test_tcp.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. # -*- coding: utf-8 -*-
  2. '''
  3. :codeauthor: Thomas Jackson <jacksontj.89@gmail.com>
  4. '''
  5. # Import python libs
  6. from __future__ import absolute_import, print_function, unicode_literals
  7. import threading
  8. import socket
  9. import logging
  10. import tornado.gen
  11. import tornado.ioloop
  12. import tornado.concurrent
  13. from tornado.testing import AsyncTestCase, gen_test
  14. import salt.config
  15. from salt.ext import six
  16. import salt.utils.platform
  17. import salt.utils.process
  18. import salt.transport.server
  19. import salt.transport.client
  20. import salt.exceptions
  21. from salt.ext.six.moves import range
  22. from salt.transport.tcp import SaltMessageClientPool, SaltMessageClient
  23. # Import Salt Testing libs
  24. from tests.support.unit import TestCase, skipIf
  25. from tests.support.helpers import get_unused_localhost_port, flaky
  26. from tests.support.mixins import AdaptedConfigurationTestCaseMixin
  27. from tests.support.mock import MagicMock, patch
  28. from tests.unit.transport.mixins import PubChannelMixin, ReqChannelMixin, run_loop_in_thread
# Module-level logger for this test module, named after the module.
log = logging.getLogger(__name__)
  30. class BaseTCPReqCase(TestCase, AdaptedConfigurationTestCaseMixin):
  31. '''
  32. Test the req server/client pair
  33. '''
  34. @classmethod
  35. def setUpClass(cls):
  36. if not hasattr(cls, '_handle_payload'):
  37. return
  38. ret_port = get_unused_localhost_port()
  39. publish_port = get_unused_localhost_port()
  40. tcp_master_pub_port = get_unused_localhost_port()
  41. tcp_master_pull_port = get_unused_localhost_port()
  42. tcp_master_publish_pull = get_unused_localhost_port()
  43. tcp_master_workers = get_unused_localhost_port()
  44. cls.master_config = cls.get_temp_config(
  45. 'master',
  46. **{'transport': 'tcp',
  47. 'auto_accept': True,
  48. 'ret_port': ret_port,
  49. 'publish_port': publish_port,
  50. 'tcp_master_pub_port': tcp_master_pub_port,
  51. 'tcp_master_pull_port': tcp_master_pull_port,
  52. 'tcp_master_publish_pull': tcp_master_publish_pull,
  53. 'tcp_master_workers': tcp_master_workers}
  54. )
  55. cls.minion_config = cls.get_temp_config(
  56. 'minion',
  57. **{'transport': 'tcp',
  58. 'master_ip': '127.0.0.1',
  59. 'master_port': ret_port,
  60. 'master_uri': 'tcp://127.0.0.1:{0}'.format(ret_port)}
  61. )
  62. cls.process_manager = salt.utils.process.ProcessManager(name='ReqServer_ProcessManager')
  63. cls.server_channel = salt.transport.server.ReqServerChannel.factory(cls.master_config)
  64. cls.server_channel.pre_fork(cls.process_manager)
  65. cls.io_loop = tornado.ioloop.IOLoop()
  66. cls.stop = threading.Event()
  67. cls.server_channel.post_fork(cls._handle_payload, io_loop=cls.io_loop)
  68. cls.server_thread = threading.Thread(
  69. target=run_loop_in_thread,
  70. args=(cls.io_loop, cls.stop,),
  71. )
  72. cls.server_thread.start()
  73. @classmethod
  74. def tearDownClass(cls):
  75. cls.server_channel.close()
  76. cls.stop.set()
  77. cls.server_thread.join()
  78. cls.process_manager.kill_children()
  79. del cls.server_channel
  80. @classmethod
  81. @tornado.gen.coroutine
  82. def _handle_payload(cls, payload):
  83. '''
  84. TODO: something besides echo
  85. '''
  86. raise tornado.gen.Return((payload, {'fun': 'send_clear'}))
  87. @skipIf(salt.utils.platform.is_darwin(), 'hanging test suite on MacOS')
  88. class ClearReqTestCases(BaseTCPReqCase, ReqChannelMixin):
  89. '''
  90. Test all of the clear msg stuff
  91. '''
  92. def setUp(self):
  93. self.channel = salt.transport.client.ReqChannel.factory(self.minion_config, crypt='clear')
  94. def tearDown(self):
  95. del self.channel
  96. @classmethod
  97. @tornado.gen.coroutine
  98. def _handle_payload(cls, payload):
  99. '''
  100. TODO: something besides echo
  101. '''
  102. raise tornado.gen.Return((payload, {'fun': 'send_clear'}))
  103. @skipIf(salt.utils.platform.is_darwin(), 'hanging test suite on MacOS')
  104. class AESReqTestCases(BaseTCPReqCase, ReqChannelMixin):
  105. def setUp(self):
  106. self.channel = salt.transport.client.ReqChannel.factory(self.minion_config)
  107. def tearDown(self):
  108. del self.channel
  109. @classmethod
  110. @tornado.gen.coroutine
  111. def _handle_payload(cls, payload):
  112. '''
  113. TODO: something besides echo
  114. '''
  115. raise tornado.gen.Return((payload, {'fun': 'send'}))
  116. # TODO: make failed returns have a specific framing so we can raise the same exception
  117. # on encrypted channels
  118. @flaky
  119. def test_badload(self):
  120. '''
  121. Test a variety of bad requests, make sure that we get some sort of error
  122. '''
  123. msgs = ['', [], tuple()]
  124. for msg in msgs:
  125. with self.assertRaises(salt.exceptions.AuthenticationError):
  126. ret = self.channel.send(msg)
  127. class BaseTCPPubCase(AsyncTestCase, AdaptedConfigurationTestCaseMixin):
  128. '''
  129. Test the req server/client pair
  130. '''
  131. @classmethod
  132. def setUpClass(cls):
  133. ret_port = get_unused_localhost_port()
  134. publish_port = get_unused_localhost_port()
  135. tcp_master_pub_port = get_unused_localhost_port()
  136. tcp_master_pull_port = get_unused_localhost_port()
  137. tcp_master_publish_pull = get_unused_localhost_port()
  138. tcp_master_workers = get_unused_localhost_port()
  139. cls.master_config = cls.get_temp_config(
  140. 'master',
  141. **{'transport': 'tcp',
  142. 'auto_accept': True,
  143. 'ret_port': ret_port,
  144. 'publish_port': publish_port,
  145. 'tcp_master_pub_port': tcp_master_pub_port,
  146. 'tcp_master_pull_port': tcp_master_pull_port,
  147. 'tcp_master_publish_pull': tcp_master_publish_pull,
  148. 'tcp_master_workers': tcp_master_workers}
  149. )
  150. cls.minion_config = cls.get_temp_config(
  151. 'minion',
  152. **{'transport': 'tcp',
  153. 'master_ip': '127.0.0.1',
  154. 'auth_timeout': 1,
  155. 'master_port': ret_port,
  156. 'master_uri': 'tcp://127.0.0.1:{0}'.format(ret_port)}
  157. )
  158. cls.process_manager = salt.utils.process.ProcessManager(name='ReqServer_ProcessManager')
  159. cls.server_channel = salt.transport.server.PubServerChannel.factory(cls.master_config)
  160. cls.server_channel.pre_fork(cls.process_manager)
  161. # we also require req server for auth
  162. cls.req_server_channel = salt.transport.server.ReqServerChannel.factory(cls.master_config)
  163. cls.req_server_channel.pre_fork(cls.process_manager)
  164. cls.io_loop = tornado.ioloop.IOLoop()
  165. cls.stop = threading.Event()
  166. cls.req_server_channel.post_fork(cls._handle_payload, io_loop=cls.io_loop)
  167. cls.server_thread = threading.Thread(
  168. target=run_loop_in_thread,
  169. args=(cls.io_loop, cls.stop,),
  170. )
  171. cls.server_thread.start()
  172. @classmethod
  173. def _handle_payload(cls, payload):
  174. '''
  175. TODO: something besides echo
  176. '''
  177. return payload, {'fun': 'send_clear'}
  178. @classmethod
  179. def tearDownClass(cls):
  180. cls.req_server_channel.close()
  181. cls.server_channel.close()
  182. cls.stop.set()
  183. cls.server_thread.join()
  184. cls.process_manager.kill_children()
  185. del cls.req_server_channel
  186. def setUp(self):
  187. super(BaseTCPPubCase, self).setUp()
  188. self._start_handlers = dict(self.io_loop._handlers)
  189. def tearDown(self):
  190. super(BaseTCPPubCase, self).tearDown()
  191. failures = []
  192. for k, v in six.iteritems(self.io_loop._handlers):
  193. if self._start_handlers.get(k) != v:
  194. failures.append((k, v))
  195. if len(failures) > 0:
  196. raise Exception('FDs still attached to the IOLoop: {0}'.format(failures))
  197. del self.channel
  198. del self._start_handlers
  199. @skipIf(True, 'Skip until we can devote time to fix this test')
  200. class AsyncPubChannelTest(BaseTCPPubCase, PubChannelMixin):
  201. '''
  202. Tests around the publish system
  203. '''
  204. class SaltMessageClientPoolTest(AsyncTestCase):
  205. def setUp(self):
  206. super(SaltMessageClientPoolTest, self).setUp()
  207. sock_pool_size = 5
  208. with patch('salt.transport.tcp.SaltMessageClient.__init__', MagicMock(return_value=None)):
  209. self.message_client_pool = SaltMessageClientPool({'sock_pool_size': sock_pool_size},
  210. args=({}, '', 0))
  211. self.original_message_clients = self.message_client_pool.message_clients
  212. self.message_client_pool.message_clients = [MagicMock() for _ in range(sock_pool_size)]
  213. def tearDown(self):
  214. with patch('salt.transport.tcp.SaltMessageClient.close', MagicMock(return_value=None)):
  215. del self.original_message_clients
  216. super(SaltMessageClientPoolTest, self).tearDown()
  217. def test_send(self):
  218. for message_client_mock in self.message_client_pool.message_clients:
  219. message_client_mock.send_queue = [0, 0, 0]
  220. message_client_mock.send.return_value = []
  221. self.assertEqual([], self.message_client_pool.send())
  222. self.message_client_pool.message_clients[2].send_queue = [0]
  223. self.message_client_pool.message_clients[2].send.return_value = [1]
  224. self.assertEqual([1], self.message_client_pool.send())
  225. def test_write_to_stream(self):
  226. for message_client_mock in self.message_client_pool.message_clients:
  227. message_client_mock.send_queue = [0, 0, 0]
  228. message_client_mock._stream.write.return_value = []
  229. self.assertEqual([], self.message_client_pool.write_to_stream(''))
  230. self.message_client_pool.message_clients[2].send_queue = [0]
  231. self.message_client_pool.message_clients[2]._stream.write.return_value = [1]
  232. self.assertEqual([1], self.message_client_pool.write_to_stream(''))
  233. def test_close(self):
  234. self.message_client_pool.close()
  235. self.assertEqual([], self.message_client_pool.message_clients)
  236. def test_on_recv(self):
  237. for message_client_mock in self.message_client_pool.message_clients:
  238. message_client_mock.on_recv.return_value = None
  239. self.message_client_pool.on_recv()
  240. for message_client_mock in self.message_client_pool.message_clients:
  241. self.assertTrue(message_client_mock.on_recv.called)
  242. def test_connect_all(self):
  243. @gen_test
  244. def test_connect(self):
  245. yield self.message_client_pool.connect()
  246. for message_client_mock in self.message_client_pool.message_clients:
  247. future = tornado.concurrent.Future()
  248. future.set_result('foo')
  249. message_client_mock.connect.return_value = future
  250. self.assertIsNone(test_connect(self))
  251. def test_connect_partial(self):
  252. @gen_test(timeout=0.1)
  253. def test_connect(self):
  254. yield self.message_client_pool.connect()
  255. for idx, message_client_mock in enumerate(self.message_client_pool.message_clients):
  256. future = tornado.concurrent.Future()
  257. if idx % 2 == 0:
  258. future.set_result('foo')
  259. message_client_mock.connect.return_value = future
  260. with self.assertRaises(tornado.ioloop.TimeoutError):
  261. test_connect(self)
  262. class SaltMessageClientCleanupTest(TestCase, AdaptedConfigurationTestCaseMixin):
  263. def setUp(self):
  264. self.listen_on = '127.0.0.1'
  265. self.port = get_unused_localhost_port()
  266. self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  267. self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  268. self.sock.bind((self.listen_on, self.port))
  269. self.sock.listen(1)
  270. def tearDown(self):
  271. self.sock.close()
  272. del self.sock
  273. def test_message_client(self):
  274. '''
  275. test message client cleanup on close
  276. '''
  277. orig_loop = tornado.ioloop.IOLoop()
  278. orig_loop.make_current()
  279. opts = self.get_temp_config('master')
  280. client = SaltMessageClient(opts, self.listen_on, self.port)
  281. # Mock the io_loop's stop method so we know when it has been called.
  282. orig_loop.real_stop = orig_loop.stop
  283. orig_loop.stop_called = False
  284. def stop(*args, **kwargs):
  285. orig_loop.stop_called = True
  286. orig_loop.real_stop()
  287. orig_loop.stop = stop
  288. try:
  289. assert client.io_loop == orig_loop
  290. client.io_loop.run_sync(client.connect)
  291. # Ensure we are testing the _read_until_future and io_loop teardown
  292. assert client._stream is not None
  293. assert client._read_until_future is not None
  294. assert orig_loop.stop_called is True
  295. # The run_sync call will set stop_called, reset it
  296. orig_loop.stop_called = False
  297. client.close()
  298. # Stop should be called again, client's io_loop should be None
  299. assert orig_loop.stop_called is True
  300. assert client.io_loop is None
  301. finally:
  302. orig_loop.stop = orig_loop.real_stop
  303. del orig_loop.real_stop
  304. del orig_loop.stop_called