test_ipc.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. # -*- coding: utf-8 -*-
  2. '''
  3. :codeauthor: Mike Place <mp@saltstack.com>
  4. '''
  5. # Import python libs
  6. from __future__ import absolute_import, print_function, unicode_literals
  7. import os
  8. import errno
  9. import socket
  10. import threading
  11. import logging
  12. import tornado.gen
  13. import tornado.ioloop
  14. import tornado.testing
  15. import salt.config
  16. import salt.exceptions
  17. import salt.transport.ipc
  18. import salt.transport.server
  19. import salt.transport.client
  20. import salt.utils.platform
  21. from salt.ext import six
  22. from salt.ext.six.moves import range
  23. # Import Salt Testing libs
  24. from tests.support.runtime import RUNTIME_VARS
  25. from tests.support.mock import MagicMock
  26. from tests.support.unit import skipIf
  27. log = logging.getLogger(__name__)
  28. @skipIf(salt.utils.platform.is_windows(), 'Windows does not support Posix IPC')
  29. class BaseIPCReqCase(tornado.testing.AsyncTestCase):
  30. '''
  31. Test the req server/client pair
  32. '''
  33. def setUp(self):
  34. super(BaseIPCReqCase, self).setUp()
  35. #self._start_handlers = dict(self.io_loop._handlers)
  36. self.socket_path = os.path.join(RUNTIME_VARS.TMP, 'ipc_test.ipc')
  37. self.server_channel = salt.transport.ipc.IPCMessageServer(
  38. self.socket_path,
  39. io_loop=self.io_loop,
  40. payload_handler=self._handle_payload,
  41. )
  42. self.server_channel.start()
  43. self.payloads = []
  44. def tearDown(self):
  45. super(BaseIPCReqCase, self).tearDown()
  46. #failures = []
  47. try:
  48. self.server_channel.close()
  49. except socket.error as exc:
  50. if exc.errno != errno.EBADF:
  51. # If its not a bad file descriptor error, raise
  52. raise
  53. os.unlink(self.socket_path)
  54. #for k, v in six.iteritems(self.io_loop._handlers):
  55. # if self._start_handlers.get(k) != v:
  56. # failures.append((k, v))
  57. #if len(failures) > 0:
  58. # raise Exception('FDs still attached to the IOLoop: {0}'.format(failures))
  59. del self.payloads
  60. del self.socket_path
  61. del self.server_channel
  62. #del self._start_handlers
  63. @tornado.gen.coroutine
  64. def _handle_payload(self, payload, reply_func):
  65. self.payloads.append(payload)
  66. yield reply_func(payload)
  67. if isinstance(payload, dict) and payload.get('stop'):
  68. self.stop()
  69. class IPCMessageClient(BaseIPCReqCase):
  70. '''
  71. Test all of the clear msg stuff
  72. '''
  73. def _get_channel(self):
  74. if not hasattr(self, 'channel') or self.channel is None:
  75. self.channel = salt.transport.ipc.IPCMessageClient(
  76. socket_path=self.socket_path,
  77. io_loop=self.io_loop,
  78. )
  79. self.channel.connect(callback=self.stop)
  80. self.wait()
  81. return self.channel
  82. def setUp(self):
  83. super(IPCMessageClient, self).setUp()
  84. self.channel = self._get_channel()
  85. def tearDown(self):
  86. super(IPCMessageClient, self).tearDown()
  87. try:
  88. # Make sure we close no matter what we've done in the tests
  89. del self.channel
  90. except socket.error as exc:
  91. if exc.errno != errno.EBADF:
  92. # If its not a bad file descriptor error, raise
  93. raise
  94. finally:
  95. self.channel = None
  96. def test_singleton(self):
  97. channel = self._get_channel()
  98. assert self.channel is channel
  99. # Delete the local channel. Since there's still one more refefence
  100. # __del__ wasn't called
  101. del channel
  102. assert self.channel
  103. msg = {'foo': 'bar', 'stop': True}
  104. self.channel.send(msg)
  105. self.wait()
  106. self.assertEqual(self.payloads[0], msg)
  107. def test_basic_send(self):
  108. msg = {'foo': 'bar', 'stop': True}
  109. self.channel.send(msg)
  110. self.wait()
  111. self.assertEqual(self.payloads[0], msg)
  112. def test_many_send(self):
  113. msgs = []
  114. self.server_channel.stream_handler = MagicMock()
  115. for i in range(0, 1000):
  116. msgs.append('test_many_send_{0}'.format(i))
  117. for i in msgs:
  118. self.channel.send(i)
  119. self.channel.send({'stop': True})
  120. self.wait()
  121. self.assertEqual(self.payloads[:-1], msgs)
  122. def test_very_big_message(self):
  123. long_str = ''.join([six.text_type(num) for num in range(10**5)])
  124. msg = {'long_str': long_str, 'stop': True}
  125. self.channel.send(msg)
  126. self.wait()
  127. self.assertEqual(msg, self.payloads[0])
  128. def test_multistream_sends(self):
  129. local_channel = self._get_channel()
  130. for c in (self.channel, local_channel):
  131. c.send('foo')
  132. self.channel.send({'stop': True})
  133. self.wait()
  134. self.assertEqual(self.payloads[:-1], ['foo', 'foo'])
  135. def test_multistream_errors(self):
  136. local_channel = self._get_channel()
  137. for c in (self.channel, local_channel):
  138. c.send(None)
  139. for c in (self.channel, local_channel):
  140. c.send('foo')
  141. self.channel.send({'stop': True})
  142. self.wait()
  143. self.assertEqual(self.payloads[:-1], [None, None, 'foo', 'foo'])
  144. @skipIf(salt.utils.platform.is_windows(), 'Windows does not support Posix IPC')
  145. class IPCMessagePubSubCase(tornado.testing.AsyncTestCase):
  146. '''
  147. Test all of the clear msg stuff
  148. '''
  149. def setUp(self):
  150. super(IPCMessagePubSubCase, self).setUp()
  151. self.opts = {'ipc_write_buffer': 0}
  152. self.socket_path = os.path.join(RUNTIME_VARS.TMP, 'ipc_test.ipc')
  153. self.pub_channel = self._get_pub_channel()
  154. self.sub_channel = self._get_sub_channel()
  155. def _get_pub_channel(self):
  156. pub_channel = salt.transport.ipc.IPCMessagePublisher(
  157. self.opts,
  158. self.socket_path,
  159. )
  160. pub_channel.start()
  161. return pub_channel
  162. def _get_sub_channel(self):
  163. sub_channel = salt.transport.ipc.IPCMessageSubscriber(
  164. socket_path=self.socket_path,
  165. io_loop=self.io_loop,
  166. )
  167. sub_channel.connect(callback=self.stop)
  168. self.wait()
  169. return sub_channel
  170. def tearDown(self):
  171. super(IPCMessagePubSubCase, self).tearDown()
  172. try:
  173. self.pub_channel.close()
  174. except socket.error as exc:
  175. if exc.errno != errno.EBADF:
  176. # If its not a bad file descriptor error, raise
  177. raise
  178. try:
  179. self.sub_channel.close()
  180. except socket.error as exc:
  181. if exc.errno != errno.EBADF:
  182. # If its not a bad file descriptor error, raise
  183. raise
  184. os.unlink(self.socket_path)
  185. del self.pub_channel
  186. del self.sub_channel
  187. def test_multi_client_reading(self):
  188. # To be completely fair let's create 2 clients.
  189. client1 = self.sub_channel
  190. client2 = self._get_sub_channel()
  191. call_cnt = []
  192. # Create a watchdog to be safe from hanging in sync loops (what old code did)
  193. evt = threading.Event()
  194. def close_server():
  195. if evt.wait(1):
  196. return
  197. client2.close()
  198. self.stop()
  199. watchdog = threading.Thread(target=close_server)
  200. watchdog.start()
  201. # Runs in ioloop thread so we're safe from race conditions here
  202. def handler(raw):
  203. call_cnt.append(raw)
  204. if len(call_cnt) >= 2:
  205. evt.set()
  206. self.stop()
  207. # Now let both waiting data at once
  208. client1.read_async(handler)
  209. client2.read_async(handler)
  210. self.pub_channel.publish('TEST')
  211. self.wait()
  212. self.assertEqual(len(call_cnt), 2)
  213. self.assertEqual(call_cnt[0], 'TEST')
  214. self.assertEqual(call_cnt[1], 'TEST')
  215. def test_sync_reading(self):
  216. # To be completely fair let's create 2 clients.
  217. client1 = self.sub_channel
  218. client2 = self._get_sub_channel()
  219. call_cnt = []
  220. # Now let both waiting data at once
  221. self.pub_channel.publish('TEST')
  222. ret1 = client1.read_sync()
  223. ret2 = client2.read_sync()
  224. self.assertEqual(ret1, 'TEST')
  225. self.assertEqual(ret2, 'TEST')