Do message reassembly

This commit is contained in:
Dan Williams 2006-05-16 16:26:23 -04:00
parent e4516c6d81
commit 0f7dc51ac0

View File

@ -15,90 +15,126 @@ _HEADER_LEN = 30
_MAGIC = 0xbaea4304 _MAGIC = 0xbaea4304
_TTL = 120 # 2 minutes _TTL = 120 # 2 minutes
def _stringify_sha(sha):
print_sha = ""
for char in sha:
print_sha = print_sha + binascii.b2a_hex(char)
return print_sha
def _sha_data(data):
sha_hash = sha.new()
sha_hash.update(data)
return sha_hash.digest()
class MessageSegment(object): class MessageSegment(object):
# 4: magic (0xbaea4304) # 4: magic (0xbaea4304)
# 2: segment number # 2: segment number
# 2: total segments # 2: total segments
# 2: data size # 2: message sequence number
#20: total message sha1 #20: total data sha1
_HEADER_TEMPLATE = "! IHHH20s" _HEADER_TEMPLATE = "! IHHH20s"
def _new_from_parts(self, num, all, data, master_sha): def _new_from_parts(self, msg_seq_num, segno, total_segs, data, master_sha):
"""Construct a new message segment from individual attributes."""
if not data: if not data:
raise ValueError("Must have valid data.") raise ValueError("Must have valid data.")
if num > 65535: if segno > 65535:
raise ValueError("Segment number cannot be more than 65535.") raise ValueError("Segment number cannot be more than 65535.")
if num < 1: if segno < 1:
raise ValueError("Segment number must be greater than zero.") raise ValueError("Segment number must be greater than zero.")
if all > 65535: if total_segs > 65535:
raise ValueError("Message cannot have more than 65535 segments.") raise ValueError("Message cannot have more than 65535 segments.")
if all < 1: if total_segs < 1:
raise ValueError("Message must have at least one segment.") raise ValueError("Message must have at least one segment.")
if msg_seq_num < 1:
raise ValueError("Message sequence number must be greater than 0.")
self._stime = time.time() self._stime = time.time()
self._data = data self._data = data
self._data_len = len(data) self._data_len = len(data)
self._master_sha = master_sha self._master_sha = master_sha
self._num = num self._segno = segno
self._all = all self._total_segs = total_segs
self._msg_seq_num = msg_seq_num
self._addr = None
self._header = struct.pack(self._HEADER_TEMPLATE, _MAGIC, self._num, # Make the header
self._all, self._data_len, self._master_sha) self._header = struct.pack(self._HEADER_TEMPLATE, _MAGIC, self._segno,
self._total_segs, self._msg_seq_num, self._master_sha)
def _new_from_data(self, addr, data):
    """Verify and construct a new message segment from network data.

    addr is stored as the segment's source address; data is the raw
    packet, i.e. the packed header immediately followed by the payload.

    Raises ValueError when the packet is shorter than the minimum length,
    when the payload is empty or larger than _MTU, when the magic does not
    match, or when a header field is out of range.
    """
    if len(data) < _HEADER_LEN + 1:
        raise ValueError("Message is less then minimum required length")

    self._stime = None
    self._addr = addr

    # Determine and verify the length of included data
    header_size = struct.calcsize(self._HEADER_TEMPLATE)
    self._data_len = len(data) - header_size
    if self._data_len < 1:
        raise ValueError("Message must have some data.")
    if self._data_len > _MTU:
        raise ValueError("Data length must not be larger than the MTU (%s)." % _MTU)

    # Read the header attributes
    raw_header = data[:header_size]
    (magic, segno, total_segs, msg_seq_num, master_sha) = struct.unpack(
            self._HEADER_TEMPLATE, raw_header)

    # Sanity checks on the message attributes
    if magic != _MAGIC:
        raise ValueError("Message does not have the correct magic.")
    if segno < 1:
        raise ValueError("Segment number must be greater than 0.")
    if segno > total_segs:
        raise ValueError("Segment number cannot be larger than message segment total.")
    if total_segs < 1:
        raise ValueError("Message must have at least one segment.")
    if msg_seq_num < 1:
        raise ValueError("Message sequence number must be greater than 0.")

    self._segno = segno
    self._total_segs = total_segs
    self._msg_seq_num = msg_seq_num
    self._master_sha = master_sha
    # Keep the received header so header() also works on parsed segments.
    self._header = raw_header

    # Reconstruct the data: the payload is everything after the header.
    self._data = data[header_size:]
def new_from_parts(msg_seq_num, segno, total_segs, data, master_sha):
    """Static constructor for creation from individual attributes.

    Returns a new MessageSegment initialised through _new_from_parts;
    any ValueError raised by that validation propagates to the caller.
    """
    segment = MessageSegment()
    segment._new_from_parts(msg_seq_num, segno, total_segs, data, master_sha)
    return segment
new_from_parts = staticmethod(new_from_parts)
def new_from_data(addr, data):
    """Static constructor for creation from a packed data stream.

    Returns a new MessageSegment initialised through _new_from_data;
    any ValueError raised while verifying the packet propagates to the
    caller.
    """
    segment = MessageSegment()
    segment._new_from_data(addr, data)
    return segment
new_from_data = staticmethod(new_from_data)
def stime(self):
    """Return the stored segment time (_stime).

    It is a time.time() value for segments built from parts and None for
    segments parsed from network data.
    """
    return self._stime
def addr(self):
    """Return the source address stored with the segment.

    It is None for segments built from parts and the address handed to
    _new_from_data for segments parsed from network data.
    """
    return self._addr
def segment_number(self):
    """Return this segment's number within its message."""
    return self._segno
def total_segments(self):
    """Return the total number of segments in this segment's message."""
    return self._total_segs
def message_sequence_number(self):
    """Return the sequence number of the message this segment belongs to."""
    return self._msg_seq_num
def data(self):
    """Return the payload carried by this segment, without the header."""
    return self._data
def header(self):
    """Return the packed segment header stored in _header."""
    return self._header
def master_sha(self):
    """Return the raw SHA-1 digest carried in the header for the whole message."""
    return self._master_sha
@ -117,19 +153,25 @@ class MostlyReliablePipe(object):
self._data_cb = data_cb self._data_cb = data_cb
self._user_data = user_data self._user_data = user_data
self._started = False self._started = False
self._worker = 0
self._seq_counter = 0
self._outgoing = [] self._outgoing = []
self._sent = [] self._sent = []
self._worker = 0
self._incoming = {} # (message sha, # of segments) -> [segment1, segment2, ...]
self._setup_listener() self._setup_listener()
self._setup_sender() self._setup_sender()
def _setup_sender(self): def _setup_sender(self):
"""Setup the send socket for multicast."""
self._send_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self._send_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# Make the socket multicast-aware, and set TTL. # Make the socket multicast-aware, and set TTL.
self._send_sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 20) # Change TTL (=20) to suit self._send_sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 20) # Change TTL (=20) to suit
def _setup_listener(self): def _setup_listener(self):
"""Set up the listener socket for multicast traffic."""
# Listener socket # Listener socket
self._listen_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self._listen_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@ -139,6 +181,7 @@ class MostlyReliablePipe(object):
self._listen_sock.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_LOOP, 1) self._listen_sock.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_LOOP, 1)
def start(self): def start(self):
"""Let the listener socket start listening for network data."""
# Set some more multicast options # Set some more multicast options
self._listen_sock.bind((self._local_addr, self._port)) # Bind to all interfaces self._listen_sock.bind((self._local_addr, self._port)) # Bind to all interfaces
self._listen_sock.settimeout(2) self._listen_sock.settimeout(2)
@ -155,31 +198,84 @@ class MostlyReliablePipe(object):
self._started = True self._started = True
def _segment_ttl_worker(self):
    """Cull already-sent message segments that are past their TTL.

    Always returns True.
    """
    # NOTE(review): the constants visible near the top of the file define
    # _TTL, not _MSG_TTL -- confirm _MSG_TTL is defined elsewhere.
    cutoff = time.time() - _MSG_TTL
    # Filter in a single pass and update the existing list object in place,
    # instead of calling remove() (a linear scan) once per expired segment.
    self._sent[:] = [segment for segment in self._sent
                     if not (segment.stime() < cutoff)]
    return True
def _dispatch_message(self, addr, message):
"""Send complete message data to the owner's data callback."""
self._data_cb(addr, message, self._user_data)
def _process_incoming(self, segment):
    """Handle a new message segment.

    A single-segment message is dispatched immediately when the checksum
    from its header matches the one computed from its data.  Any other
    segment is added to the segments already collected for its message
    (keyed by sender host, message sequence number, checksum and segment
    total); duplicates are dropped.  Once all segments are present the
    message is reassembled in segment order and dispatched if its checksum
    matches the header's.
    """
    # NOTE(review): the nesting of the early return and of the final
    # cleanup is reconstructed; confirm against the original indentation.
    string_sha = _stringify_sha(segment.master_sha())
    nsegs = segment.total_segments()
    addr = segment.addr()
    segno = segment.segment_number()

    # Short-circuit single-segment messages
    if segno == 1 and nsegs == 1:
        # Ensure the header's master sha actually equals the data's sha
        if string_sha == _stringify_sha(_sha_data(segment.data())):
            self._dispatch_message(addr, segment.data())
        return

    # Otherwise, track the new segment
    msg_seq_num = segment.message_sequence_number()
    msg_key = (addr[0], msg_seq_num, string_sha, nsegs)
    collected = self._incoming.setdefault(msg_key, {})

    # Look for a dupe, and if so, drop the new segment
    if segno in collected:
        return
    collected[segno] = segment

    if len(collected) != nsegs:
        return

    # All segments are present.  Forget the message before dispatching so a
    # failing callback cannot leave the finished entry behind.
    del self._incoming[msg_key]
    all_data = ''.join([collected[i].data() for i in range(1, nsegs + 1)])
    # Dispatch only if the reassembled data matches the header's sha
    if string_sha == _stringify_sha(_sha_data(all_data)):
        self._dispatch_message(addr, all_data)
def _handle_incoming_data(self, source, condition):
    """Handle incoming network data by making a message segment out of it
    and sending it off to the processing function.

    source is the object recvfrom() is called on and condition is the I/O
    condition mask.  Returns True in every case.
    """
    if not (condition & gobject.IO_IN):
        return True
    data, addr = source.recvfrom(_MTU + _HEADER_LEN)
    try:
        segment = MessageSegment.new_from_data(addr, data)
        self._process_incoming(segment)
    except ValueError:
        # Deliberate best-effort: a packet that fails verification is dropped.
        pass
    return True
def send(self, data): def send(self, data):
"""Break data up into chunks and queue for later transmission."""
if not self._started: if not self._started:
raise Exception("Can't send anything until started!") raise Exception("Can't send anything until started!")
self._seq_counter = self._seq_counter + 1
if self._seq_counter > 65535:
self._seq_counter = 1
# Pack the data into network byte order # Pack the data into network byte order
template = "! %ds" % len(data) template = "! %ds" % len(data)
data = struct.pack(template, data) data = struct.pack(template, data)
sha_hash = sha.new() master_sha = _sha_data(data)
sha_hash.update(data)
master_sha = sha_hash.digest()
# Split up the data into segments # Split up the data into segments
left = length = len(data) left = length = len(data)
@ -188,7 +284,8 @@ class MostlyReliablePipe(object):
nmessages = nmessages + 1 nmessages = nmessages + 1
msg_num = 1 msg_num = 1
while left > 0: while left > 0:
msg = MessageSegment.new_from_parts(msg_num, nmessages, data[:_MTU], master_sha) msg = MessageSegment.new_from_parts(self._seq_counter, msg_num,
nmessages, data[:_MTU], master_sha)
self._outgoing.append(msg) self._outgoing.append(msg)
msg_num = msg_num + 1 msg_num = msg_num + 1
data = data[_MTU:] data = data[_MTU:]
@ -197,6 +294,7 @@ class MostlyReliablePipe(object):
self._worker = gobject.idle_add(self._send_worker) self._worker = gobject.idle_add(self._send_worker)
def _send_worker(self): def _send_worker(self):
"""Send all queued segments that have yet to be transmitted."""
self._worker = 0 self._worker = 0
for segment in self._outgoing: for segment in self._outgoing:
data = segment.segment() data = segment.segment()
@ -207,14 +305,7 @@ class MostlyReliablePipe(object):
def got_data(addr, data, user_data=None):
    """Data callback used by main(): print the sender address and the data."""
    # A single parenthesised expression prints the same text whether print
    # is the Python 2 statement or a function.
    print("Data (%s): %s" % (addr, data))
def main(): def main():
pipe = MostlyReliablePipe('', '224.0.0.222', 2293, got_data) pipe = MostlyReliablePipe('', '224.0.0.222', 2293, got_data)