test_ipc.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. """
  2. :codeauthor: Mike Place <mp@saltstack.com>
  3. """
  4. import errno
  5. import logging
  6. import os
  7. import threading
  8. import salt.config
  9. import salt.exceptions
  10. import salt.ext.tornado.gen
  11. import salt.ext.tornado.ioloop
  12. import salt.ext.tornado.testing
  13. import salt.transport.client
  14. import salt.transport.ipc
  15. import salt.transport.server
  16. import salt.utils.platform
  17. from tests.support.mock import MagicMock
  18. from tests.support.runtests import RUNTIME_VARS
  19. from tests.support.unit import skipIf
  20. log = logging.getLogger(__name__)
  21. @skipIf(salt.utils.platform.is_windows(), "Windows does not support Posix IPC")
  22. class BaseIPCReqCase(salt.ext.tornado.testing.AsyncTestCase):
  23. """
  24. Test the req server/client pair
  25. """
  26. def setUp(self):
  27. super().setUp()
  28. # self._start_handlers = dict(self.io_loop._handlers)
  29. self.socket_path = os.path.join(RUNTIME_VARS.TMP, "ipc_test.ipc")
  30. self.server_channel = salt.transport.ipc.IPCMessageServer(
  31. self.socket_path,
  32. io_loop=self.io_loop,
  33. payload_handler=self._handle_payload,
  34. )
  35. self.server_channel.start()
  36. self.payloads = []
  37. def tearDown(self):
  38. super().tearDown()
  39. # failures = []
  40. try:
  41. self.server_channel.close()
  42. except OSError as exc:
  43. if exc.errno != errno.EBADF:
  44. # If its not a bad file descriptor error, raise
  45. raise
  46. os.unlink(self.socket_path)
  47. # for k, v in self.io_loop._handlers.items():
  48. # if self._start_handlers.get(k) != v:
  49. # failures.append((k, v))
  50. # if len(failures) > 0:
  51. # raise Exception('FDs still attached to the IOLoop: {0}'.format(failures))
  52. del self.payloads
  53. del self.socket_path
  54. del self.server_channel
  55. # del self._start_handlers
  56. @salt.ext.tornado.gen.coroutine
  57. def _handle_payload(self, payload, reply_func):
  58. self.payloads.append(payload)
  59. yield reply_func(payload)
  60. if isinstance(payload, dict) and payload.get("stop"):
  61. self.stop()
  62. class IPCMessageClient(BaseIPCReqCase):
  63. """
  64. Test all of the clear msg stuff
  65. """
  66. def _get_channel(self):
  67. if not hasattr(self, "channel") or self.channel is None:
  68. self.channel = salt.transport.ipc.IPCMessageClient(
  69. socket_path=self.socket_path, io_loop=self.io_loop,
  70. )
  71. self.channel.connect(callback=self.stop)
  72. self.wait()
  73. return self.channel
  74. def setUp(self):
  75. super().setUp()
  76. self.channel = self._get_channel()
  77. def tearDown(self):
  78. super().tearDown()
  79. try:
  80. # Make sure we close no matter what we've done in the tests
  81. del self.channel
  82. except OSError as exc:
  83. if exc.errno != errno.EBADF:
  84. # If its not a bad file descriptor error, raise
  85. raise
  86. finally:
  87. self.channel = None
  88. def test_singleton(self):
  89. channel = self._get_channel()
  90. assert self.channel is channel
  91. # Delete the local channel. Since there's still one more refefence
  92. # __del__ wasn't called
  93. del channel
  94. assert self.channel
  95. msg = {"foo": "bar", "stop": True}
  96. self.channel.send(msg)
  97. self.wait()
  98. self.assertEqual(self.payloads[0], msg)
  99. def test_basic_send(self):
  100. msg = {"foo": "bar", "stop": True}
  101. self.channel.send(msg)
  102. self.wait()
  103. self.assertEqual(self.payloads[0], msg)
  104. def test_many_send(self):
  105. msgs = []
  106. self.server_channel.stream_handler = MagicMock()
  107. for i in range(0, 1000):
  108. msgs.append("test_many_send_{}".format(i))
  109. for i in msgs:
  110. self.channel.send(i)
  111. self.channel.send({"stop": True})
  112. self.wait()
  113. self.assertEqual(self.payloads[:-1], msgs)
  114. def test_very_big_message(self):
  115. long_str = "".join([str(num) for num in range(10 ** 5)])
  116. msg = {"long_str": long_str, "stop": True}
  117. self.channel.send(msg)
  118. self.wait()
  119. self.assertEqual(msg, self.payloads[0])
  120. def test_multistream_sends(self):
  121. local_channel = self._get_channel()
  122. for c in (self.channel, local_channel):
  123. c.send("foo")
  124. self.channel.send({"stop": True})
  125. self.wait()
  126. self.assertEqual(self.payloads[:-1], ["foo", "foo"])
  127. def test_multistream_errors(self):
  128. local_channel = self._get_channel()
  129. for c in (self.channel, local_channel):
  130. c.send(None)
  131. for c in (self.channel, local_channel):
  132. c.send("foo")
  133. self.channel.send({"stop": True})
  134. self.wait()
  135. self.assertEqual(self.payloads[:-1], [None, None, "foo", "foo"])
  136. @skipIf(salt.utils.platform.is_windows(), "Windows does not support Posix IPC")
  137. class IPCMessagePubSubCase(salt.ext.tornado.testing.AsyncTestCase):
  138. """
  139. Test all of the clear msg stuff
  140. """
  141. def setUp(self):
  142. super().setUp()
  143. self.opts = {"ipc_write_buffer": 0}
  144. self.socket_path = os.path.join(RUNTIME_VARS.TMP, "ipc_test.ipc")
  145. self.pub_channel = self._get_pub_channel()
  146. self.sub_channel = self._get_sub_channel()
  147. def _get_pub_channel(self):
  148. pub_channel = salt.transport.ipc.IPCMessagePublisher(
  149. self.opts, self.socket_path,
  150. )
  151. pub_channel.start()
  152. return pub_channel
  153. def _get_sub_channel(self):
  154. sub_channel = salt.transport.ipc.IPCMessageSubscriber(
  155. socket_path=self.socket_path, io_loop=self.io_loop,
  156. )
  157. sub_channel.connect(callback=self.stop)
  158. self.wait()
  159. return sub_channel
  160. def tearDown(self):
  161. super().tearDown()
  162. try:
  163. self.pub_channel.close()
  164. except OSError as exc:
  165. if exc.errno != errno.EBADF:
  166. # If its not a bad file descriptor error, raise
  167. raise
  168. try:
  169. self.sub_channel.close()
  170. except OSError as exc:
  171. if exc.errno != errno.EBADF:
  172. # If its not a bad file descriptor error, raise
  173. raise
  174. os.unlink(self.socket_path)
  175. del self.pub_channel
  176. del self.sub_channel
  177. def test_multi_client_reading(self):
  178. # To be completely fair let's create 2 clients.
  179. client1 = self.sub_channel
  180. client2 = self._get_sub_channel()
  181. call_cnt = []
  182. # Create a watchdog to be safe from hanging in sync loops (what old code did)
  183. evt = threading.Event()
  184. def close_server():
  185. if evt.wait(1):
  186. return
  187. client2.close()
  188. self.stop()
  189. watchdog = threading.Thread(target=close_server)
  190. watchdog.start()
  191. # Runs in ioloop thread so we're safe from race conditions here
  192. def handler(raw):
  193. call_cnt.append(raw)
  194. if len(call_cnt) >= 2:
  195. evt.set()
  196. self.stop()
  197. # Now let both waiting data at once
  198. client1.read_async(handler)
  199. client2.read_async(handler)
  200. self.pub_channel.publish("TEST")
  201. self.wait()
  202. self.assertEqual(len(call_cnt), 2)
  203. self.assertEqual(call_cnt[0], "TEST")
  204. self.assertEqual(call_cnt[1], "TEST")
  205. def test_sync_reading(self):
  206. # To be completely fair let's create 2 clients.
  207. client1 = self.sub_channel
  208. client2 = self._get_sub_channel()
  209. call_cnt = []
  210. # Now let both waiting data at once
  211. self.pub_channel.publish("TEST")
  212. ret1 = client1.read_sync()
  213. ret2 = client2.read_sync()
  214. self.assertEqual(ret1, "TEST")
  215. self.assertEqual(ret2, "TEST")