test_payload.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. # -*- coding: utf-8 -*-
  2. """
  3. :codeauthor: Pedro Algarvio (pedro@algarvio.me)
  4. tests.unit.payload_test
  5. ~~~~~~~~~~~~~~~~~~~~~~~
  6. """
  7. from __future__ import absolute_import, print_function, unicode_literals
  8. import copy
  9. import datetime
  10. import errno
  11. import logging
  12. import threading
  13. import time
  14. import salt.exceptions
  15. import salt.payload
  16. import zmq
  17. from salt.ext import six
  18. from salt.utils import immutabletypes
  19. from salt.utils.odict import OrderedDict
  20. from tests.support.helpers import slowTest
  21. from tests.support.unit import TestCase, skipIf
  22. log = logging.getLogger(__name__)
  23. class PayloadTestCase(TestCase):
  24. def assertNoOrderedDict(self, data):
  25. if isinstance(data, OrderedDict):
  26. raise AssertionError("Found an ordered dictionary")
  27. if isinstance(data, dict):
  28. for value in six.itervalues(data):
  29. self.assertNoOrderedDict(value)
  30. elif isinstance(data, (list, tuple)):
  31. for chunk in data:
  32. self.assertNoOrderedDict(chunk)
  33. def test_list_nested_odicts(self):
  34. payload = salt.payload.Serial("msgpack")
  35. idata = {"pillar": [OrderedDict(environment="dev")]}
  36. odata = payload.loads(payload.dumps(idata.copy()))
  37. self.assertNoOrderedDict(odata)
  38. self.assertEqual(idata, odata)
  39. def test_datetime_dump_load(self):
  40. """
  41. Check the custom datetime handler can understand itself
  42. """
  43. payload = salt.payload.Serial("msgpack")
  44. dtvalue = datetime.datetime(2001, 2, 3, 4, 5, 6, 7)
  45. idata = {dtvalue: dtvalue}
  46. sdata = payload.dumps(idata.copy())
  47. odata = payload.loads(sdata)
  48. self.assertEqual(
  49. sdata,
  50. b"\x81\xc7\x18N20010203T04:05:06.000007\xc7\x18N20010203T04:05:06.000007",
  51. )
  52. self.assertEqual(idata, odata)
  53. def test_verylong_dump_load(self):
  54. """
  55. Test verylong encoder/decoder
  56. """
  57. payload = salt.payload.Serial("msgpack")
  58. idata = {"jid": 20180227140750302662}
  59. sdata = payload.dumps(idata.copy())
  60. odata = payload.loads(sdata)
  61. idata["jid"] = "{0}".format(idata["jid"])
  62. self.assertEqual(idata, odata)
  63. def test_immutable_dict_dump_load(self):
  64. """
  65. Test immutable dict encoder/decoder
  66. """
  67. payload = salt.payload.Serial("msgpack")
  68. idata = {"dict": {"key": "value"}}
  69. sdata = payload.dumps({"dict": immutabletypes.ImmutableDict(idata["dict"])})
  70. odata = payload.loads(sdata)
  71. self.assertEqual(idata, odata)
  72. def test_immutable_list_dump_load(self):
  73. """
  74. Test immutable list encoder/decoder
  75. """
  76. payload = salt.payload.Serial("msgpack")
  77. idata = {"list": [1, 2, 3]}
  78. sdata = payload.dumps({"list": immutabletypes.ImmutableList(idata["list"])})
  79. odata = payload.loads(sdata)
  80. self.assertEqual(idata, odata)
  81. def test_immutable_set_dump_load(self):
  82. """
  83. Test immutable set encoder/decoder
  84. """
  85. payload = salt.payload.Serial("msgpack")
  86. idata = {"set": ["red", "green", "blue"]}
  87. sdata = payload.dumps({"set": immutabletypes.ImmutableSet(idata["set"])})
  88. odata = payload.loads(sdata)
  89. self.assertEqual(idata, odata)
  90. def test_odict_dump_load(self):
  91. """
  92. Test odict just works. It wasn't until msgpack 0.2.0
  93. """
  94. payload = salt.payload.Serial("msgpack")
  95. data = OrderedDict()
  96. data["a"] = "b"
  97. data["y"] = "z"
  98. data["j"] = "k"
  99. data["w"] = "x"
  100. sdata = payload.dumps({"set": data})
  101. odata = payload.loads(sdata)
  102. self.assertEqual({"set": dict(data)}, odata)
  103. def test_mixed_dump_load(self):
  104. """
  105. Test we can handle all exceptions at once
  106. """
  107. payload = salt.payload.Serial("msgpack")
  108. dtvalue = datetime.datetime(2001, 2, 3, 4, 5, 6, 7)
  109. od = OrderedDict()
  110. od["a"] = "b"
  111. od["y"] = "z"
  112. od["j"] = "k"
  113. od["w"] = "x"
  114. idata = {
  115. dtvalue: dtvalue, # datetime
  116. "jid": 20180227140750302662, # long int
  117. "dict": immutabletypes.ImmutableDict({"key": "value"}), # immutable dict
  118. "list": immutabletypes.ImmutableList([1, 2, 3]), # immutable list
  119. "set": immutabletypes.ImmutableSet(
  120. ("red", "green", "blue")
  121. ), # immutable set
  122. "odict": od, # odict
  123. }
  124. edata = {
  125. dtvalue: dtvalue, # datetime, == input
  126. "jid": "20180227140750302662", # string repr of long int
  127. "dict": {"key": "value"}, # builtin dict
  128. "list": [1, 2, 3], # builtin list
  129. "set": ["red", "green", "blue"], # builtin set
  130. "odict": dict(od), # builtin dict
  131. }
  132. sdata = payload.dumps(idata)
  133. odata = payload.loads(sdata)
  134. self.assertEqual(edata, odata)
  135. def test_recursive_dump_load(self):
  136. """
  137. Test recursive payloads are (mostly) serialized
  138. """
  139. payload = salt.payload.Serial("msgpack")
  140. data = {"name": "roscivs"}
  141. data["data"] = data # Data all the things!
  142. sdata = payload.dumps(data)
  143. odata = payload.loads(sdata)
  144. self.assertTrue("recursion" in odata["data"].lower())
  145. def test_recursive_dump_load_with_identical_non_recursive_types(self):
  146. """
  147. If identical objects are nested anywhere, they should not be
  148. marked recursive unless they're one of the types we iterate
  149. over.
  150. """
  151. payload = salt.payload.Serial("msgpack")
  152. repeating = "repeating element"
  153. data = {
  154. "a": "a", # Test CPython implementation detail. Short
  155. "b": "a", # strings are interned.
  156. "c": 13, # So are small numbers.
  157. "d": 13,
  158. "fnord": repeating,
  159. # Let's go for broke and make a crazy nested structure
  160. "repeating": [
  161. [[[[{"one": repeating, "two": repeating}], repeating, 13, "a"]]],
  162. repeating,
  163. repeating,
  164. repeating,
  165. ],
  166. }
  167. # We need a nested dictionary to trigger the exception
  168. data["repeating"][0][0][0].append(data)
  169. # If we don't deepcopy the data it gets mutated
  170. sdata = payload.dumps(copy.deepcopy(data))
  171. odata = payload.loads(sdata)
  172. # Delete the recursive piece - it's served its purpose, and our
  173. # other test tests that it's actually marked as recursive.
  174. del odata["repeating"][0][0][0][-1], data["repeating"][0][0][0][-1]
  175. self.assertDictEqual(odata, data)
  176. class SREQTestCase(TestCase):
  177. port = 8845 # TODO: dynamically assign a port?
  178. @classmethod
  179. def setUpClass(cls):
  180. """
  181. Class to set up zmq echo socket
  182. """
  183. def echo_server():
  184. """
  185. A server that echos the message sent to it over zmq
  186. Optional "sleep" can be sent to delay response
  187. """
  188. context = zmq.Context()
  189. socket = context.socket(zmq.REP)
  190. socket.bind("tcp://*:{0}".format(SREQTestCase.port))
  191. payload = salt.payload.Serial("msgpack")
  192. while SREQTestCase.thread_running.is_set():
  193. try:
  194. # Wait for next request from client
  195. message = socket.recv(zmq.NOBLOCK)
  196. msg_deserialized = payload.loads(message)
  197. log.info("Echo server received message: %s", msg_deserialized)
  198. if isinstance(msg_deserialized["load"], dict) and msg_deserialized[
  199. "load"
  200. ].get("sleep"):
  201. log.info(
  202. "Test echo server sleeping for %s seconds",
  203. msg_deserialized["load"]["sleep"],
  204. )
  205. time.sleep(msg_deserialized["load"]["sleep"])
  206. socket.send(message)
  207. except zmq.ZMQError as exc:
  208. if exc.errno == errno.EAGAIN:
  209. continue
  210. raise
  211. SREQTestCase.thread_running = threading.Event()
  212. SREQTestCase.thread_running.set()
  213. SREQTestCase.echo_server = threading.Thread(target=echo_server)
  214. SREQTestCase.echo_server.start()
  215. @classmethod
  216. def tearDownClass(cls):
  217. """
  218. Remove echo server
  219. """
  220. # kill the thread
  221. SREQTestCase.thread_running.clear()
  222. SREQTestCase.echo_server.join()
  223. def get_sreq(self):
  224. return salt.payload.SREQ("tcp://127.0.0.1:{0}".format(SREQTestCase.port))
  225. @slowTest
  226. def test_send_auto(self):
  227. """
  228. Test creation, send/rect
  229. """
  230. sreq = self.get_sreq()
  231. # check default of empty load and enc clear
  232. assert sreq.send_auto({}) == {"enc": "clear", "load": {}}
  233. # check that the load always gets passed
  234. assert sreq.send_auto({"load": "foo"}) == {"load": "foo", "enc": "clear"}
  235. def test_send(self):
  236. sreq = self.get_sreq()
  237. assert sreq.send("clear", "foo") == {"enc": "clear", "load": "foo"}
  238. @skipIf(True, "Disabled until we can figure out how to make this more reliable.")
  239. def test_timeout(self):
  240. """
  241. Test SREQ Timeouts
  242. """
  243. sreq = self.get_sreq()
  244. # client-side timeout
  245. start = time.time()
  246. # This is a try/except instead of an assertRaises because of a possible
  247. # subtle bug in zmq wherein a timeout=0 actually exceutes a single poll
  248. # before the timeout is reached.
  249. log.info("Sending tries=0, timeout=0")
  250. try:
  251. sreq.send("clear", "foo", tries=0, timeout=0)
  252. except salt.exceptions.SaltReqTimeoutError:
  253. pass
  254. assert time.time() - start < 1 # ensure we didn't wait
  255. # server-side timeout
  256. log.info("Sending tries=1, timeout=1")
  257. start = time.time()
  258. with self.assertRaises(salt.exceptions.SaltReqTimeoutError):
  259. sreq.send("clear", {"sleep": 2}, tries=1, timeout=1)
  260. assert time.time() - start >= 1 # ensure we actually tried once (1s)
  261. # server-side timeout with retries
  262. log.info("Sending tries=2, timeout=1")
  263. start = time.time()
  264. with self.assertRaises(salt.exceptions.SaltReqTimeoutError):
  265. sreq.send("clear", {"sleep": 2}, tries=2, timeout=1)
  266. assert time.time() - start >= 2 # ensure we actually tried twice (2s)
  267. # test a regular send afterwards (to make sure sockets aren't in a twist
  268. log.info("Sending regular send")
  269. assert sreq.send("clear", "foo") == {"enc": "clear", "load": "foo"}
  270. def test_destroy(self):
  271. """
  272. Test the __del__ capabilities
  273. """
  274. sreq = self.get_sreq()
  275. # ensure no exceptions when we go to destroy the sreq, since __del__
  276. # swallows exceptions, we have to call destroy directly
  277. sreq.destroy()
  278. def test_raw_vs_encoding_none(self):
  279. """
  280. Test that we handle the new raw parameter in 5.0.2 correctly based on
  281. encoding. When encoding is None loads should return bytes
  282. """
  283. payload = salt.payload.Serial("msgpack")
  284. dtvalue = datetime.datetime(2001, 2, 3, 4, 5, 6, 7)
  285. idata = {dtvalue: "strval"}
  286. sdata = payload.dumps(idata.copy())
  287. odata = payload.loads(sdata, encoding=None)
  288. assert isinstance(odata[dtvalue], six.string_types)
  289. def test_raw_vs_encoding_utf8(self):
  290. """
  291. Test that we handle the new raw parameter in 5.0.2 correctly based on
  292. encoding. When encoding is utf-8 loads should return unicode
  293. """
  294. payload = salt.payload.Serial("msgpack")
  295. dtvalue = datetime.datetime(2001, 2, 3, 4, 5, 6, 7)
  296. idata = {dtvalue: "strval"}
  297. sdata = payload.dumps(idata.copy())
  298. odata = payload.loads(sdata, encoding="utf-8")
  299. assert isinstance(odata[dtvalue], six.text_type)