Add positive acknowledgements to work around 802.11 + multicast unreliabilities
This commit is contained in:
parent
9ef8013a6b
commit
29984ace33
@ -44,6 +44,7 @@ class SegmentBase(object):
|
|||||||
# Message segment packet types
|
# Message segment packet types
|
||||||
_SEGMENT_TYPE_DATA = 0
|
_SEGMENT_TYPE_DATA = 0
|
||||||
_SEGMENT_TYPE_RETRANSMIT = 1
|
_SEGMENT_TYPE_RETRANSMIT = 1
|
||||||
|
_SEGMENT_TYPE_ACK = 2
|
||||||
|
|
||||||
def magic():
|
def magic():
|
||||||
return SegmentBase._MAGIC
|
return SegmentBase._MAGIC
|
||||||
@ -61,6 +62,10 @@ class SegmentBase(object):
|
|||||||
return SegmentBase._SEGMENT_TYPE_RETRANSMIT
|
return SegmentBase._SEGMENT_TYPE_RETRANSMIT
|
||||||
type_retransmit = staticmethod(type_retransmit)
|
type_retransmit = staticmethod(type_retransmit)
|
||||||
|
|
||||||
|
def type_ack():
|
||||||
|
return SegmentBase._SEGMENT_TYPE_ACK
|
||||||
|
type_ack = staticmethod(type_ack)
|
||||||
|
|
||||||
def header_len():
|
def header_len():
|
||||||
"""Return the header size of SegmentBase packets."""
|
"""Return the header size of SegmentBase packets."""
|
||||||
return SegmentBase._HEADER_LEN
|
return SegmentBase._HEADER_LEN
|
||||||
@ -156,6 +161,8 @@ class SegmentBase(object):
|
|||||||
segment = DataSegment(segno, total_segs, msg_seq_num, master_sha)
|
segment = DataSegment(segno, total_segs, msg_seq_num, master_sha)
|
||||||
elif seg_type == SegmentBase._SEGMENT_TYPE_RETRANSMIT:
|
elif seg_type == SegmentBase._SEGMENT_TYPE_RETRANSMIT:
|
||||||
segment = RetransmitSegment(segno, total_segs, msg_seq_num, master_sha)
|
segment = RetransmitSegment(segno, total_segs, msg_seq_num, master_sha)
|
||||||
|
elif set_type == SegmentBase._SEGMENT_TYPE_ACK:
|
||||||
|
segment = AckSegment(segno, total_segs, msg_seq_num, master_sha)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Segment has invalid type.")
|
raise ValueError("Segment has invalid type.")
|
||||||
|
|
||||||
@ -319,6 +326,96 @@ class RetransmitSegment(SegmentBase):
|
|||||||
return self._rt_segment_number
|
return self._rt_segment_number
|
||||||
|
|
||||||
|
|
||||||
|
class AckSegment(SegmentBase):
|
||||||
|
"""A message segment that encapsulates a message acknowledgement."""
|
||||||
|
|
||||||
|
# Ack data format:
|
||||||
|
# 2: acked message sequence number
|
||||||
|
# 20: acked message total data sha1
|
||||||
|
# 4: acked message source IP address
|
||||||
|
_ACK_DATA_TEMPLATE = "! H20sI"
|
||||||
|
_ACK_DATA_LEN = struct.calcsize(_ACK_DATA_TEMPLATE)
|
||||||
|
|
||||||
|
def data_template():
|
||||||
|
return AckSegment._ACK_DATA_TEMPLATE
|
||||||
|
data_template = staticmethod(data_template)
|
||||||
|
|
||||||
|
def __init__(self, segno, total_segs, msg_seq_num, master_sha):
|
||||||
|
"""Should not be called directly."""
|
||||||
|
if segno != 1 or total_segs != 1:
|
||||||
|
raise ValueError("Acknowledgement messages must have only one segment.")
|
||||||
|
|
||||||
|
SegmentBase.__init__(self, segno, total_segs, msg_seq_num, master_sha)
|
||||||
|
self._type = SegmentBase._SEGMENT_TYPE_ACK
|
||||||
|
|
||||||
|
def _verify_data(ack_msg_seq_num, ack_master_sha, ack_addr):
|
||||||
|
# Sanity checks on the message attributes
|
||||||
|
if not ack_msg_seq_num or type(ack_msg_seq_num) != type(1):
|
||||||
|
raise ValueError("Ack message sequnce number must be an integer.")
|
||||||
|
if ack_msg_seq_num < 1 or ack_msg_seq_num > 65535:
|
||||||
|
raise ValueError("Ack message sequence number must be between 1 and 65535 inclusive.")
|
||||||
|
if not ack_master_sha or type(ack_master_sha) != type("") or len(ack_master_sha) != 20:
|
||||||
|
raise ValueError("Ack message SHA1 checksum invalid.")
|
||||||
|
if type(ack_addr) != type(""):
|
||||||
|
raise ValueError("Ack message invalid address.")
|
||||||
|
try:
|
||||||
|
foo = socket.inet_aton(ack_addr)
|
||||||
|
except socket.error:
|
||||||
|
raise ValueError("Ack message invalid address.")
|
||||||
|
_verify_data = staticmethod(_verify_data)
|
||||||
|
|
||||||
|
def _make_ack_data(ack_msg_seq_num, ack_master_sha, ack_addr):
|
||||||
|
"""Pack an ack payload."""
|
||||||
|
addr_data = socket.inet_aton(ack_addr)
|
||||||
|
data = struct.pack(AckSegment._ACK_DATA_TEMPLATE, ack_msg_seq_num,
|
||||||
|
ack_master_sha, addr_data)
|
||||||
|
return (data, _sha_data(data))
|
||||||
|
_make_ack_data = staticmethod(_make_ack_data)
|
||||||
|
|
||||||
|
def new_from_parts(addr, msg_seq_num, ack_msg_seq_num, ack_master_sha, ack_addr):
|
||||||
|
"""Static constructor for creation from individual attributes."""
|
||||||
|
|
||||||
|
AckSegment._verify_data(ack_msg_seq_num, ack_master_sha, ack_addr)
|
||||||
|
(data, data_sha) = AckSegment._make_ack_data(ack_msg_seq_num,
|
||||||
|
ack_master_sha, ack_addr)
|
||||||
|
segment = AckSegment(1, 1, msg_seq_num, data_sha)
|
||||||
|
segment._data_len = AckSegment._ACK_DATA_LEN
|
||||||
|
segment._data = data
|
||||||
|
SegmentBase._validate_address(addr)
|
||||||
|
segment._addr = addr
|
||||||
|
|
||||||
|
segment._ack_msg_seq_num = ack_msg_seq_num
|
||||||
|
segment._ack_master_sha = ack_master_sha
|
||||||
|
segment._ack_addr = ack_addr
|
||||||
|
return segment
|
||||||
|
new_from_parts = staticmethod(new_from_parts)
|
||||||
|
|
||||||
|
def _unpack_data(self, stream, data_len):
|
||||||
|
if data_len != self._ACK_DATA_LEN:
|
||||||
|
raise ValueError("Ack segment data had invalid length.")
|
||||||
|
data = stream.read(data_len)
|
||||||
|
(ack_msg_seq_num, ack_master_sha, ack_addr_data) = struct.unpack(self._ACK_DATA_TEMPLATE, data)
|
||||||
|
try:
|
||||||
|
ack_addr = socket.inet_ntoa(ack_addr_data)
|
||||||
|
except socket.error:
|
||||||
|
raise ValueError("Ack segment data had invalid address.")
|
||||||
|
AckSegment._verify_data(ack_msg_seq_num, ack_master_sha, ack_addr)
|
||||||
|
|
||||||
|
self._data = data
|
||||||
|
self._data_len = data_len
|
||||||
|
self._ack_msg_seq_num = ack_msg_seq_num
|
||||||
|
self._ack_master_sha = ack_master_sha
|
||||||
|
self._ack_addr = ack_addr
|
||||||
|
|
||||||
|
def ack_msg_seq_num(self):
|
||||||
|
return self._ack_msg_seq_num
|
||||||
|
|
||||||
|
def ack_master_sha(self):
|
||||||
|
return self._ack_master_sha
|
||||||
|
|
||||||
|
def ack_addr(self):
|
||||||
|
return self._ack_addr
|
||||||
|
|
||||||
class Message(object):
|
class Message(object):
|
||||||
"""Tracks an entire message object, which is composed of a number
|
"""Tracks an entire message object, which is composed of a number
|
||||||
of individual segments."""
|
of individual segments."""
|
||||||
@ -429,6 +526,53 @@ class Message(object):
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _get_local_interfaces():
|
||||||
|
import array
|
||||||
|
import struct
|
||||||
|
import fcntl
|
||||||
|
import socket
|
||||||
|
|
||||||
|
max_possible = 4
|
||||||
|
bytes = max_possible * 32
|
||||||
|
SIOCGIFCONF = 0x8912
|
||||||
|
names = array.array('B', '\0' * bytes)
|
||||||
|
|
||||||
|
sockfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
|
ifreq = struct.pack('iL', bytes, names.buffer_info()[0])
|
||||||
|
result = fcntl.ioctl(sockfd.fileno(), SIOCGIFCONF, ifreq)
|
||||||
|
sockfd.close()
|
||||||
|
|
||||||
|
outbytes = struct.unpack('iL', result)[0]
|
||||||
|
namestr = names.tostring()
|
||||||
|
|
||||||
|
return [namestr[i:i+32].split('\0', 1)[0] for i in range(0, outbytes, 32)]
|
||||||
|
|
||||||
|
def _get_local_ip_addresses():
|
||||||
|
"""Call Linux specific bits to retrieve our own IP address."""
|
||||||
|
import socket
|
||||||
|
import sys
|
||||||
|
import fcntl
|
||||||
|
import struct
|
||||||
|
|
||||||
|
intfs = _get_local_interfaces()
|
||||||
|
|
||||||
|
ips = []
|
||||||
|
SIOCGIFADDR = 0x8915
|
||||||
|
sockfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
|
for intf in intfs:
|
||||||
|
if intf == "lo":
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
ifreq = (intf + '\0'*32)[:32]
|
||||||
|
result = fcntl.ioctl(sockfd.fileno(), SIOCGIFADDR, ifreq)
|
||||||
|
addr = socket.inet_ntoa(result[20:24])
|
||||||
|
ips.append(addr)
|
||||||
|
except IOError, exc:
|
||||||
|
print "Error getting IP address: %s" % exc
|
||||||
|
sockfd.close()
|
||||||
|
return ips
|
||||||
|
|
||||||
|
|
||||||
class MostlyReliablePipe(object):
|
class MostlyReliablePipe(object):
|
||||||
"""Implement Mostly-Reliable UDP. We don't actually care about guaranteeing
|
"""Implement Mostly-Reliable UDP. We don't actually care about guaranteeing
|
||||||
delivery or receipt, just a better effort than no effort at all."""
|
delivery or receipt, just a better effort than no effort at all."""
|
||||||
@ -446,13 +590,17 @@ class MostlyReliablePipe(object):
|
|||||||
self._send_worker = 0
|
self._send_worker = 0
|
||||||
self._seq_counter = 0
|
self._seq_counter = 0
|
||||||
self._drop_prob = 0
|
self._drop_prob = 0
|
||||||
self._rt_check_worker = 0
|
self._rt_check_worker_id = 0
|
||||||
|
|
||||||
self._outgoing = []
|
self._outgoing = []
|
||||||
self._sent = {}
|
self._sent = {}
|
||||||
|
|
||||||
self._incoming = {} # (message sha, # of segments) -> [segment1, segment2, ...]
|
self._incoming = {} # (message sha, # of segments) -> [segment1, segment2, ...]
|
||||||
self._dispatched = {}
|
self._dispatched = {}
|
||||||
|
self._acks = {} # (message sequence #, master sha, source addr) -> received timestamp
|
||||||
|
self._ack_check_worker_id = 0
|
||||||
|
|
||||||
|
self._local_ips = _get_local_ip_addresses()
|
||||||
|
|
||||||
self._setup_listener()
|
self._setup_listener()
|
||||||
self._setup_sender()
|
self._setup_sender()
|
||||||
@ -461,9 +609,12 @@ class MostlyReliablePipe(object):
|
|||||||
if self._send_worker > 0:
|
if self._send_worker > 0:
|
||||||
gobject.source_remove(self._send_worker)
|
gobject.source_remove(self._send_worker)
|
||||||
self._send_worker = 0
|
self._send_worker = 0
|
||||||
if self._rt_check_worker > 0:
|
if self._rt_check_worker_id > 0:
|
||||||
gobject.source_remove(self._rt_check_worker)
|
gobject.source_remove(self._rt_check_worker_id)
|
||||||
self._rt_check_worker = 0
|
self._rt_check_worker_id = 0
|
||||||
|
if self._ack_check_worker_id > 0:
|
||||||
|
gobject.source_remove(self._ack_check_worker_id)
|
||||||
|
self._ack_check_worker_id = 0
|
||||||
|
|
||||||
def _setup_sender(self):
|
def _setup_sender(self):
|
||||||
"""Setup the send socket for multicast."""
|
"""Setup the send socket for multicast."""
|
||||||
@ -495,7 +646,8 @@ class MostlyReliablePipe(object):
|
|||||||
# Watch the listener socket for data
|
# Watch the listener socket for data
|
||||||
gobject.io_add_watch(self._listen_sock, gobject.IO_IN, self._handle_incoming_data)
|
gobject.io_add_watch(self._listen_sock, gobject.IO_IN, self._handle_incoming_data)
|
||||||
gobject.timeout_add(self._SEGMENT_TTL * 1000, self._segment_ttl_worker)
|
gobject.timeout_add(self._SEGMENT_TTL * 1000, self._segment_ttl_worker)
|
||||||
gobject.timeout_add(50, self._retransmit_check_worker)
|
self._rt_check_worker_id = gobject.timeout_add(50, self._retransmit_check_worker)
|
||||||
|
self._ack_check_worker_id = gobject.timeout_add(50, self._ack_check_worker)
|
||||||
|
|
||||||
self._started = True
|
self._started = True
|
||||||
|
|
||||||
@ -516,12 +668,18 @@ class MostlyReliablePipe(object):
|
|||||||
if message.last_incoming_time() < now - self._SEGMENT_TTL:
|
if message.last_incoming_time() < now - self._SEGMENT_TTL:
|
||||||
del self._incoming[msg_key]
|
del self._incoming[msg_key]
|
||||||
|
|
||||||
# Remove already dispatched messages after a while
|
# Remove already received and dispatched messages after a while
|
||||||
for msg_key in self._dispatched.keys()[:]:
|
for msg_key in self._dispatched.keys()[:]:
|
||||||
message = self._dispatched[msg_key]
|
message = self._dispatched[msg_key]
|
||||||
if message.dispatch_time() < now - (self._SEGMENT_TTL*2):
|
if message.dispatch_time() < now - (self._SEGMENT_TTL*2):
|
||||||
del self._dispatched[msg_key]
|
del self._dispatched[msg_key]
|
||||||
|
|
||||||
|
# Remove received acks after a while
|
||||||
|
for ack_key in self._acks.keys()[:]:
|
||||||
|
ack_time = self._acks[ack_key]
|
||||||
|
if ack_time < now - (self._SEGMENT_TTL*2):
|
||||||
|
del self._acks[ack_key]
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
_MAX_SEGMENT_RETRIES = 10
|
_MAX_SEGMENT_RETRIES = 10
|
||||||
@ -541,6 +699,8 @@ class MostlyReliablePipe(object):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def _retransmit_check_worker(self):
|
def _retransmit_check_worker(self):
|
||||||
|
"""Periodically check for and send retransmit requests for message
|
||||||
|
segments that got lost."""
|
||||||
try:
|
try:
|
||||||
now = time.time()
|
now = time.time()
|
||||||
for key in self._incoming.keys()[:]:
|
for key in self._incoming.keys()[:]:
|
||||||
@ -582,6 +742,12 @@ class MostlyReliablePipe(object):
|
|||||||
# First segment in the message
|
# First segment in the message
|
||||||
if not self._incoming.has_key(msg_key):
|
if not self._incoming.has_key(msg_key):
|
||||||
self._incoming[msg_key] = Message((addr[0], self._port), msg_seq_num, msg_sha, nsegs)
|
self._incoming[msg_key] = Message((addr[0], self._port), msg_seq_num, msg_sha, nsegs)
|
||||||
|
# Acknowledge the message if it didn't come from us
|
||||||
|
if addr[0] not in self._local_ips:
|
||||||
|
print "Sending ack for msg (%s %s) from %s)" % (msg_seq_num, msg_sha, addr[0])
|
||||||
|
ack_key = (msg_seq_num, msg_sha, addr[0])
|
||||||
|
if not self._acks.has_key(ack_key):
|
||||||
|
self._send_ack_for_message(msg_seq_num, msg_sha, addr[0])
|
||||||
|
|
||||||
message = self._incoming[msg_key]
|
message = self._incoming[msg_key]
|
||||||
# Look for a dupe, and if so, drop the new segment
|
# Look for a dupe, and if so, drop the new segment
|
||||||
@ -636,10 +802,74 @@ class MostlyReliablePipe(object):
|
|||||||
next_transmit = max(now, segment.last_transmit() + self._STD_RETRANSMIT_INTERVAL)
|
next_transmit = max(now, segment.last_transmit() + self._STD_RETRANSMIT_INTERVAL)
|
||||||
self._schedule_segment_retransmit(key, segment, next_transmit, now)
|
self._schedule_segment_retransmit(key, segment, next_transmit, now)
|
||||||
|
|
||||||
|
def _ack_check_worker(self):
|
||||||
|
"""Periodically check for messages that haven't received an ack
|
||||||
|
yet, and retransmit them."""
|
||||||
|
try:
|
||||||
|
now = time.time()
|
||||||
|
for key in self._sent.keys()[:]:
|
||||||
|
segment = self._sent[key]
|
||||||
|
# We only care about retransmitting the first segment
|
||||||
|
# of a message, since if other machines don't have the
|
||||||
|
# rest of the segments, they'll issue retransmit requests
|
||||||
|
if segment.segment_number() != 1:
|
||||||
|
continue
|
||||||
|
if segment.last_transmit() > now - 0.150: # 150ms
|
||||||
|
# Was just retransmitted recently, wait longer
|
||||||
|
# before retransmitting it
|
||||||
|
continue
|
||||||
|
ack_key = None
|
||||||
|
for ip in self._local_ips:
|
||||||
|
ack_key = (segment.message_sequence_number(), segment.master_sha(), ip)
|
||||||
|
if self._acks.has_key(ack_key):
|
||||||
|
break
|
||||||
|
ack_key = None
|
||||||
|
# If the segment already has been acked, don't send it
|
||||||
|
# again unless somebody explicitly requests a retransmit
|
||||||
|
if ack_key is not None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
del self._sent[key]
|
||||||
|
self._outgoing.append(segment)
|
||||||
|
self._schedule_send_worker()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _send_ack_for_message(self, ack_msg_seq_num, ack_msg_sha, ack_addr):
|
||||||
|
"""Send an ack segment for a message."""
|
||||||
|
msg_seq_num = self._next_msg_seq()
|
||||||
|
ack = AckSegment.new_from_parts(self._remote_addr, msg_seq_num,
|
||||||
|
ack_msg_seq_num, ack_msg_sha, ack_addr)
|
||||||
|
self._outgoing.append(ack)
|
||||||
|
self._schedule_send_worker()
|
||||||
|
self._process_incoming_ack(ack)
|
||||||
|
|
||||||
|
def _process_incoming_ack(self, segment):
|
||||||
|
"""Save the ack so that we don't send an ack when we start getting the segments
|
||||||
|
the ack was acknowledging."""
|
||||||
|
# If the ack is supposed to be for a message we sent, only accept it if
|
||||||
|
# we actually sent the message to which it refers
|
||||||
|
ack_addr = segment.ack_addr()
|
||||||
|
ack_master_sha = segment.ack_master_sha()
|
||||||
|
ack_msg_seq_num = segment.ack_msg_seq_num()
|
||||||
|
if ack_addr in self._local_ips:
|
||||||
|
sent_key = (ack_msg_seq_num, ack_master_sha, 1)
|
||||||
|
if not self._sent.has_key(sent_key):
|
||||||
|
return
|
||||||
|
ack_key = (ack_msg_seq_num, ack_master_sha, ack_addr)
|
||||||
|
if not self._acks.has_key(ack_key):
|
||||||
|
print "Got ack for msg (%s %s) originally from %s" % (ack_msg_seq_num, ack_master_sha, ack_addr)
|
||||||
|
self._acks[ack_key] = time.time()
|
||||||
|
|
||||||
def set_drop_probability(self, prob=4):
|
def set_drop_probability(self, prob=4):
|
||||||
"""Debugging function to randomly drop incoming packets.
|
"""Debugging function to randomly drop incoming packets.
|
||||||
The prob argument should be an integer between 1 and 10 to drop,
|
The prob argument should be an integer between 1 and 10 to drop,
|
||||||
or 0 to drop none. Higher numbers drop more packets."""
|
or 0 to drop none. Higher numbers drop more packets."""
|
||||||
|
if type(prob) != type(1):
|
||||||
|
raise ValueError("Drop probability must be an integer.")
|
||||||
|
if prob < 1 or prob > 10:
|
||||||
|
raise ValueError("Drop probability must be between 1 and 10 inclusive.")
|
||||||
self._drop_prob = prob
|
self._drop_prob = prob
|
||||||
|
|
||||||
def _handle_incoming_data(self, source, condition):
|
def _handle_incoming_data(self, source, condition):
|
||||||
@ -665,6 +895,8 @@ class MostlyReliablePipe(object):
|
|||||||
self._process_incoming_data(segment)
|
self._process_incoming_data(segment)
|
||||||
elif stype == SegmentBase.type_retransmit():
|
elif stype == SegmentBase.type_retransmit():
|
||||||
self._process_retransmit_request(segment)
|
self._process_retransmit_request(segment)
|
||||||
|
elif stype == SegmentBase.type_ack():
|
||||||
|
self._process_incoming_ack(segment)
|
||||||
except ValueError, exc:
|
except ValueError, exc:
|
||||||
print "(MRP): Bad segment: %s" % exc
|
print "(MRP): Bad segment: %s" % exc
|
||||||
return True
|
return True
|
||||||
@ -693,12 +925,12 @@ class MostlyReliablePipe(object):
|
|||||||
nmessages = length / mtu
|
nmessages = length / mtu
|
||||||
if length % mtu > 0:
|
if length % mtu > 0:
|
||||||
nmessages = nmessages + 1
|
nmessages = nmessages + 1
|
||||||
msg_num = 1
|
seg_num = 1
|
||||||
while left > 0:
|
while left > 0:
|
||||||
seg = DataSegment.new_from_parts(msg_num, nmessages,
|
seg = DataSegment.new_from_parts(seg_num, nmessages,
|
||||||
msg_seq, master_sha, data[:mtu])
|
msg_seq, master_sha, data[:mtu])
|
||||||
self._outgoing.append(seg)
|
self._outgoing.append(seg)
|
||||||
msg_num = msg_num + 1
|
seg_num = seg_num + 1
|
||||||
data = data[mtu:]
|
data = data[mtu:]
|
||||||
left = left - mtu
|
left = left - mtu
|
||||||
self._schedule_send_worker()
|
self._schedule_send_worker()
|
||||||
|
Loading…
Reference in New Issue
Block a user