Refactor global vars into appropriate classes

This commit is contained in:
Dan Williams 2006-05-17 10:30:11 -04:00
parent 6d7e1dcb4c
commit 093667d253

View File

@ -10,15 +10,6 @@ pygtk.require('2.0')
import gtk, gobject import gtk, gobject
_MTU = 481
_HEADER_LEN = 31
_MAGIC = 0xbaea4304
_TTL = 120 # 2 minutes
# Message segment packet types
_SEGMENT_TYPE_DATA = 0
_SEGMENT_TYPE_RETRANSMIT = 1
def _stringify_sha(sha): def _stringify_sha(sha):
print_sha = "" print_sha = ""
for char in sha: for char in sha:
@ -31,6 +22,8 @@ def _sha_data(data):
return sha_hash.digest() return sha_hash.digest()
class MessageSegment(object): class MessageSegment(object):
_MAGIC = 0xbaea4304
# 4: magic (0xbaea4304) # 4: magic (0xbaea4304)
# 1: type # 1: type
# 2: segment number # 2: segment number
@ -38,6 +31,38 @@ class MessageSegment(object):
# 2: message sequence number # 2: message sequence number
#20: total data sha1 #20: total data sha1
_HEADER_TEMPLATE = "! IbHHH20s" _HEADER_TEMPLATE = "! IbHHH20s"
_HEADER_LEN = struct.calcsize(_HEADER_TEMPLATE)
_MTU = 512 - _HEADER_LEN
# Message segment packet types
_SEGMENT_TYPE_DATA = 0
_SEGMENT_TYPE_RETRANSMIT = 1
def is_data_type(stype):
if stype == MessageSegment._SEGMENT_TYPE_DATA:
return True
return False
is_data_type = staticmethod(is_data_type)
def is_retransmit_type(stype):
if stype == MessageSegment._SEGMENT_TYPE_RETRANSMIT:
return True
return False
is_retransmit_type = staticmethod(is_retransmit_type)
def header_len():
return MessageSegment._HEADER_LEN
header_len = staticmethod(header_len)
def mtu():
return MessageSegment._MTU
mtu = staticmethod(mtu)
def is_type_valid(stype):
if MessageSegment.is_data_type(stype) or MessageSegment.is_retransmit_type(stype):
return True
return False
is_type_valid = staticmethod(is_type_valid)
def _new_from_parts(self, msg_seq_num, segno, total_segs, data, master_sha): def _new_from_parts(self, msg_seq_num, segno, total_segs, data, master_sha):
"""Construct a new message segment from individual attributes.""" """Construct a new message segment from individual attributes."""
@ -61,15 +86,15 @@ class MessageSegment(object):
self._total_segs = total_segs self._total_segs = total_segs
self._msg_seq_num = msg_seq_num self._msg_seq_num = msg_seq_num
self._addr = None self._addr = None
self._type = _SEGMENT_TYPE_DATA self._type = MessageSegment._SEGMENT_TYPE_DATA
# Make the header # Make the header
self._header = struct.pack(self._HEADER_TEMPLATE, _MAGIC, self._type, self._header = struct.pack(self._HEADER_TEMPLATE, self._MAGIC, self._type,
self._segno, self._total_segs, self._msg_seq_num, self._master_sha) self._segno, self._total_segs, self._msg_seq_num, self._master_sha)
def _new_from_data(self, addr, data): def _new_from_data(self, addr, data):
"""Verify and construct a new message segment from network data.""" """Verify and construct a new message segment from network data."""
if len(data) < _HEADER_LEN + 1: if len(data) < self._HEADER_LEN + 1:
raise ValueError("Segment is less then minimum required length") raise ValueError("Segment is less then minimum required length")
stream = StringIO.StringIO(data) stream = StringIO.StringIO(data)
self._stime = None self._stime = None
@ -86,17 +111,17 @@ class MessageSegment(object):
stream.read(header_size)) stream.read(header_size))
# Sanity checks on the message attributes # Sanity checks on the message attributes
if seg_type != _SEGMENT_TYPE_DATA and seg_type != _SEGMENT_TYPE_RETRANSMIT: if not MessageSegment.is_type_valid(seg_type):
raise ValueError("Segment has invalid type.") raise ValueError("Segment has invalid type.")
if seg_type == _SEGMENT_TYPE_RETRANSMIT: if MessageSegment.is_data_type(seg_type):
if segno != 1 or total_segs != 1: if segno != 1 or total_segs != 1:
raise ValueError("Retransmission request messages must have only one segment.") raise ValueError("Retransmission request messages must have only one segment.")
if magic != _MAGIC: if magic != self._MAGIC:
raise ValueError("Segment does not have the correct magic.") raise ValueError("Segment does not have the correct magic.")
if self._data_len < 1: if self._data_len < 1:
raise ValueError("Segment must have some data.") raise ValueError("Segment must have some data.")
if self._data_len > _MTU: if self._data_len > self._MTU:
raise ValueError("Data length must not be larger than the MTU (%s)." % _MTU) raise ValueError("Data length must not be larger than the MTU (%s)." % self._MTU)
if segno < 1: if segno < 1:
raise ValueError("Segment number must be greater than 0.") raise ValueError("Segment number must be greater than 0.")
if segno > total_segs: if segno > total_segs:
@ -161,6 +186,9 @@ class MostlyReliablePipe(object):
"""Implement Mostly-Reliable UDP. We don't actually care about guaranteeing """Implement Mostly-Reliable UDP. We don't actually care about guaranteeing
delivery or receipt, just a better effort than no effort at all.""" delivery or receipt, just a better effort than no effort at all."""
_UDP_MSG_SIZE = MessageSegment.mtu() + MessageSegment.header_len()
_SEGMENT_TTL = 120 # 2 minutes
def __init__(self, local_addr, remote_addr, port, data_cb, user_data=None): def __init__(self, local_addr, remote_addr, port, data_cb, user_data=None):
self._local_addr = local_addr self._local_addr = local_addr
self._remote_addr = remote_addr self._remote_addr = remote_addr
@ -208,7 +236,7 @@ class MostlyReliablePipe(object):
# Watch the listener socket for data # Watch the listener socket for data
gobject.io_add_watch(self._listen_sock, gobject.IO_IN, self._handle_incoming_data) gobject.io_add_watch(self._listen_sock, gobject.IO_IN, self._handle_incoming_data)
gobject.timeout_add(120000, self._segment_ttl_worker) gobject.timeout_add(self._SEGMENT_TTL * 1000, self._segment_ttl_worker)
self._started = True self._started = True
@ -216,7 +244,7 @@ class MostlyReliablePipe(object):
"""Cull already-sent message segments that are past their TTL.""" """Cull already-sent message segments that are past their TTL."""
now = time.time() now = time.time()
for segment in self._sent[:]: for segment in self._sent[:]:
if segment.stime() < now - _MSG_TTL: if segment.stime() < now - self._SEGMENT_TTL:
self._sent.remove(segment) self._sent.remove(segment)
return True return True
@ -271,10 +299,10 @@ class MostlyReliablePipe(object):
# 20: total data sha1 # 20: total data sha1
# 2: segment number # 2: segment number
data = segment.data() data = segment.data()
if len(data) != 22: template = "@ H20sH"
if len(data) != struct.calcsize(template):
print "Bad retransmission request message format." print "Bad retransmission request message format."
# Native byte-order since the receive bits already unpacked it for us # Native byte-order since the receive bits already unpacked it for us
template = "@ H20sH"
(msg_seq_num, master_sha, segno) = struct.unpack(template, data) (msg_seq_num, master_sha, segno) = struct.unpack(template, data)
def _handle_incoming_data(self, source, condition): def _handle_incoming_data(self, source, condition):
@ -283,12 +311,13 @@ class MostlyReliablePipe(object):
if not (condition & gobject.IO_IN): if not (condition & gobject.IO_IN):
return True return True
msg = {} msg = {}
data, addr = source.recvfrom(_MTU + _HEADER_LEN) data, addr = source.recvfrom(self._UDP_MSG_SIZE)
try: try:
segment = MessageSegment.new_from_data(addr, data) segment = MessageSegment.new_from_data(addr, data)
if segment.segment_type() == _SEGMENT_TYPE_DATA: stype = segment.segment_type()
if MessageSegment.is_data_type(stype):
self._process_incoming_data(segment) self._process_incoming_data(segment)
elif segment.segment_type() == _SEGMENT_TYPE_RETRANSMIT: elif MessageSegment.is_retransmit_type(stype):
self._process_retransmit_request(segment) self._process_retransmit_request(segment)
except ValueError, exc: except ValueError, exc:
print "Bad segment: %s" % exc print "Bad segment: %s" % exc
@ -310,17 +339,18 @@ class MostlyReliablePipe(object):
# Split up the data into segments # Split up the data into segments
left = length = len(data) left = length = len(data)
nmessages = length / _MTU mtu = MessageSegment.mtu()
if length % _MTU > 0: nmessages = length / mtu
if length % mtu > 0:
nmessages = nmessages + 1 nmessages = nmessages + 1
msg_num = 1 msg_num = 1
while left > 0: while left > 0:
msg = MessageSegment.new_from_parts(self._seq_counter, msg_num, msg = MessageSegment.new_from_parts(self._seq_counter, msg_num,
nmessages, data[:_MTU], master_sha) nmessages, data[:mtu], master_sha)
self._outgoing.append(msg) self._outgoing.append(msg)
msg_num = msg_num + 1 msg_num = msg_num + 1
data = data[_MTU:] data = data[mtu:]
left = left - _MTU left = left - mtu
if len(self._outgoing) > 0 and self._worker == 0: if len(self._outgoing) > 0 and self._worker == 0:
self._worker = gobject.idle_add(self._send_worker) self._worker = gobject.idle_add(self._send_worker)