diff --git a/sugar/p2p/MostlyReliablePipe.py b/sugar/p2p/MostlyReliablePipe.py index 5b503c2a..6523bc0e 100644 --- a/sugar/p2p/MostlyReliablePipe.py +++ b/sugar/p2p/MostlyReliablePipe.py @@ -44,6 +44,7 @@ class SegmentBase(object): # Message segment packet types _SEGMENT_TYPE_DATA = 0 _SEGMENT_TYPE_RETRANSMIT = 1 + _SEGMENT_TYPE_ACK = 2 def magic(): return SegmentBase._MAGIC @@ -61,6 +62,10 @@ class SegmentBase(object): return SegmentBase._SEGMENT_TYPE_RETRANSMIT type_retransmit = staticmethod(type_retransmit) + def type_ack(): + return SegmentBase._SEGMENT_TYPE_ACK + type_ack = staticmethod(type_ack) + def header_len(): """Return the header size of SegmentBase packets.""" return SegmentBase._HEADER_LEN @@ -156,6 +161,8 @@ class SegmentBase(object): segment = DataSegment(segno, total_segs, msg_seq_num, master_sha) elif seg_type == SegmentBase._SEGMENT_TYPE_RETRANSMIT: segment = RetransmitSegment(segno, total_segs, msg_seq_num, master_sha) + elif set_type == SegmentBase._SEGMENT_TYPE_ACK: + segment = AckSegment(segno, total_segs, msg_seq_num, master_sha) else: raise ValueError("Segment has invalid type.") @@ -319,6 +326,96 @@ class RetransmitSegment(SegmentBase): return self._rt_segment_number +class AckSegment(SegmentBase): + """A message segment that encapsulates a message acknowledgement.""" + + # Ack data format: + # 2: acked message sequence number + # 20: acked message total data sha1 + # 4: acked message source IP address + _ACK_DATA_TEMPLATE = "! H20sI" + _ACK_DATA_LEN = struct.calcsize(_ACK_DATA_TEMPLATE) + + def data_template(): + return AckSegment._ACK_DATA_TEMPLATE + data_template = staticmethod(data_template) + + def __init__(self, segno, total_segs, msg_seq_num, master_sha): + """Should not be called directly.""" + if segno != 1 or total_segs != 1: + raise ValueError("Acknowledgement messages must have only one segment.") + + SegmentBase.__init__(self, segno, total_segs, msg_seq_num, master_sha) + self._type = SegmentBase._SEGMENT_TYPE_ACK + + def _verify_data(ack_msg_seq_num, ack_master_sha, ack_addr): + # Sanity checks on the message attributes + if not ack_msg_seq_num or type(ack_msg_seq_num) != type(1): + raise ValueError("Ack message sequnce number must be an integer.") + if ack_msg_seq_num < 1 or ack_msg_seq_num > 65535: + raise ValueError("Ack message sequence number must be between 1 and 65535 inclusive.") + if not ack_master_sha or type(ack_master_sha) != type("") or len(ack_master_sha) != 20: + raise ValueError("Ack message SHA1 checksum invalid.") + if type(ack_addr) != type(""): + raise ValueError("Ack message invalid address.") + try: + foo = socket.inet_aton(ack_addr) + except socket.error: + raise ValueError("Ack message invalid address.") + _verify_data = staticmethod(_verify_data) + + def _make_ack_data(ack_msg_seq_num, ack_master_sha, ack_addr): + """Pack an ack payload.""" + addr_data = socket.inet_aton(ack_addr) + data = struct.pack(AckSegment._ACK_DATA_TEMPLATE, ack_msg_seq_num, + ack_master_sha, addr_data) + return (data, _sha_data(data)) + _make_ack_data = staticmethod(_make_ack_data) + + def new_from_parts(addr, msg_seq_num, ack_msg_seq_num, ack_master_sha, ack_addr): + """Static constructor for creation from individual attributes.""" + + AckSegment._verify_data(ack_msg_seq_num, ack_master_sha, ack_addr) + (data, data_sha) = AckSegment._make_ack_data(ack_msg_seq_num, + ack_master_sha, ack_addr) + segment = AckSegment(1, 1, msg_seq_num, data_sha) + segment._data_len = AckSegment._ACK_DATA_LEN + segment._data = data + SegmentBase._validate_address(addr) + segment._addr = addr + + segment._ack_msg_seq_num = ack_msg_seq_num + segment._ack_master_sha = ack_master_sha + segment._ack_addr = ack_addr + return segment + new_from_parts = staticmethod(new_from_parts) + + def _unpack_data(self, stream, data_len): + if data_len != self._ACK_DATA_LEN: + raise ValueError("Ack segment data had invalid length.") + data = stream.read(data_len) + (ack_msg_seq_num, ack_master_sha, ack_addr_data) = struct.unpack(self._ACK_DATA_TEMPLATE, data) + try: + ack_addr = socket.inet_ntoa(ack_addr_data) + except socket.error: + raise ValueError("Ack segment data had invalid address.") + AckSegment._verify_data(ack_msg_seq_num, ack_master_sha, ack_addr) + + self._data = data + self._data_len = data_len + self._ack_msg_seq_num = ack_msg_seq_num + self._ack_master_sha = ack_master_sha + self._ack_addr = ack_addr + + def ack_msg_seq_num(self): + return self._ack_msg_seq_num + + def ack_master_sha(self): + return self._ack_master_sha + + def ack_addr(self): + return self._ack_addr + class Message(object): """Tracks an entire message object, which is composed of a number of individual segments.""" @@ -429,6 +526,53 @@ class Message(object): return 0 +def _get_local_interfaces(): + import array + import struct + import fcntl + import socket + + max_possible = 4 + bytes = max_possible * 32 + SIOCGIFCONF = 0x8912 + names = array.array('B', '\0' * bytes) + + sockfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + ifreq = struct.pack('iL', bytes, names.buffer_info()[0]) + result = fcntl.ioctl(sockfd.fileno(), SIOCGIFCONF, ifreq) + sockfd.close() + + outbytes = struct.unpack('iL', result)[0] + namestr = names.tostring() + + return [namestr[i:i+32].split('\0', 1)[0] for i in range(0, outbytes, 32)] + +def _get_local_ip_addresses(): + """Call Linux specific bits to retrieve our own IP address.""" + import socket + import sys + import fcntl + import struct + + intfs = _get_local_interfaces() + + ips = [] + SIOCGIFADDR = 0x8915 + sockfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + for intf in intfs: + if intf == "lo": + continue + try: + ifreq = (intf + '\0'*32)[:32] + result = fcntl.ioctl(sockfd.fileno(), SIOCGIFADDR, ifreq) + addr = socket.inet_ntoa(result[20:24]) + ips.append(addr) + except IOError, exc: + print "Error getting IP address: %s" % exc + sockfd.close() + return ips + + class MostlyReliablePipe(object): """Implement Mostly-Reliable UDP. We don't actually care about guaranteeing delivery or receipt, just a better effort than no effort at all.""" @@ -446,13 +590,17 @@ class MostlyReliablePipe(object): self._send_worker = 0 self._seq_counter = 0 self._drop_prob = 0 - self._rt_check_worker = 0 + self._rt_check_worker_id = 0 self._outgoing = [] self._sent = {} self._incoming = {} # (message sha, # of segments) -> [segment1, segment2, ...] self._dispatched = {} + self._acks = {} # (message sequence #, master sha, source addr) -> received timestamp + self._ack_check_worker_id = 0 + + self._local_ips = _get_local_ip_addresses() self._setup_listener() self._setup_sender() @@ -461,9 +609,12 @@ class MostlyReliablePipe(object): if self._send_worker > 0: gobject.source_remove(self._send_worker) self._send_worker = 0 - if self._rt_check_worker > 0: - gobject.source_remove(self._rt_check_worker) - self._rt_check_worker = 0 + if self._rt_check_worker_id > 0: + gobject.source_remove(self._rt_check_worker_id) + self._rt_check_worker_id = 0 + if self._ack_check_worker_id > 0: + gobject.source_remove(self._ack_check_worker_id) + self._ack_check_worker_id = 0 def _setup_sender(self): """Setup the send socket for multicast.""" @@ -495,7 +646,8 @@ class MostlyReliablePipe(object): # Watch the listener socket for data gobject.io_add_watch(self._listen_sock, gobject.IO_IN, self._handle_incoming_data) gobject.timeout_add(self._SEGMENT_TTL * 1000, self._segment_ttl_worker) - gobject.timeout_add(50, self._retransmit_check_worker) + self._rt_check_worker_id = gobject.timeout_add(50, self._retransmit_check_worker) + self._ack_check_worker_id = gobject.timeout_add(50, self._ack_check_worker) self._started = True @@ -516,12 +668,18 @@ class MostlyReliablePipe(object): if message.last_incoming_time() < now - self._SEGMENT_TTL: del self._incoming[msg_key] - # Remove already dispatched messages after a while + # Remove already received and dispatched messages after a while for msg_key in self._dispatched.keys()[:]: message = self._dispatched[msg_key] if message.dispatch_time() < now - (self._SEGMENT_TTL*2): del self._dispatched[msg_key] + # Remove received acks after a while + for ack_key in self._acks.keys()[:]: + ack_time = self._acks[ack_key] + if ack_time < now - (self._SEGMENT_TTL*2): + del self._acks[ack_key] + return True _MAX_SEGMENT_RETRIES = 10 @@ -541,6 +699,8 @@ class MostlyReliablePipe(object): return False def _retransmit_check_worker(self): + """Periodically check for and send retransmit requests for message + segments that got lost.""" try: now = time.time() for key in self._incoming.keys()[:]: @@ -582,6 +742,12 @@ class MostlyReliablePipe(object): # First segment in the message if not self._incoming.has_key(msg_key): self._incoming[msg_key] = Message((addr[0], self._port), msg_seq_num, msg_sha, nsegs) + # Acknowledge the message if it didn't come from us + if addr[0] not in self._local_ips: + print "Sending ack for msg (%s %s) from %s)" % (msg_seq_num, msg_sha, addr[0]) + ack_key = (msg_seq_num, msg_sha, addr[0]) + if not self._acks.has_key(ack_key): + self._send_ack_for_message(msg_seq_num, msg_sha, addr[0]) message = self._incoming[msg_key] # Look for a dupe, and if so, drop the new segment @@ -636,10 +802,74 @@ class MostlyReliablePipe(object): next_transmit = max(now, segment.last_transmit() + self._STD_RETRANSMIT_INTERVAL) self._schedule_segment_retransmit(key, segment, next_transmit, now) + def _ack_check_worker(self): + """Periodically check for messages that haven't received an ack + yet, and retransmit them.""" + try: + now = time.time() + for key in self._sent.keys()[:]: + segment = self._sent[key] + # We only care about retransmitting the first segment + # of a message, since if other machines don't have the + # rest of the segments, they'll issue retransmit requests + if segment.segment_number() != 1: + continue + if segment.last_transmit() > now - 0.150: # 150ms + # Was just retransmitted recently, wait longer + # before retransmitting it + continue + ack_key = None + for ip in self._local_ips: + ack_key = (segment.message_sequence_number(), segment.master_sha(), ip) + if self._acks.has_key(ack_key): + break + ack_key = None + # If the segment already has been acked, don't send it + # again unless somebody explicitly requests a retransmit + if ack_key is not None: + continue + + del self._sent[key] + self._outgoing.append(segment) + self._schedule_send_worker() + except KeyboardInterrupt: + return False + return True + + def _send_ack_for_message(self, ack_msg_seq_num, ack_msg_sha, ack_addr): + """Send an ack segment for a message.""" + msg_seq_num = self._next_msg_seq() + ack = AckSegment.new_from_parts(self._remote_addr, msg_seq_num, + ack_msg_seq_num, ack_msg_sha, ack_addr) + self._outgoing.append(ack) + self._schedule_send_worker() + self._process_incoming_ack(ack) + + def _process_incoming_ack(self, segment): + """Save the ack so that we don't send an ack when we start getting the segments + the ack was acknowledging.""" + # If the ack is supposed to be for a message we sent, only accept it if + # we actually sent the message to which it refers + ack_addr = segment.ack_addr() + ack_master_sha = segment.ack_master_sha() + ack_msg_seq_num = segment.ack_msg_seq_num() + if ack_addr in self._local_ips: + sent_key = (ack_msg_seq_num, ack_master_sha, 1) + if not self._sent.has_key(sent_key): + return + ack_key = (ack_msg_seq_num, ack_master_sha, ack_addr) + if not self._acks.has_key(ack_key): + print "Got ack for msg (%s %s) originally from %s" % (ack_msg_seq_num, ack_master_sha, ack_addr) + self._acks[ack_key] = time.time() + def set_drop_probability(self, prob=4): """Debugging function to randomly drop incoming packets. The prob argument should be an integer between 1 and 10 to drop, or 0 to drop none. Higher numbers drop more packets.""" + if type(prob) != type(1): + raise ValueError("Drop probability must be an integer.") + if prob < 1 or prob > 10: + raise ValueError("Drop probability must be between 1 and 10 inclusive.") self._drop_prob = prob def _handle_incoming_data(self, source, condition): @@ -665,6 +895,8 @@ class MostlyReliablePipe(object): self._process_incoming_data(segment) elif stype == SegmentBase.type_retransmit(): self._process_retransmit_request(segment) + elif stype == SegmentBase.type_ack(): + self._process_incoming_ack(segment) except ValueError, exc: print "(MRP): Bad segment: %s" % exc return True @@ -693,12 +925,12 @@ class MostlyReliablePipe(object): nmessages = length / mtu if length % mtu > 0: nmessages = nmessages + 1 - msg_num = 1 + seg_num = 1 while left > 0: - seg = DataSegment.new_from_parts(msg_num, nmessages, + seg = DataSegment.new_from_parts(seg_num, nmessages, msg_seq, master_sha, data[:mtu]) self._outgoing.append(seg) - msg_num = msg_num + 1 + seg_num = seg_num + 1 data = data[mtu:] left = left - mtu self._schedule_send_worker()