1
0

test_msgpack.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. """
  2. Test the MessagePack utility
  3. """
  4. import inspect
  5. import os
  6. import pprint
  7. import struct
  8. import sys
  9. from io import BytesIO
  10. import salt.utils.msgpack
  11. from salt.ext.six.moves import range
  12. from salt.utils.odict import OrderedDict
  13. from tests.support.unit import TestCase, skipIf
  14. try:
  15. import msgpack
  16. except ImportError:
  17. import msgpack_pure as msgpack # pylint: disable=import-error
  18. # A keyword to pass to tests that use `raw`, which was added in msgpack 0.5.2
  19. raw = {"raw": False} if msgpack.version > (0, 5, 2) else {}
  20. @skipIf(not salt.utils.msgpack.HAS_MSGPACK, "msgpack module required for these tests")
  21. class TestMsgpack(TestCase):
  22. """
  23. In msgpack, the following aliases exist:
  24. load = unpack
  25. loads = unpackb
  26. dump = pack
  27. dumps = packb
  28. The salt.utils.msgpack versions of these functions are not aliases,
  29. verify that they pass the same relevant tests from:
  30. https://github.com/msgpack/msgpack-python/blob/master/test/
  31. """
  32. test_data = [
  33. 0,
  34. 1,
  35. 127,
  36. 128,
  37. 255,
  38. 256,
  39. 65535,
  40. 65536,
  41. 4294967295,
  42. 4294967296,
  43. -1,
  44. -32,
  45. -33,
  46. -128,
  47. -129,
  48. -32768,
  49. -32769,
  50. -4294967296,
  51. -4294967297,
  52. 1.0,
  53. b"",
  54. b"a",
  55. b"a" * 31,
  56. b"a" * 32,
  57. None,
  58. True,
  59. False,
  60. (),
  61. ((),),
  62. ((), None,),
  63. {None: 0},
  64. (1 << 23),
  65. ]
  66. def test_version(self):
  67. """
  68. Verify that the version exists and returns a value in the expected format
  69. """
  70. version = salt.utils.msgpack.version
  71. self.assertTrue(isinstance(version, tuple))
  72. self.assertGreater(version, (0, 0, 0))
  73. def test_Packer(self):
  74. data = os.urandom(1024)
  75. packer = salt.utils.msgpack.Packer()
  76. unpacker = msgpack.Unpacker(None)
  77. packed = packer.pack(data)
  78. # Sanity Check
  79. self.assertTrue(packed)
  80. self.assertNotEqual(data, packed)
  81. # Reverse the packing and the result should be equivalent to the original data
  82. unpacker.feed(packed)
  83. unpacked = msgpack.unpackb(packed)
  84. self.assertEqual(data, unpacked)
  85. def test_Unpacker(self):
  86. data = os.urandom(1024)
  87. packer = msgpack.Packer()
  88. unpacker = salt.utils.msgpack.Unpacker(None)
  89. packed = packer.pack(data)
  90. # Sanity Check
  91. self.assertTrue(packed)
  92. self.assertNotEqual(data, packed)
  93. # Reverse the packing and the result should be equivalent to the original data
  94. unpacker.feed(packed)
  95. unpacked = msgpack.unpackb(packed)
  96. self.assertEqual(data, unpacked)
  97. def test_array_size(self):
  98. sizes = [0, 5, 50, 1000]
  99. bio = BytesIO()
  100. packer = salt.utils.msgpack.Packer()
  101. for size in sizes:
  102. bio.write(packer.pack_array_header(size))
  103. for i in range(size):
  104. bio.write(packer.pack(i))
  105. bio.seek(0)
  106. unpacker = salt.utils.msgpack.Unpacker(bio, use_list=True)
  107. for size in sizes:
  108. self.assertEqual(unpacker.unpack(), list(range(size)))
  109. def test_manual_reset(self):
  110. sizes = [0, 5, 50, 1000]
  111. packer = salt.utils.msgpack.Packer(autoreset=False)
  112. for size in sizes:
  113. packer.pack_array_header(size)
  114. for i in range(size):
  115. packer.pack(i)
  116. bio = BytesIO(packer.bytes())
  117. unpacker = salt.utils.msgpack.Unpacker(bio, use_list=True)
  118. for size in sizes:
  119. self.assertEqual(unpacker.unpack(), list(range(size)))
  120. packer.reset()
  121. self.assertEqual(packer.bytes(), b"")
  122. def test_map_size(self):
  123. sizes = [0, 5, 50, 1000]
  124. bio = BytesIO()
  125. packer = salt.utils.msgpack.Packer()
  126. for size in sizes:
  127. bio.write(packer.pack_map_header(size))
  128. for i in range(size):
  129. bio.write(packer.pack(i)) # key
  130. bio.write(packer.pack(i * 2)) # value
  131. bio.seek(0)
  132. if salt.utils.msgpack.version > (0, 6, 0):
  133. unpacker = salt.utils.msgpack.Unpacker(bio, strict_map_key=False)
  134. else:
  135. unpacker = salt.utils.msgpack.Unpacker(bio)
  136. for size in sizes:
  137. self.assertEqual(unpacker.unpack(), {i: i * 2 for i in range(size)})
  138. def test_max_buffer_size(self):
  139. """
  140. Test if max buffer size allows at least 100MiB
  141. """
  142. bio = BytesIO()
  143. bio.write(salt.utils.msgpack.packb("0" * (100 * 1024 * 1024)))
  144. bio.seek(0)
  145. unpacker = salt.utils.msgpack.Unpacker(bio)
  146. raised = False
  147. try:
  148. unpacker.unpack()
  149. except ValueError:
  150. raised = True
  151. self.assertFalse(raised)
  152. def test_exceptions(self):
  153. # Verify that this exception exists
  154. self.assertTrue(salt.utils.msgpack.exceptions.PackValueError)
  155. self.assertTrue(salt.utils.msgpack.exceptions.UnpackValueError)
  156. self.assertTrue(salt.utils.msgpack.exceptions.PackValueError)
  157. self.assertTrue(salt.utils.msgpack.exceptions.UnpackValueError)
  158. def test_function_aliases(self):
  159. """
  160. Fail if core functionality from msgpack is missing in the utility
  161. """
  162. def sanitized(item):
  163. if inspect.isfunction(getattr(msgpack, item)):
  164. # Only check objects that exist in the same file as msgpack
  165. return inspect.getfile(getattr(msgpack, item)) == inspect.getfile(
  166. msgpack
  167. )
  168. msgpack_items = {
  169. x for x in dir(msgpack) if not x.startswith("_") and sanitized(x)
  170. }
  171. msgpack_util_items = set(dir(salt.utils.msgpack))
  172. self.assertFalse(
  173. msgpack_items - msgpack_util_items,
  174. "msgpack functions with no alias in `salt.utils.msgpack`",
  175. )
  176. def _test_base(self, pack_func, unpack_func):
  177. """
  178. In msgpack, 'dumps' is an alias for 'packb' and 'loads' is an alias for 'unpackb'.
  179. Verify that both salt.utils.msgpack function variations pass the exact same test
  180. """
  181. data = os.urandom(1024)
  182. packed = pack_func(data)
  183. # Sanity Check
  184. self.assertTrue(packed)
  185. self.assertIsInstance(packed, bytes)
  186. self.assertNotEqual(data, packed)
  187. # Reverse the packing and the result should be equivalent to the original data
  188. unpacked = unpack_func(packed)
  189. self.assertEqual(data, unpacked)
  190. def _test_buffered_base(self, pack_func, unpack_func):
  191. data = os.urandom(1024).decode(errors="ignore")
  192. buffer = BytesIO()
  193. # Sanity check, we are not borking the BytesIO read function
  194. self.assertNotEqual(BytesIO.read, buffer.read)
  195. buffer.read = buffer.getvalue
  196. pack_func(data, buffer)
  197. # Sanity Check
  198. self.assertTrue(buffer.getvalue())
  199. self.assertIsInstance(buffer.getvalue(), bytes)
  200. self.assertNotEqual(data, buffer.getvalue())
  201. # Reverse the packing and the result should be equivalent to the original data
  202. unpacked = unpack_func(buffer)
  203. if isinstance(unpacked, bytes):
  204. unpacked = unpacked.decode()
  205. self.assertEqual(data, unpacked)
  206. def test_buffered_base_pack(self):
  207. self._test_buffered_base(
  208. pack_func=salt.utils.msgpack.pack, unpack_func=msgpack.unpack
  209. )
  210. def test_buffered_base_unpack(self):
  211. self._test_buffered_base(
  212. pack_func=msgpack.pack, unpack_func=salt.utils.msgpack.unpack
  213. )
  214. def _test_unpack_array_header_from_file(self, pack_func, **kwargs):
  215. f = BytesIO(pack_func([1, 2, 3, 4]))
  216. unpacker = salt.utils.msgpack.Unpacker(f)
  217. self.assertEqual(unpacker.read_array_header(), 4)
  218. self.assertEqual(unpacker.unpack(), 1)
  219. self.assertEqual(unpacker.unpack(), 2)
  220. self.assertEqual(unpacker.unpack(), 3)
  221. self.assertEqual(unpacker.unpack(), 4)
  222. self.assertRaises(salt.utils.msgpack.exceptions.OutOfData, unpacker.unpack)
    # skipIf wraps this helper, so calling it raises SkipTest when
    # sys.getrefcount is unavailable
    @skipIf(
        not hasattr(sys, "getrefcount"), "sys.getrefcount() is needed to pass this test"
    )
    def _test_unpacker_hook_refcnt(self, pack_func, **kwargs):
        """
        Check that salt.utils.msgpack.Unpacker calls object_hook/list_hook for
        each decoded map and array, and releases its hook references when the
        unpacker is deleted.

        :param pack_func: callable that packs a Python object into bytes
        :param kwargs: ignored; lets callers pass ``unpack_func`` uniformly
        """
        result = []

        def hook(x):
            # Record every object passed to the hook and return it unchanged
            result.append(x)
            return x

        basecnt = sys.getrefcount(hook)
        up = salt.utils.msgpack.Unpacker(object_hook=hook, list_hook=hook)
        # The hook is passed twice (object_hook and list_hook), so the unpacker
        # is expected to hold at least two extra references to it
        self.assertGreaterEqual(sys.getrefcount(hook), basecnt + 2)
        up.feed(pack_func([{}]))
        up.feed(pack_func([{}]))
        self.assertEqual(up.unpack(), [{}])
        self.assertEqual(up.unpack(), [{}])
        # Hooks fire inner-first: the map ({}) before its enclosing array ([{}])
        self.assertEqual(result, [{}, [{}], {}, [{}]])
        del up
        # Deleting the unpacker must drop both references. This relies on the
        # object being freed immediately on del (CPython refcounting) — confirm
        # before running on other interpreters.
        self.assertEqual(sys.getrefcount(hook), basecnt)
  241. def _test_unpacker_ext_hook(self, pack_func, **kwargs):
  242. class MyUnpacker(salt.utils.msgpack.Unpacker):
  243. def __init__(self):
  244. my_kwargs = {}
  245. super().__init__(ext_hook=self._hook, **raw)
  246. def _hook(self, code, data):
  247. if code == 1:
  248. return int(data)
  249. else:
  250. return salt.utils.msgpack.ExtType(code, data)
  251. unpacker = MyUnpacker()
  252. unpacker.feed(pack_func({"a": 1}))
  253. self.assertEqual(unpacker.unpack(), {"a": 1})
  254. unpacker.feed(pack_func({"a": salt.utils.msgpack.ExtType(1, b"123")}))
  255. self.assertEqual(unpacker.unpack(), {"a": 123})
  256. unpacker.feed(pack_func({"a": salt.utils.msgpack.ExtType(2, b"321")}))
  257. self.assertEqual(
  258. unpacker.unpack(), {"a": salt.utils.msgpack.ExtType(2, b"321")}
  259. )
  260. def _check(
  261. self, data, pack_func, unpack_func, use_list=False, strict_map_key=False
  262. ):
  263. my_kwargs = {}
  264. if salt.utils.msgpack.version >= (0, 6, 0):
  265. my_kwargs["strict_map_key"] = strict_map_key
  266. ret = unpack_func(pack_func(data), use_list=use_list, **my_kwargs)
  267. self.assertEqual(ret, data)
  268. def _test_pack_unicode(self, pack_func, unpack_func):
  269. test_data = ["", "abcd", ["defgh"], "Русский текст"]
  270. for td in test_data:
  271. ret = unpack_func(pack_func(td), use_list=True, **raw)
  272. self.assertEqual(ret, td)
  273. packer = salt.utils.msgpack.Packer()
  274. data = packer.pack(td)
  275. ret = salt.utils.msgpack.Unpacker(
  276. BytesIO(data), use_list=True, **raw
  277. ).unpack()
  278. self.assertEqual(ret, td)
  279. def _test_pack_bytes(self, pack_func, unpack_func):
  280. test_data = [
  281. b"",
  282. b"abcd",
  283. (b"defgh",),
  284. ]
  285. for td in test_data:
  286. self._check(td, pack_func, unpack_func)
  287. def _test_pack_byte_arrays(self, pack_func, unpack_func):
  288. test_data = [
  289. bytearray(b""),
  290. bytearray(b"abcd"),
  291. (bytearray(b"defgh"),),
  292. ]
  293. for td in test_data:
  294. self._check(td, pack_func, unpack_func)
  295. @skipIf(sys.version_info < (3, 0), "Python 2 passes invalid surrogates")
  296. def _test_ignore_unicode_errors(self, pack_func, unpack_func):
  297. ret = unpack_func(
  298. pack_func(b"abc\xeddef", use_bin_type=False), unicode_errors="ignore", **raw
  299. )
  300. self.assertEqual("abcdef", ret)
  301. def _test_strict_unicode_unpack(self, pack_func, unpack_func):
  302. packed = pack_func(b"abc\xeddef", use_bin_type=False)
  303. self.assertRaises(UnicodeDecodeError, unpack_func, packed, use_list=True, **raw)
  304. @skipIf(sys.version_info < (3, 0), "Python 2 passes invalid surrogates")
  305. def _test_ignore_errors_pack(self, pack_func, unpack_func):
  306. ret = unpack_func(
  307. pack_func("abc\uDC80\uDCFFdef", use_bin_type=True, unicode_errors="ignore"),
  308. use_list=True,
  309. **raw
  310. )
  311. self.assertEqual("abcdef", ret)
  312. def _test_decode_binary(self, pack_func, unpack_func):
  313. ret = unpack_func(pack_func(b"abc"), use_list=True)
  314. self.assertEqual(b"abc", ret)
  315. @skipIf(
  316. salt.utils.msgpack.version < (0, 2, 2),
  317. "use_single_float was added in msgpack==0.2.2",
  318. )
  319. def _test_pack_float(self, pack_func, **kwargs):
  320. self.assertEqual(
  321. b"\xca" + struct.pack(">f", 1.0), pack_func(1.0, use_single_float=True)
  322. )
  323. self.assertEqual(
  324. b"\xcb" + struct.pack(">d", 1.0), pack_func(1.0, use_single_float=False),
  325. )
  326. def _test_odict(self, pack_func, unpack_func):
  327. seq = [(b"one", 1), (b"two", 2), (b"three", 3), (b"four", 4)]
  328. od = OrderedDict(seq)
  329. self.assertEqual(dict(seq), unpack_func(pack_func(od), use_list=True))
  330. def pair_hook(seq):
  331. return list(seq)
  332. self.assertEqual(
  333. seq, unpack_func(pack_func(od), object_pairs_hook=pair_hook, use_list=True)
  334. )
  335. def _test_pair_list(self, unpack_func, **kwargs):
  336. pairlist = [(b"a", 1), (2, b"b"), (b"foo", b"bar")]
  337. packer = salt.utils.msgpack.Packer()
  338. packed = packer.pack_map_pairs(pairlist)
  339. if salt.utils.msgpack.version > (0, 6, 0):
  340. unpacked = unpack_func(packed, object_pairs_hook=list, strict_map_key=False)
  341. else:
  342. unpacked = unpack_func(packed, object_pairs_hook=list)
  343. self.assertEqual(pairlist, unpacked)
  344. @skipIf(
  345. salt.utils.msgpack.version < (0, 6, 0),
  346. "getbuffer() was added to Packer in msgpack 0.6.0",
  347. )
  348. def _test_get_buffer(self, pack_func, **kwargs):
  349. packer = msgpack.Packer(autoreset=False, use_bin_type=True)
  350. packer.pack([1, 2])
  351. strm = BytesIO()
  352. strm.write(packer.getbuffer())
  353. written = strm.getvalue()
  354. expected = pack_func([1, 2], use_bin_type=True)
  355. self.assertEqual(expected, written)
  356. @staticmethod
  357. def no_fail_run(test, *args, **kwargs):
  358. """
  359. Run a test without failure and return any exception it raises
  360. """
  361. try:
  362. test(*args, **kwargs)
  363. except Exception as e: # pylint: disable=broad-except
  364. return e
  365. def test_binary_function_compatibility(self):
  366. functions = [
  367. {"pack_func": salt.utils.msgpack.packb, "unpack_func": msgpack.unpackb},
  368. {"pack_func": msgpack.packb, "unpack_func": salt.utils.msgpack.unpackb},
  369. ]
  370. # These functions are equivalent but could potentially be overwritten
  371. if salt.utils.msgpack.dumps is not salt.utils.msgpack.packb:
  372. functions.append(
  373. {"pack_func": salt.utils.msgpack.dumps, "unpack_func": msgpack.unpackb}
  374. )
  375. if salt.utils.msgpack.loads is not salt.utils.msgpack.unpackb:
  376. functions.append(
  377. {"pack_func": msgpack.packb, "unpack_func": salt.utils.msgpack.loads}
  378. )
  379. test_funcs = (
  380. self._test_base,
  381. self._test_unpack_array_header_from_file,
  382. self._test_unpacker_hook_refcnt,
  383. self._test_unpacker_ext_hook,
  384. self._test_pack_unicode,
  385. self._test_pack_bytes,
  386. self._test_pack_byte_arrays,
  387. self._test_ignore_unicode_errors,
  388. self._test_strict_unicode_unpack,
  389. self._test_ignore_errors_pack,
  390. self._test_decode_binary,
  391. self._test_pack_float,
  392. self._test_odict,
  393. self._test_pair_list,
  394. self._test_get_buffer,
  395. )
  396. errors = {}
  397. for test_func in test_funcs:
  398. # Run the test without the salt.utils.msgpack module for comparison
  399. vanilla_run = self.no_fail_run(
  400. test_func,
  401. **{"pack_func": msgpack.packb, "unpack_func": msgpack.unpackb}
  402. )
  403. for func_args in functions:
  404. func_name = (
  405. func_args["pack_func"]
  406. if func_args["pack_func"].__module__.startswith("salt.utils")
  407. else func_args["unpack_func"]
  408. )
  409. if hasattr(TestCase, "subTest"):
  410. with self.subTest(test=test_func.__name__, func=func_name.__name__):
  411. # Run the test with the salt.utils.msgpack module
  412. run = self.no_fail_run(test_func, **func_args)
  413. # If the vanilla msgpack module errored, then skip if we got the same error
  414. if run:
  415. if str(vanilla_run) == str(run):
  416. self.skipTest(
  417. "Failed the same way as the vanilla msgpack module:\n{}".format(
  418. run
  419. )
  420. )
  421. else:
  422. # If subTest isn't available then run the tests collect the errors of all the tests before failing
  423. run = self.no_fail_run(test_func, **func_args)
  424. if run:
  425. # If the vanilla msgpack module errored, then skip if we got the same error
  426. if str(vanilla_run) == str(run):
  427. self.skipTest(
  428. "Test failed the same way the vanilla msgpack module fails:\n{}".format(
  429. run
  430. )
  431. )
  432. else:
  433. errors[(test_func.__name__, func_name.__name__)] = run
  434. if errors:
  435. self.fail(pprint.pformat(errors))