From 11d54d71a97e72bf0fdeac389cb5c16532be9f0b Mon Sep 17 00:00:00 2001 From: Dan Williams Date: Wed, 17 May 2006 16:23:35 -0400 Subject: [PATCH] Add a bunch of testcases --- sugar/p2p/MostlyReliablePipe.py | 278 ++++++++++++++++++++++++++++++-- 1 file changed, 261 insertions(+), 17 deletions(-) diff --git a/sugar/p2p/MostlyReliablePipe.py b/sugar/p2p/MostlyReliablePipe.py index 9009d1a3..ae17c441 100644 --- a/sugar/p2p/MostlyReliablePipe.py +++ b/sugar/p2p/MostlyReliablePipe.py @@ -22,6 +22,7 @@ def _sha_data(data): sha_hash.update(data) return sha_hash.digest() +_UDP_DATAGRAM_SIZE = 512 class SegmentBase(object): _MAGIC = 0xbaea4304 @@ -34,12 +35,20 @@ class SegmentBase(object): #20: total data sha1 _HEADER_TEMPLATE = "! IbHHH20s" _HEADER_LEN = struct.calcsize(_HEADER_TEMPLATE) - _MTU = 512 - _HEADER_LEN + _MTU = _UDP_DATAGRAM_SIZE - _HEADER_LEN # Message segment packet types _SEGMENT_TYPE_DATA = 0 _SEGMENT_TYPE_RETRANSMIT = 1 + def magic(): + return SegmentBase._MAGIC + magic = staticmethod(magic) + + def header_template(): + return SegmentBase._HEADER_TEMPLATE + header_template = staticmethod(header_template) + def type_data(): return SegmentBase._SEGMENT_TYPE_DATA type_data = staticmethod(type_data) @@ -69,18 +78,22 @@ class SegmentBase(object): self._addr = None # Sanity checks on the message attributes - if segno > 65535: - raise ValueError("Segment number cannot be more than 65535.") - if segno < 1: - raise ValueError("Segment number must be greater than 0.") + if not segno or type(segno) != type(1): + raise ValueError("Segment number must be in integer.") + if segno < 1 or segno > 65535: + raise ValueError("Segment number must be between 1 and 65535 inclusive.") + if not total_segs or type(total_segs) != type(1): + raise ValueError("Message segment total must be an integer.") + if total_segs < 1 or total_segs > 65535: + raise ValueError("Message must have between 1 and 65535 segments inclusive.") if segno > total_segs: raise ValueError("Segment number cannot be larger than message segment total.") - if total_segs < 1: - raise ValueError("Message must have at least one segment.") - if total_segs > 65535: - raise ValueError("Message cannot have more than 65535 segments.") - if msg_seq_num < 1: - raise ValueError("Message sequence number must be greater than 0.") + if not msg_seq_num or type(msg_seq_num) != type(1): + raise ValueError("Message sequnce number must be an integer.") + if msg_seq_num < 1 or msg_seq_num > 65535: + raise ValueError("Message sequence number must be between 1 and 65535 inclusive.") + if not master_sha or type(master_sha) != type("") or len(master_sha) != 20: + raise ValueError("Message SHA1 checksum invalid.") self._segno = segno self._total_segs = total_segs @@ -90,15 +103,26 @@ class SegmentBase(object): def new_from_data(addr, data): """Static constructor for creation from a packed data stream.""" + if not addr or type(addr) != type(()): + raise ValueError("Address must be a tuple.") + if len(addr) != 2 or type(addr[0]) != type("") or type(addr[1]) != type(1): + raise ValueError("Address format was invalid.") + if addr[1] < 1 or addr[1] > 65535: + raise ValueError("Address port was invalid.") + # Verify minimum length - if len(data) < SegmentBase.header_len() + 1: + if not data: + raise ValueError("Segment data is invalid.") + data_len = len(data) + if data_len < SegmentBase.header_len() + 1: raise ValueError("Segment is less then minimum required length") + if data_len > _UDP_DATAGRAM_SIZE: + raise ValueError("Segment data is larger than allowed.") stream = StringIO.StringIO(data) # Determine and verify the length of included data stream.seek(0, 2) - header_size = struct.calcsize(SegmentBase._HEADER_TEMPLATE) - data_len = stream.tell() - header_size + data_len = stream.tell() - SegmentBase._HEADER_LEN stream.seek(0) if data_len < 1: @@ -108,12 +132,19 @@ class SegmentBase(object): # Read the first header attributes (magic, seg_type, segno, total_segs, msg_seq_num, master_sha) = struct.unpack(SegmentBase._HEADER_TEMPLATE, - stream.read(header_size)) + stream.read(SegmentBase._HEADER_LEN)) # Sanity checks on the message attributes if magic != SegmentBase._MAGIC: raise ValueError("Segment does not have the correct magic.") + # if the segment is the only one in the message, validate the data + if segno == 1 and total_segs == 1: + data_sha = _sha_data(stream.read(data_len)) + if data_sha != master_sha: + raise ValueError("Single segment message SHA checksums didn't match.") + stream.seek(SegmentBase._HEADER_LEN) + if seg_type == SegmentBase._SEGMENT_TYPE_DATA: segment = DataSegment(segno, total_segs, msg_seq_num, master_sha) elif seg_type == SegmentBase._SEGMENT_TYPE_RETRANSMIT: @@ -204,9 +235,10 @@ class RetransmitSegment(SegmentBase): # 2: message sequence number # 20: total data sha1 # 2: segment number - _RT_DATA_TEMPLATE = "@ H20sH" + _RT_DATA_TEMPLATE = "! H20sH" def __init__(self, segno, total_segs, msg_seq_num, master_sha): + """Should not be called directly.""" if segno != 1 or total_segs != 1: raise ValueError("Retransmission request messages must have only one segment.") @@ -239,6 +271,8 @@ class RetransmitSegment(SegmentBase): self._data_len = data_len (rt_msg_seq_num, rt_master_sha, rt_seg_no) = struct.unpack(self._RT_DATA_TEMPLATE, stream.read(self._data_len)) + self._data = struct.pack(self._RT_DATA_TEMPLATE, rt_msg_seq_num, + rt_master_sha, rt_seg_no) self._rt_msg_seq_num = rt_msg_seq_num self._rt_master_sha = rt_master_sha self._rt_segment_number = rt_seg_no @@ -473,10 +507,220 @@ class MostlyReliablePipe(object): return False +################################################################# +# Tests +################################################################# + +import unittest + + +class SegmentBaseTestCase(unittest.TestCase): + _DEF_SEGNO = 1 + _DEF_TOT_SEGS = 5 + _DEF_MSG_SEQ_NUM = 4556 + _DEF_MASTER_SHA = "12345678901234567890" + _DEF_SEG_TYPE = 0 + + _DEF_ADDRESS = ('123.3.2.1', 3333) + _SEG_MAGIC = 0xbaea4304 + + +class SegmentBaseInitTestCase(SegmentBaseTestCase): + def _test_init_fail(self, segno, total_segs, msg_seq_num, master_sha, fail_msg): + try: + seg = SegmentBase(segno, total_segs, msg_seq_num, master_sha) + except ValueError, exc: + pass + else: + self.fail("expected a ValueError for %s." % fail_msg) + + def testSegmentBase(self): + assert SegmentBase.magic() == self._SEG_MAGIC, "Segment magic wasn't correct!" + assert SegmentBase.header_len() > 0, "header size was not greater than zero." + assert SegmentBase.mtu() > 0, "MTU was not greater than zero." + assert SegmentBase.mtu() + SegmentBase.header_len() == _UDP_DATAGRAM_SIZE, "MTU + header size didn't equal expected %d." % _UDP_DATAGRAM_SIZE + + def testGoodInit(self): + seg = SegmentBase(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA) + assert seg.stime() < time.time(), "segment start time is less than now!" + assert not seg.addr(), "Segment address was not None after init." + assert seg.segment_number() == self._DEF_SEGNO, "Segment number wasn't correct after init." + assert seg.total_segments() == self._DEF_TOT_SEGS, "Total segments wasn't correct after init." + assert seg.message_sequence_number() == self._DEF_MSG_SEQ_NUM, "Message sequence number wasn't correct after init." + assert seg.master_sha() == self._DEF_MASTER_SHA, "Message master SHA wasn't correct after init." + assert seg.segment_type() == None, "Segment type was not None after init." + assert seg.transmits() == 0, "Segment transmits was not 0 after init." + assert seg.last_transmit() == 0, "Segment last transmit was not 0 after init." + assert seg.data() == None, "Segment data was not None after init." + + def testSegmentNumber(self): + self._test_init_fail(0, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid segment number") + self._test_init_fail(65536, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid segment number") + self._test_init_fail(None, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid segment number") + self._test_init_fail("", self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid segment number") + + def testTotalMessageSegmentNumber(self): + self._test_init_fail(self._DEF_SEGNO, 0, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid total segments") + self._test_init_fail(self._DEF_SEGNO, 65536, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid total segments") + self._test_init_fail(self._DEF_SEGNO, None, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid total segments") + self._test_init_fail(self._DEF_SEGNO, "", self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid total segments") + + def testMessageSequenceNumber(self): + self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, 0, self._DEF_MASTER_SHA, "invalid message sequence number") + self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, 65536, self._DEF_MASTER_SHA, "invalid message sequence number") + self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, None, self._DEF_MASTER_SHA, "invalid message sequence number") + self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, "", self._DEF_MASTER_SHA, "invalid message sequence number") + + def testMasterSHA(self): + self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, "1" * 19, "invalid SHA1 data hash") + self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, "1" * 21, "invalid SHA1 data hash") + self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, None, "invalid SHA1 data hash") + self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, 1234, "invalid SHA1 data hash") + + def _testNewFromDataFail(self, addr, data, fail_msg): + try: + seg = SegmentBase.new_from_data(addr, data) + except ValueError, exc: + pass + else: + self.fail("expected a ValueError about %s." % fail_msg) + + def testNewFromDataAddress(self): + self._testNewFromDataFail(None, None, "bad address") + self._testNewFromDataFail('', None, "bad address") + self._testNewFromDataFail((''), None, "bad address") + self._testNewFromDataFail((1), None, "bad address") + self._testNewFromDataFail(('', ''), None, "bad address") + self._testNewFromDataFail((1, 3333), None, "bad address") + self._testNewFromDataFail(('', 0), None, "bad address") + self._testNewFromDataFail(('', 65536), None, "bad address") + + def testNewFromDataData(self): + """Only test generic new_from_data() bits, not type-specific ones.""" + self._testNewFromDataFail(self._DEF_ADDRESS, None, "invalid data") + + really_short_data = "111" + self._testNewFromDataFail(self._DEF_ADDRESS, really_short_data, "data too short") + + only_header_data = "1" * SegmentBase.header_len() + self._testNewFromDataFail(self._DEF_ADDRESS, only_header_data, "data too short") + + too_much_data = "1" * (_UDP_DATAGRAM_SIZE + 1) + self._testNewFromDataFail(self._DEF_ADDRESS, too_much_data, "too much data") + + header_template = SegmentBase.header_template() + bad_magic_data = struct.pack(header_template, 0x12345678, self._DEF_SEG_TYPE, + self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA) + self._testNewFromDataFail(self._DEF_ADDRESS, bad_magic_data, "invalid magic") + + bad_type_data = struct.pack(header_template, self._SEG_MAGIC, -1, self._DEF_SEGNO, + self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA) + self._testNewFromDataFail(self._DEF_ADDRESS, bad_type_data, "invalid segment type") + + # Test master_sha that doesn't match data's SHA + header = struct.pack(header_template, self._SEG_MAGIC, self._DEF_SEG_TYPE, 1, 1, + self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA) + data = struct.pack("! 15s", "7" * 15) + self._testNewFromDataFail(self._DEF_ADDRESS, header + data, "single-segment message SHA mismatch") + + def addToSuite(suite): + suite.addTest(SegmentBaseInitTestCase("testGoodInit")) + suite.addTest(SegmentBaseInitTestCase("testSegmentNumber")) + suite.addTest(SegmentBaseInitTestCase("testTotalMessageSegmentNumber")) + suite.addTest(SegmentBaseInitTestCase("testMessageSequenceNumber")) + suite.addTest(SegmentBaseInitTestCase("testMasterSHA")) + suite.addTest(SegmentBaseInitTestCase("testNewFromDataAddress")) + suite.addTest(SegmentBaseInitTestCase("testNewFromDataData")) + addToSuite = staticmethod(addToSuite) + + +class DataSegmentTestCase(SegmentBaseTestCase): + """Test DataSegment class specific initialization and stuff.""" + + def testInit(self): + seg = DataSegment(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, + self._DEF_MASTER_SHA) + assert seg.segment_type() == SegmentBase.type_data(), "Segment wasn't a data segment." + + def testNewFromParts(self): + try: + seg = DataSegment.new_from_parts(self._DEF_SEGNO, self._DEF_TOT_SEGS, + self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, None) + except ValueError, exc: + pass + else: + self.fail("Expected ValueError about invalid data.") + + # Ensure message data is same as we stuff in after object is instantiated + payload = "How are you today?" + seg = DataSegment.new_from_parts(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, + self._DEF_MASTER_SHA, payload) + assert seg.data() == payload, "Data after segment creation didn't match expected." + + def testNewFromData(self): + """Test DataSegment's new_from_data() functionality.""" + + # Make sure something valid actually works + header_template = SegmentBase.header_template() + payload_str = "How are you today?" + payload = struct.pack("! %ds" % len(payload_str), payload_str) + payload_sha = _sha_data(payload) + header = struct.pack(header_template, self._SEG_MAGIC, SegmentBase.type_data(), self._DEF_SEGNO, + self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, payload_sha) + seg = SegmentBase.new_from_data(self._DEF_ADDRESS, header + payload) + + assert seg.addr() == self._DEF_ADDRESS, "Segment address did not match expected." + assert seg.segment_type() == SegmentBase.type_data(), "Segment type did not match expected." + assert seg.segment_number() == self._DEF_SEGNO, "Segment number did not match expected." + assert seg.total_segments() == self._DEF_TOT_SEGS, "Total segments did not match expected." + assert seg.message_sequence_number() == self._DEF_MSG_SEQ_NUM, "Message sequence number did not match expected." + assert seg.master_sha() == payload_sha, "Message master SHA did not match expected." + assert seg.data() == payload, "Segment data did not match expected payload." + + def addToSuite(suite): + suite.addTest(DataSegmentTestCase("testInit")) + suite.addTest(DataSegmentTestCase("testNewFromParts")) + suite.addTest(DataSegmentTestCase("testNewFromData")) + addToSuite = staticmethod(addToSuite) + + +class SHAUtilsTestCase(unittest.TestCase): + def testSHA(self): + data = "235jklqt3hjwasdv879wfe89723rqjh32tr3hwaejksdvd89udsv89dsgiougjktqjhk23tjht23hjt3qhjewagthjasgdgsd" + data_sha = _sha_data(data) + assert len(data_sha) == 20, "SHA wasn't correct size." + known_sha = "\xee\x9e\xb9\x1d\xe8\x96\x75\xcb\x12\xf1\x25\x22\x0f\x76\xf7\xf3\xc8\x4e\xbf\xcd" + assert data_sha == known_sha, "SHA didn't match known SHA." + + def testStringifySHA(self): + data = "jlkwjlkaegdjlksgdjklsdgajklganjtwn23n325n23tjwgeajkga nafDA fwqnjlqtjkl23tjk2365jlk235jkl2356jlktjkltewjlktewjklewtjklaggsda" + data_known_sha = "9650c23db78092a0ffda4577c87ebf36d25c868e" + assert _stringify_sha(_sha_data(data)) == data_known_sha, "SHA stringify didn't return correct SHA." + # Do it twice for kicks + assert _stringify_sha(_sha_data(data)) == data_known_sha, "SHA stringify didn't return correct SHA." + + def addToSuite(suite): + suite.addTest(SHAUtilsTestCase("testSHA")) + suite.addTest(SHAUtilsTestCase("testStringifySHA")) + addToSuite = staticmethod(addToSuite) + + + +def main(): + suite = unittest.TestSuite() + SegmentBaseInitTestCase.addToSuite(suite) + DataSegmentTestCase.addToSuite(suite) + SHAUtilsTestCase.addToSuite(suite) + + runner = unittest.TextTestRunner() + runner.run(suite) + + + def got_data(addr, data, user_data=None): print "Data (%s): %s" % (addr, data) -def main(): +def foobar(): pipe = MostlyReliablePipe('', '224.0.0.222', 2293, got_data) pipe.start() pipe.send('The quick brown fox jumps over the lazy dog')