test_tcp.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. # -*- coding: utf-8 -*-
  2. '''
  3. :codeauthor: Thomas Jackson <jacksontj.89@gmail.com>
  4. '''
  5. # Import python libs
  6. from __future__ import absolute_import, print_function, unicode_literals
  7. import threading
  8. import socket
  9. import logging
  10. import tornado.gen
  11. import tornado.ioloop
  12. import tornado.concurrent
  13. from tornado.testing import AsyncTestCase, gen_test
  14. import salt.config
  15. from salt.ext import six
  16. import salt.utils.platform
  17. import salt.utils.process
  18. import salt.transport.server
  19. import salt.transport.client
  20. import salt.exceptions
  21. from salt.ext.six.moves import range
  22. from salt.transport.tcp import SaltMessageClientPool, SaltMessageClient
  23. # Import Salt Testing libs
  24. from tests.support.unit import TestCase, skipIf
  25. from tests.support.helpers import get_unused_localhost_port, flaky
  26. from tests.support.mixins import AdaptedConfigurationTestCaseMixin
  27. from tests.support.mock import MagicMock, patch
  28. from tests.unit.transport.mixins import PubChannelMixin, ReqChannelMixin, run_loop_in_thread
  29. log = logging.getLogger(__name__)
  30. class BaseTCPReqCase(TestCase, AdaptedConfigurationTestCaseMixin):
  31. '''
  32. Test the req server/client pair
  33. '''
  34. @classmethod
  35. def setUpClass(cls):
  36. if not hasattr(cls, '_handle_payload'):
  37. return
  38. ret_port = get_unused_localhost_port()
  39. publish_port = get_unused_localhost_port()
  40. tcp_master_pub_port = get_unused_localhost_port()
  41. tcp_master_pull_port = get_unused_localhost_port()
  42. tcp_master_publish_pull = get_unused_localhost_port()
  43. tcp_master_workers = get_unused_localhost_port()
  44. cls.master_config = cls.get_temp_config(
  45. 'master',
  46. **{'transport': 'tcp',
  47. 'auto_accept': True,
  48. 'ret_port': ret_port,
  49. 'publish_port': publish_port,
  50. 'tcp_master_pub_port': tcp_master_pub_port,
  51. 'tcp_master_pull_port': tcp_master_pull_port,
  52. 'tcp_master_publish_pull': tcp_master_publish_pull,
  53. 'tcp_master_workers': tcp_master_workers}
  54. )
  55. cls.minion_config = cls.get_temp_config(
  56. 'minion',
  57. **{'transport': 'tcp',
  58. 'master_ip': '127.0.0.1',
  59. 'master_port': ret_port,
  60. 'master_uri': 'tcp://127.0.0.1:{0}'.format(ret_port)}
  61. )
  62. cls.process_manager = salt.utils.process.ProcessManager(name='ReqServer_ProcessManager')
  63. cls.server_channel = salt.transport.server.ReqServerChannel.factory(cls.master_config)
  64. cls.server_channel.pre_fork(cls.process_manager)
  65. cls.io_loop = tornado.ioloop.IOLoop()
  66. cls.stop = threading.Event()
  67. cls.server_channel.post_fork(cls._handle_payload, io_loop=cls.io_loop)
  68. cls.server_thread = threading.Thread(
  69. target=run_loop_in_thread,
  70. args=(cls.io_loop, cls.stop,),
  71. )
  72. cls.server_thread.start()
  73. @classmethod
  74. def tearDownClass(cls):
  75. cls.server_channel.close()
  76. cls.stop.set()
  77. cls.server_thread.join()
  78. cls.process_manager.kill_children()
  79. del cls.server_channel
  80. @classmethod
  81. @tornado.gen.coroutine
  82. def _handle_payload(cls, payload):
  83. '''
  84. TODO: something besides echo
  85. '''
  86. raise tornado.gen.Return((payload, {'fun': 'send_clear'}))
  87. @skipIf(salt.utils.platform.is_darwin(), 'hanging test suite on MacOS')
  88. class ClearReqTestCases(BaseTCPReqCase, ReqChannelMixin):
  89. '''
  90. Test all of the clear msg stuff
  91. '''
  92. def setUp(self):
  93. self.channel = salt.transport.client.ReqChannel.factory(self.minion_config, crypt='clear')
  94. def tearDown(self):
  95. self.channel.close()
  96. del self.channel
  97. @classmethod
  98. @tornado.gen.coroutine
  99. def _handle_payload(cls, payload):
  100. '''
  101. TODO: something besides echo
  102. '''
  103. raise tornado.gen.Return((payload, {'fun': 'send_clear'}))
  104. @skipIf(salt.utils.platform.is_darwin(), 'hanging test suite on MacOS')
  105. class AESReqTestCases(BaseTCPReqCase, ReqChannelMixin):
  106. def setUp(self):
  107. self.channel = salt.transport.client.ReqChannel.factory(self.minion_config)
  108. def tearDown(self):
  109. self.channel.close()
  110. del self.channel
  111. @classmethod
  112. @tornado.gen.coroutine
  113. def _handle_payload(cls, payload):
  114. '''
  115. TODO: something besides echo
  116. '''
  117. raise tornado.gen.Return((payload, {'fun': 'send'}))
  118. # TODO: make failed returns have a specific framing so we can raise the same exception
  119. # on encrypted channels
  120. @flaky
  121. def test_badload(self):
  122. '''
  123. Test a variety of bad requests, make sure that we get some sort of error
  124. '''
  125. msgs = ['', [], tuple()]
  126. for msg in msgs:
  127. with self.assertRaises(salt.exceptions.AuthenticationError):
  128. ret = self.channel.send(msg)
  129. class BaseTCPPubCase(AsyncTestCase, AdaptedConfigurationTestCaseMixin):
  130. '''
  131. Test the req server/client pair
  132. '''
  133. @classmethod
  134. def setUpClass(cls):
  135. ret_port = get_unused_localhost_port()
  136. publish_port = get_unused_localhost_port()
  137. tcp_master_pub_port = get_unused_localhost_port()
  138. tcp_master_pull_port = get_unused_localhost_port()
  139. tcp_master_publish_pull = get_unused_localhost_port()
  140. tcp_master_workers = get_unused_localhost_port()
  141. cls.master_config = cls.get_temp_config(
  142. 'master',
  143. **{'transport': 'tcp',
  144. 'auto_accept': True,
  145. 'ret_port': ret_port,
  146. 'publish_port': publish_port,
  147. 'tcp_master_pub_port': tcp_master_pub_port,
  148. 'tcp_master_pull_port': tcp_master_pull_port,
  149. 'tcp_master_publish_pull': tcp_master_publish_pull,
  150. 'tcp_master_workers': tcp_master_workers}
  151. )
  152. cls.minion_config = cls.get_temp_config(
  153. 'minion',
  154. **{'transport': 'tcp',
  155. 'master_ip': '127.0.0.1',
  156. 'auth_timeout': 1,
  157. 'master_port': ret_port,
  158. 'master_uri': 'tcp://127.0.0.1:{0}'.format(ret_port)}
  159. )
  160. cls.process_manager = salt.utils.process.ProcessManager(name='ReqServer_ProcessManager')
  161. cls.server_channel = salt.transport.server.PubServerChannel.factory(cls.master_config)
  162. cls.server_channel.pre_fork(cls.process_manager)
  163. # we also require req server for auth
  164. cls.req_server_channel = salt.transport.server.ReqServerChannel.factory(cls.master_config)
  165. cls.req_server_channel.pre_fork(cls.process_manager)
  166. cls.io_loop = tornado.ioloop.IOLoop()
  167. cls.stop = threading.Event()
  168. cls.req_server_channel.post_fork(cls._handle_payload, io_loop=cls.io_loop)
  169. cls.server_thread = threading.Thread(
  170. target=run_loop_in_thread,
  171. args=(cls.io_loop, cls.stop,),
  172. )
  173. cls.server_thread.start()
  174. @classmethod
  175. def _handle_payload(cls, payload):
  176. '''
  177. TODO: something besides echo
  178. '''
  179. return payload, {'fun': 'send_clear'}
  180. @classmethod
  181. def tearDownClass(cls):
  182. cls.req_server_channel.close()
  183. cls.server_channel.close()
  184. cls.stop.set()
  185. cls.server_thread.join()
  186. cls.process_manager.kill_children()
  187. del cls.req_server_channel
  188. def setUp(self):
  189. super(BaseTCPPubCase, self).setUp()
  190. self._start_handlers = dict(self.io_loop._handlers)
  191. def tearDown(self):
  192. super(BaseTCPPubCase, self).tearDown()
  193. failures = []
  194. for k, v in six.iteritems(self.io_loop._handlers):
  195. if self._start_handlers.get(k) != v:
  196. failures.append((k, v))
  197. if len(failures) > 0:
  198. raise Exception('FDs still attached to the IOLoop: {0}'.format(failures))
  199. del self.channel
  200. del self._start_handlers
  201. @skipIf(True, 'Skip until we can devote time to fix this test')
  202. class AsyncPubChannelTest(BaseTCPPubCase, PubChannelMixin):
  203. '''
  204. Tests around the publish system
  205. '''
  206. class SaltMessageClientPoolTest(AsyncTestCase):
  207. def setUp(self):
  208. super(SaltMessageClientPoolTest, self).setUp()
  209. sock_pool_size = 5
  210. with patch('salt.transport.tcp.SaltMessageClient.__init__', MagicMock(return_value=None)):
  211. self.message_client_pool = SaltMessageClientPool({'sock_pool_size': sock_pool_size},
  212. args=({}, '', 0))
  213. self.original_message_clients = self.message_client_pool.message_clients
  214. self.message_client_pool.message_clients = [MagicMock() for _ in range(sock_pool_size)]
  215. def tearDown(self):
  216. with patch('salt.transport.tcp.SaltMessageClient.close', MagicMock(return_value=None)):
  217. del self.original_message_clients
  218. super(SaltMessageClientPoolTest, self).tearDown()
  219. def test_send(self):
  220. for message_client_mock in self.message_client_pool.message_clients:
  221. message_client_mock.send_queue = [0, 0, 0]
  222. message_client_mock.send.return_value = []
  223. self.assertEqual([], self.message_client_pool.send())
  224. self.message_client_pool.message_clients[2].send_queue = [0]
  225. self.message_client_pool.message_clients[2].send.return_value = [1]
  226. self.assertEqual([1], self.message_client_pool.send())
  227. def test_write_to_stream(self):
  228. for message_client_mock in self.message_client_pool.message_clients:
  229. message_client_mock.send_queue = [0, 0, 0]
  230. message_client_mock._stream.write.return_value = []
  231. self.assertEqual([], self.message_client_pool.write_to_stream(''))
  232. self.message_client_pool.message_clients[2].send_queue = [0]
  233. self.message_client_pool.message_clients[2]._stream.write.return_value = [1]
  234. self.assertEqual([1], self.message_client_pool.write_to_stream(''))
  235. def test_close(self):
  236. self.message_client_pool.close()
  237. self.assertEqual([], self.message_client_pool.message_clients)
  238. def test_on_recv(self):
  239. for message_client_mock in self.message_client_pool.message_clients:
  240. message_client_mock.on_recv.return_value = None
  241. self.message_client_pool.on_recv()
  242. for message_client_mock in self.message_client_pool.message_clients:
  243. self.assertTrue(message_client_mock.on_recv.called)
  244. def test_connect_all(self):
  245. @gen_test
  246. def test_connect(self):
  247. yield self.message_client_pool.connect()
  248. for message_client_mock in self.message_client_pool.message_clients:
  249. future = tornado.concurrent.Future()
  250. future.set_result('foo')
  251. message_client_mock.connect.return_value = future
  252. self.assertIsNone(test_connect(self))
  253. def test_connect_partial(self):
  254. @gen_test(timeout=0.1)
  255. def test_connect(self):
  256. yield self.message_client_pool.connect()
  257. for idx, message_client_mock in enumerate(self.message_client_pool.message_clients):
  258. future = tornado.concurrent.Future()
  259. if idx % 2 == 0:
  260. future.set_result('foo')
  261. message_client_mock.connect.return_value = future
  262. with self.assertRaises(tornado.ioloop.TimeoutError):
  263. test_connect(self)
  264. class SaltMessageClientCleanupTest(TestCase, AdaptedConfigurationTestCaseMixin):
  265. def setUp(self):
  266. self.listen_on = '127.0.0.1'
  267. self.port = get_unused_localhost_port()
  268. self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  269. self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  270. self.sock.bind((self.listen_on, self.port))
  271. self.sock.listen(1)
  272. def tearDown(self):
  273. self.sock.close()
  274. del self.sock
  275. def test_message_client(self):
  276. '''
  277. test message client cleanup on close
  278. '''
  279. orig_loop = tornado.ioloop.IOLoop()
  280. orig_loop.make_current()
  281. opts = self.get_temp_config('master')
  282. client = SaltMessageClient(opts, self.listen_on, self.port)
  283. # Mock the io_loop's stop method so we know when it has been called.
  284. orig_loop.real_stop = orig_loop.stop
  285. orig_loop.stop_called = False
  286. def stop(*args, **kwargs):
  287. orig_loop.stop_called = True
  288. orig_loop.real_stop()
  289. orig_loop.stop = stop
  290. try:
  291. assert client.io_loop == orig_loop
  292. client.io_loop.run_sync(client.connect)
  293. # Ensure we are testing the _read_until_future and io_loop teardown
  294. assert client._stream is not None
  295. assert client._read_until_future is not None
  296. assert orig_loop.stop_called is True
  297. # The run_sync call will set stop_called, reset it
  298. orig_loop.stop_called = False
  299. client.close()
  300. # Stop should be called again, client's io_loop should be None
  301. assert orig_loop.stop_called is True
  302. assert client.io_loop is None
  303. finally:
  304. orig_loop.stop = orig_loop.real_stop
  305. del orig_loop.real_stop
  306. del orig_loop.stop_called