Add a bunch of testcases
This commit is contained in:
parent
f751407d50
commit
11d54d71a9
@ -22,6 +22,7 @@ def _sha_data(data):
|
||||
sha_hash.update(data)
|
||||
return sha_hash.digest()
|
||||
|
||||
_UDP_DATAGRAM_SIZE = 512
|
||||
|
||||
class SegmentBase(object):
|
||||
_MAGIC = 0xbaea4304
|
||||
@ -34,12 +35,20 @@ class SegmentBase(object):
|
||||
#20: total data sha1
|
||||
_HEADER_TEMPLATE = "! IbHHH20s"
|
||||
_HEADER_LEN = struct.calcsize(_HEADER_TEMPLATE)
|
||||
_MTU = 512 - _HEADER_LEN
|
||||
_MTU = _UDP_DATAGRAM_SIZE - _HEADER_LEN
|
||||
|
||||
# Message segment packet types
|
||||
_SEGMENT_TYPE_DATA = 0
|
||||
_SEGMENT_TYPE_RETRANSMIT = 1
|
||||
|
||||
def magic():
|
||||
return SegmentBase._MAGIC
|
||||
magic = staticmethod(magic)
|
||||
|
||||
def header_template():
|
||||
return SegmentBase._HEADER_TEMPLATE
|
||||
header_template = staticmethod(header_template)
|
||||
|
||||
def type_data():
|
||||
return SegmentBase._SEGMENT_TYPE_DATA
|
||||
type_data = staticmethod(type_data)
|
||||
@ -69,18 +78,22 @@ class SegmentBase(object):
|
||||
self._addr = None
|
||||
|
||||
# Sanity checks on the message attributes
|
||||
if segno > 65535:
|
||||
raise ValueError("Segment number cannot be more than 65535.")
|
||||
if segno < 1:
|
||||
raise ValueError("Segment number must be greater than 0.")
|
||||
if not segno or type(segno) != type(1):
|
||||
raise ValueError("Segment number must be in integer.")
|
||||
if segno < 1 or segno > 65535:
|
||||
raise ValueError("Segment number must be between 1 and 65535 inclusive.")
|
||||
if not total_segs or type(total_segs) != type(1):
|
||||
raise ValueError("Message segment total must be an integer.")
|
||||
if total_segs < 1 or total_segs > 65535:
|
||||
raise ValueError("Message must have between 1 and 65535 segments inclusive.")
|
||||
if segno > total_segs:
|
||||
raise ValueError("Segment number cannot be larger than message segment total.")
|
||||
if total_segs < 1:
|
||||
raise ValueError("Message must have at least one segment.")
|
||||
if total_segs > 65535:
|
||||
raise ValueError("Message cannot have more than 65535 segments.")
|
||||
if msg_seq_num < 1:
|
||||
raise ValueError("Message sequence number must be greater than 0.")
|
||||
if not msg_seq_num or type(msg_seq_num) != type(1):
|
||||
raise ValueError("Message sequnce number must be an integer.")
|
||||
if msg_seq_num < 1 or msg_seq_num > 65535:
|
||||
raise ValueError("Message sequence number must be between 1 and 65535 inclusive.")
|
||||
if not master_sha or type(master_sha) != type("") or len(master_sha) != 20:
|
||||
raise ValueError("Message SHA1 checksum invalid.")
|
||||
|
||||
self._segno = segno
|
||||
self._total_segs = total_segs
|
||||
@ -90,15 +103,26 @@ class SegmentBase(object):
|
||||
def new_from_data(addr, data):
|
||||
"""Static constructor for creation from a packed data stream."""
|
||||
|
||||
if not addr or type(addr) != type(()):
|
||||
raise ValueError("Address must be a tuple.")
|
||||
if len(addr) != 2 or type(addr[0]) != type("") or type(addr[1]) != type(1):
|
||||
raise ValueError("Address format was invalid.")
|
||||
if addr[1] < 1 or addr[1] > 65535:
|
||||
raise ValueError("Address port was invalid.")
|
||||
|
||||
# Verify minimum length
|
||||
if len(data) < SegmentBase.header_len() + 1:
|
||||
if not data:
|
||||
raise ValueError("Segment data is invalid.")
|
||||
data_len = len(data)
|
||||
if data_len < SegmentBase.header_len() + 1:
|
||||
raise ValueError("Segment is less then minimum required length")
|
||||
if data_len > _UDP_DATAGRAM_SIZE:
|
||||
raise ValueError("Segment data is larger than allowed.")
|
||||
stream = StringIO.StringIO(data)
|
||||
|
||||
# Determine and verify the length of included data
|
||||
stream.seek(0, 2)
|
||||
header_size = struct.calcsize(SegmentBase._HEADER_TEMPLATE)
|
||||
data_len = stream.tell() - header_size
|
||||
data_len = stream.tell() - SegmentBase._HEADER_LEN
|
||||
stream.seek(0)
|
||||
|
||||
if data_len < 1:
|
||||
@ -108,12 +132,19 @@ class SegmentBase(object):
|
||||
|
||||
# Read the first header attributes
|
||||
(magic, seg_type, segno, total_segs, msg_seq_num, master_sha) = struct.unpack(SegmentBase._HEADER_TEMPLATE,
|
||||
stream.read(header_size))
|
||||
stream.read(SegmentBase._HEADER_LEN))
|
||||
|
||||
# Sanity checks on the message attributes
|
||||
if magic != SegmentBase._MAGIC:
|
||||
raise ValueError("Segment does not have the correct magic.")
|
||||
|
||||
# if the segment is the only one in the message, validate the data
|
||||
if segno == 1 and total_segs == 1:
|
||||
data_sha = _sha_data(stream.read(data_len))
|
||||
if data_sha != master_sha:
|
||||
raise ValueError("Single segment message SHA checksums didn't match.")
|
||||
stream.seek(SegmentBase._HEADER_LEN)
|
||||
|
||||
if seg_type == SegmentBase._SEGMENT_TYPE_DATA:
|
||||
segment = DataSegment(segno, total_segs, msg_seq_num, master_sha)
|
||||
elif seg_type == SegmentBase._SEGMENT_TYPE_RETRANSMIT:
|
||||
@ -204,9 +235,10 @@ class RetransmitSegment(SegmentBase):
|
||||
# 2: message sequence number
|
||||
# 20: total data sha1
|
||||
# 2: segment number
|
||||
_RT_DATA_TEMPLATE = "@ H20sH"
|
||||
_RT_DATA_TEMPLATE = "! H20sH"
|
||||
|
||||
def __init__(self, segno, total_segs, msg_seq_num, master_sha):
|
||||
"""Should not be called directly."""
|
||||
if segno != 1 or total_segs != 1:
|
||||
raise ValueError("Retransmission request messages must have only one segment.")
|
||||
|
||||
@ -239,6 +271,8 @@ class RetransmitSegment(SegmentBase):
|
||||
self._data_len = data_len
|
||||
(rt_msg_seq_num, rt_master_sha, rt_seg_no) = struct.unpack(self._RT_DATA_TEMPLATE,
|
||||
stream.read(self._data_len))
|
||||
self._data = struct.pack(self._RT_DATA_TEMPLATE, rt_msg_seq_num,
|
||||
rt_master_sha, rt_seg_no)
|
||||
self._rt_msg_seq_num = rt_msg_seq_num
|
||||
self._rt_master_sha = rt_master_sha
|
||||
self._rt_segment_number = rt_seg_no
|
||||
@ -473,10 +507,220 @@ class MostlyReliablePipe(object):
|
||||
return False
|
||||
|
||||
|
||||
#################################################################
|
||||
# Tests
|
||||
#################################################################
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class SegmentBaseTestCase(unittest.TestCase):
|
||||
_DEF_SEGNO = 1
|
||||
_DEF_TOT_SEGS = 5
|
||||
_DEF_MSG_SEQ_NUM = 4556
|
||||
_DEF_MASTER_SHA = "12345678901234567890"
|
||||
_DEF_SEG_TYPE = 0
|
||||
|
||||
_DEF_ADDRESS = ('123.3.2.1', 3333)
|
||||
_SEG_MAGIC = 0xbaea4304
|
||||
|
||||
|
||||
class SegmentBaseInitTestCase(SegmentBaseTestCase):
|
||||
def _test_init_fail(self, segno, total_segs, msg_seq_num, master_sha, fail_msg):
|
||||
try:
|
||||
seg = SegmentBase(segno, total_segs, msg_seq_num, master_sha)
|
||||
except ValueError, exc:
|
||||
pass
|
||||
else:
|
||||
self.fail("expected a ValueError for %s." % fail_msg)
|
||||
|
||||
def testSegmentBase(self):
|
||||
assert SegmentBase.magic() == self._SEG_MAGIC, "Segment magic wasn't correct!"
|
||||
assert SegmentBase.header_len() > 0, "header size was not greater than zero."
|
||||
assert SegmentBase.mtu() > 0, "MTU was not greater than zero."
|
||||
assert SegmentBase.mtu() + SegmentBase.header_len() == _UDP_DATAGRAM_SIZE, "MTU + header size didn't equal expected %d." % _UDP_DATAGRAM_SIZE
|
||||
|
||||
def testGoodInit(self):
|
||||
seg = SegmentBase(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA)
|
||||
assert seg.stime() < time.time(), "segment start time is less than now!"
|
||||
assert not seg.addr(), "Segment address was not None after init."
|
||||
assert seg.segment_number() == self._DEF_SEGNO, "Segment number wasn't correct after init."
|
||||
assert seg.total_segments() == self._DEF_TOT_SEGS, "Total segments wasn't correct after init."
|
||||
assert seg.message_sequence_number() == self._DEF_MSG_SEQ_NUM, "Message sequence number wasn't correct after init."
|
||||
assert seg.master_sha() == self._DEF_MASTER_SHA, "Message master SHA wasn't correct after init."
|
||||
assert seg.segment_type() == None, "Segment type was not None after init."
|
||||
assert seg.transmits() == 0, "Segment transmits was not 0 after init."
|
||||
assert seg.last_transmit() == 0, "Segment last transmit was not 0 after init."
|
||||
assert seg.data() == None, "Segment data was not None after init."
|
||||
|
||||
def testSegmentNumber(self):
|
||||
self._test_init_fail(0, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid segment number")
|
||||
self._test_init_fail(65536, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid segment number")
|
||||
self._test_init_fail(None, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid segment number")
|
||||
self._test_init_fail("", self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid segment number")
|
||||
|
||||
def testTotalMessageSegmentNumber(self):
|
||||
self._test_init_fail(self._DEF_SEGNO, 0, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid total segments")
|
||||
self._test_init_fail(self._DEF_SEGNO, 65536, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid total segments")
|
||||
self._test_init_fail(self._DEF_SEGNO, None, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid total segments")
|
||||
self._test_init_fail(self._DEF_SEGNO, "", self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid total segments")
|
||||
|
||||
def testMessageSequenceNumber(self):
|
||||
self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, 0, self._DEF_MASTER_SHA, "invalid message sequence number")
|
||||
self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, 65536, self._DEF_MASTER_SHA, "invalid message sequence number")
|
||||
self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, None, self._DEF_MASTER_SHA, "invalid message sequence number")
|
||||
self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, "", self._DEF_MASTER_SHA, "invalid message sequence number")
|
||||
|
||||
def testMasterSHA(self):
|
||||
self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, "1" * 19, "invalid SHA1 data hash")
|
||||
self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, "1" * 21, "invalid SHA1 data hash")
|
||||
self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, None, "invalid SHA1 data hash")
|
||||
self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, 1234, "invalid SHA1 data hash")
|
||||
|
||||
def _testNewFromDataFail(self, addr, data, fail_msg):
|
||||
try:
|
||||
seg = SegmentBase.new_from_data(addr, data)
|
||||
except ValueError, exc:
|
||||
pass
|
||||
else:
|
||||
self.fail("expected a ValueError about %s." % fail_msg)
|
||||
|
||||
def testNewFromDataAddress(self):
|
||||
self._testNewFromDataFail(None, None, "bad address")
|
||||
self._testNewFromDataFail('', None, "bad address")
|
||||
self._testNewFromDataFail((''), None, "bad address")
|
||||
self._testNewFromDataFail((1), None, "bad address")
|
||||
self._testNewFromDataFail(('', ''), None, "bad address")
|
||||
self._testNewFromDataFail((1, 3333), None, "bad address")
|
||||
self._testNewFromDataFail(('', 0), None, "bad address")
|
||||
self._testNewFromDataFail(('', 65536), None, "bad address")
|
||||
|
||||
def testNewFromDataData(self):
|
||||
"""Only test generic new_from_data() bits, not type-specific ones."""
|
||||
self._testNewFromDataFail(self._DEF_ADDRESS, None, "invalid data")
|
||||
|
||||
really_short_data = "111"
|
||||
self._testNewFromDataFail(self._DEF_ADDRESS, really_short_data, "data too short")
|
||||
|
||||
only_header_data = "1" * SegmentBase.header_len()
|
||||
self._testNewFromDataFail(self._DEF_ADDRESS, only_header_data, "data too short")
|
||||
|
||||
too_much_data = "1" * (_UDP_DATAGRAM_SIZE + 1)
|
||||
self._testNewFromDataFail(self._DEF_ADDRESS, too_much_data, "too much data")
|
||||
|
||||
header_template = SegmentBase.header_template()
|
||||
bad_magic_data = struct.pack(header_template, 0x12345678, self._DEF_SEG_TYPE,
|
||||
self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA)
|
||||
self._testNewFromDataFail(self._DEF_ADDRESS, bad_magic_data, "invalid magic")
|
||||
|
||||
bad_type_data = struct.pack(header_template, self._SEG_MAGIC, -1, self._DEF_SEGNO,
|
||||
self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA)
|
||||
self._testNewFromDataFail(self._DEF_ADDRESS, bad_type_data, "invalid segment type")
|
||||
|
||||
# Test master_sha that doesn't match data's SHA
|
||||
header = struct.pack(header_template, self._SEG_MAGIC, self._DEF_SEG_TYPE, 1, 1,
|
||||
self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA)
|
||||
data = struct.pack("! 15s", "7" * 15)
|
||||
self._testNewFromDataFail(self._DEF_ADDRESS, header + data, "single-segment message SHA mismatch")
|
||||
|
||||
def addToSuite(suite):
|
||||
suite.addTest(SegmentBaseInitTestCase("testGoodInit"))
|
||||
suite.addTest(SegmentBaseInitTestCase("testSegmentNumber"))
|
||||
suite.addTest(SegmentBaseInitTestCase("testTotalMessageSegmentNumber"))
|
||||
suite.addTest(SegmentBaseInitTestCase("testMessageSequenceNumber"))
|
||||
suite.addTest(SegmentBaseInitTestCase("testMasterSHA"))
|
||||
suite.addTest(SegmentBaseInitTestCase("testNewFromDataAddress"))
|
||||
suite.addTest(SegmentBaseInitTestCase("testNewFromDataData"))
|
||||
addToSuite = staticmethod(addToSuite)
|
||||
|
||||
|
||||
class DataSegmentTestCase(SegmentBaseTestCase):
|
||||
"""Test DataSegment class specific initialization and stuff."""
|
||||
|
||||
def testInit(self):
|
||||
seg = DataSegment(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM,
|
||||
self._DEF_MASTER_SHA)
|
||||
assert seg.segment_type() == SegmentBase.type_data(), "Segment wasn't a data segment."
|
||||
|
||||
def testNewFromParts(self):
|
||||
try:
|
||||
seg = DataSegment.new_from_parts(self._DEF_SEGNO, self._DEF_TOT_SEGS,
|
||||
self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, None)
|
||||
except ValueError, exc:
|
||||
pass
|
||||
else:
|
||||
self.fail("Expected ValueError about invalid data.")
|
||||
|
||||
# Ensure message data is same as we stuff in after object is instantiated
|
||||
payload = "How are you today?"
|
||||
seg = DataSegment.new_from_parts(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM,
|
||||
self._DEF_MASTER_SHA, payload)
|
||||
assert seg.data() == payload, "Data after segment creation didn't match expected."
|
||||
|
||||
def testNewFromData(self):
|
||||
"""Test DataSegment's new_from_data() functionality."""
|
||||
|
||||
# Make sure something valid actually works
|
||||
header_template = SegmentBase.header_template()
|
||||
payload_str = "How are you today?"
|
||||
payload = struct.pack("! %ds" % len(payload_str), payload_str)
|
||||
payload_sha = _sha_data(payload)
|
||||
header = struct.pack(header_template, self._SEG_MAGIC, SegmentBase.type_data(), self._DEF_SEGNO,
|
||||
self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, payload_sha)
|
||||
seg = SegmentBase.new_from_data(self._DEF_ADDRESS, header + payload)
|
||||
|
||||
assert seg.addr() == self._DEF_ADDRESS, "Segment address did not match expected."
|
||||
assert seg.segment_type() == SegmentBase.type_data(), "Segment type did not match expected."
|
||||
assert seg.segment_number() == self._DEF_SEGNO, "Segment number did not match expected."
|
||||
assert seg.total_segments() == self._DEF_TOT_SEGS, "Total segments did not match expected."
|
||||
assert seg.message_sequence_number() == self._DEF_MSG_SEQ_NUM, "Message sequence number did not match expected."
|
||||
assert seg.master_sha() == payload_sha, "Message master SHA did not match expected."
|
||||
assert seg.data() == payload, "Segment data did not match expected payload."
|
||||
|
||||
def addToSuite(suite):
|
||||
suite.addTest(DataSegmentTestCase("testInit"))
|
||||
suite.addTest(DataSegmentTestCase("testNewFromParts"))
|
||||
suite.addTest(DataSegmentTestCase("testNewFromData"))
|
||||
addToSuite = staticmethod(addToSuite)
|
||||
|
||||
|
||||
class SHAUtilsTestCase(unittest.TestCase):
|
||||
def testSHA(self):
|
||||
data = "235jklqt3hjwasdv879wfe89723rqjh32tr3hwaejksdvd89udsv89dsgiougjktqjhk23tjht23hjt3qhjewagthjasgdgsd"
|
||||
data_sha = _sha_data(data)
|
||||
assert len(data_sha) == 20, "SHA wasn't correct size."
|
||||
known_sha = "\xee\x9e\xb9\x1d\xe8\x96\x75\xcb\x12\xf1\x25\x22\x0f\x76\xf7\xf3\xc8\x4e\xbf\xcd"
|
||||
assert data_sha == known_sha, "SHA didn't match known SHA."
|
||||
|
||||
def testStringifySHA(self):
|
||||
data = "jlkwjlkaegdjlksgdjklsdgajklganjtwn23n325n23tjwgeajkga nafDA fwqnjlqtjkl23tjk2365jlk235jkl2356jlktjkltewjlktewjklewtjklaggsda"
|
||||
data_known_sha = "9650c23db78092a0ffda4577c87ebf36d25c868e"
|
||||
assert _stringify_sha(_sha_data(data)) == data_known_sha, "SHA stringify didn't return correct SHA."
|
||||
# Do it twice for kicks
|
||||
assert _stringify_sha(_sha_data(data)) == data_known_sha, "SHA stringify didn't return correct SHA."
|
||||
|
||||
def addToSuite(suite):
|
||||
suite.addTest(SHAUtilsTestCase("testSHA"))
|
||||
suite.addTest(SHAUtilsTestCase("testStringifySHA"))
|
||||
addToSuite = staticmethod(addToSuite)
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
suite = unittest.TestSuite()
|
||||
SegmentBaseInitTestCase.addToSuite(suite)
|
||||
DataSegmentTestCase.addToSuite(suite)
|
||||
SHAUtilsTestCase.addToSuite(suite)
|
||||
|
||||
runner = unittest.TextTestRunner()
|
||||
runner.run(suite)
|
||||
|
||||
|
||||
|
||||
def got_data(addr, data, user_data=None):
|
||||
print "Data (%s): %s" % (addr, data)
|
||||
|
||||
def main():
|
||||
def foobar():
|
||||
pipe = MostlyReliablePipe('', '224.0.0.222', 2293, got_data)
|
||||
pipe.start()
|
||||
pipe.send('The quick brown fox jumps over the lazy dog')
|
||||
|
Loading…
Reference in New Issue
Block a user