test_ipc.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. # -*- coding: utf-8 -*-
  2. '''
  3. :codeauthor: Mike Place <mp@saltstack.com>
  4. '''
  5. # Import python libs
  6. from __future__ import absolute_import, print_function, unicode_literals
  7. import os
  8. import errno
  9. import socket
  10. import threading
  11. import logging
  12. import tornado.gen
  13. import tornado.ioloop
  14. import tornado.testing
  15. import salt.config
  16. import salt.exceptions
  17. import salt.transport.ipc
  18. import salt.transport.server
  19. import salt.transport.client
  20. import salt.utils.platform
  21. from salt.ext import six
  22. from salt.ext.six.moves import range
  23. # Import Salt Testing libs
  24. from tests.support.runtests import RUNTIME_VARS
  25. from tests.support.mock import MagicMock
  26. from tests.support.unit import skipIf
  # Module-level logger, followed by the IPC req server test fixture.
log = logging.getLogger(__name__)


@skipIf(salt.utils.platform.is_windows(), 'Windows does not support Posix IPC')
class BaseIPCReqCase(tornado.testing.AsyncTestCase):
    '''
    Test the req server/client pair.

    Before every test an ``IPCMessageServer`` is started on
    ``self.socket_path`` and bound to this test's IOLoop. Each payload the
    server receives is recorded in ``self.payloads`` and passed back to the
    sender's ``reply_func``. A dict payload with a truthy ``'stop'`` key ends
    the current ``self.wait()``.
    '''

    def setUp(self):
        super(BaseIPCReqCase, self).setUp()
        self.socket_path = os.path.join(RUNTIME_VARS.TMP, 'ipc_test.ipc')
        # Create the payload list before the server (and its payload
        # handler) is started, so the handler can never run against a
        # missing attribute.
        self.payloads = []
        self.server_channel = salt.transport.ipc.IPCMessageServer(
            salt.config.master_config(None),
            self.socket_path,
            io_loop=self.io_loop,
            payload_handler=self._handle_payload,
        )
        self.server_channel.start()

    def tearDown(self):
        super(BaseIPCReqCase, self).tearDown()
        try:
            self.server_channel.close()
        except socket.error as exc:
            # EBADF means the descriptor is already closed (presumably by
            # the base-class tearDown above); any other error is real.
            if exc.errno != errno.EBADF:
                raise
        try:
            os.unlink(self.socket_path)
        except OSError as exc:
            # Nothing to clean up if the socket file is already gone.
            if exc.errno != errno.ENOENT:
                raise
        del self.payloads
        del self.socket_path
        del self.server_channel

    @tornado.gen.coroutine
    def _handle_payload(self, payload, reply_func):
        '''
        Server-side payload handler.

        Records ``payload``, replies with the same payload through
        ``reply_func``, and stops the test's ``wait()`` when ``payload`` is
        a dict whose ``'stop'`` value is truthy.
        '''
        self.payloads.append(payload)
        yield reply_func(payload)
        if isinstance(payload, dict) and payload.get('stop'):
            self.stop()
  # Client-side request tests that reuse the BaseIPCReqCase server fixture.
class IPCMessageClient(BaseIPCReqCase):
    '''
    Exercise ``IPCMessageClient`` against the fixture server.

    Each test starts with a connected client in ``self.channel``. Messages
    end up in ``self.payloads``, and sending a dict with ``'stop': True``
    releases the pending ``self.wait()``.
    '''

    def _get_channel(self):
        '''
        Return the cached ``self.channel``. When the attribute is missing or
        ``None``, first create a client on ``self.socket_path`` and block
        until its connect callback fires.
        '''
        if getattr(self, 'channel', None) is not None:
            return self.channel
        self.channel = salt.transport.ipc.IPCMessageClient(
            socket_path=self.socket_path,
            io_loop=self.io_loop,
        )
        self.channel.connect(callback=self.stop)
        self.wait()
        return self.channel

    def setUp(self):
        super(IPCMessageClient, self).setUp()
        self.channel = self._get_channel()

    def tearDown(self):
        super(IPCMessageClient, self).tearDown()
        try:
            # Drop our reference to the client regardless of what the test
            # did with it.
            # NOTE(review): Python does not propagate exceptions raised by a
            # finalizer out of ``del``, so this handler is likely unreachable.
            del self.channel
        except socket.error as exc:
            bad_fd = exc.errno == errno.EBADF
            # Only a bad-file-descriptor error is tolerated.
            if not bad_fd:
                raise
        finally:
            self.channel = None

    def test_singleton(self):
        same = self._get_channel()
        assert self.channel is same
        # Dropping the local name still leaves self.channel referencing the
        # client, so it is not finalized here.
        del same
        assert self.channel
        payload = {'foo': 'bar', 'stop': True}
        self.channel.send(payload)
        self.wait()
        self.assertEqual(self.payloads[0], payload)

    def test_basic_send(self):
        payload = {'foo': 'bar', 'stop': True}
        self.channel.send(payload)
        self.wait()
        self.assertEqual(self.payloads[0], payload)

    def test_many_send(self):
        self.server_channel.stream_handler = MagicMock()
        expected = ['test_many_send_{0}'.format(idx) for idx in range(1000)]
        for item in expected:
            self.channel.send(item)
        self.channel.send({'stop': True})
        self.wait()
        # The trailing stop message is excluded from the comparison.
        self.assertEqual(self.payloads[:-1], expected)

    def test_very_big_message(self):
        long_str = ''.join(six.text_type(n) for n in range(10 ** 5))
        payload = {'long_str': long_str, 'stop': True}
        self.channel.send(payload)
        self.wait()
        self.assertEqual(payload, self.payloads[0])

    def test_multistream_sends(self):
        # NOTE(review): _get_channel returns the cached self.channel here, so
        # both sends go through the same client object.
        second = self._get_channel()
        self.channel.send('foo')
        second.send('foo')
        self.channel.send({'stop': True})
        self.wait()
        self.assertEqual(self.payloads[:-1], ['foo', 'foo'])

    def test_multistream_errors(self):
        second = self._get_channel()
        senders = (self.channel, second)
        # First a None from each sender, then a 'foo' from each sender.
        for message in (None, 'foo'):
            for sender in senders:
                sender.send(message)
        self.channel.send({'stop': True})
        self.wait()
        self.assertEqual(self.payloads[:-1], [None, None, 'foo', 'foo'])
  # Publish/subscribe tests for the IPC transport.
@skipIf(salt.utils.platform.is_windows(), 'Windows does not support Posix IPC')
class IPCMessagePubSubCase(tornado.testing.AsyncTestCase):
    '''
    Test the IPC publisher/subscriber pair.

    Every test gets a started ``IPCMessagePublisher`` in ``self.pub_channel``
    and one connected ``IPCMessageSubscriber`` in ``self.sub_channel``, both
    using ``self.socket_path``.
    '''

    def setUp(self):
        super(IPCMessagePubSubCase, self).setUp()
        self.opts = {
            'ipc_write_buffer': 0,
            'ipc_so_backlog': 128,
        }
        self.socket_path = os.path.join(RUNTIME_VARS.TMP, 'ipc_test.ipc')
        self.pub_channel = self._get_pub_channel()
        self.sub_channel = self._get_sub_channel()

    def _get_pub_channel(self):
        '''
        Create and start a publisher on ``self.socket_path`` configured with
        ``self.opts``. No ``io_loop`` argument is passed to it.
        '''
        pub_channel = salt.transport.ipc.IPCMessagePublisher(
            self.opts,
            self.socket_path,
        )
        pub_channel.start()
        return pub_channel

    def _get_sub_channel(self):
        '''
        Create a subscriber on ``self.socket_path`` bound to this test's
        IOLoop, and block in ``self.wait()`` until its connect callback fires.
        '''
        sub_channel = salt.transport.ipc.IPCMessageSubscriber(
            socket_path=self.socket_path,
            io_loop=self.io_loop,
        )
        sub_channel.connect(callback=self.stop)
        self.wait()
        return sub_channel

    @staticmethod
    def _close_channel(channel):
        '''
        Close ``channel``, tolerating EBADF.

        :raises socket.error: for any socket error other than EBADF.
        '''
        try:
            channel.close()
        except socket.error as exc:
            # The descriptor may already be closed by the time this runs
            # (tearDown invokes the base-class tearDown first). Any other
            # error is real.
            if exc.errno != errno.EBADF:
                raise

    def tearDown(self):
        super(IPCMessagePubSubCase, self).tearDown()
        self._close_channel(self.pub_channel)
        self._close_channel(self.sub_channel)
        try:
            os.unlink(self.socket_path)
        except OSError as exc:
            # Nothing to clean up if the socket file is already gone.
            if exc.errno != errno.ENOENT:
                raise
        del self.pub_channel
        del self.sub_channel

    def test_multi_client_reading(self):
        # Use two subscribers so the single publish must reach both.
        client1 = self.sub_channel
        client2 = self._get_sub_channel()
        call_cnt = []
        # Set (from the watchdog thread) only if the watchdog closed client2.
        watchdog_fired = []

        # Watchdog: if both messages have not arrived within 1 second, close
        # client2 and stop the wait so the test fails instead of hanging.
        evt = threading.Event()

        def close_server():
            if evt.wait(1):
                return
            watchdog_fired.append(True)
            client2.close()
            self.stop()

        watchdog = threading.Thread(target=close_server)
        watchdog.start()

        # Invoked from the IOLoop thread (per the original design), so
        # appending to call_cnt needs no locking.
        def handler(raw):
            call_cnt.append(raw)
            if len(call_cnt) >= 2:
                evt.set()
                self.stop()

        try:
            # Register the handler on both clients, start both reading,
            # then publish once.
            client1.callbacks.add(handler)
            client2.callbacks.add(handler)
            client1.read_async()
            client2.read_async()
            self.pub_channel.publish('TEST')
            self.wait()
        finally:
            # Release the watchdog if it is still waiting, and make sure the
            # thread has finished before the test returns.
            evt.set()
            watchdog.join()
            # The watchdog only closes client2 on timeout; otherwise close
            # it here so it does not outlive the test.
            if not watchdog_fired:
                self._close_channel(client2)

        self.assertEqual(len(call_cnt), 2)
        self.assertEqual(call_cnt[0], 'TEST')
        self.assertEqual(call_cnt[1], 'TEST')

    def test_sync_reading(self):
        # Two subscribers each read the same published message synchronously.
        client1 = self.sub_channel
        client2 = self._get_sub_channel()
        try:
            self.pub_channel.publish('TEST')
            ret1 = client1.read_sync()
            ret2 = client2.read_sync()
        finally:
            # client2 is not closed by tearDown (only self.sub_channel is).
            self._close_channel(client2)
        self.assertEqual(ret1, 'TEST')
        self.assertEqual(ret2, 'TEST')
  229. ret2 = client2.read_sync()
  230. self.assertEqual(ret1, 'TEST')
  231. self.assertEqual(ret2, 'TEST')