123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503 |
- """
- Test the MessagePack utility
- """
- import inspect
- import os
- import pprint
- import struct
- import sys
- from io import BytesIO
- import salt.utils.msgpack
- from salt.ext.six.moves import range
- from salt.utils.odict import OrderedDict
- from tests.support.unit import TestCase, skipIf
- try:
- import msgpack
- except ImportError:
- import msgpack_pure as msgpack # pylint: disable=import-error
- # A keyword to pass to tests that use `raw`, which was added in msgpack 0.5.2
- raw = {"raw": False} if msgpack.version > (0, 5, 2) else {}
- @skipIf(not salt.utils.msgpack.HAS_MSGPACK, "msgpack module required for these tests")
- class TestMsgpack(TestCase):
- """
- In msgpack, the following aliases exist:
- load = unpack
- loads = unpackb
- dump = pack
- dumps = packb
- The salt.utils.msgpack versions of these functions are not aliases,
- verify that they pass the same relevant tests from:
- https://github.com/msgpack/msgpack-python/blob/master/test/
- """
- test_data = [
- 0,
- 1,
- 127,
- 128,
- 255,
- 256,
- 65535,
- 65536,
- 4294967295,
- 4294967296,
- -1,
- -32,
- -33,
- -128,
- -129,
- -32768,
- -32769,
- -4294967296,
- -4294967297,
- 1.0,
- b"",
- b"a",
- b"a" * 31,
- b"a" * 32,
- None,
- True,
- False,
- (),
- ((),),
- ((), None,),
- {None: 0},
- (1 << 23),
- ]
- def test_version(self):
- """
- Verify that the version exists and returns a value in the expected format
- """
- version = salt.utils.msgpack.version
- self.assertTrue(isinstance(version, tuple))
- self.assertGreater(version, (0, 0, 0))
- def test_Packer(self):
- data = os.urandom(1024)
- packer = salt.utils.msgpack.Packer()
- unpacker = msgpack.Unpacker(None)
- packed = packer.pack(data)
- # Sanity Check
- self.assertTrue(packed)
- self.assertNotEqual(data, packed)
- # Reverse the packing and the result should be equivalent to the original data
- unpacker.feed(packed)
- unpacked = msgpack.unpackb(packed)
- self.assertEqual(data, unpacked)
- def test_Unpacker(self):
- data = os.urandom(1024)
- packer = msgpack.Packer()
- unpacker = salt.utils.msgpack.Unpacker(None)
- packed = packer.pack(data)
- # Sanity Check
- self.assertTrue(packed)
- self.assertNotEqual(data, packed)
- # Reverse the packing and the result should be equivalent to the original data
- unpacker.feed(packed)
- unpacked = msgpack.unpackb(packed)
- self.assertEqual(data, unpacked)
- def test_array_size(self):
- sizes = [0, 5, 50, 1000]
- bio = BytesIO()
- packer = salt.utils.msgpack.Packer()
- for size in sizes:
- bio.write(packer.pack_array_header(size))
- for i in range(size):
- bio.write(packer.pack(i))
- bio.seek(0)
- unpacker = salt.utils.msgpack.Unpacker(bio, use_list=True)
- for size in sizes:
- self.assertEqual(unpacker.unpack(), list(range(size)))
- def test_manual_reset(self):
- sizes = [0, 5, 50, 1000]
- packer = salt.utils.msgpack.Packer(autoreset=False)
- for size in sizes:
- packer.pack_array_header(size)
- for i in range(size):
- packer.pack(i)
- bio = BytesIO(packer.bytes())
- unpacker = salt.utils.msgpack.Unpacker(bio, use_list=True)
- for size in sizes:
- self.assertEqual(unpacker.unpack(), list(range(size)))
- packer.reset()
- self.assertEqual(packer.bytes(), b"")
- def test_map_size(self):
- sizes = [0, 5, 50, 1000]
- bio = BytesIO()
- packer = salt.utils.msgpack.Packer()
- for size in sizes:
- bio.write(packer.pack_map_header(size))
- for i in range(size):
- bio.write(packer.pack(i)) # key
- bio.write(packer.pack(i * 2)) # value
- bio.seek(0)
- if salt.utils.msgpack.version > (0, 6, 0):
- unpacker = salt.utils.msgpack.Unpacker(bio, strict_map_key=False)
- else:
- unpacker = salt.utils.msgpack.Unpacker(bio)
- for size in sizes:
- self.assertEqual(unpacker.unpack(), {i: i * 2 for i in range(size)})
- def test_max_buffer_size(self):
- """
- Test if max buffer size allows at least 100MiB
- """
- bio = BytesIO()
- bio.write(salt.utils.msgpack.packb("0" * (100 * 1024 * 1024)))
- bio.seek(0)
- unpacker = salt.utils.msgpack.Unpacker(bio)
- raised = False
- try:
- unpacker.unpack()
- except ValueError:
- raised = True
- self.assertFalse(raised)
- def test_exceptions(self):
- # Verify that this exception exists
- self.assertTrue(salt.utils.msgpack.exceptions.PackValueError)
- self.assertTrue(salt.utils.msgpack.exceptions.UnpackValueError)
- self.assertTrue(salt.utils.msgpack.exceptions.PackValueError)
- self.assertTrue(salt.utils.msgpack.exceptions.UnpackValueError)
- def test_function_aliases(self):
- """
- Fail if core functionality from msgpack is missing in the utility
- """
- def sanitized(item):
- if inspect.isfunction(getattr(msgpack, item)):
- # Only check objects that exist in the same file as msgpack
- return inspect.getfile(getattr(msgpack, item)) == inspect.getfile(
- msgpack
- )
- msgpack_items = {
- x for x in dir(msgpack) if not x.startswith("_") and sanitized(x)
- }
- msgpack_util_items = set(dir(salt.utils.msgpack))
- self.assertFalse(
- msgpack_items - msgpack_util_items,
- "msgpack functions with no alias in `salt.utils.msgpack`",
- )
- def _test_base(self, pack_func, unpack_func):
- """
- In msgpack, 'dumps' is an alias for 'packb' and 'loads' is an alias for 'unpackb'.
- Verify that both salt.utils.msgpack function variations pass the exact same test
- """
- data = os.urandom(1024)
- packed = pack_func(data)
- # Sanity Check
- self.assertTrue(packed)
- self.assertIsInstance(packed, bytes)
- self.assertNotEqual(data, packed)
- # Reverse the packing and the result should be equivalent to the original data
- unpacked = unpack_func(packed)
- self.assertEqual(data, unpacked)
- def _test_buffered_base(self, pack_func, unpack_func):
- data = os.urandom(1024).decode(errors="ignore")
- buffer = BytesIO()
- # Sanity check, we are not borking the BytesIO read function
- self.assertNotEqual(BytesIO.read, buffer.read)
- buffer.read = buffer.getvalue
- pack_func(data, buffer)
- # Sanity Check
- self.assertTrue(buffer.getvalue())
- self.assertIsInstance(buffer.getvalue(), bytes)
- self.assertNotEqual(data, buffer.getvalue())
- # Reverse the packing and the result should be equivalent to the original data
- unpacked = unpack_func(buffer)
- if isinstance(unpacked, bytes):
- unpacked = unpacked.decode()
- self.assertEqual(data, unpacked)
- def test_buffered_base_pack(self):
- self._test_buffered_base(
- pack_func=salt.utils.msgpack.pack, unpack_func=msgpack.unpack
- )
- def test_buffered_base_unpack(self):
- self._test_buffered_base(
- pack_func=msgpack.pack, unpack_func=salt.utils.msgpack.unpack
- )
- def _test_unpack_array_header_from_file(self, pack_func, **kwargs):
- f = BytesIO(pack_func([1, 2, 3, 4]))
- unpacker = salt.utils.msgpack.Unpacker(f)
- self.assertEqual(unpacker.read_array_header(), 4)
- self.assertEqual(unpacker.unpack(), 1)
- self.assertEqual(unpacker.unpack(), 2)
- self.assertEqual(unpacker.unpack(), 3)
- self.assertEqual(unpacker.unpack(), 4)
- self.assertRaises(salt.utils.msgpack.exceptions.OutOfData, unpacker.unpack)
- @skipIf(
- not hasattr(sys, "getrefcount"), "sys.getrefcount() is needed to pass this test"
- )
- def _test_unpacker_hook_refcnt(self, pack_func, **kwargs):
- result = []
- def hook(x):
- result.append(x)
- return x
- basecnt = sys.getrefcount(hook)
- up = salt.utils.msgpack.Unpacker(object_hook=hook, list_hook=hook)
- self.assertGreaterEqual(sys.getrefcount(hook), basecnt + 2)
- up.feed(pack_func([{}]))
- up.feed(pack_func([{}]))
- self.assertEqual(up.unpack(), [{}])
- self.assertEqual(up.unpack(), [{}])
- self.assertEqual(result, [{}, [{}], {}, [{}]])
- del up
- self.assertEqual(sys.getrefcount(hook), basecnt)
- def _test_unpacker_ext_hook(self, pack_func, **kwargs):
- class MyUnpacker(salt.utils.msgpack.Unpacker):
- def __init__(self):
- my_kwargs = {}
- super().__init__(ext_hook=self._hook, **raw)
- def _hook(self, code, data):
- if code == 1:
- return int(data)
- else:
- return salt.utils.msgpack.ExtType(code, data)
- unpacker = MyUnpacker()
- unpacker.feed(pack_func({"a": 1}))
- self.assertEqual(unpacker.unpack(), {"a": 1})
- unpacker.feed(pack_func({"a": salt.utils.msgpack.ExtType(1, b"123")}))
- self.assertEqual(unpacker.unpack(), {"a": 123})
- unpacker.feed(pack_func({"a": salt.utils.msgpack.ExtType(2, b"321")}))
- self.assertEqual(
- unpacker.unpack(), {"a": salt.utils.msgpack.ExtType(2, b"321")}
- )
- def _check(
- self, data, pack_func, unpack_func, use_list=False, strict_map_key=False
- ):
- my_kwargs = {}
- if salt.utils.msgpack.version >= (0, 6, 0):
- my_kwargs["strict_map_key"] = strict_map_key
- ret = unpack_func(pack_func(data), use_list=use_list, **my_kwargs)
- self.assertEqual(ret, data)
- def _test_pack_unicode(self, pack_func, unpack_func):
- test_data = ["", "abcd", ["defgh"], "Русский текст"]
- for td in test_data:
- ret = unpack_func(pack_func(td), use_list=True, **raw)
- self.assertEqual(ret, td)
- packer = salt.utils.msgpack.Packer()
- data = packer.pack(td)
- ret = salt.utils.msgpack.Unpacker(
- BytesIO(data), use_list=True, **raw
- ).unpack()
- self.assertEqual(ret, td)
- def _test_pack_bytes(self, pack_func, unpack_func):
- test_data = [
- b"",
- b"abcd",
- (b"defgh",),
- ]
- for td in test_data:
- self._check(td, pack_func, unpack_func)
- def _test_pack_byte_arrays(self, pack_func, unpack_func):
- test_data = [
- bytearray(b""),
- bytearray(b"abcd"),
- (bytearray(b"defgh"),),
- ]
- for td in test_data:
- self._check(td, pack_func, unpack_func)
- @skipIf(sys.version_info < (3, 0), "Python 2 passes invalid surrogates")
- def _test_ignore_unicode_errors(self, pack_func, unpack_func):
- ret = unpack_func(
- pack_func(b"abc\xeddef", use_bin_type=False), unicode_errors="ignore", **raw
- )
- self.assertEqual("abcdef", ret)
- def _test_strict_unicode_unpack(self, pack_func, unpack_func):
- packed = pack_func(b"abc\xeddef", use_bin_type=False)
- self.assertRaises(UnicodeDecodeError, unpack_func, packed, use_list=True, **raw)
- @skipIf(sys.version_info < (3, 0), "Python 2 passes invalid surrogates")
- def _test_ignore_errors_pack(self, pack_func, unpack_func):
- ret = unpack_func(
- pack_func("abc\uDC80\uDCFFdef", use_bin_type=True, unicode_errors="ignore"),
- use_list=True,
- **raw
- )
- self.assertEqual("abcdef", ret)
- def _test_decode_binary(self, pack_func, unpack_func):
- ret = unpack_func(pack_func(b"abc"), use_list=True)
- self.assertEqual(b"abc", ret)
- @skipIf(
- salt.utils.msgpack.version < (0, 2, 2),
- "use_single_float was added in msgpack==0.2.2",
- )
- def _test_pack_float(self, pack_func, **kwargs):
- self.assertEqual(
- b"\xca" + struct.pack(">f", 1.0), pack_func(1.0, use_single_float=True)
- )
- self.assertEqual(
- b"\xcb" + struct.pack(">d", 1.0), pack_func(1.0, use_single_float=False),
- )
- def _test_odict(self, pack_func, unpack_func):
- seq = [(b"one", 1), (b"two", 2), (b"three", 3), (b"four", 4)]
- od = OrderedDict(seq)
- self.assertEqual(dict(seq), unpack_func(pack_func(od), use_list=True))
- def pair_hook(seq):
- return list(seq)
- self.assertEqual(
- seq, unpack_func(pack_func(od), object_pairs_hook=pair_hook, use_list=True)
- )
- def _test_pair_list(self, unpack_func, **kwargs):
- pairlist = [(b"a", 1), (2, b"b"), (b"foo", b"bar")]
- packer = salt.utils.msgpack.Packer()
- packed = packer.pack_map_pairs(pairlist)
- if salt.utils.msgpack.version > (0, 6, 0):
- unpacked = unpack_func(packed, object_pairs_hook=list, strict_map_key=False)
- else:
- unpacked = unpack_func(packed, object_pairs_hook=list)
- self.assertEqual(pairlist, unpacked)
- @skipIf(
- salt.utils.msgpack.version < (0, 6, 0),
- "getbuffer() was added to Packer in msgpack 0.6.0",
- )
- def _test_get_buffer(self, pack_func, **kwargs):
- packer = msgpack.Packer(autoreset=False, use_bin_type=True)
- packer.pack([1, 2])
- strm = BytesIO()
- strm.write(packer.getbuffer())
- written = strm.getvalue()
- expected = pack_func([1, 2], use_bin_type=True)
- self.assertEqual(expected, written)
- @staticmethod
- def no_fail_run(test, *args, **kwargs):
- """
- Run a test without failure and return any exception it raises
- """
- try:
- test(*args, **kwargs)
- except Exception as e: # pylint: disable=broad-except
- return e
- def test_binary_function_compatibility(self):
- functions = [
- {"pack_func": salt.utils.msgpack.packb, "unpack_func": msgpack.unpackb},
- {"pack_func": msgpack.packb, "unpack_func": salt.utils.msgpack.unpackb},
- ]
- # These functions are equivalent but could potentially be overwritten
- if salt.utils.msgpack.dumps is not salt.utils.msgpack.packb:
- functions.append(
- {"pack_func": salt.utils.msgpack.dumps, "unpack_func": msgpack.unpackb}
- )
- if salt.utils.msgpack.loads is not salt.utils.msgpack.unpackb:
- functions.append(
- {"pack_func": msgpack.packb, "unpack_func": salt.utils.msgpack.loads}
- )
- test_funcs = (
- self._test_base,
- self._test_unpack_array_header_from_file,
- self._test_unpacker_hook_refcnt,
- self._test_unpacker_ext_hook,
- self._test_pack_unicode,
- self._test_pack_bytes,
- self._test_pack_byte_arrays,
- self._test_ignore_unicode_errors,
- self._test_strict_unicode_unpack,
- self._test_ignore_errors_pack,
- self._test_decode_binary,
- self._test_pack_float,
- self._test_odict,
- self._test_pair_list,
- self._test_get_buffer,
- )
- errors = {}
- for test_func in test_funcs:
- # Run the test without the salt.utils.msgpack module for comparison
- vanilla_run = self.no_fail_run(
- test_func,
- **{"pack_func": msgpack.packb, "unpack_func": msgpack.unpackb}
- )
- for func_args in functions:
- func_name = (
- func_args["pack_func"]
- if func_args["pack_func"].__module__.startswith("salt.utils")
- else func_args["unpack_func"]
- )
- if hasattr(TestCase, "subTest"):
- with self.subTest(test=test_func.__name__, func=func_name.__name__):
- # Run the test with the salt.utils.msgpack module
- run = self.no_fail_run(test_func, **func_args)
- # If the vanilla msgpack module errored, then skip if we got the same error
- if run:
- if str(vanilla_run) == str(run):
- self.skipTest(
- "Failed the same way as the vanilla msgpack module:\n{}".format(
- run
- )
- )
- else:
- # If subTest isn't available then run the tests collect the errors of all the tests before failing
- run = self.no_fail_run(test_func, **func_args)
- if run:
- # If the vanilla msgpack module errored, then skip if we got the same error
- if str(vanilla_run) == str(run):
- self.skipTest(
- "Test failed the same way the vanilla msgpack module fails:\n{}".format(
- run
- )
- )
- else:
- errors[(test_func.__name__, func_name.__name__)] = run
- if errors:
- self.fail(pprint.pformat(errors))
|