Add positive acknowledgements to work around 802.11 + multicast unreliabilities

This commit is contained in:
Dan Williams 2006-06-29 13:30:41 -04:00
parent 9ef8013a6b
commit 29984ace33

View File

@ -44,6 +44,7 @@ class SegmentBase(object):
# Message segment packet types # Message segment packet types
_SEGMENT_TYPE_DATA = 0 _SEGMENT_TYPE_DATA = 0
_SEGMENT_TYPE_RETRANSMIT = 1 _SEGMENT_TYPE_RETRANSMIT = 1
_SEGMENT_TYPE_ACK = 2
def magic(): def magic():
return SegmentBase._MAGIC return SegmentBase._MAGIC
@ -61,6 +62,10 @@ class SegmentBase(object):
return SegmentBase._SEGMENT_TYPE_RETRANSMIT return SegmentBase._SEGMENT_TYPE_RETRANSMIT
type_retransmit = staticmethod(type_retransmit) type_retransmit = staticmethod(type_retransmit)
def type_ack():
    """Return the segment type code that marks acknowledgement segments."""
    return SegmentBase._SEGMENT_TYPE_ACK
type_ack = staticmethod(type_ack)
def header_len(): def header_len():
"""Return the header size of SegmentBase packets.""" """Return the header size of SegmentBase packets."""
return SegmentBase._HEADER_LEN return SegmentBase._HEADER_LEN
@ -156,6 +161,8 @@ class SegmentBase(object):
segment = DataSegment(segno, total_segs, msg_seq_num, master_sha) segment = DataSegment(segno, total_segs, msg_seq_num, master_sha)
elif seg_type == SegmentBase._SEGMENT_TYPE_RETRANSMIT: elif seg_type == SegmentBase._SEGMENT_TYPE_RETRANSMIT:
segment = RetransmitSegment(segno, total_segs, msg_seq_num, master_sha) segment = RetransmitSegment(segno, total_segs, msg_seq_num, master_sha)
elif set_type == SegmentBase._SEGMENT_TYPE_ACK:
segment = AckSegment(segno, total_segs, msg_seq_num, master_sha)
else: else:
raise ValueError("Segment has invalid type.") raise ValueError("Segment has invalid type.")
@ -319,6 +326,96 @@ class RetransmitSegment(SegmentBase):
return self._rt_segment_number return self._rt_segment_number
class AckSegment(SegmentBase):
    """A message segment that encapsulates a message acknowledgement.

    The payload identifies the acknowledged message by its sequence
    number, the SHA1 of its total data, and the dotted-quad IP address
    of the host the message came from.
    """

    # Ack data format (network byte order, no padding):
    #   2: acked message sequence number
    #  20: acked message total data sha1
    #   4: acked message source IP address, in packed (inet_aton) form
    #
    # The address field is "4s" and not "I": socket.inet_aton() returns a
    # 4-byte string and socket.inet_ntoa() expects one, so an integer field
    # can neither be packed from nor unpacked into what those calls use.
    # Both spellings occupy 4 bytes, so the payload length is unchanged.
    _ACK_DATA_TEMPLATE = "! H20s4s"
    _ACK_DATA_LEN = struct.calcsize(_ACK_DATA_TEMPLATE)

    def data_template():
        """Return the struct format string of the ack payload."""
        return AckSegment._ACK_DATA_TEMPLATE
    data_template = staticmethod(data_template)

    def __init__(self, segno, total_segs, msg_seq_num, master_sha):
        """Should not be called directly.

        Raises ValueError unless segno and total_segs are both 1, since
        an acknowledgement always fits in a single segment.
        """
        if segno != 1 or total_segs != 1:
            raise ValueError("Acknowledgement messages must have only one segment.")
        SegmentBase.__init__(self, segno, total_segs, msg_seq_num, master_sha)
        self._type = SegmentBase._SEGMENT_TYPE_ACK

    def _verify_data(ack_msg_seq_num, ack_master_sha, ack_addr):
        """Sanity check the attributes of an ack payload.

        ack_msg_seq_num must be an integer from 1 to 65535, ack_master_sha
        a 20 character string, and ack_addr a string that
        socket.inet_aton() accepts.  Raises ValueError otherwise.
        """
        if not ack_msg_seq_num or type(ack_msg_seq_num) != type(1):
            raise ValueError("Ack message sequnce number must be an integer.")
        if ack_msg_seq_num < 1 or ack_msg_seq_num > 65535:
            raise ValueError("Ack message sequence number must be between 1 and 65535 inclusive.")
        if not ack_master_sha or type(ack_master_sha) != type("") or len(ack_master_sha) != 20:
            raise ValueError("Ack message SHA1 checksum invalid.")
        if type(ack_addr) != type(""):
            raise ValueError("Ack message invalid address.")
        try:
            # Only the side effect matters: inet_aton() rejects bad addresses
            socket.inet_aton(ack_addr)
        except socket.error:
            raise ValueError("Ack message invalid address.")
    _verify_data = staticmethod(_verify_data)

    def _make_ack_data(ack_msg_seq_num, ack_master_sha, ack_addr):
        """Pack an ack payload.

        Returns a (data, sha) tuple: the packed payload and the result of
        _sha_data() over it.
        """
        addr_data = socket.inet_aton(ack_addr)
        data = struct.pack(AckSegment._ACK_DATA_TEMPLATE, ack_msg_seq_num,
                ack_master_sha, addr_data)
        return (data, _sha_data(data))
    _make_ack_data = staticmethod(_make_ack_data)

    def new_from_parts(addr, msg_seq_num, ack_msg_seq_num, ack_master_sha, ack_addr):
        """Static constructor for creation from individual attributes.

        addr is the address stored on the new segment (checked with
        SegmentBase._validate_address) and msg_seq_num the sequence number
        of the ack segment itself.  The ack_* arguments describe the
        message being acknowledged.  Raises ValueError when the ack_*
        arguments fail _verify_data().
        """
        AckSegment._verify_data(ack_msg_seq_num, ack_master_sha, ack_addr)
        (data, data_sha) = AckSegment._make_ack_data(ack_msg_seq_num,
                ack_master_sha, ack_addr)
        # The ack's own master sha is the sha of its payload
        segment = AckSegment(1, 1, msg_seq_num, data_sha)
        segment._data_len = AckSegment._ACK_DATA_LEN
        segment._data = data
        SegmentBase._validate_address(addr)
        segment._addr = addr

        segment._ack_msg_seq_num = ack_msg_seq_num
        segment._ack_master_sha = ack_master_sha
        segment._ack_addr = ack_addr
        return segment
    new_from_parts = staticmethod(new_from_parts)

    def _unpack_data(self, stream, data_len):
        """Read and decode the ack payload from stream.

        Raises ValueError when data_len is not the ack payload size, when
        the stream holds fewer bytes than that, or when the decoded
        attributes fail _verify_data().
        """
        if data_len != self._ACK_DATA_LEN:
            raise ValueError("Ack segment data had invalid length.")
        data = stream.read(data_len)
        # A short read would otherwise surface as struct.error from unpack()
        if len(data) != data_len:
            raise ValueError("Ack segment data had invalid length.")
        (ack_msg_seq_num, ack_master_sha, ack_addr_data) = struct.unpack(self._ACK_DATA_TEMPLATE, data)
        try:
            ack_addr = socket.inet_ntoa(ack_addr_data)
        except socket.error:
            raise ValueError("Ack segment data had invalid address.")
        AckSegment._verify_data(ack_msg_seq_num, ack_master_sha, ack_addr)

        self._data = data
        self._data_len = data_len
        self._ack_msg_seq_num = ack_msg_seq_num
        self._ack_master_sha = ack_master_sha
        self._ack_addr = ack_addr

    def ack_msg_seq_num(self):
        """Return the sequence number of the acknowledged message."""
        return self._ack_msg_seq_num

    def ack_master_sha(self):
        """Return the SHA1 of the acknowledged message's total data."""
        return self._ack_master_sha

    def ack_addr(self):
        """Return the dotted-quad source address of the acknowledged message."""
        return self._ack_addr
class Message(object): class Message(object):
"""Tracks an entire message object, which is composed of a number """Tracks an entire message object, which is composed of a number
of individual segments.""" of individual segments."""
@ -429,6 +526,53 @@ class Message(object):
return 0 return 0
def _get_local_interfaces():
    """Return the names of the local network interfaces.

    Linux specific: issues the SIOCGIFCONF ioctl on a throwaway datagram
    socket and takes the NUL-terminated name at the start of each
    returned record.  At most max_possible interfaces are reported.
    """
    import array
    import struct
    import fcntl
    import socket

    max_possible = 4
    # NOTE(review): records are assumed to be 32 bytes each, which matches
    # struct ifreq on 32-bit Linux -- confirm before relying on this on
    # 64-bit hosts.
    record_len = 32
    buf_len = max_possible * record_len
    SIOCGIFCONF = 0x8912
    names = array.array('B', '\0' * buf_len)
    # struct ifconf: the buffer length followed by the buffer's address
    ifconf = struct.pack('iL', buf_len, names.buffer_info()[0])
    sockfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        result = fcntl.ioctl(sockfd.fileno(), SIOCGIFCONF, ifconf)
    finally:
        # Release the descriptor even when the ioctl fails
        sockfd.close()
    # The kernel rewrites the length field with the number of bytes filled in
    outbytes = struct.unpack('iL', result)[0]
    namestr = names.tostring()
    return [namestr[i:i+record_len].split('\0', 1)[0] for i in range(0, outbytes, record_len)]
def _get_local_ip_addresses():
    """Call Linux specific bits to retrieve our own IP addresses.

    Returns a list of dotted-quad strings, one for every interface
    reported by _get_local_interfaces() other than "lo".  An interface
    whose address cannot be read is reported on stdout and skipped.
    """
    import socket
    import sys
    import fcntl

    SIOCGIFADDR = 0x8915
    ips = []
    intfs = _get_local_interfaces()
    sockfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        for intf in intfs:
            if intf == "lo":
                continue
            try:
                # The request is the interface name NUL-padded to 32 bytes
                ifreq = (intf + '\0'*32)[:32]
                result = fcntl.ioctl(sockfd.fileno(), SIOCGIFADDR, ifreq)
                # The packed IPv4 address is read from bytes 20..23 of the reply
                addr = socket.inet_ntoa(result[20:24])
                ips.append(addr)
            except IOError:
                # Fetch the exception through sys so that this clause parses
                # the same way on every Python version
                exc = sys.exc_info()[1]
                print("Error getting IP address: %s" % exc)
    finally:
        # Release the descriptor even if something unexpected is raised
        sockfd.close()
    return ips
class MostlyReliablePipe(object): class MostlyReliablePipe(object):
"""Implement Mostly-Reliable UDP. We don't actually care about guaranteeing """Implement Mostly-Reliable UDP. We don't actually care about guaranteeing
delivery or receipt, just a better effort than no effort at all.""" delivery or receipt, just a better effort than no effort at all."""
@ -446,13 +590,17 @@ class MostlyReliablePipe(object):
self._send_worker = 0 self._send_worker = 0
self._seq_counter = 0 self._seq_counter = 0
self._drop_prob = 0 self._drop_prob = 0
self._rt_check_worker = 0 self._rt_check_worker_id = 0
self._outgoing = [] self._outgoing = []
self._sent = {} self._sent = {}
self._incoming = {} # (message sha, # of segments) -> [segment1, segment2, ...] self._incoming = {} # (message sha, # of segments) -> [segment1, segment2, ...]
self._dispatched = {} self._dispatched = {}
self._acks = {} # (message sequence #, master sha, source addr) -> received timestamp
self._ack_check_worker_id = 0
self._local_ips = _get_local_ip_addresses()
self._setup_listener() self._setup_listener()
self._setup_sender() self._setup_sender()
@ -461,9 +609,12 @@ class MostlyReliablePipe(object):
if self._send_worker > 0: if self._send_worker > 0:
gobject.source_remove(self._send_worker) gobject.source_remove(self._send_worker)
self._send_worker = 0 self._send_worker = 0
if self._rt_check_worker > 0: if self._rt_check_worker_id > 0:
gobject.source_remove(self._rt_check_worker) gobject.source_remove(self._rt_check_worker_id)
self._rt_check_worker = 0 self._rt_check_worker_id = 0
if self._ack_check_worker_id > 0:
gobject.source_remove(self._ack_check_worker_id)
self._ack_check_worker_id = 0
def _setup_sender(self): def _setup_sender(self):
"""Setup the send socket for multicast.""" """Setup the send socket for multicast."""
@ -495,7 +646,8 @@ class MostlyReliablePipe(object):
# Watch the listener socket for data # Watch the listener socket for data
gobject.io_add_watch(self._listen_sock, gobject.IO_IN, self._handle_incoming_data) gobject.io_add_watch(self._listen_sock, gobject.IO_IN, self._handle_incoming_data)
gobject.timeout_add(self._SEGMENT_TTL * 1000, self._segment_ttl_worker) gobject.timeout_add(self._SEGMENT_TTL * 1000, self._segment_ttl_worker)
gobject.timeout_add(50, self._retransmit_check_worker) self._rt_check_worker_id = gobject.timeout_add(50, self._retransmit_check_worker)
self._ack_check_worker_id = gobject.timeout_add(50, self._ack_check_worker)
self._started = True self._started = True
@ -516,12 +668,18 @@ class MostlyReliablePipe(object):
if message.last_incoming_time() < now - self._SEGMENT_TTL: if message.last_incoming_time() < now - self._SEGMENT_TTL:
del self._incoming[msg_key] del self._incoming[msg_key]
# Remove already dispatched messages after a while # Remove already received and dispatched messages after a while
for msg_key in self._dispatched.keys()[:]: for msg_key in self._dispatched.keys()[:]:
message = self._dispatched[msg_key] message = self._dispatched[msg_key]
if message.dispatch_time() < now - (self._SEGMENT_TTL*2): if message.dispatch_time() < now - (self._SEGMENT_TTL*2):
del self._dispatched[msg_key] del self._dispatched[msg_key]
# Remove received acks after a while
for ack_key in self._acks.keys()[:]:
ack_time = self._acks[ack_key]
if ack_time < now - (self._SEGMENT_TTL*2):
del self._acks[ack_key]
return True return True
_MAX_SEGMENT_RETRIES = 10 _MAX_SEGMENT_RETRIES = 10
@ -541,6 +699,8 @@ class MostlyReliablePipe(object):
return False return False
def _retransmit_check_worker(self): def _retransmit_check_worker(self):
"""Periodically check for and send retransmit requests for message
segments that got lost."""
try: try:
now = time.time() now = time.time()
for key in self._incoming.keys()[:]: for key in self._incoming.keys()[:]:
@ -582,6 +742,12 @@ class MostlyReliablePipe(object):
# First segment in the message # First segment in the message
if not self._incoming.has_key(msg_key): if not self._incoming.has_key(msg_key):
self._incoming[msg_key] = Message((addr[0], self._port), msg_seq_num, msg_sha, nsegs) self._incoming[msg_key] = Message((addr[0], self._port), msg_seq_num, msg_sha, nsegs)
# Acknowledge the message if it didn't come from us
if addr[0] not in self._local_ips:
print "Sending ack for msg (%s %s) from %s)" % (msg_seq_num, msg_sha, addr[0])
ack_key = (msg_seq_num, msg_sha, addr[0])
if not self._acks.has_key(ack_key):
self._send_ack_for_message(msg_seq_num, msg_sha, addr[0])
message = self._incoming[msg_key] message = self._incoming[msg_key]
# Look for a dupe, and if so, drop the new segment # Look for a dupe, and if so, drop the new segment
@ -636,10 +802,74 @@ class MostlyReliablePipe(object):
next_transmit = max(now, segment.last_transmit() + self._STD_RETRANSMIT_INTERVAL) next_transmit = max(now, segment.last_transmit() + self._STD_RETRANSMIT_INTERVAL)
self._schedule_segment_retransmit(key, segment, next_transmit, now) self._schedule_segment_retransmit(key, segment, next_transmit, now)
def _ack_check_worker(self):
    """Re-queue the first segment of every sent message nobody has acked.

    Only segment 1 is considered: peers that are missing later segments
    ask for them with retransmit requests.  A segment transmitted within
    the last 150ms is left alone for now.  Returns True so the periodic
    timeout keeps running, or False after a KeyboardInterrupt.
    """
    try:
        now = time.time()
        recent_cutoff = now - 0.150 # 150ms
        for sent_key in list(self._sent.keys()):
            seg = self._sent[sent_key]

            # Later segments are recovered through retransmit requests
            if seg.segment_number() != 1:
                continue

            # Transmitted too recently; give the ack time to arrive
            if seg.last_transmit() > recent_cutoff:
                continue

            # An ack for our message names one of our own addresses as
            # the message source
            for local_ip in self._local_ips:
                candidate = (seg.message_sequence_number(), seg.master_sha(), local_ip)
                if candidate in self._acks:
                    # Already acked: only an explicit retransmit request
                    # will send it again
                    break
            else:
                # No ack under any local address: move it back to the
                # outgoing queue
                del self._sent[sent_key]
                self._outgoing.append(seg)
                self._schedule_send_worker()
    except KeyboardInterrupt:
        return False
    return True
def _send_ack_for_message(self, ack_msg_seq_num, ack_msg_sha, ack_addr):
    """Queue an ack segment for the given message and record it locally.

    ack_msg_seq_num, ack_msg_sha and ack_addr identify the message being
    acknowledged.  The ack travels under a fresh sequence number of its
    own.
    """
    own_seq_num = self._next_msg_seq()
    ack_segment = AckSegment.new_from_parts(self._remote_addr, own_seq_num,
            ack_msg_seq_num, ack_msg_sha, ack_addr)

    self._outgoing.append(ack_segment)
    self._schedule_send_worker()

    # Feed our own ack through the normal ack path so it is remembered too
    self._process_incoming_ack(ack_segment)
def _process_incoming_ack(self, segment):
    """Remember an ack so the acknowledged message is not acked again
    once its segments start to arrive.

    The arrival time is stored under (sequence number, master sha,
    source address) the first time that combination is seen.
    """
    src_addr = segment.ack_addr()
    src_sha = segment.ack_master_sha()
    src_seq_num = segment.ack_msg_seq_num()

    # An ack that names one of our own addresses as the message source is
    # only believed if the first segment of that message is in our sent
    # list
    if src_addr in self._local_ips and (src_seq_num, src_sha, 1) not in self._sent:
        return

    ack_key = (src_seq_num, src_sha, src_addr)
    if ack_key in self._acks:
        return

    print("Got ack for msg (%s %s) originally from %s" % (src_seq_num, src_sha, src_addr))
    self._acks[ack_key] = time.time()
def set_drop_probability(self, prob=4): def set_drop_probability(self, prob=4):
"""Debugging function to randomly drop incoming packets. """Debugging function to randomly drop incoming packets.
The prob argument should be an integer between 1 and 10 to drop, The prob argument should be an integer between 1 and 10 to drop,
or 0 to drop none. Higher numbers drop more packets.""" or 0 to drop none. Higher numbers drop more packets."""
if type(prob) != type(1):
raise ValueError("Drop probability must be an integer.")
if prob < 1 or prob > 10:
raise ValueError("Drop probability must be between 1 and 10 inclusive.")
self._drop_prob = prob self._drop_prob = prob
def _handle_incoming_data(self, source, condition): def _handle_incoming_data(self, source, condition):
@ -665,6 +895,8 @@ class MostlyReliablePipe(object):
self._process_incoming_data(segment) self._process_incoming_data(segment)
elif stype == SegmentBase.type_retransmit(): elif stype == SegmentBase.type_retransmit():
self._process_retransmit_request(segment) self._process_retransmit_request(segment)
elif stype == SegmentBase.type_ack():
self._process_incoming_ack(segment)
except ValueError, exc: except ValueError, exc:
print "(MRP): Bad segment: %s" % exc print "(MRP): Bad segment: %s" % exc
return True return True
@ -693,12 +925,12 @@ class MostlyReliablePipe(object):
nmessages = length / mtu nmessages = length / mtu
if length % mtu > 0: if length % mtu > 0:
nmessages = nmessages + 1 nmessages = nmessages + 1
msg_num = 1 seg_num = 1
while left > 0: while left > 0:
seg = DataSegment.new_from_parts(msg_num, nmessages, seg = DataSegment.new_from_parts(seg_num, nmessages,
msg_seq, master_sha, data[:mtu]) msg_seq, master_sha, data[:mtu])
self._outgoing.append(seg) self._outgoing.append(seg)
msg_num = msg_num + 1 seg_num = seg_num + 1
data = data[mtu:] data = data[mtu:]
left = left - mtu left = left - mtu
self._schedule_send_worker() self._schedule_send_worker()