Do message reassembly
This commit is contained in:
		
							parent
							
								
									e4516c6d81
								
							
						
					
					
						commit
						0f7dc51ac0
					
				| @ -15,90 +15,126 @@ _HEADER_LEN = 30 | ||||
| _MAGIC = 0xbaea4304 | ||||
| _TTL = 120 # 2 minutes | ||||
| 
 | ||||
| def _stringify_sha(sha): | ||||
| 	print_sha = "" | ||||
| 	for char in sha: | ||||
| 		print_sha = print_sha + binascii.b2a_hex(char) | ||||
| 	return print_sha | ||||
| 
 | ||||
| def _sha_data(data): | ||||
| 	sha_hash = sha.new() | ||||
| 	sha_hash.update(data) | ||||
| 	return sha_hash.digest() | ||||
| 
 | ||||
| class MessageSegment(object): | ||||
| 	# 4: magic (0xbaea4304) | ||||
| 	# 2: segment number | ||||
| 	# 2: total segments | ||||
| 	# 2: data size | ||||
| 	#20: total message sha1 | ||||
| 	# 2: message sequence number | ||||
| 	#20: total data sha1 | ||||
| 	_HEADER_TEMPLATE = "! IHHH20s" | ||||
| 
 | ||||
| 	def _new_from_parts(self, num, all, data, master_sha): | ||||
| 	def _new_from_parts(self, msg_seq_num, segno, total_segs, data, master_sha): | ||||
| 		"""Construct a new message segment from individual attributes.""" | ||||
| 		if not data: | ||||
| 			raise ValueError("Must have valid data.") | ||||
| 		if num > 65535: | ||||
| 		if segno > 65535: | ||||
| 			raise ValueError("Segment number cannot be more than 65535.") | ||||
| 		if num < 1: | ||||
| 		if segno < 1: | ||||
| 			raise ValueError("Segment number must be greater than zero.") | ||||
| 		if all > 65535: | ||||
| 		if total_segs > 65535: | ||||
| 			raise ValueError("Message cannot have more than 65535 segments.") | ||||
| 		if all < 1: | ||||
| 		if total_segs < 1: | ||||
| 			raise ValueError("Message must have at least one segment.") | ||||
| 		if msg_seq_num < 1: | ||||
| 			raise ValueError("Message sequence number must be greater than 0.") | ||||
| 		self._stime = time.time() | ||||
| 		self._data = data | ||||
| 		self._data_len = len(data) | ||||
| 		self._master_sha = master_sha | ||||
| 		self._num = num | ||||
| 		self._all = all | ||||
| 		self._segno = segno | ||||
| 		self._total_segs = total_segs | ||||
| 		self._msg_seq_num = msg_seq_num | ||||
| 		self._addr = None | ||||
| 
 | ||||
| 		self._header = struct.pack(self._HEADER_TEMPLATE, _MAGIC, self._num, | ||||
| 				self._all, self._data_len, self._master_sha) | ||||
| 		# Make the header | ||||
| 		self._header = struct.pack(self._HEADER_TEMPLATE, _MAGIC, self._segno, | ||||
| 				self._total_segs, self._msg_seq_num, self._master_sha) | ||||
| 
 | ||||
| 	def _new_from_data(self, data): | ||||
| 	def _new_from_data(self, addr, data): | ||||
| 		"""Verify and construct a new message segment from network data.""" | ||||
| 		if len(data) < _HEADER_LEN + 1: | ||||
| 			raise ValueError("Message is less then minimum required length") | ||||
| 		stream = StringIO.StringIO(data) | ||||
| 		self._stime = None | ||||
| 		(magic, num, all, data_len, master_sha) = struct.unpack(self._HEADER_TEMPLATE, | ||||
| 				stream.read(struct.calcsize(self._HEADER_TEMPLATE))) | ||||
| 		self._addr = addr | ||||
| 
 | ||||
| 		# Format checking | ||||
| 		# Determine and verify the length of included data | ||||
| 		stream.seek(0, 2) | ||||
| 		header_size = struct.calcsize(self._HEADER_TEMPLATE) | ||||
| 		self._data_len = stream.tell() - header_size | ||||
| 		if self._data_len < 1: | ||||
| 			raise ValueError("Message must have some data.") | ||||
| 		if self._data_len > _MTU: | ||||
| 			raise ValueError("Data length must not be larger than the MTU (%s)." % _MTU) | ||||
| 		stream.seek(0) | ||||
| 
 | ||||
| 		# Read the header attributes | ||||
| 		(magic, segno, total_segs, msg_seq_num, master_sha) = struct.unpack(self._HEADER_TEMPLATE, | ||||
| 				stream.read(header_size)) | ||||
| 
 | ||||
| 		# Sanity checks on the message attributes | ||||
| 		if magic != _MAGIC: | ||||
| 			raise ValueError("Message does not have the correct magic.") | ||||
| 		if not num: | ||||
| 		if segno < 1: | ||||
| 			raise ValueError("Segment number must be greater than 0.") | ||||
| 		if not all: | ||||
| 		if segno > total_segs: | ||||
| 			raise ValueError("Segment number cannot be larger than message segment total.") | ||||
| 		if total_segs < 1: | ||||
| 			raise ValueError("Message must have at least one segment.") | ||||
| 		if not data_len: | ||||
| 			raise ValueError("Message must have some data.") | ||||
| 		if data_len > _MTU: | ||||
| 			raise ValueError("Data length must not be larger than the MTU (%s)." % _MTU) | ||||
| 		if msg_seq_num < 1: | ||||
| 			raise ValueError("Message sequence number must be greater than 0.") | ||||
| 
 | ||||
| 		self._num = num | ||||
| 		self._all = all | ||||
| 		self._data_len = data_len | ||||
| 		self._segno = segno | ||||
| 		self._total_segs = total_segs | ||||
| 		self._msg_seq_num = msg_seq_num | ||||
| 		self._master_sha = master_sha | ||||
| 
 | ||||
| 		# Read data | ||||
| 		self._data = struct.unpack("! %ds" % self._data_len, stream.read(self._data_len)) | ||||
| 		# Reconstruct the data | ||||
| 		self._data = struct.unpack("! %ds" % self._data_len, stream.read(self._data_len))[0] | ||||
| 
 | ||||
| 	def new_from_parts(num, all, data, master_sha): | ||||
| 	def new_from_parts(msg_seq_num, segno, total_segs, data, master_sha): | ||||
| 		"""Static constructor for creation from individual attributes.""" | ||||
| 		segment = MessageSegment() | ||||
| 		segment._new_from_parts(num, all, data, master_sha) | ||||
| 		segment._new_from_parts(msg_seq_num, segno, total_segs, data, master_sha) | ||||
| 		return segment | ||||
| 	new_from_parts = staticmethod(new_from_parts) | ||||
| 
 | ||||
| 	def new_from_data(data): | ||||
| 	def new_from_data(addr, data): | ||||
| 		"""Static constructor for creation from a packed data stream.""" | ||||
| 		segment = MessageSegment() | ||||
| 		segment._new_from_data(data) | ||||
| 		segment._new_from_data(addr, data) | ||||
| 		return segment | ||||
| 	new_from_data = staticmethod(new_from_data) | ||||
| 
 | ||||
| 	def stime(self): | ||||
| 		return self._stime | ||||
| 
 | ||||
| 	def num(self): | ||||
| 		return self._num | ||||
| 	def addr(self): | ||||
| 		return self._addr | ||||
| 
 | ||||
| 	def all(self): | ||||
| 		return self._all | ||||
| 	def segment_number(self): | ||||
| 		return self._segno | ||||
| 
 | ||||
| 	def total_segments(self): | ||||
| 		return self._total_segs | ||||
| 
 | ||||
| 	def message_sequence_number(self): | ||||
| 		return self._msg_seq_num | ||||
| 
 | ||||
| 	def data(self): | ||||
| 		return self._data | ||||
| 
 | ||||
| 	def header(self): | ||||
| 		return self._header | ||||
| 
 | ||||
| 	def master_sha(self): | ||||
| 		return self._master_sha | ||||
| 
 | ||||
| @ -117,19 +153,25 @@ class MostlyReliablePipe(object): | ||||
| 		self._data_cb = data_cb | ||||
| 		self._user_data = user_data | ||||
| 		self._started = False | ||||
| 		self._worker = 0 | ||||
| 		self._seq_counter = 0 | ||||
| 
 | ||||
| 		self._outgoing = [] | ||||
| 		self._sent = [] | ||||
| 		self._worker = 0 | ||||
| 
 | ||||
| 		self._incoming = {}  # (message sha, # of segments) -> [segment1, segment2, ...] | ||||
| 
 | ||||
| 		self._setup_listener() | ||||
| 		self._setup_sender() | ||||
| 
 | ||||
| 	def _setup_sender(self): | ||||
| 		"""Setup the send socket for multicast.""" | ||||
| 		self._send_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) | ||||
| 		# Make the socket multicast-aware, and set TTL. | ||||
| 		self._send_sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 20) # Change TTL (=20) to suit | ||||
| 
 | ||||
| 	def _setup_listener(self): | ||||
| 		"""Set up the listener socket for multicast traffic.""" | ||||
| 		# Listener socket | ||||
| 		self._listen_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) | ||||
| 
 | ||||
| @ -139,6 +181,7 @@ class MostlyReliablePipe(object): | ||||
| 		self._listen_sock.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_LOOP, 1) | ||||
| 
 | ||||
| 	def start(self): | ||||
| 		"""Let the listener socket start listening for network data.""" | ||||
| 		# Set some more multicast options | ||||
| 		self._listen_sock.bind((self._local_addr, self._port))  # Bind to all interfaces | ||||
| 		self._listen_sock.settimeout(2) | ||||
| @ -155,31 +198,84 @@ class MostlyReliablePipe(object): | ||||
| 		self._started = True | ||||
| 
 | ||||
| 	def _segment_ttl_worker(self): | ||||
| 		"""Cull already-sent message segments that are past their TTL.""" | ||||
| 		now = time.time() | ||||
| 		for segment in self._sent[:]: | ||||
| 			if segment.stime() < now - _MSG_TTL: | ||||
| 				self._sent.remove(segment) | ||||
| 		return True | ||||
| 
 | ||||
| 	def _dispatch_message(self, addr, message): | ||||
| 		"""Send complete message data to the owner's data callback.""" | ||||
| 		self._data_cb(addr, message, self._user_data) | ||||
| 
 | ||||
| 	def _process_incoming(self, segment): | ||||
| 		"""Handle a new message segment.  First checks if there is only one | ||||
| 		segment to the message, and if the checksum from the header matches | ||||
| 		that computed from the data, dispatches it.  Otherwise, it adds the | ||||
| 		new segment to the list of other segments for that message, and  | ||||
| 		checks to see if the message is complete.  If all segments are present, | ||||
| 		the message is reassembled and dispatched.""" | ||||
| 
 | ||||
| 		string_sha = _stringify_sha(segment.master_sha()) | ||||
| 		nsegs = segment.total_segments() | ||||
| 		addr = segment.addr() | ||||
| 		segno = segment.segment_number() | ||||
| 
 | ||||
| 		# Short-circuit single-segment messages | ||||
| 		if segno == 1 and nsegs == 1: | ||||
| 			# Ensure the header's master sha actually equals the data's sha | ||||
| 			if string_sha == _stringify_sha(_sha_data(segment.data())): | ||||
| 				self._dispatch_message(addr, segment.data()) | ||||
| 				return | ||||
| 
 | ||||
| 		# Otherwise, track the new segment | ||||
| 		msg_seq_num = segment.message_sequence_number() | ||||
| 		msg_key = (addr[0], msg_seq_num, string_sha, nsegs) | ||||
| 		if not self._incoming.has_key(msg_key): | ||||
| 			self._incoming[msg_key] = {} | ||||
| 
 | ||||
| 		# Look for a dupe, and if so, drop the new segment | ||||
| 		if self._incoming[msg_key].has_key(segno): | ||||
| 			return | ||||
| 		self._incoming[msg_key][segno] = segment | ||||
| 
 | ||||
| 		# Dispatch the message if all segments are present and the sha is correct | ||||
| 		if len(self._incoming[msg_key]) == nsegs: | ||||
| 			all_data = '' | ||||
| 			for i in range(1, nsegs + 1): | ||||
| 				all_data = all_data + self._incoming[msg_key][i].data() | ||||
| 			if string_sha == _stringify_sha(_sha_data(all_data)): | ||||
| 				self._dispatch_message(addr, all_data) | ||||
| 			del self._incoming[msg_key] | ||||
| 
 | ||||
| 	def _handle_incoming_data(self, source, condition): | ||||
| 		"""Handle incoming network data by making a message segment out of it | ||||
| 		sending it off to the processing function.""" | ||||
| 		if not (condition & gobject.IO_IN): | ||||
| 			return True | ||||
| 		msg = {} | ||||
| 		data, addr = source.recvfrom(_MTU + _HEADER_LEN) | ||||
| 		if self._data_cb: | ||||
| 			self._data_cb(addr, data, self._user_data) | ||||
| 		try: | ||||
| 			segment = MessageSegment.new_from_data(addr, data) | ||||
| 			self._process_incoming(segment) | ||||
| 		except ValueError, exc: | ||||
| 			pass | ||||
| 		return True | ||||
| 
 | ||||
| 	def send(self, data): | ||||
| 		"""Break data up into chunks and queue for later transmission.""" | ||||
| 		if not self._started: | ||||
| 			raise Exception("Can't send anything until started!") | ||||
| 
 | ||||
| 		self._seq_counter = self._seq_counter + 1 | ||||
| 		if self._seq_counter > 65535: | ||||
| 			self._seq_counter = 1 | ||||
| 
 | ||||
| 		# Pack the data into network byte order | ||||
| 		template = "! %ds" % len(data) | ||||
| 		data = struct.pack(template, data) | ||||
| 		sha_hash = sha.new() | ||||
| 		sha_hash.update(data) | ||||
| 		master_sha = sha_hash.digest() | ||||
| 		master_sha = _sha_data(data) | ||||
| 
 | ||||
| 		# Split up the data into segments | ||||
| 		left = length = len(data) | ||||
| @ -188,7 +284,8 @@ class MostlyReliablePipe(object): | ||||
| 			nmessages = nmessages + 1 | ||||
| 		msg_num = 1 | ||||
| 		while left > 0: | ||||
| 			msg = MessageSegment.new_from_parts(msg_num, nmessages, data[:_MTU], master_sha) | ||||
| 			msg = MessageSegment.new_from_parts(self._seq_counter, msg_num, | ||||
| 					nmessages, data[:_MTU], master_sha) | ||||
| 			self._outgoing.append(msg) | ||||
| 			msg_num = msg_num + 1 | ||||
| 			data = data[_MTU:] | ||||
| @ -197,6 +294,7 @@ class MostlyReliablePipe(object): | ||||
| 			self._worker = gobject.idle_add(self._send_worker) | ||||
| 
 | ||||
| 	def _send_worker(self): | ||||
| 		"""Send all queued segments that have yet to be transmitted.""" | ||||
| 		self._worker = 0 | ||||
| 		for segment in self._outgoing: | ||||
| 			data = segment.segment() | ||||
| @ -207,14 +305,7 @@ class MostlyReliablePipe(object): | ||||
| 
 | ||||
| 
 | ||||
| def got_data(addr, data, user_data=None): | ||||
| 	segment = MessageSegment.new_from_data(data) | ||||
| 	print "Segment (%d/%d)" % (segment.num(), segment.all()) | ||||
| 	print_sha = "" | ||||
| 	for char in segment.master_sha(): | ||||
| 		print_sha = print_sha + binascii.b2a_hex(char) | ||||
| 	print "   Master SHA: %s" % print_sha | ||||
| 	print "   Data: '%s'" % segment.data() | ||||
| 	print "" | ||||
| 	print "Data (%s): %s" % (addr, data) | ||||
| 
 | ||||
| def main(): | ||||
| 	pipe = MostlyReliablePipe('', '224.0.0.222', 2293, got_data) | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 Dan Williams
						Dan Williams