1
0

test_tcp.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. """
  2. :codeauthor: Thomas Jackson <jacksontj.89@gmail.com>
  3. """
  4. import logging
  5. import socket
  6. import threading
  7. import salt.config
  8. import salt.exceptions
  9. import salt.ext.tornado.concurrent
  10. import salt.ext.tornado.gen
  11. import salt.ext.tornado.ioloop
  12. import salt.transport.client
  13. import salt.transport.server
  14. import salt.utils.platform
  15. import salt.utils.process
  16. from salt.ext.tornado.testing import AsyncTestCase, gen_test
  17. from salt.transport.tcp import (
  18. SaltMessageClient,
  19. SaltMessageClientPool,
  20. TCPPubServerChannel,
  21. )
  22. from saltfactories.utils.ports import get_unused_localhost_port
  23. from tests.support.helpers import flaky, slowTest
  24. from tests.support.mixins import AdaptedConfigurationTestCaseMixin
  25. from tests.support.mock import MagicMock, patch
  26. from tests.support.unit import TestCase, skipIf
  27. from tests.unit.transport.mixins import (
  28. PubChannelMixin,
  29. ReqChannelMixin,
  30. run_loop_in_thread,
  31. )
  # Module-level logger, named after this test module.
  log = logging.getLogger(name=__name__)
  class BaseTCPReqCase(TestCase, AdaptedConfigurationTestCaseMixin):
      """
      Test the req server/client pair

      ``setUpClass`` starts a TCP ``ReqServerChannel`` whose payloads are handed
      to ``cls._handle_payload``, with its IOLoop driven by a background thread.
      ``tearDownClass`` shuts that server down and releases everything that
      ``setUpClass`` attached to the class.
      """

      @classmethod
      def setUpClass(cls):
          # Without a payload handler there is nothing to serve. This class
          # defines one itself, so the guard only short-circuits if that
          # definition is removed.
          if not hasattr(cls, "_handle_payload"):
              return
          ret_port = get_unused_localhost_port()
          publish_port = get_unused_localhost_port()
          tcp_master_pub_port = get_unused_localhost_port()
          tcp_master_pull_port = get_unused_localhost_port()
          tcp_master_publish_pull = get_unused_localhost_port()
          tcp_master_workers = get_unused_localhost_port()
          cls.master_config = cls.get_temp_config(
              "master",
              **{
                  "transport": "tcp",
                  "auto_accept": True,
                  "ret_port": ret_port,
                  "publish_port": publish_port,
                  "tcp_master_pub_port": tcp_master_pub_port,
                  "tcp_master_pull_port": tcp_master_pull_port,
                  "tcp_master_publish_pull": tcp_master_publish_pull,
                  "tcp_master_workers": tcp_master_workers,
              }
          )
          cls.minion_config = cls.get_temp_config(
              "minion",
              **{
                  "transport": "tcp",
                  "master_ip": "127.0.0.1",
                  "master_port": ret_port,
                  "master_uri": "tcp://127.0.0.1:{}".format(ret_port),
              }
          )
          cls.process_manager = salt.utils.process.ProcessManager(
              name="ReqServer_ProcessManager"
          )
          cls.server_channel = salt.transport.server.ReqServerChannel.factory(
              cls.master_config
          )
          cls.server_channel.pre_fork(cls.process_manager)
          cls.io_loop = salt.ext.tornado.ioloop.IOLoop()
          # Event handed to run_loop_in_thread; tearDownClass sets it before
          # joining the thread (presumably the helper stops the loop on it).
          cls.stop = threading.Event()
          cls.server_channel.post_fork(cls._handle_payload, io_loop=cls.io_loop)
          cls.server_thread = threading.Thread(
              target=run_loop_in_thread, args=(cls.io_loop, cls.stop,),
          )
          cls.server_thread.start()

      @classmethod
      def tearDownClass(cls):
          # Mirror the early return in setUpClass: if it bailed out, none of
          # the attributes used below were ever created.
          if not hasattr(cls, "_handle_payload"):
              return
          cls.server_channel.close()
          cls.stop.set()
          cls.server_thread.join()
          cls.process_manager.kill_children()
          # Release every per-class resource created in setUpClass so nothing
          # stays referenced from the class object after its tests finish.
          del cls.server_channel
          del cls.server_thread
          del cls.stop
          del cls.io_loop
          del cls.process_manager
          del cls.master_config
          del cls.minion_config

      @classmethod
      @salt.ext.tornado.gen.coroutine
      def _handle_payload(cls, payload):
          """
          TODO: something besides echo

          Echo ``payload`` back together with ``{"fun": "send_clear"}``.
          """
          raise salt.ext.tornado.gen.Return((payload, {"fun": "send_clear"}))
  @skipIf(salt.utils.platform.is_darwin(), "hanging test suite on MacOS")
  class ClearReqTestCases(BaseTCPReqCase, ReqChannelMixin):
      """
      Test all of the clear msg stuff

      Each test gets a fresh req channel created with ``crypt="clear"``.
      """

      def setUp(self):
          make_channel = salt.transport.client.ReqChannel.factory
          self.channel = make_channel(self.minion_config, crypt="clear")

      def tearDown(self):
          self.channel.close()
          delattr(self, "channel")

      @classmethod
      @salt.ext.tornado.gen.coroutine
      def _handle_payload(cls, payload):
          """
          TODO: something besides echo

          Echo the payload back, tagged with ``{"fun": "send_clear"}``.
          """
          reply = (payload, {"fun": "send_clear"})
          raise salt.ext.tornado.gen.Return(reply)
  @skipIf(salt.utils.platform.is_darwin(), "hanging test suite on MacOS")
  class AESReqTestCases(BaseTCPReqCase, ReqChannelMixin):
      """
      Run the req channel tests over a channel built without an explicit
      ``crypt`` argument (the AES channel, going by the class name).
      """

      def setUp(self):
          self.channel = salt.transport.client.ReqChannel.factory(self.minion_config)

      def tearDown(self):
          self.channel.close()
          del self.channel

      @classmethod
      @salt.ext.tornado.gen.coroutine
      def _handle_payload(cls, payload):
          """
          TODO: something besides echo

          Echo the payload back, tagged with ``{"fun": "send"}``.
          """
          raise salt.ext.tornado.gen.Return((payload, {"fun": "send"}))

      # TODO: make failed returns have a specific framing so we can raise the same exception
      # on encrypted channels
      @flaky
      @slowTest
      def test_badload(self):
          """
          Test a variety of bad requests, make sure that we get some sort of error
          """
          # Only the raised exception matters; the return value is never used.
          for msg in ("", [], ()):
              with self.assertRaises(salt.exceptions.AuthenticationError):
                  self.channel.send(msg)
  class BaseTCPPubCase(AsyncTestCase, AdaptedConfigurationTestCaseMixin):
      """
      Test the pub server/client pair

      ``setUpClass`` starts a TCP ``PubServerChannel`` plus the
      ``ReqServerChannel`` it needs for authentication; the req server's IOLoop
      is driven by a background thread. Each test also checks that it left no
      new handlers registered on ``self.io_loop``.
      """

      @classmethod
      def setUpClass(cls):
          ret_port = get_unused_localhost_port()
          publish_port = get_unused_localhost_port()
          tcp_master_pub_port = get_unused_localhost_port()
          tcp_master_pull_port = get_unused_localhost_port()
          tcp_master_publish_pull = get_unused_localhost_port()
          tcp_master_workers = get_unused_localhost_port()
          cls.master_config = cls.get_temp_config(
              "master",
              **{
                  "transport": "tcp",
                  "auto_accept": True,
                  "ret_port": ret_port,
                  "publish_port": publish_port,
                  "tcp_master_pub_port": tcp_master_pub_port,
                  "tcp_master_pull_port": tcp_master_pull_port,
                  "tcp_master_publish_pull": tcp_master_publish_pull,
                  "tcp_master_workers": tcp_master_workers,
              }
          )
          cls.minion_config = cls.get_temp_config(
              "minion",
              **{
                  "transport": "tcp",
                  "master_ip": "127.0.0.1",
                  "auth_timeout": 1,
                  "master_port": ret_port,
                  "master_uri": "tcp://127.0.0.1:{}".format(ret_port),
              }
          )
          cls.process_manager = salt.utils.process.ProcessManager(
              name="ReqServer_ProcessManager"
          )
          cls.server_channel = salt.transport.server.PubServerChannel.factory(
              cls.master_config
          )
          cls.server_channel.pre_fork(cls.process_manager)
          # we also require req server for auth
          cls.req_server_channel = salt.transport.server.ReqServerChannel.factory(
              cls.master_config
          )
          cls.req_server_channel.pre_fork(cls.process_manager)
          cls.io_loop = salt.ext.tornado.ioloop.IOLoop()
          cls.stop = threading.Event()
          cls.req_server_channel.post_fork(cls._handle_payload, io_loop=cls.io_loop)
          cls.server_thread = threading.Thread(
              target=run_loop_in_thread, args=(cls.io_loop, cls.stop,),
          )
          cls.server_thread.start()

      @classmethod
      @salt.ext.tornado.gen.coroutine
      def _handle_payload(cls, payload):
          """
          TODO: something besides echo

          Echo the payload back, tagged with ``{"fun": "send_clear"}``.

          Written as a coroutine, the same way as
          ``BaseTCPReqCase._handle_payload``. A bare tuple is not a yieldable
          object under Tornado's ``gen.coroutine``, in case the req server
          yields the handler's result (NOTE(review): confirm in
          salt.transport.tcp).
          """
          raise salt.ext.tornado.gen.Return((payload, {"fun": "send_clear"}))

      @classmethod
      def tearDownClass(cls):
          cls.req_server_channel.close()
          cls.server_channel.close()
          cls.stop.set()
          cls.server_thread.join()
          cls.process_manager.kill_children()
          # Release both channels, not just the req one.
          del cls.req_server_channel
          del cls.server_channel

      def setUp(self):
          super().setUp()
          # Snapshot the handlers already registered so tearDown can flag any
          # that a test added and left behind.
          self._start_handlers = dict(self.io_loop._handlers)

      def tearDown(self):
          super().tearDown()
          try:
              failures = [
                  (fd, handler)
                  for fd, handler in self.io_loop._handlers.items()
                  if self._start_handlers.get(fd) != handler
              ]
              if failures:
                  raise Exception("FDs still attached to the IOLoop: {}".format(failures))
          finally:
              # Drop per-test state even when the FD check fails, and tolerate
              # tests that never created a channel.
              vars(self).pop("channel", None)
              del self._start_handlers
  class AsyncTCPPubChannelTest(AsyncTestCase, AdaptedConfigurationTestCaseMixin):
      """
      Tests for ``salt.transport.tcp.AsyncTCPPubChannel``.
      """

      @slowTest
      def test_connect_publish_port(self):
          """
          test that an explicitly configured publish_port (1234 here) is the
          one handed to SaltMessageClientPool when connecting
          """
          opts = self.get_temp_config("master")
          opts["master_uri"] = ""
          opts["master_ip"] = "127.0.0.1"
          opts["publish_port"] = 1234
          channel = salt.transport.tcp.AsyncTCPPubChannel(opts)
          # Make authentication look already done so connect() goes straight
          # to building the client pool.
          patch_auth = MagicMock(return_value=True)
          patch_client = MagicMock(spec=SaltMessageClientPool)
          with patch("salt.crypt.AsyncAuth.gen_token", patch_auth), patch(
              "salt.crypt.AsyncAuth.authenticated", patch_auth
          ), patch("salt.transport.tcp.SaltMessageClientPool", patch_client):
              channel.connect()
          # Fail with a clear message instead of a TypeError on a None
          # call_args if connect() never built the pool.
          self.assertTrue(
              patch_client.called, "SaltMessageClientPool was never constructed"
          )
          assert patch_client.call_args[0][0]["publish_port"] == opts["publish_port"]
  @skipIf(True, "Skip until we can devote time to fix this test")
  class AsyncPubChannelTest(BaseTCPPubCase, PubChannelMixin):
      """
      Publish-system tests: the shared PubChannelMixin cases run against the
      TCP servers started by BaseTCPPubCase. Unconditionally skipped for now.
      """
  class SaltMessageClientPoolTest(AsyncTestCase):
      """
      Unit tests for SaltMessageClientPool whose pooled clients are swapped for
      MagicMock objects.
      """

      def setUp(self):
          super().setUp()
          pool_size = 5
          # Build the pool with SaltMessageClient.__init__ stubbed out, so no
          # real client gets initialised.
          init_stub = MagicMock(return_value=None)
          with patch("salt.transport.tcp.SaltMessageClient.__init__", init_stub):
              self.message_client_pool = SaltMessageClientPool(
                  {"sock_pool_size": pool_size}, args=({}, "", 0)
              )
          # Keep the pool's own clients for tearDown, then replace them with mocks.
          self.original_message_clients = self.message_client_pool.message_clients
          self.message_client_pool.message_clients = [
              MagicMock() for _ in range(pool_size)
          ]

      def tearDown(self):
          # Drop the stub-initialised clients while SaltMessageClient.close is
          # patched out, in case releasing them ends up calling close().
          close_stub = MagicMock(return_value=None)
          with patch("salt.transport.tcp.SaltMessageClient.close", close_stub):
              delattr(self, "original_message_clients")
          super().tearDown()

      def test_send(self):
          for client in self.message_client_pool.message_clients:
              client.send_queue = [0, 0, 0]
              client.send.return_value = []
          self.assertEqual([], self.message_client_pool.send())

          # Give client #2 the shortest queue; its send() result is expected.
          shortest = self.message_client_pool.message_clients[2]
          shortest.send_queue = [0]
          shortest.send.return_value = [1]
          self.assertEqual([1], self.message_client_pool.send())

      def test_write_to_stream(self):
          for client in self.message_client_pool.message_clients:
              client.send_queue = [0, 0, 0]
              client._stream.write.return_value = []
          self.assertEqual([], self.message_client_pool.write_to_stream(""))

          # Give client #2 the shortest queue; its stream write result is expected.
          shortest = self.message_client_pool.message_clients[2]
          shortest.send_queue = [0]
          shortest._stream.write.return_value = [1]
          self.assertEqual([1], self.message_client_pool.write_to_stream(""))

      def test_close(self):
          pool = self.message_client_pool
          pool.close()
          self.assertEqual([], pool.message_clients)

      def test_on_recv(self):
          for client in self.message_client_pool.message_clients:
              client.on_recv.return_value = None
          self.message_client_pool.on_recv()
          for client in self.message_client_pool.message_clients:
              self.assertTrue(client.on_recv.called)

      def test_connect_all(self):
          @gen_test
          def _connect(case):
              yield case.message_client_pool.connect()

          for client in self.message_client_pool.message_clients:
              resolved = salt.ext.tornado.concurrent.Future()
              resolved.set_result("foo")
              client.connect.return_value = resolved
          self.assertIsNone(_connect(self))

      def test_connect_partial(self):
          @gen_test(timeout=0.1)
          def _connect(case):
              yield case.message_client_pool.connect()

          clients = self.message_client_pool.message_clients
          for position, client in enumerate(clients):
              pending = salt.ext.tornado.concurrent.Future()
              # Only every other client finishes connecting; the rest never do.
              if not position % 2:
                  pending.set_result("foo")
              client.connect.return_value = pending
          with self.assertRaises(salt.ext.tornado.ioloop.TimeoutError):
              _connect(self)
  class SaltMessageClientCleanupTest(TestCase, AdaptedConfigurationTestCaseMixin):
      """
      Check that SaltMessageClient.close() stops its io_loop and drops its
      reference to it.
      """

      def setUp(self):
          # A plain listening socket is enough for the client to connect to.
          self.listen_on = "127.0.0.1"
          self.port = get_unused_localhost_port()
          self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
          self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
          self.sock.bind((self.listen_on, self.port))
          self.sock.listen(1)

      def tearDown(self):
          self.sock.close()
          del self.sock

      def test_message_client(self):
          """
          test message client cleanup on close
          """
          # Remember which loop (if any) was current so it can be restored;
          # otherwise this test's private loop stays current for later tests.
          prev_loop = salt.ext.tornado.ioloop.IOLoop.current(instance=False)
          orig_loop = salt.ext.tornado.ioloop.IOLoop()
          orig_loop.make_current()
          opts = self.get_temp_config("master")
          client = SaltMessageClient(opts, self.listen_on, self.port)

          # Mock the io_loop's stop method so we know when it has been called.
          orig_loop.real_stop = orig_loop.stop
          orig_loop.stop_called = False

          def stop(*args, **kwargs):
              orig_loop.stop_called = True
              orig_loop.real_stop()

          orig_loop.stop = stop
          try:
              assert client.io_loop == orig_loop
              client.io_loop.run_sync(client.connect)

              # Ensure we are testing the _read_until_future and io_loop teardown
              assert client._stream is not None
              assert client._read_until_future is not None
              assert orig_loop.stop_called is True

              # The run_sync call will set stop_called, reset it
              orig_loop.stop_called = False
              client.close()

              # Stop should be called again, client's io_loop should be None
              assert orig_loop.stop_called is True
              assert client.io_loop is None
          finally:
              orig_loop.stop = orig_loop.real_stop
              del orig_loop.real_stop
              del orig_loop.stop_called
              # Hand "current loop" back to whatever had it before, then close
              # the loop this test created so it does not leak.
              if prev_loop is not None:
                  prev_loop.make_current()
              else:
                  salt.ext.tornado.ioloop.IOLoop.clear_current()
              orig_loop.close()
  class TCPPubServerChannelTest(TestCase, AdaptedConfigurationTestCaseMixin):
      """
      Check how TCPPubServerChannel.publish() sets ``topic_lst`` on the payload
      it sends, depending on the target type and the syndic setting.
      """

      def _build_channel(self, sync_wrapper, crypticle, secrets):
          """
          Create a TCPPubServerChannel with message signing disabled, and wire
          the patched crypto/sync helpers to return mocks.

          Returns ``(opts, channel, wrap)``; ``wrap`` is the mock that
          ``SyncWrapper`` returns, so its ``send`` calls can be inspected.
          """
          opts = self.get_temp_config("master")
          opts["sign_pub_messages"] = False
          channel = TCPPubServerChannel(opts)
          wrap = MagicMock()
          crypt = MagicMock()
          crypt.dumps.return_value = {"test": "value"}
          secrets.return_value = {"aes": {"secret": None}}
          crypticle.return_value = crypt
          sync_wrapper.return_value = wrap
          return opts, channel, wrap

      @staticmethod
      def _sent_payload(wrap):
          """Return the first positional argument of the latest ``wrap.send`` call."""
          return wrap.send.call_args[0][0]

      @patch("salt.master.SMaster.secrets")
      @patch("salt.crypt.Crypticle")
      @patch("salt.utils.asynchronous.SyncWrapper")
      def test_publish_filtering(self, sync_wrapper, crypticle, secrets):
          opts, channel, wrap = self._build_channel(sync_wrapper, crypticle, secrets)

          # A glob target is sent without any specific topic.
          channel.publish({"test": "value", "tgt_type": "glob", "tgt": "*"})
          sent = self._sent_payload(wrap)
          assert "topic_lst" not in sent

          # A list target is sent with exactly those minions as topics.
          channel.publish({"test": "value", "tgt_type": "list", "tgt": ["minion01"]})
          sent = self._sent_payload(wrap)
          assert "topic_lst" in sent
          self.assertEqual(sent["topic_lst"], ["minion01"])

          # With syndic settings (order_masters) no topic list is sent.
          opts["order_masters"] = True
          channel.publish({"test": "value", "tgt_type": "list", "tgt": ["minion01"]})
          sent = self._sent_payload(wrap)
          assert "topic_lst" not in sent

      @patch("salt.utils.minions.CkMinions.check_minions")
      @patch("salt.master.SMaster.secrets")
      @patch("salt.crypt.Crypticle")
      @patch("salt.utils.asynchronous.SyncWrapper")
      def test_publish_filtering_str_list(
          self, sync_wrapper, crypticle, secrets, check_minions
      ):
          _, channel, wrap = self._build_channel(sync_wrapper, crypticle, secrets)
          check_minions.return_value = {"minions": ["minion02"]}

          # A list target given as a plain string is resolved via check_minions.
          channel.publish({"test": "value", "tgt_type": "list", "tgt": "minion02"})
          sent = self._sent_payload(wrap)
          assert "topic_lst" in sent
          self.assertEqual(sent["topic_lst"], ["minion02"])
          check_minions.assert_called_with("minion02", tgt_type="list")