Add positive acknowledgements to work around 802.11 + multicast unreliabilities
This commit is contained in:
		
							parent
							
								
									9ef8013a6b
								
							
						
					
					
						commit
						29984ace33
					
				@ -44,6 +44,7 @@ class SegmentBase(object):
 | 
				
			|||||||
	# Message segment packet types
 | 
						# Message segment packet types
 | 
				
			||||||
	_SEGMENT_TYPE_DATA = 0
 | 
						_SEGMENT_TYPE_DATA = 0
 | 
				
			||||||
	_SEGMENT_TYPE_RETRANSMIT = 1
 | 
						_SEGMENT_TYPE_RETRANSMIT = 1
 | 
				
			||||||
 | 
						_SEGMENT_TYPE_ACK = 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	def magic():
 | 
						def magic():
 | 
				
			||||||
		return SegmentBase._MAGIC
 | 
							return SegmentBase._MAGIC
 | 
				
			||||||
@ -61,6 +62,10 @@ class SegmentBase(object):
 | 
				
			|||||||
		return SegmentBase._SEGMENT_TYPE_RETRANSMIT
 | 
							return SegmentBase._SEGMENT_TYPE_RETRANSMIT
 | 
				
			||||||
	type_retransmit = staticmethod(type_retransmit)
 | 
						type_retransmit = staticmethod(type_retransmit)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						def type_ack():
 | 
				
			||||||
 | 
							return SegmentBase._SEGMENT_TYPE_ACK
 | 
				
			||||||
 | 
						type_ack = staticmethod(type_ack)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	def header_len():
 | 
						def header_len():
 | 
				
			||||||
		"""Return the header size of SegmentBase packets."""
 | 
							"""Return the header size of SegmentBase packets."""
 | 
				
			||||||
		return SegmentBase._HEADER_LEN
 | 
							return SegmentBase._HEADER_LEN
 | 
				
			||||||
@ -156,6 +161,8 @@ class SegmentBase(object):
 | 
				
			|||||||
			segment = DataSegment(segno, total_segs, msg_seq_num, master_sha)
 | 
								segment = DataSegment(segno, total_segs, msg_seq_num, master_sha)
 | 
				
			||||||
		elif seg_type == SegmentBase._SEGMENT_TYPE_RETRANSMIT:
 | 
							elif seg_type == SegmentBase._SEGMENT_TYPE_RETRANSMIT:
 | 
				
			||||||
			segment = RetransmitSegment(segno, total_segs, msg_seq_num, master_sha)
 | 
								segment = RetransmitSegment(segno, total_segs, msg_seq_num, master_sha)
 | 
				
			||||||
 | 
							elif set_type == SegmentBase._SEGMENT_TYPE_ACK:
 | 
				
			||||||
 | 
								segment = AckSegment(segno, total_segs, msg_seq_num, master_sha)
 | 
				
			||||||
		else:
 | 
							else:
 | 
				
			||||||
			raise ValueError("Segment has invalid type.")
 | 
								raise ValueError("Segment has invalid type.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -319,6 +326,96 @@ class RetransmitSegment(SegmentBase):
 | 
				
			|||||||
		return self._rt_segment_number	
 | 
							return self._rt_segment_number	
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AckSegment(SegmentBase):
 | 
				
			||||||
 | 
						"""A message segment that encapsulates a message acknowledgement."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						# Ack data format:
 | 
				
			||||||
 | 
						#  2: acked message sequence number
 | 
				
			||||||
 | 
						# 20: acked message total data sha1
 | 
				
			||||||
 | 
						#  4: acked message source IP address
 | 
				
			||||||
 | 
						_ACK_DATA_TEMPLATE = "! H20sI"
 | 
				
			||||||
 | 
						_ACK_DATA_LEN = struct.calcsize(_ACK_DATA_TEMPLATE)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						def data_template():
 | 
				
			||||||
 | 
							return AckSegment._ACK_DATA_TEMPLATE
 | 
				
			||||||
 | 
						data_template = staticmethod(data_template)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						def __init__(self, segno, total_segs, msg_seq_num, master_sha):
 | 
				
			||||||
 | 
							"""Should not be called directly."""
 | 
				
			||||||
 | 
							if segno != 1 or total_segs != 1:
 | 
				
			||||||
 | 
								raise ValueError("Acknowledgement messages must have only one segment.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							SegmentBase.__init__(self, segno, total_segs, msg_seq_num, master_sha)
 | 
				
			||||||
 | 
							self._type = SegmentBase._SEGMENT_TYPE_ACK
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						def _verify_data(ack_msg_seq_num, ack_master_sha, ack_addr):
 | 
				
			||||||
 | 
							# Sanity checks on the message attributes
 | 
				
			||||||
 | 
							if not ack_msg_seq_num or type(ack_msg_seq_num) != type(1):
 | 
				
			||||||
 | 
								raise ValueError("Ack message sequnce number must be an integer.")
 | 
				
			||||||
 | 
							if ack_msg_seq_num < 1 or ack_msg_seq_num > 65535:
 | 
				
			||||||
 | 
								raise ValueError("Ack message sequence number must be between 1 and 65535 inclusive.")
 | 
				
			||||||
 | 
							if not ack_master_sha or type(ack_master_sha) != type("") or len(ack_master_sha) != 20:
 | 
				
			||||||
 | 
								raise ValueError("Ack message SHA1 checksum invalid.")
 | 
				
			||||||
 | 
							if type(ack_addr) != type(""):
 | 
				
			||||||
 | 
								raise ValueError("Ack message invalid address.")
 | 
				
			||||||
 | 
							try:
 | 
				
			||||||
 | 
								foo = socket.inet_aton(ack_addr)
 | 
				
			||||||
 | 
							except socket.error:
 | 
				
			||||||
 | 
								raise ValueError("Ack message invalid address.")
 | 
				
			||||||
 | 
						_verify_data = staticmethod(_verify_data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						def _make_ack_data(ack_msg_seq_num, ack_master_sha, ack_addr):
 | 
				
			||||||
 | 
							"""Pack an ack payload."""
 | 
				
			||||||
 | 
							addr_data = socket.inet_aton(ack_addr)
 | 
				
			||||||
 | 
							data = struct.pack(AckSegment._ACK_DATA_TEMPLATE, ack_msg_seq_num,
 | 
				
			||||||
 | 
									ack_master_sha, addr_data)
 | 
				
			||||||
 | 
							return (data, _sha_data(data))
 | 
				
			||||||
 | 
						_make_ack_data = staticmethod(_make_ack_data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						def new_from_parts(addr, msg_seq_num, ack_msg_seq_num, ack_master_sha, ack_addr):
 | 
				
			||||||
 | 
							"""Static constructor for creation from individual attributes."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							AckSegment._verify_data(ack_msg_seq_num, ack_master_sha, ack_addr)
 | 
				
			||||||
 | 
							(data, data_sha) = AckSegment._make_ack_data(ack_msg_seq_num,
 | 
				
			||||||
 | 
									ack_master_sha, ack_addr)
 | 
				
			||||||
 | 
							segment = AckSegment(1, 1, msg_seq_num, data_sha)
 | 
				
			||||||
 | 
							segment._data_len = AckSegment._ACK_DATA_LEN
 | 
				
			||||||
 | 
							segment._data = data
 | 
				
			||||||
 | 
							SegmentBase._validate_address(addr)
 | 
				
			||||||
 | 
							segment._addr = addr
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							segment._ack_msg_seq_num = ack_msg_seq_num
 | 
				
			||||||
 | 
							segment._ack_master_sha = ack_master_sha
 | 
				
			||||||
 | 
							segment._ack_addr = ack_addr
 | 
				
			||||||
 | 
							return segment
 | 
				
			||||||
 | 
						new_from_parts = staticmethod(new_from_parts)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						def _unpack_data(self, stream, data_len):
 | 
				
			||||||
 | 
							if data_len != self._ACK_DATA_LEN:
 | 
				
			||||||
 | 
								raise ValueError("Ack segment data had invalid length.")
 | 
				
			||||||
 | 
							data = stream.read(data_len)
 | 
				
			||||||
 | 
							(ack_msg_seq_num, ack_master_sha, ack_addr_data) = struct.unpack(self._ACK_DATA_TEMPLATE, data)
 | 
				
			||||||
 | 
							try:
 | 
				
			||||||
 | 
								ack_addr = socket.inet_ntoa(ack_addr_data)
 | 
				
			||||||
 | 
							except socket.error:
 | 
				
			||||||
 | 
								raise ValueError("Ack segment data had invalid address.")
 | 
				
			||||||
 | 
							AckSegment._verify_data(ack_msg_seq_num, ack_master_sha, ack_addr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							self._data = data
 | 
				
			||||||
 | 
							self._data_len = data_len
 | 
				
			||||||
 | 
							self._ack_msg_seq_num = ack_msg_seq_num
 | 
				
			||||||
 | 
							self._ack_master_sha = ack_master_sha
 | 
				
			||||||
 | 
							self._ack_addr = ack_addr
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						def ack_msg_seq_num(self):
 | 
				
			||||||
 | 
							return self._ack_msg_seq_num
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						def ack_master_sha(self):
 | 
				
			||||||
 | 
							return self._ack_master_sha
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						def ack_addr(self):
 | 
				
			||||||
 | 
							return self._ack_addr
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Message(object):
 | 
					class Message(object):
 | 
				
			||||||
	"""Tracks an entire message object, which is composed of a number
 | 
						"""Tracks an entire message object, which is composed of a number
 | 
				
			||||||
	of individual segments."""
 | 
						of individual segments."""
 | 
				
			||||||
@ -429,6 +526,53 @@ class Message(object):
 | 
				
			|||||||
		return 0
 | 
							return 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _get_local_interfaces():
 | 
				
			||||||
 | 
						import array
 | 
				
			||||||
 | 
						import struct
 | 
				
			||||||
 | 
						import fcntl
 | 
				
			||||||
 | 
						import socket
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						max_possible = 4
 | 
				
			||||||
 | 
						bytes = max_possible * 32
 | 
				
			||||||
 | 
						SIOCGIFCONF = 0x8912
 | 
				
			||||||
 | 
						names = array.array('B', '\0' * bytes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						sockfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
 | 
				
			||||||
 | 
						ifreq = struct.pack('iL', bytes, names.buffer_info()[0])
 | 
				
			||||||
 | 
						result = fcntl.ioctl(sockfd.fileno(), SIOCGIFCONF, ifreq)
 | 
				
			||||||
 | 
						sockfd.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						outbytes = struct.unpack('iL', result)[0]
 | 
				
			||||||
 | 
						namestr = names.tostring()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return [namestr[i:i+32].split('\0', 1)[0] for i in range(0, outbytes, 32)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _get_local_ip_addresses():
 | 
				
			||||||
 | 
						"""Call Linux specific bits to retrieve our own IP address."""
 | 
				
			||||||
 | 
						import socket
 | 
				
			||||||
 | 
						import sys
 | 
				
			||||||
 | 
						import fcntl
 | 
				
			||||||
 | 
						import struct
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						intfs = _get_local_interfaces()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ips = []
 | 
				
			||||||
 | 
						SIOCGIFADDR = 0x8915
 | 
				
			||||||
 | 
						sockfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
 | 
				
			||||||
 | 
						for intf in intfs:
 | 
				
			||||||
 | 
							if intf == "lo":
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							try:
 | 
				
			||||||
 | 
								ifreq = (intf + '\0'*32)[:32]
 | 
				
			||||||
 | 
								result = fcntl.ioctl(sockfd.fileno(), SIOCGIFADDR, ifreq)
 | 
				
			||||||
 | 
								addr = socket.inet_ntoa(result[20:24])
 | 
				
			||||||
 | 
								ips.append(addr)
 | 
				
			||||||
 | 
							except IOError, exc:
 | 
				
			||||||
 | 
								print "Error getting IP address: %s" % exc
 | 
				
			||||||
 | 
						sockfd.close()
 | 
				
			||||||
 | 
						return ips
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MostlyReliablePipe(object):
 | 
					class MostlyReliablePipe(object):
 | 
				
			||||||
	"""Implement Mostly-Reliable UDP.  We don't actually care about guaranteeing
 | 
						"""Implement Mostly-Reliable UDP.  We don't actually care about guaranteeing
 | 
				
			||||||
	delivery or receipt, just a better effort than no effort at all."""
 | 
						delivery or receipt, just a better effort than no effort at all."""
 | 
				
			||||||
@ -446,13 +590,17 @@ class MostlyReliablePipe(object):
 | 
				
			|||||||
		self._send_worker = 0
 | 
							self._send_worker = 0
 | 
				
			||||||
		self._seq_counter = 0
 | 
							self._seq_counter = 0
 | 
				
			||||||
		self._drop_prob = 0
 | 
							self._drop_prob = 0
 | 
				
			||||||
		self._rt_check_worker = 0
 | 
							self._rt_check_worker_id = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		self._outgoing = []
 | 
							self._outgoing = []
 | 
				
			||||||
		self._sent = {}
 | 
							self._sent = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		self._incoming = {}  # (message sha, # of segments) -> [segment1, segment2, ...]
 | 
							self._incoming = {}  # (message sha, # of segments) -> [segment1, segment2, ...]
 | 
				
			||||||
		self._dispatched = {}
 | 
							self._dispatched = {}
 | 
				
			||||||
 | 
							self._acks = {} # (message sequence #, master sha, source addr) -> received timestamp
 | 
				
			||||||
 | 
							self._ack_check_worker_id = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							self._local_ips = _get_local_ip_addresses()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		self._setup_listener()
 | 
							self._setup_listener()
 | 
				
			||||||
		self._setup_sender()
 | 
							self._setup_sender()
 | 
				
			||||||
@ -461,9 +609,12 @@ class MostlyReliablePipe(object):
 | 
				
			|||||||
		if self._send_worker > 0:
 | 
							if self._send_worker > 0:
 | 
				
			||||||
			gobject.source_remove(self._send_worker)
 | 
								gobject.source_remove(self._send_worker)
 | 
				
			||||||
			self._send_worker = 0
 | 
								self._send_worker = 0
 | 
				
			||||||
		if self._rt_check_worker > 0:
 | 
							if self._rt_check_worker_id > 0:
 | 
				
			||||||
			gobject.source_remove(self._rt_check_worker)
 | 
								gobject.source_remove(self._rt_check_worker_id)
 | 
				
			||||||
			self._rt_check_worker = 0
 | 
								self._rt_check_worker_id = 0
 | 
				
			||||||
 | 
							if self._ack_check_worker_id > 0:
 | 
				
			||||||
 | 
								gobject.source_remove(self._ack_check_worker_id)
 | 
				
			||||||
 | 
								self._ack_check_worker_id = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	def _setup_sender(self):
 | 
						def _setup_sender(self):
 | 
				
			||||||
		"""Setup the send socket for multicast."""
 | 
							"""Setup the send socket for multicast."""
 | 
				
			||||||
@ -495,7 +646,8 @@ class MostlyReliablePipe(object):
 | 
				
			|||||||
		# Watch the listener socket for data
 | 
							# Watch the listener socket for data
 | 
				
			||||||
		gobject.io_add_watch(self._listen_sock, gobject.IO_IN, self._handle_incoming_data)
 | 
							gobject.io_add_watch(self._listen_sock, gobject.IO_IN, self._handle_incoming_data)
 | 
				
			||||||
		gobject.timeout_add(self._SEGMENT_TTL * 1000, self._segment_ttl_worker)
 | 
							gobject.timeout_add(self._SEGMENT_TTL * 1000, self._segment_ttl_worker)
 | 
				
			||||||
		gobject.timeout_add(50, self._retransmit_check_worker)
 | 
							self._rt_check_worker_id = gobject.timeout_add(50, self._retransmit_check_worker)
 | 
				
			||||||
 | 
							self._ack_check_worker_id = gobject.timeout_add(50, self._ack_check_worker)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		self._started = True
 | 
							self._started = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -516,12 +668,18 @@ class MostlyReliablePipe(object):
 | 
				
			|||||||
			if message.last_incoming_time() < now - self._SEGMENT_TTL:
 | 
								if message.last_incoming_time() < now - self._SEGMENT_TTL:
 | 
				
			||||||
				del self._incoming[msg_key]
 | 
									del self._incoming[msg_key]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		# Remove already dispatched messages after a while
 | 
							# Remove already received and dispatched messages after a while
 | 
				
			||||||
		for msg_key in self._dispatched.keys()[:]:
 | 
							for msg_key in self._dispatched.keys()[:]:
 | 
				
			||||||
			message = self._dispatched[msg_key]
 | 
								message = self._dispatched[msg_key]
 | 
				
			||||||
			if message.dispatch_time() < now - (self._SEGMENT_TTL*2):
 | 
								if message.dispatch_time() < now - (self._SEGMENT_TTL*2):
 | 
				
			||||||
				del self._dispatched[msg_key]
 | 
									del self._dispatched[msg_key]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							# Remove received acks after a while
 | 
				
			||||||
 | 
							for ack_key in self._acks.keys()[:]:
 | 
				
			||||||
 | 
								ack_time = self._acks[ack_key]
 | 
				
			||||||
 | 
								if ack_time < now - (self._SEGMENT_TTL*2):
 | 
				
			||||||
 | 
									del self._acks[ack_key]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return True
 | 
							return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	_MAX_SEGMENT_RETRIES = 10
 | 
						_MAX_SEGMENT_RETRIES = 10
 | 
				
			||||||
@ -541,6 +699,8 @@ class MostlyReliablePipe(object):
 | 
				
			|||||||
		return False
 | 
							return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	def _retransmit_check_worker(self):
 | 
						def _retransmit_check_worker(self):
 | 
				
			||||||
 | 
							"""Periodically check for and send retransmit requests for message
 | 
				
			||||||
 | 
							segments that got lost."""
 | 
				
			||||||
		try:
 | 
							try:
 | 
				
			||||||
			now = time.time()
 | 
								now = time.time()
 | 
				
			||||||
			for key in self._incoming.keys()[:]:
 | 
								for key in self._incoming.keys()[:]:
 | 
				
			||||||
@ -582,6 +742,12 @@ class MostlyReliablePipe(object):
 | 
				
			|||||||
		# First segment in the message
 | 
							# First segment in the message
 | 
				
			||||||
		if not self._incoming.has_key(msg_key):
 | 
							if not self._incoming.has_key(msg_key):
 | 
				
			||||||
			self._incoming[msg_key] = Message((addr[0], self._port), msg_seq_num, msg_sha, nsegs)
 | 
								self._incoming[msg_key] = Message((addr[0], self._port), msg_seq_num, msg_sha, nsegs)
 | 
				
			||||||
 | 
								# Acknowledge the message if it didn't come from us
 | 
				
			||||||
 | 
								if addr[0] not in self._local_ips:
 | 
				
			||||||
 | 
									print "Sending ack for msg (%s %s) from %s)" % (msg_seq_num, msg_sha, addr[0])
 | 
				
			||||||
 | 
									ack_key = (msg_seq_num, msg_sha, addr[0])
 | 
				
			||||||
 | 
									if not self._acks.has_key(ack_key):
 | 
				
			||||||
 | 
										self._send_ack_for_message(msg_seq_num, msg_sha, addr[0])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		message = self._incoming[msg_key]
 | 
							message = self._incoming[msg_key]
 | 
				
			||||||
		# Look for a dupe, and if so, drop the new segment
 | 
							# Look for a dupe, and if so, drop the new segment
 | 
				
			||||||
@ -636,10 +802,74 @@ class MostlyReliablePipe(object):
 | 
				
			|||||||
		next_transmit = max(now, segment.last_transmit() + self._STD_RETRANSMIT_INTERVAL)
 | 
							next_transmit = max(now, segment.last_transmit() + self._STD_RETRANSMIT_INTERVAL)
 | 
				
			||||||
		self._schedule_segment_retransmit(key, segment, next_transmit, now)
 | 
							self._schedule_segment_retransmit(key, segment, next_transmit, now)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						def _ack_check_worker(self):
 | 
				
			||||||
 | 
							"""Periodically check for messages that haven't received an ack
 | 
				
			||||||
 | 
							yet, and retransmit them."""
 | 
				
			||||||
 | 
							try:
 | 
				
			||||||
 | 
								now = time.time()
 | 
				
			||||||
 | 
								for key in self._sent.keys()[:]:
 | 
				
			||||||
 | 
									segment = self._sent[key]
 | 
				
			||||||
 | 
									# We only care about retransmitting the first segment
 | 
				
			||||||
 | 
									# of a message, since if other machines don't have the
 | 
				
			||||||
 | 
									# rest of the segments, they'll issue retransmit requests
 | 
				
			||||||
 | 
									if segment.segment_number() != 1:
 | 
				
			||||||
 | 
										continue
 | 
				
			||||||
 | 
									if segment.last_transmit() > now - 0.150: # 150ms
 | 
				
			||||||
 | 
										# Was just retransmitted recently, wait longer
 | 
				
			||||||
 | 
										# before retransmitting it
 | 
				
			||||||
 | 
										continue
 | 
				
			||||||
 | 
									ack_key = None
 | 
				
			||||||
 | 
									for ip in self._local_ips:
 | 
				
			||||||
 | 
										ack_key = (segment.message_sequence_number(), segment.master_sha(), ip)
 | 
				
			||||||
 | 
										if self._acks.has_key(ack_key):
 | 
				
			||||||
 | 
											break
 | 
				
			||||||
 | 
										ack_key = None
 | 
				
			||||||
 | 
									# If the segment already has been acked, don't send it
 | 
				
			||||||
 | 
									# again unless somebody explicitly requests a retransmit
 | 
				
			||||||
 | 
									if ack_key is not None:
 | 
				
			||||||
 | 
										continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									del self._sent[key]
 | 
				
			||||||
 | 
									self._outgoing.append(segment)
 | 
				
			||||||
 | 
									self._schedule_send_worker()
 | 
				
			||||||
 | 
							except KeyboardInterrupt:
 | 
				
			||||||
 | 
								return False
 | 
				
			||||||
 | 
							return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						def _send_ack_for_message(self, ack_msg_seq_num, ack_msg_sha, ack_addr):
 | 
				
			||||||
 | 
							"""Send an ack segment for a message."""
 | 
				
			||||||
 | 
							msg_seq_num = self._next_msg_seq()
 | 
				
			||||||
 | 
							ack = AckSegment.new_from_parts(self._remote_addr, msg_seq_num,
 | 
				
			||||||
 | 
								ack_msg_seq_num, ack_msg_sha, ack_addr)
 | 
				
			||||||
 | 
							self._outgoing.append(ack)
 | 
				
			||||||
 | 
							self._schedule_send_worker()
 | 
				
			||||||
 | 
							self._process_incoming_ack(ack)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						def _process_incoming_ack(self, segment):
 | 
				
			||||||
 | 
							"""Save the ack so that we don't send an ack when we start getting the segments
 | 
				
			||||||
 | 
							the ack was acknowledging."""
 | 
				
			||||||
 | 
							# If the ack is supposed to be for a message we sent, only accept it if
 | 
				
			||||||
 | 
							# we actually sent the message to which it refers
 | 
				
			||||||
 | 
							ack_addr = segment.ack_addr()
 | 
				
			||||||
 | 
							ack_master_sha = segment.ack_master_sha()
 | 
				
			||||||
 | 
							ack_msg_seq_num = segment.ack_msg_seq_num()
 | 
				
			||||||
 | 
							if ack_addr in self._local_ips:
 | 
				
			||||||
 | 
								sent_key = (ack_msg_seq_num, ack_master_sha, 1)
 | 
				
			||||||
 | 
								if not self._sent.has_key(sent_key):
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
							ack_key = (ack_msg_seq_num, ack_master_sha, ack_addr)
 | 
				
			||||||
 | 
							if not self._acks.has_key(ack_key):
 | 
				
			||||||
 | 
								print "Got ack for msg (%s %s) originally from %s" % (ack_msg_seq_num, ack_master_sha, ack_addr)
 | 
				
			||||||
 | 
								self._acks[ack_key] = time.time()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	def set_drop_probability(self, prob=4):
 | 
						def set_drop_probability(self, prob=4):
 | 
				
			||||||
		"""Debugging function to randomly drop incoming packets.
 | 
							"""Debugging function to randomly drop incoming packets.
 | 
				
			||||||
		The prob argument should be an integer between 1 and 10 to drop,
 | 
							The prob argument should be an integer between 1 and 10 to drop,
 | 
				
			||||||
		or 0 to drop none.  Higher numbers drop more packets."""
 | 
							or 0 to drop none.  Higher numbers drop more packets."""
 | 
				
			||||||
 | 
							if type(prob) != type(1):
 | 
				
			||||||
 | 
								raise ValueError("Drop probability must be an integer.")
 | 
				
			||||||
 | 
							if prob < 1 or prob > 10:
 | 
				
			||||||
 | 
								raise ValueError("Drop probability must be between 1 and 10 inclusive.")
 | 
				
			||||||
		self._drop_prob = prob
 | 
							self._drop_prob = prob
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	def _handle_incoming_data(self, source, condition):
 | 
						def _handle_incoming_data(self, source, condition):
 | 
				
			||||||
@ -665,6 +895,8 @@ class MostlyReliablePipe(object):
 | 
				
			|||||||
					self._process_incoming_data(segment)
 | 
										self._process_incoming_data(segment)
 | 
				
			||||||
				elif stype == SegmentBase.type_retransmit():
 | 
									elif stype == SegmentBase.type_retransmit():
 | 
				
			||||||
					self._process_retransmit_request(segment)
 | 
										self._process_retransmit_request(segment)
 | 
				
			||||||
 | 
									elif stype == SegmentBase.type_ack():
 | 
				
			||||||
 | 
										self._process_incoming_ack(segment)
 | 
				
			||||||
		except ValueError, exc:
 | 
							except ValueError, exc:
 | 
				
			||||||
			print "(MRP): Bad segment: %s" % exc
 | 
								print "(MRP): Bad segment: %s" % exc
 | 
				
			||||||
		return True
 | 
							return True
 | 
				
			||||||
@ -693,12 +925,12 @@ class MostlyReliablePipe(object):
 | 
				
			|||||||
		nmessages = length / mtu
 | 
							nmessages = length / mtu
 | 
				
			||||||
		if length % mtu > 0:
 | 
							if length % mtu > 0:
 | 
				
			||||||
			nmessages = nmessages + 1
 | 
								nmessages = nmessages + 1
 | 
				
			||||||
		msg_num = 1
 | 
							seg_num = 1
 | 
				
			||||||
		while left > 0:
 | 
							while left > 0:
 | 
				
			||||||
			seg = DataSegment.new_from_parts(msg_num, nmessages,
 | 
								seg = DataSegment.new_from_parts(seg_num, nmessages,
 | 
				
			||||||
					msg_seq, master_sha, data[:mtu])
 | 
										msg_seq, master_sha, data[:mtu])
 | 
				
			||||||
			self._outgoing.append(seg)
 | 
								self._outgoing.append(seg)
 | 
				
			||||||
			msg_num = msg_num + 1
 | 
								seg_num = seg_num + 1
 | 
				
			||||||
			data = data[mtu:]
 | 
								data = data[mtu:]
 | 
				
			||||||
			left = left - mtu
 | 
								left = left - mtu
 | 
				
			||||||
		self._schedule_send_worker()
 | 
							self._schedule_send_worker()
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user