Add positive acknowledgements to work around 802.11 + multicast unreliabilities
This commit is contained in:
parent
9ef8013a6b
commit
29984ace33
@ -44,6 +44,7 @@ class SegmentBase(object):
|
||||
# Message segment packet types
|
||||
_SEGMENT_TYPE_DATA = 0
|
||||
_SEGMENT_TYPE_RETRANSMIT = 1
|
||||
_SEGMENT_TYPE_ACK = 2
|
||||
|
||||
def magic():
|
||||
return SegmentBase._MAGIC
|
||||
@ -61,6 +62,10 @@ class SegmentBase(object):
|
||||
return SegmentBase._SEGMENT_TYPE_RETRANSMIT
|
||||
type_retransmit = staticmethod(type_retransmit)
|
||||
|
||||
def type_ack():
    """Return the segment type code that marks acknowledgement segments."""
    ack_type = SegmentBase._SEGMENT_TYPE_ACK
    return ack_type
type_ack = staticmethod(type_ack)
|
||||
|
||||
def header_len():
|
||||
"""Return the header size of SegmentBase packets."""
|
||||
return SegmentBase._HEADER_LEN
|
||||
@ -156,6 +161,8 @@ class SegmentBase(object):
|
||||
segment = DataSegment(segno, total_segs, msg_seq_num, master_sha)
|
||||
elif seg_type == SegmentBase._SEGMENT_TYPE_RETRANSMIT:
|
||||
segment = RetransmitSegment(segno, total_segs, msg_seq_num, master_sha)
|
||||
elif set_type == SegmentBase._SEGMENT_TYPE_ACK:
|
||||
segment = AckSegment(segno, total_segs, msg_seq_num, master_sha)
|
||||
else:
|
||||
raise ValueError("Segment has invalid type.")
|
||||
|
||||
@ -319,6 +326,96 @@ class RetransmitSegment(SegmentBase):
|
||||
return self._rt_segment_number
|
||||
|
||||
|
||||
class AckSegment(SegmentBase):
    """A message segment that encapsulates a message acknowledgement.

    The payload identifies the acknowledged message by its sequence
    number, the SHA1 of its complete data, and the IPv4 address of the
    machine that originally sent it.
    """

    # Ack data format (network byte order, no padding):
    #   2: acked message sequence number
    #  20: acked message total data sha1
    #   4: acked message source IP address, packed as by socket.inet_aton()
    #
    # The address field is "4s", not "I": inet_aton() returns and
    # inet_ntoa() expects a 4-byte packed string, which struct cannot pack
    # into or unpack from an integer field.  Both codes are 4 bytes wide,
    # so the payload length is unchanged.
    _ACK_DATA_TEMPLATE = "! H20s4s"
    _ACK_DATA_LEN = struct.calcsize(_ACK_DATA_TEMPLATE)

    def data_template():
        """Return the struct format string of the ack payload."""
        return AckSegment._ACK_DATA_TEMPLATE
    data_template = staticmethod(data_template)

    def __init__(self, segno, total_segs, msg_seq_num, master_sha):
        """Should not be called directly.

        Raises ValueError unless segno and total_segs are both 1, since an
        acknowledgement always fits in a single segment.
        """
        if segno != 1 or total_segs != 1:
            raise ValueError("Acknowledgement messages must have only one segment.")

        SegmentBase.__init__(self, segno, total_segs, msg_seq_num, master_sha)
        self._type = SegmentBase._SEGMENT_TYPE_ACK

    def _verify_data(ack_msg_seq_num, ack_master_sha, ack_addr):
        """Sanity check the attributes of the message being acknowledged.

        ack_msg_seq_num must be an integer between 1 and 65535,
        ack_master_sha a 20 character string, and ack_addr a string that
        socket.inet_aton() accepts.  Raises ValueError otherwise.
        """
        if not ack_msg_seq_num or type(ack_msg_seq_num) != type(1):
            raise ValueError("Ack message sequnce number must be an integer.")
        if ack_msg_seq_num < 1 or ack_msg_seq_num > 65535:
            raise ValueError("Ack message sequence number must be between 1 and 65535 inclusive.")
        if not ack_master_sha or type(ack_master_sha) != type("") or len(ack_master_sha) != 20:
            raise ValueError("Ack message SHA1 checksum invalid.")
        if type(ack_addr) != type(""):
            raise ValueError("Ack message invalid address.")
        try:
            # Only the parse attempt matters here; the result is discarded
            socket.inet_aton(ack_addr)
        except socket.error:
            raise ValueError("Ack message invalid address.")
    _verify_data = staticmethod(_verify_data)

    def _make_ack_data(ack_msg_seq_num, ack_master_sha, ack_addr):
        """Pack an ack payload.

        Returns a (payload, payload sha) tuple.
        """
        addr_data = socket.inet_aton(ack_addr)
        data = struct.pack(AckSegment._ACK_DATA_TEMPLATE, ack_msg_seq_num,
                ack_master_sha, addr_data)
        return (data, _sha_data(data))
    _make_ack_data = staticmethod(_make_ack_data)

    def new_from_parts(addr, msg_seq_num, ack_msg_seq_num, ack_master_sha, ack_addr):
        """Static constructor for creation from individual attributes.

        addr is the address the ack segment is destined for and
        msg_seq_num its own sequence number; the ack_* arguments describe
        the message being acknowledged.  Raises ValueError if any ack_*
        attribute fails validation.
        """
        AckSegment._verify_data(ack_msg_seq_num, ack_master_sha, ack_addr)
        (data, data_sha) = AckSegment._make_ack_data(ack_msg_seq_num,
                ack_master_sha, ack_addr)

        # The payload is the segment's whole message, so the payload sha
        # doubles as the segment's master sha.
        segment = AckSegment(1, 1, msg_seq_num, data_sha)
        segment._data_len = AckSegment._ACK_DATA_LEN
        segment._data = data
        SegmentBase._validate_address(addr)
        segment._addr = addr

        segment._ack_msg_seq_num = ack_msg_seq_num
        segment._ack_master_sha = ack_master_sha
        segment._ack_addr = ack_addr
        return segment
    new_from_parts = staticmethod(new_from_parts)

    def _unpack_data(self, stream, data_len):
        """Read and validate the ack payload from stream.

        Raises ValueError if the payload has the wrong length or carries
        invalid attributes.
        """
        if data_len != self._ACK_DATA_LEN:
            raise ValueError("Ack segment data had invalid length.")
        data = stream.read(data_len)
        if len(data) != data_len:
            # A short read would otherwise surface as struct.error rather
            # than the ValueError this class uses for malformed segments.
            raise ValueError("Ack segment data had invalid length.")
        (ack_msg_seq_num, ack_master_sha, ack_addr_data) = struct.unpack(self._ACK_DATA_TEMPLATE, data)
        try:
            ack_addr = socket.inet_ntoa(ack_addr_data)
        except socket.error:
            raise ValueError("Ack segment data had invalid address.")
        AckSegment._verify_data(ack_msg_seq_num, ack_master_sha, ack_addr)

        self._data = data
        self._data_len = data_len
        self._ack_msg_seq_num = ack_msg_seq_num
        self._ack_master_sha = ack_master_sha
        self._ack_addr = ack_addr

    def ack_msg_seq_num(self):
        """Return the sequence number of the acknowledged message."""
        return self._ack_msg_seq_num

    def ack_master_sha(self):
        """Return the total-data SHA1 of the acknowledged message."""
        return self._ack_master_sha

    def ack_addr(self):
        """Return the source IP address of the acknowledged message."""
        return self._ack_addr
|
||||
|
||||
class Message(object):
|
||||
"""Tracks an entire message object, which is composed of a number
|
||||
of individual segments."""
|
||||
@ -429,6 +526,53 @@ class Message(object):
|
||||
return 0
|
||||
|
||||
|
||||
def _get_local_interfaces():
    """Return the names of the network interfaces the kernel reports.

    Linux specific: issues the SIOCGIFCONF ioctl on a throwaway UDP
    socket and extracts the NUL-terminated interface name that starts
    each returned record.
    """
    import array
    import struct
    import fcntl
    import socket

    SIOCGIFCONF = 0x8912
    # Size of one returned record (struct ifreq): 32 bytes on 32-bit
    # Linux, 40 bytes on 64-bit Linux.  Slicing with the wrong stride
    # yields garbage interface names.
    if struct.calcsize('P') == 8:
        ifreq_size = 40
    else:
        ifreq_size = 32
    max_possible = 4

    sockfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        while True:
            nbytes = max_possible * ifreq_size
            names = array.array('B', '\0' * nbytes)
            ifconf = struct.pack('iL', nbytes, names.buffer_info()[0])
            result = fcntl.ioctl(sockfd.fileno(), SIOCGIFCONF, ifconf)
            outbytes = struct.unpack('iL', result)[0]
            if outbytes < nbytes:
                break
            # The buffer came back completely full, so the list may have
            # been truncated; retry with room for twice as many records.
            max_possible = max_possible * 2
    finally:
        # Release the descriptor even when the ioctl fails
        sockfd.close()

    namestr = names.tostring()
    return [namestr[i:i+ifreq_size].split('\0', 1)[0]
            for i in range(0, outbytes, ifreq_size)]
|
||||
|
||||
def _get_local_ip_addresses():
    """Call Linux specific bits to retrieve our own IP addresses.

    Queries every interface returned by _get_local_interfaces() except
    "lo" with the SIOCGIFADDR ioctl and returns the dotted-quad address
    strings.  An interface whose query fails with IOError is reported on
    stdout and skipped.
    """
    import socket
    import sys
    import fcntl

    intfs = _get_local_interfaces()

    ips = []
    SIOCGIFADDR = 0x8915
    sockfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        for intf in intfs:
            if intf == "lo":
                continue
            try:
                # The request buffer is the interface name NUL-padded to
                # 32 bytes
                ifreq = (intf + '\0'*32)[:32]
                result = fcntl.ioctl(sockfd.fileno(), SIOCGIFADDR, ifreq)
                # Bytes 20-23 of the reply hold the packed IPv4 address
                addr = socket.inet_ntoa(result[20:24])
                ips.append(addr)
            except IOError:
                # sys.exc_info() fetches the exception without relying on
                # either the "except X, e" or the "except X as e" spelling
                exc = sys.exc_info()[1]
                print("Error getting IP address: %s" % exc)
    finally:
        # Release the descriptor even if an unexpected exception escapes
        sockfd.close()
    return ips
|
||||
|
||||
|
||||
class MostlyReliablePipe(object):
|
||||
"""Implement Mostly-Reliable UDP. We don't actually care about guaranteeing
|
||||
delivery or receipt, just a better effort than no effort at all."""
|
||||
@ -446,13 +590,17 @@ class MostlyReliablePipe(object):
|
||||
self._send_worker = 0
|
||||
self._seq_counter = 0
|
||||
self._drop_prob = 0
|
||||
self._rt_check_worker = 0
|
||||
self._rt_check_worker_id = 0
|
||||
|
||||
self._outgoing = []
|
||||
self._sent = {}
|
||||
|
||||
self._incoming = {} # (message sha, # of segments) -> [segment1, segment2, ...]
|
||||
self._dispatched = {}
|
||||
self._acks = {} # (message sequence #, master sha, source addr) -> received timestamp
|
||||
self._ack_check_worker_id = 0
|
||||
|
||||
self._local_ips = _get_local_ip_addresses()
|
||||
|
||||
self._setup_listener()
|
||||
self._setup_sender()
|
||||
@ -461,9 +609,12 @@ class MostlyReliablePipe(object):
|
||||
if self._send_worker > 0:
|
||||
gobject.source_remove(self._send_worker)
|
||||
self._send_worker = 0
|
||||
if self._rt_check_worker > 0:
|
||||
gobject.source_remove(self._rt_check_worker)
|
||||
self._rt_check_worker = 0
|
||||
if self._rt_check_worker_id > 0:
|
||||
gobject.source_remove(self._rt_check_worker_id)
|
||||
self._rt_check_worker_id = 0
|
||||
if self._ack_check_worker_id > 0:
|
||||
gobject.source_remove(self._ack_check_worker_id)
|
||||
self._ack_check_worker_id = 0
|
||||
|
||||
def _setup_sender(self):
|
||||
"""Setup the send socket for multicast."""
|
||||
@ -495,7 +646,8 @@ class MostlyReliablePipe(object):
|
||||
# Watch the listener socket for data
|
||||
gobject.io_add_watch(self._listen_sock, gobject.IO_IN, self._handle_incoming_data)
|
||||
gobject.timeout_add(self._SEGMENT_TTL * 1000, self._segment_ttl_worker)
|
||||
gobject.timeout_add(50, self._retransmit_check_worker)
|
||||
self._rt_check_worker_id = gobject.timeout_add(50, self._retransmit_check_worker)
|
||||
self._ack_check_worker_id = gobject.timeout_add(50, self._ack_check_worker)
|
||||
|
||||
self._started = True
|
||||
|
||||
@ -516,12 +668,18 @@ class MostlyReliablePipe(object):
|
||||
if message.last_incoming_time() < now - self._SEGMENT_TTL:
|
||||
del self._incoming[msg_key]
|
||||
|
||||
# Remove already dispatched messages after a while
|
||||
# Remove already received and dispatched messages after a while
|
||||
for msg_key in self._dispatched.keys()[:]:
|
||||
message = self._dispatched[msg_key]
|
||||
if message.dispatch_time() < now - (self._SEGMENT_TTL*2):
|
||||
del self._dispatched[msg_key]
|
||||
|
||||
# Remove received acks after a while
|
||||
for ack_key in self._acks.keys()[:]:
|
||||
ack_time = self._acks[ack_key]
|
||||
if ack_time < now - (self._SEGMENT_TTL*2):
|
||||
del self._acks[ack_key]
|
||||
|
||||
return True
|
||||
|
||||
_MAX_SEGMENT_RETRIES = 10
|
||||
@ -541,6 +699,8 @@ class MostlyReliablePipe(object):
|
||||
return False
|
||||
|
||||
def _retransmit_check_worker(self):
|
||||
"""Periodically check for and send retransmit requests for message
|
||||
segments that got lost."""
|
||||
try:
|
||||
now = time.time()
|
||||
for key in self._incoming.keys()[:]:
|
||||
@ -582,6 +742,12 @@ class MostlyReliablePipe(object):
|
||||
# First segment in the message
|
||||
if not self._incoming.has_key(msg_key):
|
||||
self._incoming[msg_key] = Message((addr[0], self._port), msg_seq_num, msg_sha, nsegs)
|
||||
# Acknowledge the message if it didn't come from us
|
||||
if addr[0] not in self._local_ips:
|
||||
print "Sending ack for msg (%s %s) from %s)" % (msg_seq_num, msg_sha, addr[0])
|
||||
ack_key = (msg_seq_num, msg_sha, addr[0])
|
||||
if not self._acks.has_key(ack_key):
|
||||
self._send_ack_for_message(msg_seq_num, msg_sha, addr[0])
|
||||
|
||||
message = self._incoming[msg_key]
|
||||
# Look for a dupe, and if so, drop the new segment
|
||||
@ -636,10 +802,74 @@ class MostlyReliablePipe(object):
|
||||
next_transmit = max(now, segment.last_transmit() + self._STD_RETRANSMIT_INTERVAL)
|
||||
self._schedule_segment_retransmit(key, segment, next_transmit, now)
|
||||
|
||||
def _ack_check_worker(self):
    """Periodically check for messages that haven't received an ack
    yet, and retransmit them.

    Only the first segment of each sent message is considered.  It is
    moved from self._sent back onto self._outgoing when it was last
    transmitted more than 150ms ago and no ack naming one of our local
    addresses as the message source has been recorded.

    Returns True normally and False after a KeyboardInterrupt.
    """
    try:
        now = time.time()
        # Iterate over a snapshot of the keys; entries are deleted below
        for key in list(self._sent.keys()):
            segment = self._sent[key]
            # We only care about retransmitting the first segment
            # of a message, since if other machines don't have the
            # rest of the segments, they'll issue retransmit requests
            if segment.segment_number() != 1:
                continue
            if segment.last_transmit() > now - 0.150:   # 150ms
                # Was just retransmitted recently, wait longer
                # before retransmitting it
                continue

            # Acks are keyed by the sender's address, and we may have sent
            # the message from any of our local addresses.
            acked = False
            for ip in self._local_ips:
                ack_key = (segment.message_sequence_number(), segment.master_sha(), ip)
                if ack_key in self._acks:
                    acked = True
                    break
            # If the segment already has been acked, don't send it
            # again unless somebody explicitly requests a retransmit
            if acked:
                continue

            del self._sent[key]
            self._outgoing.append(segment)
            self._schedule_send_worker()
    except KeyboardInterrupt:
        return False
    return True
|
||||
|
||||
def _send_ack_for_message(self, ack_msg_seq_num, ack_msg_sha, ack_addr):
    """Queue an ack segment for the message identified by the arguments.

    The ack is built with a fresh sequence number, appended to the
    outgoing queue, and then fed through _process_incoming_ack() as well.
    """
    own_seq = self._next_msg_seq()
    ack_segment = AckSegment.new_from_parts(
        self._remote_addr,
        own_seq,
        ack_msg_seq_num,
        ack_msg_sha,
        ack_addr)
    self._outgoing.append(ack_segment)
    self._schedule_send_worker()
    # Handle our own ack exactly as if it had arrived from the network
    self._process_incoming_ack(ack_segment)
|
||||
|
||||
def _process_incoming_ack(self, segment):
    """Save the ack so that we don't send an ack when we start getting the segments
    the ack was acknowledging.

    The ack is stored in self._acks under the key (acked message sequence
    number, acked message sha, acked message source address), with the
    time it was first seen as the value.  A repeated ack leaves the
    stored time untouched.
    """
    ack_addr = segment.ack_addr()
    ack_master_sha = segment.ack_master_sha()
    ack_msg_seq_num = segment.ack_msg_seq_num()

    # If the ack is supposed to be for a message we sent, only accept it if
    # we actually sent the message to which it refers
    if ack_addr in self._local_ips:
        sent_key = (ack_msg_seq_num, ack_master_sha, 1)
        if sent_key not in self._sent:
            return

    ack_key = (ack_msg_seq_num, ack_master_sha, ack_addr)
    if ack_key not in self._acks:
        print("Got ack for msg (%s %s) originally from %s" % (ack_msg_seq_num, ack_master_sha, ack_addr))
        self._acks[ack_key] = time.time()
|
||||
|
||||
def set_drop_probability(self, prob=4):
    """Debugging function to randomly drop incoming packets.

    The prob argument should be an integer between 1 and 10 to drop,
    or 0 to drop none.  Higher numbers drop more packets.

    Raises ValueError if prob is not an integer or lies outside 0..10.
    """
    if type(prob) is not int:
        raise ValueError("Drop probability must be an integer.")
    # 0 has to be accepted, otherwise dropping can never be switched
    # off again once it has been enabled.
    if prob < 0 or prob > 10:
        raise ValueError("Drop probability must be between 0 and 10 inclusive.")
    self._drop_prob = prob
|
||||
|
||||
def _handle_incoming_data(self, source, condition):
|
||||
@ -665,6 +895,8 @@ class MostlyReliablePipe(object):
|
||||
self._process_incoming_data(segment)
|
||||
elif stype == SegmentBase.type_retransmit():
|
||||
self._process_retransmit_request(segment)
|
||||
elif stype == SegmentBase.type_ack():
|
||||
self._process_incoming_ack(segment)
|
||||
except ValueError, exc:
|
||||
print "(MRP): Bad segment: %s" % exc
|
||||
return True
|
||||
@ -693,12 +925,12 @@ class MostlyReliablePipe(object):
|
||||
nmessages = length / mtu
|
||||
if length % mtu > 0:
|
||||
nmessages = nmessages + 1
|
||||
msg_num = 1
|
||||
seg_num = 1
|
||||
while left > 0:
|
||||
seg = DataSegment.new_from_parts(msg_num, nmessages,
|
||||
seg = DataSegment.new_from_parts(seg_num, nmessages,
|
||||
msg_seq, master_sha, data[:mtu])
|
||||
self._outgoing.append(seg)
|
||||
msg_num = msg_num + 1
|
||||
seg_num = seg_num + 1
|
||||
data = data[mtu:]
|
||||
left = left - mtu
|
||||
self._schedule_send_worker()
|
||||
|
Loading…
Reference in New Issue
Block a user