1395 lines
58 KiB
Python
1395 lines
58 KiB
Python
# Copyright (C) 2006, Red Hat, Inc.
|
|
#
|
|
# This library is free software; you can redistribute it and/or
|
|
# modify it under the terms of the GNU Lesser General Public
|
|
# License as published by the Free Software Foundation; either
|
|
# version 2 of the License, or (at your option) any later version.
|
|
#
|
|
# This library is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
# Lesser General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Lesser General Public
|
|
# License along with this library; if not, write to the
|
|
# Free Software Foundation, Inc., 59 Temple Place - Suite 330,
|
|
# Boston, MA 02111-1307, USA.
|
|
|
|
# FIXME: tests assign variables that are never read; is there a better
# way to silence pylint (W0612, unused variable) for those?
|
|
# pylint: disable-msg = W0612
|
|
|
|
import socket
|
|
import time
|
|
import sha
|
|
import struct
|
|
import StringIO
|
|
import binascii
|
|
import random
|
|
|
|
import gtk
|
|
import gobject
|
|
|
|
|
|
def _stringify_sha(sha_hash):
|
|
print_sha = ""
|
|
for char in sha_hash:
|
|
print_sha = print_sha + binascii.b2a_hex(char)
|
|
return print_sha
|
|
|
|
def _sha_data(data):
|
|
sha_hash = sha.new()
|
|
sha_hash.update(data)
|
|
return sha_hash.digest()
|
|
|
|
# Largest datagram, in bytes and header included, that a segment may occupy.
# SegmentBase derives its payload MTU from this, and
# SegmentBase.new_from_data() rejects anything longer.
_UDP_DATAGRAM_SIZE = 512
|
|
|
|
class SegmentBase(object):
    """Header fields and wire format shared by every segment type.

    Each datagram is a fixed-size header followed by a type-specific
    payload of 1.._MTU bytes.  Header layout (network byte order, no
    padding, see _HEADER_TEMPLATE):
       4: magic (0xbaea4304)
       1: type
       2: segment number
       2: total segments
       2: message sequence number
      20: total data sha1
    Subclasses set _type and implement _unpack_data().
    """

    _MAGIC = 0xbaea4304

    _HEADER_TEMPLATE = "! IbHHH20s"
    _HEADER_LEN = struct.calcsize(_HEADER_TEMPLATE)
    _MTU = _UDP_DATAGRAM_SIZE - _HEADER_LEN

    # Message segment packet types
    _SEGMENT_TYPE_DATA = 0
    _SEGMENT_TYPE_RETRANSMIT = 1
    _SEGMENT_TYPE_ACK = 2

    def magic():
        """Return the magic number that starts every segment header."""
        return SegmentBase._MAGIC
    magic = staticmethod(magic)

    def header_template():
        """Return the struct format string of the segment header."""
        return SegmentBase._HEADER_TEMPLATE
    header_template = staticmethod(header_template)

    def type_data():
        """Return the segment type code of data segments."""
        return SegmentBase._SEGMENT_TYPE_DATA
    type_data = staticmethod(type_data)

    def type_retransmit():
        """Return the segment type code of retransmission requests."""
        return SegmentBase._SEGMENT_TYPE_RETRANSMIT
    type_retransmit = staticmethod(type_retransmit)

    def type_ack():
        """Return the segment type code of acknowledgements."""
        return SegmentBase._SEGMENT_TYPE_ACK
    type_ack = staticmethod(type_ack)

    def header_len():
        """Return the header size of SegmentBase packets."""
        return SegmentBase._HEADER_LEN
    header_len = staticmethod(header_len)

    def mtu():
        """Return the SegmentBase packet MTU."""
        return SegmentBase._MTU
    mtu = staticmethod(mtu)

    def __init__(self, segno, total_segs, msg_seq_num, master_sha):
        """Validate and store the header fields.

        segno       -- 1-based number of this segment (1..total_segs)
        total_segs  -- number of segments in the message (1..65535)
        msg_seq_num -- message sequence number (1..65535)
        master_sha  -- 20-character SHA-1 digest of the whole message data

        Raises ValueError if any field is invalid.
        """
        self._type = None
        self._transmits = 0
        self._last_transmit = 0
        self._data = None
        self._data_len = 0
        # Free slot for the owner: MostlyReliablePipe keeps a pending
        # retransmit timeout id here.
        self.userdata = None
        self._stime = time.time()
        self._addr = None

        # Sanity checks on the message attributes
        if not segno or not isinstance(segno, int):
            raise ValueError("Segment number must be in integer.")
        if segno < 1 or segno > 65535:
            raise ValueError("Segment number must be between 1 and 65535 inclusive.")
        if not total_segs or not isinstance(total_segs, int):
            raise ValueError("Message segment total must be an integer.")
        if total_segs < 1 or total_segs > 65535:
            raise ValueError("Message must have between 1 and 65535 segments inclusive.")
        if segno > total_segs:
            raise ValueError("Segment number cannot be larger than message segment total.")
        if not msg_seq_num or not isinstance(msg_seq_num, int):
            raise ValueError("Message sequnce number must be an integer.")
        if msg_seq_num < 1 or msg_seq_num > 65535:
            raise ValueError("Message sequence number must be between 1 and 65535 inclusive.")
        if not master_sha or not isinstance(master_sha, str) or len(master_sha) != 20:
            raise ValueError("Message SHA1 checksum invalid.")

        self._segno = segno
        self._total_segs = total_segs
        self._msg_seq_num = msg_seq_num
        self._master_sha = master_sha

    def _validate_address(addr):
        """Raise ValueError unless addr is a (str host, int port 1..65535) tuple."""
        if not addr or not isinstance(addr, tuple):
            raise ValueError("Address must be a tuple.")
        if len(addr) != 2 or not isinstance(addr[0], str) or not isinstance(addr[1], int):
            raise ValueError("Address format was invalid.")
        if addr[1] < 1 or addr[1] > 65535:
            raise ValueError("Address port was invalid.")
    _validate_address = staticmethod(_validate_address)

    def new_from_data(addr, data):
        """Static constructor for creation from a packed data stream.

        addr -- (host, port) tuple the datagram came from
        data -- the raw datagram: header followed by payload

        Returns a DataSegment, RetransmitSegment or AckSegment, depending on
        the header's type field.  Raises ValueError for malformed input.
        """
        SegmentBase._validate_address(addr)

        # Verify minimum and maximum length
        if not data:
            raise ValueError("Segment data is invalid.")
        total_len = len(data)
        if total_len < SegmentBase._HEADER_LEN + 1:
            raise ValueError("Segment is less then minimum required length")
        if total_len > _UDP_DATAGRAM_SIZE:
            raise ValueError("Segment data is larger than allowed.")

        # The two checks above already guarantee 1 <= data_len <= _MTU,
        # because _MTU == _UDP_DATAGRAM_SIZE - _HEADER_LEN.
        data_len = total_len - SegmentBase._HEADER_LEN
        stream = StringIO.StringIO(data)

        # Read the header attributes
        (magic, seg_type, segno, total_segs, msg_seq_num, master_sha) = \
            struct.unpack(SegmentBase._HEADER_TEMPLATE,
                          stream.read(SegmentBase._HEADER_LEN))

        if magic != SegmentBase._MAGIC:
            raise ValueError("Segment does not have the correct magic.")

        # A single-segment message carries all of its data, so the master
        # checksum can be verified immediately.
        if segno == 1 and total_segs == 1:
            data_sha = _sha_data(stream.read(data_len))
            if data_sha != master_sha:
                raise ValueError("Single segment message SHA checksums didn't match.")
            stream.seek(SegmentBase._HEADER_LEN)

        if seg_type == SegmentBase._SEGMENT_TYPE_DATA:
            segment = DataSegment(segno, total_segs, msg_seq_num, master_sha)
        elif seg_type == SegmentBase._SEGMENT_TYPE_RETRANSMIT:
            segment = RetransmitSegment(segno, total_segs, msg_seq_num, master_sha)
        elif seg_type == SegmentBase._SEGMENT_TYPE_ACK:
            segment = AckSegment(segno, total_segs, msg_seq_num, master_sha)
        else:
            raise ValueError("Segment has invalid type.")

        # Segment specific data interpretation
        segment._addr = addr
        segment._unpack_data(stream, data_len)
        return segment
    new_from_data = staticmethod(new_from_data)

    def _unpack_data(self, stream, data_len):
        """Parse the data_len-byte payload from stream (positioned after the
        header).  Every concrete segment type must override this."""
        raise NotImplementedError("%s must implement _unpack_data()"
                                  % self.__class__.__name__)

    def stime(self):
        """Return the time.time() at which this segment object was created."""
        return self._stime

    def address(self):
        """Return the (host, port) address attached to this segment, or None."""
        return self._addr

    def segment_number(self):
        return self._segno

    def total_segments(self):
        return self._total_segs

    def message_sequence_number(self):
        return self._msg_seq_num

    def data(self):
        """Return the raw payload string."""
        return self._data

    def master_sha(self):
        return self._master_sha

    def segment_type(self):
        return self._type

    def packetize(self):
        """Return a correctly formatted message that can be immediately sent."""
        header = struct.pack(self._HEADER_TEMPLATE, self._MAGIC, self._type,
                             self._segno, self._total_segs, self._msg_seq_num,
                             self._master_sha)
        return header + self._data

    def transmits(self):
        """Return how many times inc_transmits() has been called."""
        return self._transmits

    def inc_transmits(self):
        """Count one more transmission and remember when it happened."""
        self._transmits = self._transmits + 1
        self._last_transmit = time.time()

    def last_transmit(self):
        """Return the time of the last inc_transmits() call (0 if never)."""
        return self._last_transmit
|
|
|
|
class DataSegment(SegmentBase):
    """A message segment that encapsulates random data."""

    def __init__(self, segno, total_segs, msg_seq_num, master_sha):
        SegmentBase.__init__(self, segno, total_segs, msg_seq_num, master_sha)
        self._type = SegmentBase._SEGMENT_TYPE_DATA

    def _get_template_for_len(length):
        """Return the struct format for a raw string of the given length."""
        return "! %ds" % length
    _get_template_for_len = staticmethod(_get_template_for_len)

    def _unpack_data(self, stream, data_len):
        """Unpack the data stream, called by constructor."""
        self._data_len = data_len
        template = DataSegment._get_template_for_len(self._data_len)
        self._data = struct.unpack(template, stream.read(self._data_len))[0]

    def new_from_parts(segno, total_segs, msg_seq_num, master_sha, data):
        """Construct a new message segment from individual attributes.

        data is this segment's slice of the message; it must be non-empty
        and at most SegmentBase.mtu() bytes long.  Raises ValueError
        otherwise, or if a header field is invalid.
        """
        if not data:
            raise ValueError("Must have valid data.")
        if len(data) > SegmentBase.mtu():
            # A larger payload yields a datagram longer than
            # _UDP_DATAGRAM_SIZE, which every receiver rejects.
            raise ValueError("Data length must not be larger than the MTU (%s)."
                             % SegmentBase.mtu())
        segment = DataSegment(segno, total_segs, msg_seq_num, master_sha)
        segment._data_len = len(data)
        template = DataSegment._get_template_for_len(segment._data_len)
        segment._data = struct.pack(template, data)
        return segment
    new_from_parts = staticmethod(new_from_parts)
|
|
|
|
|
|
class RetransmitSegment(SegmentBase):
    """A message segment that encapsulates a retransmission request.

    Payload layout (see _RT_DATA_TEMPLATE):
       2: message sequence number of the wanted segment's message
      20: total data sha1 of that message
       2: wanted segment number
    The request is always a one-segment message whose master sha is the
    SHA-1 of this payload.
    """

    _RT_DATA_TEMPLATE = "! H20sH"
    _RT_DATA_LEN = struct.calcsize(_RT_DATA_TEMPLATE)

    def data_template():
        """Return the struct format of the request payload."""
        return RetransmitSegment._RT_DATA_TEMPLATE
    data_template = staticmethod(data_template)

    def __init__(self, segno, total_segs, msg_seq_num, master_sha):
        """Should not be called directly."""
        is_single = (segno == 1 and total_segs == 1)
        if not is_single:
            raise ValueError("Retransmission request messages must have only one segment.")
        SegmentBase.__init__(self, segno, total_segs, msg_seq_num, master_sha)
        self._type = SegmentBase._SEGMENT_TYPE_RETRANSMIT

    def _verify_data(rt_msg_seq_num, rt_master_sha, rt_segment_number):
        """Raise ValueError unless the requested (sequence number, sha,
        segment number) triple is well formed."""
        if not (rt_segment_number and isinstance(rt_segment_number, int)):
            raise ValueError("RT Segment number must be in integer.")
        if not (1 <= rt_segment_number <= 65535):
            raise ValueError("RT Segment number must be between 1 and 65535 inclusive.")
        if not (rt_msg_seq_num and isinstance(rt_msg_seq_num, int)):
            raise ValueError("RT Message sequnce number must be an integer.")
        if not (1 <= rt_msg_seq_num <= 65535):
            raise ValueError("RT Message sequence number must be between 1 and 65535 inclusive.")
        sha_ok = (rt_master_sha and isinstance(rt_master_sha, str)
                  and len(rt_master_sha) == 20)
        if not sha_ok:
            raise ValueError("RT Message SHA1 checksum invalid.")
    _verify_data = staticmethod(_verify_data)

    def _make_rtms_data(rt_msg_seq_num, rt_master_sha, rt_segment_number):
        """Pack the request payload; return (payload, SHA-1 of payload)."""
        payload = struct.pack(RetransmitSegment._RT_DATA_TEMPLATE,
                              rt_msg_seq_num, rt_master_sha, rt_segment_number)
        return payload, _sha_data(payload)
    _make_rtms_data = staticmethod(_make_rtms_data)

    def new_from_parts(addr, msg_seq_num, rt_msg_seq_num, rt_master_sha, rt_segment_number):
        """Static constructor for creation from individual attributes.

        addr        -- (host, port) tuple stored as the segment's address
        msg_seq_num -- sequence number of the request message itself
        rt_*        -- identify the segment being asked for
        Raises ValueError if any argument is invalid.
        """
        RetransmitSegment._verify_data(rt_msg_seq_num, rt_master_sha, rt_segment_number)
        payload, payload_sha = RetransmitSegment._make_rtms_data(
            rt_msg_seq_num, rt_master_sha, rt_segment_number)

        request = RetransmitSegment(1, 1, msg_seq_num, payload_sha)
        request._data_len = RetransmitSegment._RT_DATA_LEN
        request._data = payload
        SegmentBase._validate_address(addr)
        request._addr = addr

        request._rt_msg_seq_num = rt_msg_seq_num
        request._rt_master_sha = rt_master_sha
        request._rt_segment_number = rt_segment_number
        return request
    new_from_parts = staticmethod(new_from_parts)

    def _unpack_data(self, stream, data_len):
        """Parse and validate the request payload read from stream."""
        if data_len != self._RT_DATA_LEN:
            raise ValueError("Retransmission request data had invalid length.")
        payload = stream.read(data_len)
        fields = struct.unpack(self._RT_DATA_TEMPLATE, payload)
        RetransmitSegment._verify_data(*fields)

        self._data = payload
        self._data_len = data_len
        (self._rt_msg_seq_num, self._rt_master_sha,
         self._rt_segment_number) = fields

    def rt_msg_seq_num(self):
        """Sequence number of the message whose segment is requested."""
        return self._rt_msg_seq_num

    def rt_master_sha(self):
        """Master sha of the message whose segment is requested."""
        return self._rt_master_sha

    def rt_segment_number(self):
        """Number of the requested segment."""
        return self._rt_segment_number
|
|
|
|
|
|
class AckSegment(SegmentBase):
    """A message segment that encapsulates a message acknowledgement.

    Payload layout (see _ACK_DATA_TEMPLATE):
       2: acked message sequence number
      20: acked message total data sha1
       4: acked message source IP address (packed IPv4)
    The ack is always a one-segment message whose master sha is the SHA-1
    of this payload.
    """

    _ACK_DATA_TEMPLATE = "! H20s4s"
    _ACK_DATA_LEN = struct.calcsize(_ACK_DATA_TEMPLATE)

    def data_template():
        """Return the struct format of the ack payload."""
        return AckSegment._ACK_DATA_TEMPLATE
    data_template = staticmethod(data_template)

    def __init__(self, segno, total_segs, msg_seq_num, master_sha):
        """Should not be called directly."""
        is_single = (segno == 1 and total_segs == 1)
        if not is_single:
            raise ValueError("Acknowledgement messages must have only one segment.")
        SegmentBase.__init__(self, segno, total_segs, msg_seq_num, master_sha)
        self._type = SegmentBase._SEGMENT_TYPE_ACK

    def _verify_data(ack_msg_seq_num, ack_master_sha, ack_addr):
        """Raise ValueError unless the acked (sequence number, sha, dotted
        IPv4 address) triple is well formed."""
        if not (ack_msg_seq_num and isinstance(ack_msg_seq_num, int)):
            raise ValueError("Ack message sequnce number must be an integer.")
        if not (1 <= ack_msg_seq_num <= 65535):
            raise ValueError("Ack message sequence number must be between 1 and 65535 inclusive.")
        sha_ok = (ack_master_sha and isinstance(ack_master_sha, str)
                  and len(ack_master_sha) == 20)
        if not sha_ok:
            raise ValueError("Ack message SHA1 checksum invalid.")
        if not isinstance(ack_addr, str):
            raise ValueError("Ack message invalid address type.")
        try:
            # Only called for validation; the packed result is not needed.
            socket.inet_aton(ack_addr)
        except socket.error:
            raise ValueError("Ack message invalid address.")
    _verify_data = staticmethod(_verify_data)

    def _make_ack_data(ack_msg_seq_num, ack_master_sha, ack_addr):
        """Pack the ack payload; return (payload, SHA-1 of payload)."""
        payload = struct.pack(AckSegment._ACK_DATA_TEMPLATE, ack_msg_seq_num,
                              ack_master_sha, socket.inet_aton(ack_addr))
        return payload, _sha_data(payload)
    _make_ack_data = staticmethod(_make_ack_data)

    def new_from_parts(addr, msg_seq_num, ack_msg_seq_num, ack_master_sha, ack_addr):
        """Static constructor for creation from individual attributes.

        addr        -- (host, port) tuple stored as the segment's address
        msg_seq_num -- sequence number of the ack message itself
        ack_*       -- identify the message being acknowledged and the
                       dotted IPv4 address of its sender
        Raises ValueError if any argument is invalid.
        """
        AckSegment._verify_data(ack_msg_seq_num, ack_master_sha, ack_addr)
        payload, payload_sha = AckSegment._make_ack_data(
            ack_msg_seq_num, ack_master_sha, ack_addr)

        ack = AckSegment(1, 1, msg_seq_num, payload_sha)
        ack._data_len = AckSegment._ACK_DATA_LEN
        ack._data = payload
        SegmentBase._validate_address(addr)
        ack._addr = addr

        ack._ack_msg_seq_num = ack_msg_seq_num
        ack._ack_master_sha = ack_master_sha
        ack._ack_addr = ack_addr
        return ack
    new_from_parts = staticmethod(new_from_parts)

    def _unpack_data(self, stream, data_len):
        """Parse and validate the ack payload read from stream."""
        if data_len != self._ACK_DATA_LEN:
            raise ValueError("Ack segment data had invalid length.")
        payload = stream.read(data_len)
        seq, sha, packed_addr = struct.unpack(self._ACK_DATA_TEMPLATE, payload)
        try:
            dotted_addr = socket.inet_ntoa(packed_addr)
        except socket.error:
            raise ValueError("Ack segment data had invalid address.")
        AckSegment._verify_data(seq, sha, dotted_addr)

        self._data = payload
        self._data_len = data_len
        self._ack_msg_seq_num = seq
        self._ack_master_sha = sha
        self._ack_addr = dotted_addr

    def ack_msg_seq_num(self):
        """Sequence number of the acknowledged message."""
        return self._ack_msg_seq_num

    def ack_master_sha(self):
        """Master sha of the acknowledged message."""
        return self._ack_master_sha

    def ack_addr(self):
        """Dotted IPv4 address of the acknowledged message's sender."""
        return self._ack_addr
|
|
|
|
class Message(object):
    """Tracks an entire message object, which is composed of a number
    of individual segments.

    Segments are stored by segment number.  Once all of them are present,
    the data is reassembled in segment-number order and its SHA-1 computed.
    The object also tracks how often each segment has been requested for
    retransmission and when the next request is due.
    """

    # Base delay between retransmission requests: 90ms (in seconds)
    _DEF_RT_REQUEST_INTERVAL = 0.09

    def __init__(self, src_addr, msg_seq_num, msg_sha, total_segments):
        """src_addr is the (host, port) the message came from, msg_sha its
        master SHA-1 digest and total_segments its segment count."""
        self._rt_target = 0
        self._next_rt_time = 0
        self._last_incoming_time = 0
        self._segments = {}
        self._complete = False
        # Read by dispatch_time(); initialised here so that getter works
        # before set_dispatch_time() has been called.
        self._dispatch_time = 0
        self._data = None
        self._data_sha = None
        self._src_addr = src_addr
        self._msg_seq_num = msg_seq_num
        self._msg_sha = msg_sha
        self._total_segments = total_segments
        self._rt_tries = {}
        for i in range(1, self._total_segments + 1):
            self._rt_tries[i] = 0

    def __del__(self):
        self.clear()

    def sha(self):
        return self._msg_sha

    def source_address(self):
        return self._src_addr

    def clear(self):
        """Forget all received segments and retransmission counters."""
        self._segments = {}
        self._rt_tries = {}

    def has_segment(self, segno):
        return segno in self._segments

    def first_missing(self):
        """Return the lowest segment number not yet received, or 0 if none
        is missing."""
        for i in range(1, self._total_segments + 1):
            if i not in self._segments:
                return i
        return 0

    def update_rt_wait(self, now):
        """Schedule the next retransmission request relative to now (seconds).

        If data arrived within the last 20ms, the wait is stretched in
        proportion to how much of the message is still missing.
        """
        wait = self._DEF_RT_REQUEST_INTERVAL
        if self._last_incoming_time > now - 0.02:
            msg_completeness = float(len(self._segments)) / float(self._total_segments)
            wait = wait + (self._DEF_RT_REQUEST_INTERVAL * (1.0 - msg_completeness))
        self._next_rt_time = now + wait

    def add_segment(self, segment):
        """Store segment unless the message is complete or it is a duplicate.

        When the last missing segment arrives, the message is marked complete
        and its data and data SHA-1 become available through data().
        """
        if self.complete():
            return
        segno = segment.segment_number()
        if segno in self._segments:
            return
        self._segments[segno] = segment
        self._rt_tries[segno] = 0
        now = time.time()
        self._last_incoming_time = now

        num_segs = len(self._segments)
        if num_segs == self._total_segments:
            self._complete = True
            self._next_rt_time = 0
            # Concatenate in segment-number order; dict iteration order is
            # not guaranteed to follow the keys.
            self._data = ''.join([self._segments[i].data()
                                  for i in range(1, self._total_segments + 1)])
            self._data_sha = _sha_data(self._data)
        elif segno == num_segs or num_segs == 1:
            # Probably not missing segments (the count matches this segment's
            # number, or this is the first one): push back retransmit request
            self.update_rt_wait(now)

    def get_retransmit_message(self, msg_seq_num, segno):
        """Return a RetransmitSegment asking this message's sender for segno,
        or None if segno is out of range.  Counts the request against segno."""
        if segno < 1 or segno > self._total_segments:
            return None
        seg = RetransmitSegment.new_from_parts(self._src_addr, msg_seq_num,
                                               self._msg_seq_num, self._msg_sha, segno)
        self._rt_tries[segno] = self._rt_tries.get(segno, 0) + 1
        self.update_rt_wait(time.time())
        return seg

    def complete(self):
        return self._complete

    def dispatch_time(self):
        """Return the time set by set_dispatch_time(), or 0 if never set."""
        return self._dispatch_time

    def set_dispatch_time(self):
        self._dispatch_time = time.time()

    def data(self):
        """Return (reassembled data, its SHA-1); both None until complete."""
        return (self._data, self._data_sha)

    def last_incoming_time(self):
        return self._last_incoming_time

    def next_rt_time(self):
        """Return when the next retransmission request is due (0 = none)."""
        return self._next_rt_time

    def rt_tries(self, segno):
        """Return how many retransmissions of segno have been requested."""
        return self._rt_tries.get(segno, 0)
|
|
|
|
|
|
def _get_local_interfaces():
    """Return the names of the configured IPv4 interfaces (Linux only).

    Uses the SIOCGIFCONF ioctl, which fills a caller-supplied array of
    struct ifreq records, each starting with the NUL-padded interface name.
    An interface with several addresses appears once per address.
    """
    import array
    import struct
    import fcntl
    import socket

    SIOCGIFCONF = 0x8912
    # sizeof(struct ifreq) is 32 bytes on 32-bit Linux but 40 bytes on
    # 64-bit Linux, because its union holds pointer-sized members.  The
    # wrong stride yields garbage names.
    if struct.calcsize('P') == 8:
        ifreq_size = 40
    else:
        ifreq_size = 32
    # Room for this many ifreq records.  A small limit silently drops
    # interfaces on hosts with many addresses.
    max_possible = 128
    buf_size = max_possible * ifreq_size
    names = array.array('B', '\0' * buf_size)

    sockfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # struct ifconf: buffer length, then buffer address
        ifreq = struct.pack('iL', buf_size, names.buffer_info()[0])
        result = fcntl.ioctl(sockfd.fileno(), SIOCGIFCONF, ifreq)
    finally:
        sockfd.close()

    # The kernel rewrites the length field with the number of bytes used.
    outbytes = struct.unpack('iL', result)[0]
    namestr = names.tostring()

    return [namestr[i:i + ifreq_size].split('\0', 1)[0]
            for i in range(0, outbytes, ifreq_size)]
|
|
|
|
def _get_local_ip_addresses():
|
|
"""Call Linux specific bits to retrieve our own IP address."""
|
|
import socket
|
|
import sys
|
|
import fcntl
|
|
import struct
|
|
|
|
intfs = _get_local_interfaces()
|
|
|
|
ips = []
|
|
SIOCGIFADDR = 0x8915
|
|
sockfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
|
for intf in intfs:
|
|
if intf == "lo":
|
|
continue
|
|
try:
|
|
ifreq = (intf + '\0'*32)[:32]
|
|
result = fcntl.ioctl(sockfd.fileno(), SIOCGIFADDR, ifreq)
|
|
addr = socket.inet_ntoa(result[20:24])
|
|
ips.append(addr)
|
|
except IOError, exc:
|
|
print "Error getting IP address: %s" % exc
|
|
sockfd.close()
|
|
return ips
|
|
|
|
|
|
class MostlyReliablePipe(object):
|
|
"""Implement Mostly-Reliable UDP. We don't actually care about guaranteeing
|
|
delivery or receipt, just a better effort than no effort at all."""
|
|
|
|
_UDP_MSG_SIZE = SegmentBase.mtu() + SegmentBase.header_len()
|
|
_SEGMENT_TTL = 120 # 2 minutes
|
|
|
|
def __init__(self, local_addr, remote_addr, port, data_cb, user_data=None):
    """Create a pipe bound to local_addr that exchanges datagrams with the
    multicast group remote_addr on port.

    data_cb is called as data_cb(addr, data, user_data) for every fully
    reassembled message.  Sockets are created here; nothing is received
    until start() is called.
    """
    # Endpoint configuration
    self._local_addr, self._remote_addr, self._port = local_addr, remote_addr, port
    self._data_cb = data_cb
    self._user_data = user_data

    # Runtime state; a worker id of 0 means no gobject source is installed
    self._started = False
    self._seq_counter = 0
    self._drop_prob = 0
    self._send_worker = 0
    self._rt_check_worker_id = 0
    self._ack_check_worker_id = 0

    # Outgoing bookkeeping: queued segments, and segments already sent
    self._outgoing = []
    self._sent = {}

    # Incoming bookkeeping
    self._incoming = {}    # (sender ip, msg seq #, msg sha, # of segments) -> Message
    self._dispatched = {}  # same key -> Message already delivered or dropped
    self._acks = {}        # (message sequence #, master sha, source addr) -> received timestamp

    self._local_ips = _get_local_ip_addresses()

    self._setup_listener()
    self._setup_sender()
|
|
|
|
def __del__(self):
    """Remove the gobject sources whose ids this pipe still holds."""
    for attr_name in ('_send_worker', '_rt_check_worker_id', '_ack_check_worker_id'):
        source_id = getattr(self, attr_name)
        if source_id > 0:
            gobject.source_remove(source_id)
            setattr(self, attr_name, 0)
|
|
|
|
def _setup_sender(self):
    """Create the UDP socket used for outgoing multicast traffic."""
    self._send_sock = sender = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    # Multicast TTL of 20 hops; change to suit the deployment.
    sender.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 20)
|
|
|
|
def _setup_listener(self):
    """Create the UDP socket for incoming multicast traffic; start() binds it."""
    self._listen_sock = listener = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    # Address reuse plus multicast TTL 20 and multicast loopback enabled
    socket_options = ((socket.SOL_SOCKET, socket.SO_REUSEADDR, 1),
                      (socket.SOL_IP, socket.IP_MULTICAST_TTL, 20),
                      (socket.SOL_IP, socket.IP_MULTICAST_LOOP, 1))
    for level, option, value in socket_options:
        listener.setsockopt(level, option, value)
|
|
|
|
def start(self):
    """Bind the listener socket, join the multicast group and install the
    gobject IO watch and periodic workers.  send() refuses to run until
    this has been called."""
    listener = self._listen_sock
    listener.bind((self._local_addr, self._port))
    listener.settimeout(2)
    # Selecting the outgoing interface is disabled for now, to try to fix
    # "cannot assign requested address" errors:
    #   intf = socket.gethostbyname(socket.gethostname())
    #   listener.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_IF,
    #                       socket.inet_aton(intf) + socket.inet_aton('0.0.0.0'))

    # Membership request: group address followed by local interface 0.0.0.0
    membership = socket.inet_aton(self._remote_addr) + socket.inet_aton('0.0.0.0')
    listener.setsockopt(socket.SOL_IP, socket.IP_ADD_MEMBERSHIP, membership)

    # Watch the listener socket for data and start the periodic workers
    gobject.io_add_watch(listener, gobject.IO_IN, self._handle_incoming_data)
    gobject.timeout_add(self._SEGMENT_TTL * 1000, self._segment_ttl_worker)
    self._rt_check_worker_id = gobject.timeout_add(50, self._retransmit_check_worker)
    self._ack_check_worker_id = gobject.timeout_add(50, self._ack_check_worker)

    self._started = True
|
|
|
|
def _segment_ttl_worker(self):
    """Periodic gobject worker that expires stale bookkeeping entries.

    Drops these entries:
    - sent segments older than _SEGMENT_TTL, cancelling any pending
      retransmit timeout;
    - incomplete incoming messages with no data for _SEGMENT_TTL;
    - dispatched messages and recorded acks older than twice _SEGMENT_TTL.

    Always returns True so the timeout stays installed.
    """
    now = time.time()
    ttl_cutoff = now - self._SEGMENT_TTL
    long_cutoff = now - (self._SEGMENT_TTL * 2)

    for key, sent_segment in list(self._sent.items()):
        if sent_segment.stime() < ttl_cutoff:
            if sent_segment.userdata:
                gobject.source_remove(sent_segment.userdata)
            del self._sent[key]

    for key, message in list(self._incoming.items()):
        if message.last_incoming_time() < ttl_cutoff:
            del self._incoming[key]

    for key, message in list(self._dispatched.items()):
        if message.dispatch_time() < long_cutoff:
            del self._dispatched[key]

    for key, received_at in list(self._acks.items()):
        if received_at < long_cutoff:
            del self._acks[key]

    return True
|
|
|
|
_MAX_SEGMENT_RETRIES = 10
|
|
def _retransmit_request(self, message):
|
|
"""Returns true if the message has exceeded it's retry limit."""
|
|
first_missing = message.first_missing()
|
|
if first_missing > 0:
|
|
num_retries = message.rt_tries(first_missing)
|
|
if num_retries > self._MAX_SEGMENT_RETRIES:
|
|
return True
|
|
msg_seq = self._next_msg_seq()
|
|
seg = message.get_retransmit_message(msg_seq, first_missing)
|
|
if seg:
|
|
print "(MRP): Requesting retransmit of %d by %s" % (first_missing, message.source_address())
|
|
self._outgoing.append(seg)
|
|
self._schedule_send_worker()
|
|
return False
|
|
|
|
def _retransmit_check_worker(self):
|
|
"""Periodically check for and send retransmit requests for message
|
|
segments that got lost."""
|
|
try:
|
|
now = time.time()
|
|
for key in self._incoming.keys()[:]:
|
|
message = self._incoming[key]
|
|
if message.complete():
|
|
continue
|
|
next_rt = message.next_rt_time()
|
|
if next_rt == 0 or next_rt > now:
|
|
continue
|
|
if self._retransmit_request(message):
|
|
# Kill the message, too many retries
|
|
print "(MRP): Dropped message %s, exceeded retries." % _stringify_sha(message.sha())
|
|
self._dispatched[key] = message
|
|
message.set_dispatch_time()
|
|
del self._incoming[key]
|
|
except KeyboardInterrupt:
|
|
return False
|
|
return True
|
|
|
|
def _process_incoming_data(self, segment):
    """File a received data segment under its message and deliver the
    message once it is complete.

    Messages are keyed by (sender IP, sequence number, master sha, segment
    count).  Segments of already-dispatched messages and duplicate
    segments are ignored.  The first time a message from another host is
    seen, an ack is sent unless one is already recorded.  When every
    segment is present and the reassembled data matches the master sha,
    data_cb is called and the message moves to the dispatched table.
    (Single-segment checksums are already verified in
    SegmentBase.new_from_data.)
    """
    addr = segment.address()
    sender_ip = addr[0]
    master_sha = segment.master_sha()
    seq = segment.message_sequence_number()
    key = (sender_ip, seq, master_sha, segment.total_segments())

    if key in self._dispatched:
        # Already delivered; this segment is useless
        return

    message = self._incoming.get(key)
    if message is None:
        # First segment seen for this message
        message = Message((sender_ip, self._port), seq, master_sha, key[3])
        self._incoming[key] = message
        # Acknowledge the message if it didn't come from us
        if sender_ip not in self._local_ips:
            if (seq, master_sha, sender_ip) not in self._acks:
                self._send_ack_for_message(seq, master_sha, sender_ip)

    # Drop duplicates
    if message.has_segment(segment.segment_number()):
        return
    message.add_segment(segment)

    if not message.complete():
        return
    assembled, assembled_sha = message.data()
    if assembled_sha == master_sha:
        self._data_cb(addr, assembled, self._user_data)
        self._dispatched[key] = message
        message.set_dispatch_time()
        del self._incoming[key]
|
|
|
|
def _segment_retransmit_cb(self, key, segment):
|
|
"""Add a segment ot the outgoing queue and schedule its transmission."""
|
|
del self._sent[key]
|
|
self._outgoing.append(segment)
|
|
self._schedule_send_worker()
|
|
return False
|
|
|
|
def _schedule_segment_retransmit(self, key, segment, when, now):
    """Requeue segment at time when (seconds, same clock as now), unless
    a retransmission is already pending for it.  Times at or before now
    requeue immediately."""
    if segment.userdata:
        # Already scheduled for retransmit
        return
    delay = when - now
    if delay <= 0:
        self._segment_retransmit_cb(key, segment)
        return
    # gobject timeouts are expressed in milliseconds
    delay_ms = int(delay * 1000)
    segment.userdata = gobject.timeout_add(delay_ms, self._segment_retransmit_cb,
                                           key, segment)
|
|
|
|
_STD_RETRANSMIT_INTERVAL = 0.05 # 50ms (in seconds)

def _process_retransmit_request(self, segment):
    """Validate and process a retransmission request.

    segment is a received RetransmitSegment.  If the requested segment is
    still in the sent table, it is rescheduled no sooner than
    _STD_RETRANSMIT_INTERVAL (50ms) after its last transmission.
    """
    key = (segment.rt_msg_seq_num(), segment.rt_master_sha(),
           segment.rt_segment_number())
    wanted = self._sent.get(key)
    if wanted is None:
        # Either we don't know about the segment, or it was already culled
        return
    now = time.time()
    earliest = wanted.last_transmit() + self._STD_RETRANSMIT_INTERVAL
    self._schedule_segment_retransmit(key, wanted, max(now, earliest), now)
|
|
|
|
def _ack_check_worker(self):
    """Periodic worker: requeue the first segment of any message we sent
    that no host has acknowledged yet.  Returns False on
    KeyboardInterrupt (removing the timeout), True otherwise."""
    try:
        now = time.time()
        for key in list(self._sent.keys()):
            segment = self._sent[key]
            # Only the first segment matters: receivers missing later
            # segments will issue retransmit requests for them.
            if segment.segment_number() != 1:
                continue
            # Transmitted within the last 150ms; give it more time
            if segment.last_transmit() > now - 0.150:
                continue
            seq = segment.message_sequence_number()
            sha = segment.master_sha()
            # Once acked under any of our addresses, resend only on an
            # explicit retransmit request.
            acked = False
            for ip in self._local_ips:
                if (seq, sha, ip) in self._acks:
                    acked = True
                    break
            if acked:
                continue
            del self._sent[key]
            self._outgoing.append(segment)
            self._schedule_send_worker()
    except KeyboardInterrupt:
        return False
    return True
|
|
|
|
def _send_ack_for_message(self, ack_msg_seq_num, ack_msg_sha, ack_addr):
    """Queue an ack to the multicast group for the message
    (ack_msg_seq_num, ack_msg_sha) sent by ack_addr.  The ack is also
    recorded locally, so the same message is not acked again."""
    ack = AckSegment.new_from_parts((self._remote_addr, self._port),
                                    self._next_msg_seq(),
                                    ack_msg_seq_num, ack_msg_sha, ack_addr)
    self._outgoing.append(ack)
    self._schedule_send_worker()
    self._process_incoming_ack(ack)
|
|
|
|
def _process_incoming_ack(self, segment):
    """Record an ack, whether received or generated by us.

    The record stops us from acking the same message twice and from
    retransmitting a message that has been acked.  An ack naming one of
    our own addresses is accepted only if the first segment of the
    referenced message is still in our sent table.
    """
    acked_ip = segment.ack_addr()
    acked_sha = segment.ack_master_sha()
    acked_seq = segment.ack_msg_seq_num()
    if acked_ip in self._local_ips and (acked_seq, acked_sha, 1) not in self._sent:
        return
    record_key = (acked_seq, acked_sha, acked_ip)
    if record_key not in self._acks:
        self._acks[record_key] = time.time()
|
|
|
|
def set_drop_probability(self, prob=4):
|
|
"""Debugging function to randomly drop incoming packets.
|
|
The prob argument should be an integer between 1 and 10 to drop,
|
|
or 0 to drop none. Higher numbers drop more packets."""
|
|
if not isinstance(prob, int):
|
|
raise ValueError("Drop probability must be an integer.")
|
|
if prob < 1 or prob > 10:
|
|
raise ValueError("Drop probability must be between 1 and 10 inclusive.")
|
|
self._drop_prob = prob
|
|
|
|
def _handle_incoming_data(self, source, condition):
|
|
"""Handle incoming network data by making a message segment out of it
|
|
sending it off to the processing function."""
|
|
if not (condition & gobject.IO_IN):
|
|
return True
|
|
msg = {}
|
|
data, addr = source.recvfrom(self._UDP_MSG_SIZE)
|
|
|
|
should_drop = False
|
|
p = random.random() * 10.0
|
|
if self._drop_prob > 0 and p <= self._drop_prob:
|
|
should_drop = True
|
|
|
|
try:
|
|
segment = SegmentBase.new_from_data(addr, data)
|
|
if should_drop:
|
|
print "(MRP): Dropped segment %d." % segment.segment_number()
|
|
else:
|
|
stype = segment.segment_type()
|
|
if stype == SegmentBase.type_data():
|
|
self._process_incoming_data(segment)
|
|
elif stype == SegmentBase.type_retransmit():
|
|
self._process_retransmit_request(segment)
|
|
elif stype == SegmentBase.type_ack():
|
|
self._process_incoming_ack(segment)
|
|
except ValueError, exc:
|
|
print "(MRP): Bad segment: %s" % exc
|
|
return True
|
|
|
|
def _next_msg_seq(self):
|
|
self._seq_counter = self._seq_counter + 1
|
|
if self._seq_counter > 65535:
|
|
self._seq_counter = 1
|
|
return self._seq_counter
|
|
|
|
def send(self, data):
    """Break data up into MTU-sized segments and queue them for transmission.

    data is converted with str() first.  All segments of the message share
    one freshly allocated message sequence number and the SHA1 of the whole
    payload; segments are numbered from 1.  An empty payload consumes a
    sequence number but queues nothing.  Oversized messages are presumably
    rejected with ValueError by DataSegment.new_from_parts before anything
    is queued (SegmentBase refuses 65536+ total segments).

    Raises Exception if the pipe has not been started.
    """
    if not self._started:
        raise Exception("Can't send anything until started!")

    msg_seq = self._next_msg_seq()

    # Pack the data into network byte order (for an "s" field this is a
    # byte-for-byte copy of the string).
    payload = str(data)
    payload = struct.pack("! %ds" % len(payload), payload)
    master_sha = _sha_data(payload)

    length = len(payload)
    mtu = SegmentBase.mtu()
    # Ceiling division: a partial trailing chunk still needs its own segment.
    nsegments = (length + mtu - 1) // mtu

    # Slice by offset so each byte is copied once; chopping the front off
    # the buffer every iteration would recopy the remainder per segment.
    for offset in range(0, length, mtu):
        seg_num = offset // mtu + 1
        seg = DataSegment.new_from_parts(seg_num, nsegments, msg_seq,
                                         master_sha,
                                         payload[offset:offset + mtu])
        self._outgoing.append(seg)
    self._schedule_send_worker()
|
|
|
|
def _schedule_send_worker(self):
    """Arrange for _send_worker_cb to run after 50 ms if output is queued.

    Does nothing when the outgoing queue is empty or a send-worker timeout
    is already pending (self._send_worker is its source id, 0 when none).
    """
    if not self._outgoing or self._send_worker != 0:
        return
    self._send_worker = gobject.timeout_add(50, self._send_worker_cb)
|
|
|
|
def _send_worker_cb(self):
    """Timeout callback: transmit up to 11 queued segments.

    For each segment sent: its transmit count is incremented, any pending
    retransmission GSource stored in segment.userdata is removed, and the
    segment is recorded in self._sent under
    (message sequence number, master SHA, segment number).  Segments with no
    explicit address go to (self._remote_addr, self._port).  Sent segments
    leave the queue; if some remain, another run is scheduled.  Returns False
    so this timeout source is not repeated.
    """
    # Clear the pending-worker marker first so _schedule_send_worker below
    # may install a fresh timeout.
    self._send_worker = 0

    batch = self._outgoing[:11]
    for seg in batch:
        wire_data = seg.packetize()
        seg.inc_transmits()

        dest = seg.address()
        if not dest:
            dest = (self._remote_addr, self._port)
        self._send_sock.sendto(wire_data, dest)

        # Cancel the pending retransmission GSource, if any.
        if seg.userdata:
            gobject.source_remove(seg.userdata)
        seg.userdata = None

        sent_key = (seg.message_sequence_number(), seg.master_sha(),
                    seg.segment_number())
        self._sent[sent_key] = seg

    self._outgoing = self._outgoing[len(batch):]
    if self._outgoing:
        self._schedule_send_worker()
    return False
|
|
|
|
|
|
#################################################################
|
|
# Tests
|
|
#################################################################
|
|
|
|
import unittest
|
|
|
|
|
|
class SegmentBaseTestCase(unittest.TestCase):
    """Shared default field values for the segment test cases."""

    # Must agree with SegmentBase.magic().
    _SEG_MAGIC = 0xBAEA4304

    # 0 is the data segment type (SegmentBase._SEGMENT_TYPE_DATA).
    _DEF_SEG_TYPE = 0
    _DEF_SEGNO = 1
    _DEF_TOT_SEGS = 5
    _DEF_MSG_SEQ_NUM = 4556
    # Master SHA1 digests are 20 bytes long.
    _DEF_MASTER_SHA = "1234567890" * 2

    _DEF_ADDRESS = ('123.3.2.1', 3333)
|
|
|
|
|
|
class SegmentBaseInitTestCase(SegmentBaseTestCase):
    """Tests for SegmentBase construction and generic new_from_data() parsing."""

    def _test_init_fail(self, segno, total_segs, msg_seq_num, master_sha, fail_msg):
        """Assert that SegmentBase(...) rejects these fields with ValueError."""
        try:
            SegmentBase(segno, total_segs, msg_seq_num, master_sha)
        except ValueError:
            pass
        else:
            self.fail("expected a ValueError for %s." % fail_msg)

    def testSegmentBase(self):
        assert SegmentBase.magic() == self._SEG_MAGIC, "Segment magic wasn't correct!"
        assert SegmentBase.header_len() > 0, "header size was not greater than zero."
        assert SegmentBase.mtu() > 0, "MTU was not greater than zero."
        assert SegmentBase.mtu() + SegmentBase.header_len() == _UDP_DATAGRAM_SIZE, \
            "MTU + header size didn't equal expected %d." % _UDP_DATAGRAM_SIZE

    def testGoodInit(self):
        seg = SegmentBase(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA)
        # <= rather than <: with a coarse system clock the segment can be
        # created within the same clock tick as this comparison.
        assert seg.stime() <= time.time(), "segment start time is later than now!"
        assert not seg.address(), "Segment address was not None after init."
        assert seg.segment_number() == self._DEF_SEGNO, "Segment number wasn't correct after init."
        assert seg.total_segments() == self._DEF_TOT_SEGS, "Total segments wasn't correct after init."
        assert seg.message_sequence_number() == self._DEF_MSG_SEQ_NUM, "Message sequence number wasn't correct after init."
        assert seg.master_sha() == self._DEF_MASTER_SHA, "Message master SHA wasn't correct after init."
        assert seg.segment_type() is None, "Segment type was not None after init."
        assert seg.transmits() == 0, "Segment transmits was not 0 after init."
        assert seg.last_transmit() == 0, "Segment last transmit was not 0 after init."
        assert seg.data() is None, "Segment data was not None after init."

    def testSegmentNumber(self):
        for bad in (0, 65536, None, ""):
            self._test_init_fail(bad, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM,
                                 self._DEF_MASTER_SHA, "invalid segment number")

    def testTotalMessageSegmentNumber(self):
        for bad in (0, 65536, None, ""):
            self._test_init_fail(self._DEF_SEGNO, bad, self._DEF_MSG_SEQ_NUM,
                                 self._DEF_MASTER_SHA, "invalid total segments")

    def testMessageSequenceNumber(self):
        for bad in (0, 65536, None, ""):
            self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, bad,
                                 self._DEF_MASTER_SHA, "invalid message sequence number")

    def testMasterSHA(self):
        for bad in ("1" * 19, "1" * 21, None, 1234):
            self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS,
                                 self._DEF_MSG_SEQ_NUM, bad, "invalid SHA1 data hash")

    def _testNewFromDataFail(self, addr, data, fail_msg):
        """Assert that SegmentBase.new_from_data() raises ValueError."""
        try:
            SegmentBase.new_from_data(addr, data)
        except ValueError:
            pass
        else:
            self.fail("expected a ValueError about %s." % fail_msg)

    def testNewFromDataAddress(self):
        # Malformed addresses: not a tuple, wrong arity (note the trailing
        # commas -- ('') and (1) would just be '' and 1), wrong host/port
        # types, or a port outside 1..65535.
        bad_addresses = (None, '', ('',), (1,), ('', ''), (1, 3333),
                         ('', 0), ('', 65536))
        for addr in bad_addresses:
            self._testNewFromDataFail(addr, None, "bad address")

    def testNewFromDataData(self):
        """Only test generic new_from_data() bits, not type-specific ones."""
        self._testNewFromDataFail(self._DEF_ADDRESS, None, "invalid data")

        really_short_data = "111"
        self._testNewFromDataFail(self._DEF_ADDRESS, really_short_data, "data too short")

        only_header_data = "1" * SegmentBase.header_len()
        self._testNewFromDataFail(self._DEF_ADDRESS, only_header_data, "data too short")

        too_much_data = "1" * (_UDP_DATAGRAM_SIZE + 1)
        self._testNewFromDataFail(self._DEF_ADDRESS, too_much_data, "too much data")

        header_template = SegmentBase.header_template()
        bad_magic_data = struct.pack(header_template, 0x12345678, self._DEF_SEG_TYPE,
                                     self._DEF_SEGNO, self._DEF_TOT_SEGS,
                                     self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA)
        self._testNewFromDataFail(self._DEF_ADDRESS, bad_magic_data, "invalid magic")

        bad_type_data = struct.pack(header_template, self._SEG_MAGIC, -1, self._DEF_SEGNO,
                                    self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM,
                                    self._DEF_MASTER_SHA)
        self._testNewFromDataFail(self._DEF_ADDRESS, bad_type_data, "invalid segment type")

        # Test master_sha that doesn't match data's SHA
        header = struct.pack(header_template, self._SEG_MAGIC, self._DEF_SEG_TYPE, 1, 1,
                             self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA)
        data = struct.pack("! 15s", "7" * 15)
        self._testNewFromDataFail(self._DEF_ADDRESS, header + data,
                                  "single-segment message SHA mismatch")

    def addToSuite(suite):
        """Add every test of this case (including testSegmentBase) to suite."""
        for name in ("testSegmentBase", "testGoodInit", "testSegmentNumber",
                     "testTotalMessageSegmentNumber", "testMessageSequenceNumber",
                     "testMasterSHA", "testNewFromDataAddress", "testNewFromDataData"):
            suite.addTest(SegmentBaseInitTestCase(name))
    addToSuite = staticmethod(addToSuite)
|
|
|
|
|
|
class DataSegmentTestCase(SegmentBaseTestCase):
    """Exercise DataSegment construction, new_from_parts() and new_from_data()."""

    def testInit(self):
        data_seg = DataSegment(self._DEF_SEGNO, self._DEF_TOT_SEGS,
                               self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA)
        assert data_seg.segment_type() == SegmentBase.type_data(), "Segment wasn't a data segment."

    def testNewFromParts(self):
        # A data segment without a payload must be refused.
        rejected = False
        try:
            DataSegment.new_from_parts(self._DEF_SEGNO, self._DEF_TOT_SEGS,
                                       self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, None)
        except ValueError:
            rejected = True
        if not rejected:
            self.fail("Expected ValueError about invalid data.")

        # The payload given to new_from_parts() must come back unchanged.
        text = "How are you today?"
        data_seg = DataSegment.new_from_parts(self._DEF_SEGNO, self._DEF_TOT_SEGS,
                                              self._DEF_MSG_SEQ_NUM,
                                              self._DEF_MASTER_SHA, text)
        assert data_seg.data() == text, "Data after segment creation didn't match expected."

    def testNewFromData(self):
        """Test DataSegment's new_from_data() functionality."""
        # Build a well-formed single data packet by hand and parse it back.
        text = "How are you today?"
        body = struct.pack("! %ds" % len(text), text)
        body_sha = _sha_data(body)
        header = struct.pack(SegmentBase.header_template(), self._SEG_MAGIC,
                             SegmentBase.type_data(), self._DEF_SEGNO,
                             self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, body_sha)
        parsed = SegmentBase.new_from_data(self._DEF_ADDRESS, header + body)

        assert parsed.address() == self._DEF_ADDRESS, "Segment address did not match expected."
        assert parsed.segment_type() == SegmentBase.type_data(), "Segment type did not match expected."
        assert parsed.segment_number() == self._DEF_SEGNO, "Segment number did not match expected."
        assert parsed.total_segments() == self._DEF_TOT_SEGS, "Total segments did not match expected."
        assert parsed.message_sequence_number() == self._DEF_MSG_SEQ_NUM, "Message sequence number did not match expected."
        assert parsed.master_sha() == body_sha, "Message master SHA did not match expected."
        assert parsed.data() == body, "Segment data did not match expected payload."

    def addToSuite(suite):
        for test_name in ("testInit", "testNewFromParts", "testNewFromData"):
            suite.addTest(DataSegmentTestCase(test_name))
    addToSuite = staticmethod(addToSuite)
|
|
|
|
|
|
class RetransmitSegmentTestCase(SegmentBaseTestCase):
    """Test RetransmitSegment class specific initialization and stuff."""

    def _test_init_fail(self, segno, total_segs, msg_seq_num, master_sha, fail_msg):
        """Assert that RetransmitSegment(...) rejects these fields with ValueError."""
        try:
            RetransmitSegment(segno, total_segs, msg_seq_num, master_sha)
        except ValueError:
            pass
        else:
            self.fail("expected a ValueError for %s." % fail_msg)

    def testInit(self):
        # A retransmit request is only valid as segment 1 of 1.
        self._test_init_fail(0, 1, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid segment number")
        self._test_init_fail(2, 1, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid segment number")
        self._test_init_fail(1, 0, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid number of total segments")
        self._test_init_fail(1, 2, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid number of total segments")

        # Something that's supposed to work
        seg = RetransmitSegment(1, 1, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA)
        assert seg.segment_type() == SegmentBase.type_retransmit(), "Segment wasn't a retransmit segment."

    def _test_new_from_parts_fail(self, msg_seq_num, rt_msg_seq_num, rt_master_sha, rt_segment_number, fail_msg):
        """Assert that RetransmitSegment.new_from_parts() raises ValueError."""
        try:
            RetransmitSegment.new_from_parts(self._DEF_ADDRESS, msg_seq_num, rt_msg_seq_num,
                                             rt_master_sha, rt_segment_number)
        except ValueError:
            pass
        else:
            self.fail("expected a ValueError for %s." % fail_msg)

    def testNewFromParts(self):
        """Test RetransmitSegment's new_from_parts() functionality."""
        seq = self._DEF_MSG_SEQ_NUM
        sha = self._DEF_MASTER_SHA
        segno = self._DEF_SEGNO

        # Feed each field invalid values while keeping the others valid.
        for bad in (0, 65536, None, ""):
            self._test_new_from_parts_fail(bad, seq, sha, segno,
                                           "invalid message sequence number")
        for bad in (0, 65536, None, ""):
            self._test_new_from_parts_fail(seq, bad, sha, segno,
                                           "invalid retransmit message sequence number")
        for bad in ("1" * 19, "1" * 21, None, 1234):
            self._test_new_from_parts_fail(seq, seq, bad, segno,
                                           "invalid retransmit message master SHA")
        for bad in (0, 65536, None, ""):
            self._test_new_from_parts_fail(seq, seq, sha, bad,
                                           "invalid retransmit message segment number")

        # Ensure message data is same as we stuff in after object is instantiated
        seg = RetransmitSegment.new_from_parts(self._DEF_ADDRESS, seq, seq, sha, segno)
        assert seg.rt_msg_seq_num() == self._DEF_MSG_SEQ_NUM, "RT message sequence number after segment creation didn't match expected."
        assert seg.rt_master_sha() == self._DEF_MASTER_SHA, "RT master SHA after segment creation didn't match expected."
        assert seg.rt_segment_number() == self._DEF_SEGNO, "RT segment number after segment creation didn't match expected."

    def _new_from_data(self, rt_msg_seq_num, rt_master_sha, rt_segment_number):
        """Build a raw single-segment retransmit packet (header + payload)."""
        payload = struct.pack(RetransmitSegment.data_template(), rt_msg_seq_num, rt_master_sha, rt_segment_number)
        payload_sha = _sha_data(payload)
        header_template = SegmentBase.header_template()
        header = struct.pack(header_template, self._SEG_MAGIC, SegmentBase.type_retransmit(), 1, 1,
                             self._DEF_MSG_SEQ_NUM, payload_sha)
        return header + payload

    def _test_new_from_data_fail(self, rt_msg_seq_num, rt_master_sha, rt_segment_number, fail_msg):
        """Assert that parsing a retransmit packet with these fields raises ValueError."""
        try:
            packet = self._new_from_data(rt_msg_seq_num, rt_master_sha, rt_segment_number)
            SegmentBase.new_from_data(self._DEF_ADDRESS, packet)
        except ValueError:
            pass
        else:
            self.fail("Expected a ValueError about %s." % fail_msg)

    def testNewFromData(self):
        """Test RetransmitSegment's new_from_data() functionality."""
        self._test_new_from_data_fail(0, self._DEF_MASTER_SHA, self._DEF_SEGNO, "invalid RT message sequence number")
        self._test_new_from_data_fail(65536, self._DEF_MASTER_SHA, self._DEF_SEGNO, "invalid RT message sequence number")

        self._test_new_from_data_fail(self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, 0, "invalid RT segment number")
        self._test_new_from_data_fail(self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, 65536, "invalid RT segment number")

        # Ensure something that should work
        packet = self._new_from_data(self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, self._DEF_SEGNO)
        seg = SegmentBase.new_from_data(self._DEF_ADDRESS, packet)
        assert seg.segment_type() == SegmentBase.type_retransmit(), "Segment wasn't expected type."
        assert seg.rt_msg_seq_num() == self._DEF_MSG_SEQ_NUM, "Segment RT message sequence number didn't match expected."
        assert seg.rt_master_sha() == self._DEF_MASTER_SHA, "Segment RT master SHA didn't match expected."
        assert seg.rt_segment_number() == self._DEF_SEGNO, "Segment RT segment number didn't match expected."

    def testPartsToData(self):
        seg = RetransmitSegment.new_from_parts(self._DEF_ADDRESS, self._DEF_MSG_SEQ_NUM, self._DEF_MSG_SEQ_NUM,
                                               self._DEF_MASTER_SHA, self._DEF_SEGNO)
        # Round-trip: packetize() output must parse back to the same fields.
        new_seg = SegmentBase.new_from_data(self._DEF_ADDRESS, seg.packetize())
        assert new_seg.rt_msg_seq_num() == self._DEF_MSG_SEQ_NUM, "Segment RT message sequence number didn't match expected."
        assert new_seg.rt_master_sha() == self._DEF_MASTER_SHA, "Segment RT master SHA didn't match expected."
        assert new_seg.rt_segment_number() == self._DEF_SEGNO, "Segment RT segment number didn't match expected."

    def addToSuite(suite):
        for name in ("testInit", "testNewFromParts", "testNewFromData", "testPartsToData"):
            suite.addTest(RetransmitSegmentTestCase(name))
    addToSuite = staticmethod(addToSuite)
|
|
|
|
|
|
class SHAUtilsTestCase(unittest.TestCase):
    """Check the module's SHA1 helpers against precomputed digests."""

    def testSHA(self):
        text = "235jklqt3hjwasdv879wfe89723rqjh32tr3hwaejksdvd89udsv89dsgiougjktqjhk23tjht23hjt3qhjewagthjasgdgsd"
        digest = _sha_data(text)
        assert len(digest) == 20, "SHA wasn't correct size."
        expected = "\xee\x9e\xb9\x1d\xe8\x96\x75\xcb\x12\xf1\x25\x22\x0f\x76\xf7\xf3\xc8\x4e\xbf\xcd"
        assert digest == expected, "SHA didn't match known SHA."

    def testStringifySHA(self):
        text = "jlkwjlkaegdjlksgdjklsdgajklganjtwn23n325n23tjwgeajkga nafDA fwqnjlqtjkl23tjk2365jlk235jkl2356jlktjkltewjlktewjklewtjklaggsda"
        expected_hex = "9650c23db78092a0ffda4577c87ebf36d25c868e"
        # Run it twice to make sure repeated calls agree.
        for attempt in range(2):
            assert _stringify_sha(_sha_data(text)) == expected_hex, "SHA stringify didn't return correct SHA."

    def addToSuite(suite):
        for name in ("testSHA", "testStringifySHA"):
            suite.addTest(SHAUtilsTestCase(name))
    addToSuite = staticmethod(addToSuite)
|
|
|
|
|
|
|
|
def unit_test():
    """Collect every segment/SHA test case into one suite and run it."""
    suite = unittest.TestSuite()
    for case in (SegmentBaseInitTestCase, DataSegmentTestCase,
                 RetransmitSegmentTestCase, SHAUtilsTestCase):
        case.addToSuite(suite)
    unittest.TextTestRunner().run(suite)
|
|
|
|
|
|
|
|
def got_data(addr, data, user_data=None):
    """Receive callback for simple_test(): write a received message to a file.

    addr is the sender's address and data the received payload.  user_data is
    the output file path (presumably the sys.argv[2] that simple_test passes
    to MostlyReliablePipe).  The file is truncated and rewritten for every
    message.
    """
    print("Got data from %s, writing to %s." % (addr, user_data))
    fl = open(user_data, "w+")
    # try/finally so the file is closed even if the write fails.
    try:
        fl.write(data)
    finally:
        fl.close()
|
|
|
|
def simple_test():
    """Manual test: send one file over the pipe and save what is received.

    Usage: <script> INPUT_FILE OUTPUT_FILE.  INPUT_FILE is read and sent once;
    OUTPUT_FILE is handed to got_data() as its user_data, which writes each
    received message to it.  Runs the GTK main loop until Ctrl+C.
    """
    import sys

    if len(sys.argv) < 3:
        print("Usage: %s INPUT_FILE OUTPUT_FILE" % sys.argv[0])
        return

    # Read the input first so a bad path fails before any socket is opened;
    # try/finally closes the file even if the read fails.
    fl = open(sys.argv[1], "r")
    try:
        data = fl.read()
    finally:
        fl.close()

    pipe = MostlyReliablePipe('', '224.0.0.222', 2293, got_data, sys.argv[2])
    # To exercise retransmission, enable: pipe.set_drop_probability(4)
    pipe.start()
    pipe.send(data)
    try:
        gtk.main()
    except KeyboardInterrupt:
        print('Ctrl+C pressed, exiting...')
|
|
|
|
|
|
|
|
def net_test_got_data(addr, data, user_data=None):
    """Receive callback for network_test(): print each message and its sender.

    user_data is presumably network_test()'s "send" flag; when it is true,
    nothing is printed so only receivers report data.
    """
    if not user_data:
        print("%s (%s)" % (data, addr))
|
|
|
|
# Number of the next test message transmit_data() will send; it is
# incremented after every send.
idstamp = 0
|
|
def transmit_data(pipe):
    """Timeout callback for network_test(): send one numbered test message.

    Uses and then advances the module-level idstamp counter.  Returns True so
    the gobject timeout keeps firing.
    """
    global idstamp
    text = "Message #%d" % idstamp
    print("Sending '%s'" % text)
    pipe.send(text)
    idstamp += 1
    return True
|
|
|
|
def network_test():
    """Manual test on multicast group 224.0.0.222:2293; pass 'send' or 'recv'.

    A sender transmits a numbered message every second via transmit_data();
    a receiver prints each message via net_test_got_data().  Exits with
    status 1 on a missing or unknown argument.  Runs the GTK main loop until
    interrupted with Ctrl+C.
    """
    import sys

    # sys.exit (not os._exit) so buffered stdout is flushed and the usage
    # message is not lost when output is redirected.
    if len(sys.argv) != 2:
        print("Need one arg, either 'send' or 'recv'")
        sys.exit(1)
    if sys.argv[1] == "send":
        send = True
    elif sys.argv[1] == "recv":
        send = False
    else:
        print("Arg should be either 'send' or 'recv'")
        sys.exit(1)

    pipe = MostlyReliablePipe('', '224.0.0.222', 2293, net_test_got_data, send)
    pipe.start()
    if send:
        gobject.timeout_add(1000, transmit_data, pipe)
    try:
        gtk.main()
    except KeyboardInterrupt:
        print('Ctrl+C pressed, exiting...')
|
|
|
|
|
|
def main():
    """Script entry point: run one of the test drivers.

    Only network_test() is active; call unit_test() or simple_test() here
    instead to run those.
    """
    network_test()


if __name__ == "__main__":
    main()
|
|
|