From f751407d50cb55d85722c889d42e216eaeb7ee94 Mon Sep 17 00:00:00 2001 From: Dan Williams Date: Wed, 17 May 2006 13:11:48 -0400 Subject: [PATCH] Refactor the segment class into two type-specific classes and base class --- sugar/p2p/MostlyReliablePipe.py | 333 +++++++++++++++++++++----------- 1 file changed, 222 insertions(+), 111 deletions(-) diff --git a/sugar/p2p/MostlyReliablePipe.py b/sugar/p2p/MostlyReliablePipe.py index c57b022e..9009d1a3 100644 --- a/sugar/p2p/MostlyReliablePipe.py +++ b/sugar/p2p/MostlyReliablePipe.py @@ -4,6 +4,7 @@ import sha import struct import StringIO import binascii +import random import pygtk pygtk.require('2.0') @@ -21,7 +22,8 @@ def _sha_data(data): sha_hash.update(data) return sha_hash.digest() -class MessageSegment(object): + +class SegmentBase(object): _MAGIC = 0xbaea4304 # 4: magic (0xbaea4304) @@ -38,119 +40,91 @@ class MessageSegment(object): _SEGMENT_TYPE_DATA = 0 _SEGMENT_TYPE_RETRANSMIT = 1 - def is_data_type(stype): - if stype == MessageSegment._SEGMENT_TYPE_DATA: - return True - return False - is_data_type = staticmethod(is_data_type) + def type_data(): + return SegmentBase._SEGMENT_TYPE_DATA + type_data = staticmethod(type_data) - def is_retransmit_type(stype): - if stype == MessageSegment._SEGMENT_TYPE_RETRANSMIT: - return True - return False - is_retransmit_type = staticmethod(is_retransmit_type) + def type_retransmit(): + return SegmentBase._SEGMENT_TYPE_RETRANSMIT + type_retransmit = staticmethod(type_retransmit) def header_len(): - return MessageSegment._HEADER_LEN + """Return the header size of SegmentBase packets.""" + return SegmentBase._HEADER_LEN header_len = staticmethod(header_len) def mtu(): - return MessageSegment._MTU + """Return the SegmentBase packet MTU.""" + return SegmentBase._MTU mtu = staticmethod(mtu) - def is_type_valid(stype): - if MessageSegment.is_data_type(stype) or MessageSegment.is_retransmit_type(stype): - return True - return False - is_type_valid = staticmethod(is_type_valid) - - def _new_from_parts(self, msg_seq_num, segno, total_segs, data, master_sha): - """Construct a new message segment from individual attributes.""" - if not data: - raise ValueError("Must have valid data.") - if segno > 65535: - raise ValueError("Segment number cannot be more than 65535.") - if segno < 1: - raise ValueError("Segment number must be greater than zero.") - if total_segs > 65535: - raise ValueError("Message cannot have more than 65535 segments.") - if total_segs < 1: - raise ValueError("Message must have at least one segment.") - if msg_seq_num < 1: - raise ValueError("Message sequence number must be greater than 0.") + def __init__(self, segno, total_segs, msg_seq_num, master_sha): + self._type = None + self._transmits = 0 + self._last_transmit = 0 + self._data = None + self._data_len = 0 + self.userdata = None self._stime = time.time() - self._data = data - self._data_len = len(data) - self._master_sha = master_sha - self._segno = segno - self._total_segs = total_segs - self._msg_seq_num = msg_seq_num self._addr = None - self._type = MessageSegment._SEGMENT_TYPE_DATA - - # Make the header - self._header = struct.pack(self._HEADER_TEMPLATE, self._MAGIC, self._type, - self._segno, self._total_segs, self._msg_seq_num, self._master_sha) - - def _new_from_data(self, addr, data): - """Verify and construct a new message segment from network data.""" - if len(data) < self._HEADER_LEN + 1: - raise ValueError("Segment is less then minimum required length") - stream = StringIO.StringIO(data) - self._stime = None - self._addr = addr - - # Determine and verify the length of included data - stream.seek(0, 2) - header_size = struct.calcsize(self._HEADER_TEMPLATE) - self._data_len = stream.tell() - header_size - stream.seek(0) - - # Read the header attributes - (magic, seg_type, segno, total_segs, msg_seq_num, master_sha) = struct.unpack(self._HEADER_TEMPLATE, - stream.read(header_size)) # Sanity checks on the message attributes - if not MessageSegment.is_type_valid(seg_type): - raise ValueError("Segment has invalid type.") - if MessageSegment.is_data_type(seg_type): - if segno != 1 or total_segs != 1: - raise ValueError("Retransmission request messages must have only one segment.") - if magic != self._MAGIC: - raise ValueError("Segment does not have the correct magic.") - if self._data_len < 1: - raise ValueError("Segment must have some data.") - if self._data_len > self._MTU: - raise ValueError("Data length must not be larger than the MTU (%s)." % self._MTU) + if segno > 65535: + raise ValueError("Segment number cannot be more than 65535.") if segno < 1: raise ValueError("Segment number must be greater than 0.") if segno > total_segs: raise ValueError("Segment number cannot be larger than message segment total.") if total_segs < 1: raise ValueError("Message must have at least one segment.") + if total_segs > 65535: + raise ValueError("Message cannot have more than 65535 segments.") if msg_seq_num < 1: raise ValueError("Message sequence number must be greater than 0.") - self._type = seg_type self._segno = segno self._total_segs = total_segs self._msg_seq_num = msg_seq_num self._master_sha = master_sha - # Reconstruct the data - self._data = struct.unpack("! %ds" % self._data_len, stream.read(self._data_len))[0] - - def new_from_parts(msg_seq_num, segno, total_segs, data, master_sha): - """Static constructor for creation from individual attributes.""" - segment = MessageSegment() - segment._new_from_parts(msg_seq_num, segno, total_segs, data, master_sha) - return segment - new_from_parts = staticmethod(new_from_parts) - def new_from_data(addr, data): """Static constructor for creation from a packed data stream.""" - segment = MessageSegment() - segment._new_from_data(addr, data) + + # Verify minimum length + if len(data) < SegmentBase.header_len() + 1: + raise ValueError("Segment is less then minimum required length") + stream = StringIO.StringIO(data) + + # Determine and verify the length of included data + stream.seek(0, 2) + header_size = struct.calcsize(SegmentBase._HEADER_TEMPLATE) + data_len = stream.tell() - header_size + stream.seek(0) + + if data_len < 1: + raise ValueError("Segment must have some data.") + if data_len > SegmentBase._MTU: + raise ValueError("Data length must not be larger than the MTU (%s)." % SegmentBase._MTU) + + # Read the first header attributes + (magic, seg_type, segno, total_segs, msg_seq_num, master_sha) = struct.unpack(SegmentBase._HEADER_TEMPLATE, + stream.read(header_size)) + + # Sanity checks on the message attributes + if magic != SegmentBase._MAGIC: + raise ValueError("Segment does not have the correct magic.") + + if seg_type == SegmentBase._SEGMENT_TYPE_DATA: + segment = DataSegment(segno, total_segs, msg_seq_num, master_sha) + elif seg_type == SegmentBase._SEGMENT_TYPE_RETRANSMIT: + segment = RetransmitSegment(segno, total_segs, msg_seq_num, master_sha) + else: + raise ValueError("Segment has invalid type.") + + # Segment specific data interpretation + segment._addr = addr + segment._unpack_data(stream, data_len) + return segment new_from_data = staticmethod(new_from_data) @@ -178,15 +152,112 @@ class MessageSegment(object): def segment_type(self): return self._type - def segment(self): + def packetize(self): """Return a correctly formatted message that can be immediately sent.""" - return self._header + self._data + header = struct.pack(self._HEADER_TEMPLATE, self._MAGIC, self._type, + self._segno, self._total_segs, self._msg_seq_num, self._master_sha) + return header + self._data + + def transmits(self): + return self._transmits + + def inc_transmits(self): + self._transmits = self._transmits + 1 + self._last_transmit = time.time() + + def last_transmit(self): + return self._last_transmit + +class DataSegment(SegmentBase): + """A message segment that encapsulates random data.""" + + def __init__(self, segno, total_segs, msg_seq_num, master_sha): + SegmentBase.__init__(self, segno, total_segs, msg_seq_num, master_sha) + self._type = SegmentBase._SEGMENT_TYPE_DATA + + def _get_template_for_len(length): + return "! %ds" % length + _get_template_for_len = staticmethod(_get_template_for_len) + + def _unpack_data(self, stream, data_len): + """Unpack the data stream, called by constructor.""" + self._data_len = data_len + template = DataSegment._get_template_for_len(self._data_len) + self._data = struct.unpack(template, stream.read(self._data_len))[0] + + def new_from_parts(segno, total_segs, msg_seq_num, master_sha, data): + """Construct a new message segment from individual attributes.""" + if not data: + raise ValueError("Must have valid data.") + segment = DataSegment(segno, total_segs, msg_seq_num, master_sha) + segment._data_len = len(data) + template = DataSegment._get_template_for_len(segment._data_len) + segment._data = struct.pack(template, data) + return segment + new_from_parts = staticmethod(new_from_parts) + + +class RetransmitSegment(SegmentBase): + """A message segment that encapsulates a retransmission request.""" + + # Retransmission data format: + # 2: message sequence number + # 20: total data sha1 + # 2: segment number + _RT_DATA_TEMPLATE = "@ H20sH" + + def __init__(self, segno, total_segs, msg_seq_num, master_sha): + if segno != 1 or total_segs != 1: + raise ValueError("Retransmission request messages must have only one segment.") + + SegmentBase.__init__(self, segno, total_segs, msg_seq_num, master_sha) + self._type = SegmentBase._SEGMENT_TYPE_DATA + + def _make_rtms_data(rt_msg_seq_num, rt_master_sha, rt_segment_number): + """Pack retransmission request payload.""" + data = struct.pack(RetransmitSegment._RT_DATA_TEMPLATE, rt_msg_seq_num, + rt_master_sha, rt_segment_number) + return (data, _sha_data(data), struct.calcsize(RetransmitSegment._RT_DATA_TEMPLATE)) + _make_rtms_data = staticmethod(_make_rtms_data) + + def new_from_parts(msg_seq_num, rt_msg_seq_num, rt_master_sha, rt_segment_number): + """Static constructor for creation from individual attributes.""" + (data, data_sha, data_len) = segment._make_rtms_data() + segment = RetransmitSegment(1, 1, msg_seq_num, data_sha) + segment._data_len = data_len + segment._data = data + + segment._rt_msg_seq_num = rt_msg_seq_num + segment._rt_master_sha = rt_master_sha + segment._rt_segment_number = rt_segment_number + return segment + new_from_parts = staticmethod(new_from_parts) + + def _unpack_data(self, stream, data_len): + if data_len != struct.calcsize(self._RT_DATA_TEMPLATE): + raise ValueError("Retransmission request data had invalid length.") + self._data_len = data_len + (rt_msg_seq_num, rt_master_sha, rt_seg_no) = struct.unpack(self._RT_DATA_TEMPLATE, + stream.read(self._data_len)) + self._rt_msg_seq_num = rt_msg_seq_num + self._rt_master_sha = rt_master_sha + self._rt_segment_number = rt_seg_no + + def rt_msg_seq_num(self): + return self._rt_msg_seq_num + + def rt_master_sha(self): + return self._rt_master_sha + + def rt_segment_number(self): + return self._rt_segment_number + class MostlyReliablePipe(object): """Implement Mostly-Reliable UDP. We don't actually care about guaranteeing delivery or receipt, just a better effort than no effort at all.""" - _UDP_MSG_SIZE = MessageSegment.mtu() + MessageSegment.header_len() + _UDP_MSG_SIZE = SegmentBase.mtu() + SegmentBase.header_len() _SEGMENT_TTL = 120 # 2 minutes def __init__(self, local_addr, remote_addr, port, data_cb, user_data=None): @@ -200,7 +271,7 @@ class MostlyReliablePipe(object): self._seq_counter = 0 self._outgoing = [] - self._sent = [] + self._sent = {} self._incoming = {} # (message sha, # of segments) -> [segment1, segment2, ...] @@ -292,18 +363,49 @@ class MostlyReliablePipe(object): self._dispatch_message(addr, all_data) del self._incoming[msg_key] + _STD_RETRANSMIT_INTERVAL = 500 # 1/2 second (in milliseconds) + def _calc_next_retransmit(self, segment, now): + """Calculate the next time (in seconds) that a packet can be retransmitted.""" + num_retrans = segment.transmits() - 1 + interval = num_retrans * self._STD_RETRANSMIT_INTERVAL + randomness = num_retrans * random.randint(-4, 11) + real_interval = max(self._STD_RETRANSMIT_INTERVAL, interval + randomness) + return max(now, segment.last_transmit() + (real_interval * .001)) + + def _segment_retransmit_cb(self, segment): + """Add a segment ot the outgoing queue and schedule its transmission.""" + del self._sent[key] + self._outgoing.append(segment) + self._schedule_send_worker() + return False + + def _schedule_segment_retransmit(self, segment, when): + """Schedule retransmission of a segment if one is not already scheduled.""" + if segment.userdata: + # Already scheduled for retransmit + return + + if when == 0: + # Immediate retransmission + segment.userdata = gobject.idle_add(self._segment_retransmit_cb, segment) + else: + # convert time to milliseconds + timeout = int((when - time.time()) * 1000) + segment.userdata = gobject.timeout_add(timeout, self._segment_retransmit_cb, + segment) + def _process_retransmit_request(self, segment): """Validate and process a retransmission request.""" - # Retransmission data format: - # 2: message sequence number - # 20: total data sha1 - # 2: segment number - data = segment.data() - template = "@ H20sH" - if len(data) != struct.calcsize(template): - print "Bad retransmission request message format." - # Native byte-order since the receive bits already unpacked it for us - (msg_seq_num, master_sha, segno) = struct.unpack(template, data) + key = (segment.rt_msg_seq_num(), segment.rt_master_sha(), segment.rt_segment_number()) + if not self._sent.has_key(key): + # Either we don't know about the segment, or it was already culled + return + + # Calculate next retransmission time and schedule packet for retransmit + segment = self._sent[key] + now = time.time() + next_retrans = self._calc_next_retransmit(segment, now) + self._schedule_segment_retransmit(segment, next_retrans - now) def _handle_incoming_data(self, source, condition): """Handle incoming network data by making a message segment out of it @@ -313,11 +415,11 @@ class MostlyReliablePipe(object): msg = {} data, addr = source.recvfrom(self._UDP_MSG_SIZE) try: - segment = MessageSegment.new_from_data(addr, data) + segment = SegmentBase.new_from_data(addr, data) stype = segment.segment_type() - if MessageSegment.is_data_type(stype): + if stype == SegmentBase.type_data(): self._process_incoming_data(segment) - elif MessageSegment.is_retransmit_type(stype): + elif stype == SegmentBase.type_retransmit(): self._process_retransmit_request(segment) except ValueError, exc: print "Bad segment: %s" % exc @@ -339,18 +441,21 @@ class MostlyReliablePipe(object): # Split up the data into segments left = length = len(data) - mtu = MessageSegment.mtu() + mtu = SegmentBase.mtu() nmessages = length / mtu if length % mtu > 0: nmessages = nmessages + 1 msg_num = 1 while left > 0: - msg = MessageSegment.new_from_parts(self._seq_counter, msg_num, - nmessages, data[:mtu], master_sha) - self._outgoing.append(msg) + seg = DataSegment.new_from_parts(msg_num, nmessages, + self._seq_counter, master_sha, data[:mtu]) + self._outgoing.append(seg) msg_num = msg_num + 1 data = data[mtu:] left = left - mtu + self._schedule_send_worker() + + def _schedule_send_worker(self): if len(self._outgoing) > 0 and self._worker == 0: self._worker = gobject.idle_add(self._send_worker) @@ -358,9 +463,12 @@ class MostlyReliablePipe(object): """Send all queued segments that have yet to be transmitted.""" self._worker = 0 for segment in self._outgoing: - data = segment.segment() - self._send_sock.sendto(data, (self._remote_addr, self._port)) - self._sent = self._outgoing + packet = segment.packetize() + segment.inc_transmits() + self._send_sock.sendto(packet, (self._remote_addr, self._port)) + segment.userdata = None # Retransmission GSource + key = (segment.message_sequence_number(), segment.master_sha(), segment.segment_number()) + self._sent[key] = segment self._outgoing = [] return False @@ -372,7 +480,10 @@ def main(): pipe = MostlyReliablePipe('', '224.0.0.222', 2293, got_data) pipe.start() pipe.send('The quick brown fox jumps over the lazy dog') - gtk.main() + try: + gtk.main() + except KeyboardInterrupt: + pass if __name__ == "__main__":