1
0

test_tcp.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453
  1. # -*- coding: utf-8 -*-
  2. '''
  3. :codeauthor: Thomas Jackson <jacksontj.89@gmail.com>
  4. '''
  5. # Import python libs
  6. from __future__ import absolute_import, print_function, unicode_literals
  7. import threading
  8. import socket
  9. import logging
  10. import salt.ext.tornado.gen
  11. import salt.ext.tornado.ioloop
  12. import salt.ext.tornado.concurrent
  13. from salt.ext.tornado.testing import AsyncTestCase, gen_test
  14. import salt.config
  15. from salt.ext import six
  16. import salt.utils.platform
  17. import salt.utils.process
  18. import salt.transport.server
  19. import salt.transport.client
  20. import salt.exceptions
  21. from salt.ext.six.moves import range
  22. from salt.transport.tcp import SaltMessageClientPool, SaltMessageClient, TCPPubServerChannel
  23. # Import Salt Testing libs
  24. from tests.support.unit import TestCase, skipIf
  25. from tests.support.helpers import get_unused_localhost_port, flaky
  26. from tests.support.mixins import AdaptedConfigurationTestCaseMixin
  27. from tests.support.mock import MagicMock, patch
  28. from tests.unit.transport.mixins import PubChannelMixin, ReqChannelMixin, run_loop_in_thread
  29. log = logging.getLogger(__name__)
  30. class BaseTCPReqCase(TestCase, AdaptedConfigurationTestCaseMixin):
  31. '''
  32. Test the req server/client pair
  33. '''
  34. @classmethod
  35. def setUpClass(cls):
  36. if not hasattr(cls, '_handle_payload'):
  37. return
  38. ret_port = get_unused_localhost_port()
  39. publish_port = get_unused_localhost_port()
  40. tcp_master_pub_port = get_unused_localhost_port()
  41. tcp_master_pull_port = get_unused_localhost_port()
  42. tcp_master_publish_pull = get_unused_localhost_port()
  43. tcp_master_workers = get_unused_localhost_port()
  44. cls.master_config = cls.get_temp_config(
  45. 'master',
  46. **{'transport': 'tcp',
  47. 'auto_accept': True,
  48. 'ret_port': ret_port,
  49. 'publish_port': publish_port,
  50. 'tcp_master_pub_port': tcp_master_pub_port,
  51. 'tcp_master_pull_port': tcp_master_pull_port,
  52. 'tcp_master_publish_pull': tcp_master_publish_pull,
  53. 'tcp_master_workers': tcp_master_workers}
  54. )
  55. cls.minion_config = cls.get_temp_config(
  56. 'minion',
  57. **{'transport': 'tcp',
  58. 'master_ip': '127.0.0.1',
  59. 'master_port': ret_port,
  60. 'master_uri': 'tcp://127.0.0.1:{0}'.format(ret_port)}
  61. )
  62. cls.process_manager = salt.utils.process.ProcessManager(name='ReqServer_ProcessManager')
  63. cls.server_channel = salt.transport.server.ReqServerChannel.factory(cls.master_config)
  64. cls.server_channel.pre_fork(cls.process_manager)
  65. cls.io_loop = salt.ext.tornado.ioloop.IOLoop()
  66. cls.stop = threading.Event()
  67. cls.server_channel.post_fork(cls._handle_payload, io_loop=cls.io_loop)
  68. cls.server_thread = threading.Thread(
  69. target=run_loop_in_thread,
  70. args=(cls.io_loop, cls.stop,),
  71. )
  72. cls.server_thread.start()
  73. @classmethod
  74. def tearDownClass(cls):
  75. cls.server_channel.close()
  76. cls.stop.set()
  77. cls.server_thread.join()
  78. cls.process_manager.kill_children()
  79. del cls.server_channel
  80. @classmethod
  81. @salt.ext.tornado.gen.coroutine
  82. def _handle_payload(cls, payload):
  83. '''
  84. TODO: something besides echo
  85. '''
  86. raise salt.ext.tornado.gen.Return((payload, {'fun': 'send_clear'}))
  87. @skipIf(salt.utils.platform.is_darwin(), 'hanging test suite on MacOS')
  88. class ClearReqTestCases(BaseTCPReqCase, ReqChannelMixin):
  89. '''
  90. Test all of the clear msg stuff
  91. '''
  92. def setUp(self):
  93. self.channel = salt.transport.client.ReqChannel.factory(self.minion_config, crypt='clear')
  94. def tearDown(self):
  95. self.channel.close()
  96. del self.channel
  97. @classmethod
  98. @salt.ext.tornado.gen.coroutine
  99. def _handle_payload(cls, payload):
  100. '''
  101. TODO: something besides echo
  102. '''
  103. raise salt.ext.tornado.gen.Return((payload, {'fun': 'send_clear'}))
  104. @skipIf(salt.utils.platform.is_darwin(), 'hanging test suite on MacOS')
  105. class AESReqTestCases(BaseTCPReqCase, ReqChannelMixin):
  106. def setUp(self):
  107. self.channel = salt.transport.client.ReqChannel.factory(self.minion_config)
  108. def tearDown(self):
  109. self.channel.close()
  110. del self.channel
  111. @classmethod
  112. @salt.ext.tornado.gen.coroutine
  113. def _handle_payload(cls, payload):
  114. '''
  115. TODO: something besides echo
  116. '''
  117. raise salt.ext.tornado.gen.Return((payload, {'fun': 'send'}))
  118. # TODO: make failed returns have a specific framing so we can raise the same exception
  119. # on encrypted channels
  120. @flaky
  121. def test_badload(self):
  122. '''
  123. Test a variety of bad requests, make sure that we get some sort of error
  124. '''
  125. msgs = ['', [], tuple()]
  126. for msg in msgs:
  127. with self.assertRaises(salt.exceptions.AuthenticationError):
  128. ret = self.channel.send(msg)
  129. class BaseTCPPubCase(AsyncTestCase, AdaptedConfigurationTestCaseMixin):
  130. '''
  131. Test the req server/client pair
  132. '''
  133. @classmethod
  134. def setUpClass(cls):
  135. ret_port = get_unused_localhost_port()
  136. publish_port = get_unused_localhost_port()
  137. tcp_master_pub_port = get_unused_localhost_port()
  138. tcp_master_pull_port = get_unused_localhost_port()
  139. tcp_master_publish_pull = get_unused_localhost_port()
  140. tcp_master_workers = get_unused_localhost_port()
  141. cls.master_config = cls.get_temp_config(
  142. 'master',
  143. **{'transport': 'tcp',
  144. 'auto_accept': True,
  145. 'ret_port': ret_port,
  146. 'publish_port': publish_port,
  147. 'tcp_master_pub_port': tcp_master_pub_port,
  148. 'tcp_master_pull_port': tcp_master_pull_port,
  149. 'tcp_master_publish_pull': tcp_master_publish_pull,
  150. 'tcp_master_workers': tcp_master_workers}
  151. )
  152. cls.minion_config = cls.get_temp_config(
  153. 'minion',
  154. **{'transport': 'tcp',
  155. 'master_ip': '127.0.0.1',
  156. 'auth_timeout': 1,
  157. 'master_port': ret_port,
  158. 'master_uri': 'tcp://127.0.0.1:{0}'.format(ret_port)}
  159. )
  160. cls.process_manager = salt.utils.process.ProcessManager(name='ReqServer_ProcessManager')
  161. cls.server_channel = salt.transport.server.PubServerChannel.factory(cls.master_config)
  162. cls.server_channel.pre_fork(cls.process_manager)
  163. # we also require req server for auth
  164. cls.req_server_channel = salt.transport.server.ReqServerChannel.factory(cls.master_config)
  165. cls.req_server_channel.pre_fork(cls.process_manager)
  166. cls.io_loop = salt.ext.tornado.ioloop.IOLoop()
  167. cls.stop = threading.Event()
  168. cls.req_server_channel.post_fork(cls._handle_payload, io_loop=cls.io_loop)
  169. cls.server_thread = threading.Thread(
  170. target=run_loop_in_thread,
  171. args=(cls.io_loop, cls.stop,),
  172. )
  173. cls.server_thread.start()
  174. @classmethod
  175. def _handle_payload(cls, payload):
  176. '''
  177. TODO: something besides echo
  178. '''
  179. return payload, {'fun': 'send_clear'}
  180. @classmethod
  181. def tearDownClass(cls):
  182. cls.req_server_channel.close()
  183. cls.server_channel.close()
  184. cls.stop.set()
  185. cls.server_thread.join()
  186. cls.process_manager.kill_children()
  187. del cls.req_server_channel
  188. def setUp(self):
  189. super(BaseTCPPubCase, self).setUp()
  190. self._start_handlers = dict(self.io_loop._handlers)
  191. def tearDown(self):
  192. super(BaseTCPPubCase, self).tearDown()
  193. failures = []
  194. for k, v in six.iteritems(self.io_loop._handlers):
  195. if self._start_handlers.get(k) != v:
  196. failures.append((k, v))
  197. if failures:
  198. raise Exception('FDs still attached to the IOLoop: {0}'.format(failures))
  199. del self.channel
  200. del self._start_handlers
  201. class AsyncTCPPubChannelTest(AsyncTestCase, AdaptedConfigurationTestCaseMixin):
  202. def test_connect_publish_port(self):
  203. '''
  204. test when publish_port is not 4506
  205. '''
  206. opts = self.get_temp_config('master')
  207. opts['master_uri'] = ''
  208. opts['master_ip'] = '127.0.0.1'
  209. opts['publish_port'] = 1234
  210. channel = salt.transport.tcp.AsyncTCPPubChannel(opts)
  211. patch_auth = MagicMock(return_value=True)
  212. patch_client = MagicMock(spec=SaltMessageClientPool)
  213. with patch('salt.crypt.AsyncAuth.gen_token', patch_auth), \
  214. patch('salt.crypt.AsyncAuth.authenticated', patch_auth), \
  215. patch('salt.transport.tcp.SaltMessageClientPool',
  216. patch_client):
  217. channel.connect()
  218. assert patch_client.call_args[0][0]['publish_port'] == opts['publish_port']
  219. @skipIf(True, 'Skip until we can devote time to fix this test')
  220. class AsyncPubChannelTest(BaseTCPPubCase, PubChannelMixin):
  221. '''
  222. Tests around the publish system
  223. '''
  224. class SaltMessageClientPoolTest(AsyncTestCase):
  225. def setUp(self):
  226. super(SaltMessageClientPoolTest, self).setUp()
  227. sock_pool_size = 5
  228. with patch('salt.transport.tcp.SaltMessageClient.__init__', MagicMock(return_value=None)):
  229. self.message_client_pool = SaltMessageClientPool({'sock_pool_size': sock_pool_size},
  230. args=({}, '', 0))
  231. self.original_message_clients = self.message_client_pool.message_clients
  232. self.message_client_pool.message_clients = [MagicMock() for _ in range(sock_pool_size)]
  233. def tearDown(self):
  234. with patch('salt.transport.tcp.SaltMessageClient.close', MagicMock(return_value=None)):
  235. del self.original_message_clients
  236. super(SaltMessageClientPoolTest, self).tearDown()
  237. def test_send(self):
  238. for message_client_mock in self.message_client_pool.message_clients:
  239. message_client_mock.send_queue = [0, 0, 0]
  240. message_client_mock.send.return_value = []
  241. self.assertEqual([], self.message_client_pool.send())
  242. self.message_client_pool.message_clients[2].send_queue = [0]
  243. self.message_client_pool.message_clients[2].send.return_value = [1]
  244. self.assertEqual([1], self.message_client_pool.send())
  245. def test_write_to_stream(self):
  246. for message_client_mock in self.message_client_pool.message_clients:
  247. message_client_mock.send_queue = [0, 0, 0]
  248. message_client_mock._stream.write.return_value = []
  249. self.assertEqual([], self.message_client_pool.write_to_stream(''))
  250. self.message_client_pool.message_clients[2].send_queue = [0]
  251. self.message_client_pool.message_clients[2]._stream.write.return_value = [1]
  252. self.assertEqual([1], self.message_client_pool.write_to_stream(''))
  253. def test_close(self):
  254. self.message_client_pool.close()
  255. self.assertEqual([], self.message_client_pool.message_clients)
  256. def test_on_recv(self):
  257. for message_client_mock in self.message_client_pool.message_clients:
  258. message_client_mock.on_recv.return_value = None
  259. self.message_client_pool.on_recv()
  260. for message_client_mock in self.message_client_pool.message_clients:
  261. self.assertTrue(message_client_mock.on_recv.called)
  262. def test_connect_all(self):
  263. @gen_test
  264. def test_connect(self):
  265. yield self.message_client_pool.connect()
  266. for message_client_mock in self.message_client_pool.message_clients:
  267. future = salt.ext.tornado.concurrent.Future()
  268. future.set_result('foo')
  269. message_client_mock.connect.return_value = future
  270. self.assertIsNone(test_connect(self))
  271. def test_connect_partial(self):
  272. @gen_test(timeout=0.1)
  273. def test_connect(self):
  274. yield self.message_client_pool.connect()
  275. for idx, message_client_mock in enumerate(self.message_client_pool.message_clients):
  276. future = salt.ext.tornado.concurrent.Future()
  277. if idx % 2 == 0:
  278. future.set_result('foo')
  279. message_client_mock.connect.return_value = future
  280. with self.assertRaises(salt.ext.tornado.ioloop.TimeoutError):
  281. test_connect(self)
  282. class SaltMessageClientCleanupTest(TestCase, AdaptedConfigurationTestCaseMixin):
  283. def setUp(self):
  284. self.listen_on = '127.0.0.1'
  285. self.port = get_unused_localhost_port()
  286. self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  287. self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  288. self.sock.bind((self.listen_on, self.port))
  289. self.sock.listen(1)
  290. def tearDown(self):
  291. self.sock.close()
  292. del self.sock
  293. def test_message_client(self):
  294. '''
  295. test message client cleanup on close
  296. '''
  297. orig_loop = salt.ext.tornado.ioloop.IOLoop()
  298. orig_loop.make_current()
  299. opts = self.get_temp_config('master')
  300. client = SaltMessageClient(opts, self.listen_on, self.port)
  301. # Mock the io_loop's stop method so we know when it has been called.
  302. orig_loop.real_stop = orig_loop.stop
  303. orig_loop.stop_called = False
  304. def stop(*args, **kwargs):
  305. orig_loop.stop_called = True
  306. orig_loop.real_stop()
  307. orig_loop.stop = stop
  308. try:
  309. assert client.io_loop == orig_loop
  310. client.io_loop.run_sync(client.connect)
  311. # Ensure we are testing the _read_until_future and io_loop teardown
  312. assert client._stream is not None
  313. assert client._read_until_future is not None
  314. assert orig_loop.stop_called is True
  315. # The run_sync call will set stop_called, reset it
  316. orig_loop.stop_called = False
  317. client.close()
  318. # Stop should be called again, client's io_loop should be None
  319. assert orig_loop.stop_called is True
  320. assert client.io_loop is None
  321. finally:
  322. orig_loop.stop = orig_loop.real_stop
  323. del orig_loop.real_stop
  324. del orig_loop.stop_called
  325. class TCPPubServerChannelTest(TestCase, AdaptedConfigurationTestCaseMixin):
  326. @patch('salt.master.SMaster.secrets')
  327. @patch('salt.crypt.Crypticle')
  328. @patch('salt.utils.asynchronous.SyncWrapper')
  329. def test_publish_filtering(self, sync_wrapper, crypticle, secrets):
  330. opts = self.get_temp_config('master')
  331. opts["sign_pub_messages"] = False
  332. channel = TCPPubServerChannel(opts)
  333. wrap = MagicMock()
  334. crypt = MagicMock()
  335. crypt.dumps.return_value = {"test": "value"}
  336. secrets.return_value = {"aes": {"secret": None}}
  337. crypticle.return_value = crypt
  338. sync_wrapper.return_value = wrap
  339. # try simple publish with glob tgt_type
  340. channel.publish({"test": "value", "tgt_type": "glob", "tgt": "*"})
  341. payload = wrap.send.call_args[0][0]
  342. # verify we send it without any specific topic
  343. assert "topic_lst" not in payload
  344. # try simple publish with list tgt_type
  345. channel.publish({"test": "value", "tgt_type": "list", "tgt": ["minion01"]})
  346. payload = wrap.send.call_args[0][0]
  347. # verify we send it with correct topic
  348. assert "topic_lst" in payload
  349. self.assertEqual(payload["topic_lst"], ["minion01"])
  350. # try with syndic settings
  351. opts['order_masters'] = True
  352. channel.publish({"test": "value", "tgt_type": "list", "tgt": ["minion01"]})
  353. payload = wrap.send.call_args[0][0]
  354. # verify we send it without topic for syndics
  355. assert "topic_lst" not in payload
  356. @patch('salt.utils.minions.CkMinions.check_minions')
  357. @patch('salt.master.SMaster.secrets')
  358. @patch('salt.crypt.Crypticle')
  359. @patch('salt.utils.asynchronous.SyncWrapper')
  360. def test_publish_filtering_str_list(self, sync_wrapper, crypticle, secrets, check_minions):
  361. opts = self.get_temp_config('master')
  362. opts["sign_pub_messages"] = False
  363. channel = TCPPubServerChannel(opts)
  364. wrap = MagicMock()
  365. crypt = MagicMock()
  366. crypt.dumps.return_value = {"test": "value"}
  367. secrets.return_value = {"aes": {"secret": None}}
  368. crypticle.return_value = crypt
  369. sync_wrapper.return_value = wrap
  370. check_minions.return_value = {"minions": ["minion02"]}
  371. # try simple publish with list tgt_type
  372. channel.publish({"test": "value", "tgt_type": "list", "tgt": "minion02"})
  373. payload = wrap.send.call_args[0][0]
  374. # verify we send it with correct topic
  375. assert "topic_lst" in payload
  376. self.assertEqual(payload["topic_lst"], ["minion02"])
  377. # verify it was correctly calling check_minions
  378. check_minions.assert_called_with("minion02", tgt_type="list")