test_ipc.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. # -*- coding: utf-8 -*-
  2. """
  3. :codeauthor: Mike Place <mp@saltstack.com>
  4. """
  5. # Import python libs
  6. from __future__ import absolute_import, print_function, unicode_literals
  7. import errno
  8. import logging
  9. import os
  10. import socket
  11. import threading
  12. import salt.config
  13. import salt.exceptions
  14. import salt.ext.tornado.gen
  15. import salt.ext.tornado.ioloop
  16. import salt.ext.tornado.testing
  17. import salt.transport.client
  18. import salt.transport.ipc
  19. import salt.transport.server
  20. import salt.utils.platform
  21. from salt.ext import six
  22. from salt.ext.six.moves import range
  23. from tests.support.mock import MagicMock
  24. # Import Salt Testing libs
  25. from tests.support.runtests import RUNTIME_VARS
  26. from tests.support.unit import skipIf
  27. log = logging.getLogger(__name__)
  28. @skipIf(salt.utils.platform.is_windows(), "Windows does not support Posix IPC")
  29. class BaseIPCReqCase(salt.ext.tornado.testing.AsyncTestCase):
  30. """
  31. Test the req server/client pair
  32. """
  33. def setUp(self):
  34. super(BaseIPCReqCase, self).setUp()
  35. # self._start_handlers = dict(self.io_loop._handlers)
  36. self.socket_path = os.path.join(RUNTIME_VARS.TMP, "ipc_test.ipc")
  37. self.server_channel = salt.transport.ipc.IPCMessageServer(
  38. self.socket_path,
  39. io_loop=self.io_loop,
  40. payload_handler=self._handle_payload,
  41. )
  42. self.server_channel.start()
  43. self.payloads = []
  44. def tearDown(self):
  45. super(BaseIPCReqCase, self).tearDown()
  46. # failures = []
  47. try:
  48. self.server_channel.close()
  49. except socket.error as exc:
  50. if exc.errno != errno.EBADF:
  51. # If its not a bad file descriptor error, raise
  52. raise
  53. os.unlink(self.socket_path)
  54. # for k, v in six.iteritems(self.io_loop._handlers):
  55. # if self._start_handlers.get(k) != v:
  56. # failures.append((k, v))
  57. # if len(failures) > 0:
  58. # raise Exception('FDs still attached to the IOLoop: {0}'.format(failures))
  59. del self.payloads
  60. del self.socket_path
  61. del self.server_channel
  62. # del self._start_handlers
  63. @salt.ext.tornado.gen.coroutine
  64. def _handle_payload(self, payload, reply_func):
  65. self.payloads.append(payload)
  66. yield reply_func(payload)
  67. if isinstance(payload, dict) and payload.get("stop"):
  68. self.stop()
  69. class IPCMessageClient(BaseIPCReqCase):
  70. """
  71. Test all of the clear msg stuff
  72. """
  73. def _get_channel(self):
  74. if not hasattr(self, "channel") or self.channel is None:
  75. self.channel = salt.transport.ipc.IPCMessageClient(
  76. socket_path=self.socket_path, io_loop=self.io_loop,
  77. )
  78. self.channel.connect(callback=self.stop)
  79. self.wait()
  80. return self.channel
  81. def setUp(self):
  82. super(IPCMessageClient, self).setUp()
  83. self.channel = self._get_channel()
  84. def tearDown(self):
  85. super(IPCMessageClient, self).tearDown()
  86. try:
  87. # Make sure we close no matter what we've done in the tests
  88. del self.channel
  89. except socket.error as exc:
  90. if exc.errno != errno.EBADF:
  91. # If its not a bad file descriptor error, raise
  92. raise
  93. finally:
  94. self.channel = None
  95. def test_singleton(self):
  96. channel = self._get_channel()
  97. assert self.channel is channel
  98. # Delete the local channel. Since there's still one more refefence
  99. # __del__ wasn't called
  100. del channel
  101. assert self.channel
  102. msg = {"foo": "bar", "stop": True}
  103. self.channel.send(msg)
  104. self.wait()
  105. self.assertEqual(self.payloads[0], msg)
  106. def test_basic_send(self):
  107. msg = {"foo": "bar", "stop": True}
  108. self.channel.send(msg)
  109. self.wait()
  110. self.assertEqual(self.payloads[0], msg)
  111. def test_many_send(self):
  112. msgs = []
  113. self.server_channel.stream_handler = MagicMock()
  114. for i in range(0, 1000):
  115. msgs.append("test_many_send_{0}".format(i))
  116. for i in msgs:
  117. self.channel.send(i)
  118. self.channel.send({"stop": True})
  119. self.wait()
  120. self.assertEqual(self.payloads[:-1], msgs)
  121. def test_very_big_message(self):
  122. long_str = "".join([six.text_type(num) for num in range(10 ** 5)])
  123. msg = {"long_str": long_str, "stop": True}
  124. self.channel.send(msg)
  125. self.wait()
  126. self.assertEqual(msg, self.payloads[0])
  127. def test_multistream_sends(self):
  128. local_channel = self._get_channel()
  129. for c in (self.channel, local_channel):
  130. c.send("foo")
  131. self.channel.send({"stop": True})
  132. self.wait()
  133. self.assertEqual(self.payloads[:-1], ["foo", "foo"])
  134. def test_multistream_errors(self):
  135. local_channel = self._get_channel()
  136. for c in (self.channel, local_channel):
  137. c.send(None)
  138. for c in (self.channel, local_channel):
  139. c.send("foo")
  140. self.channel.send({"stop": True})
  141. self.wait()
  142. self.assertEqual(self.payloads[:-1], [None, None, "foo", "foo"])
  143. @skipIf(salt.utils.platform.is_windows(), "Windows does not support Posix IPC")
  144. class IPCMessagePubSubCase(salt.ext.tornado.testing.AsyncTestCase):
  145. """
  146. Test all of the clear msg stuff
  147. """
  148. def setUp(self):
  149. super(IPCMessagePubSubCase, self).setUp()
  150. self.opts = {"ipc_write_buffer": 0}
  151. self.socket_path = os.path.join(RUNTIME_VARS.TMP, "ipc_test.ipc")
  152. self.pub_channel = self._get_pub_channel()
  153. self.sub_channel = self._get_sub_channel()
  154. def _get_pub_channel(self):
  155. pub_channel = salt.transport.ipc.IPCMessagePublisher(
  156. self.opts, self.socket_path,
  157. )
  158. pub_channel.start()
  159. return pub_channel
  160. def _get_sub_channel(self):
  161. sub_channel = salt.transport.ipc.IPCMessageSubscriber(
  162. socket_path=self.socket_path, io_loop=self.io_loop,
  163. )
  164. sub_channel.connect(callback=self.stop)
  165. self.wait()
  166. return sub_channel
  167. def tearDown(self):
  168. super(IPCMessagePubSubCase, self).tearDown()
  169. try:
  170. self.pub_channel.close()
  171. except socket.error as exc:
  172. if exc.errno != errno.EBADF:
  173. # If its not a bad file descriptor error, raise
  174. raise
  175. try:
  176. self.sub_channel.close()
  177. except socket.error as exc:
  178. if exc.errno != errno.EBADF:
  179. # If its not a bad file descriptor error, raise
  180. raise
  181. os.unlink(self.socket_path)
  182. del self.pub_channel
  183. del self.sub_channel
  184. def test_multi_client_reading(self):
  185. # To be completely fair let's create 2 clients.
  186. client1 = self.sub_channel
  187. client2 = self._get_sub_channel()
  188. call_cnt = []
  189. # Create a watchdog to be safe from hanging in sync loops (what old code did)
  190. evt = threading.Event()
  191. def close_server():
  192. if evt.wait(1):
  193. return
  194. client2.close()
  195. self.stop()
  196. watchdog = threading.Thread(target=close_server)
  197. watchdog.start()
  198. # Runs in ioloop thread so we're safe from race conditions here
  199. def handler(raw):
  200. call_cnt.append(raw)
  201. if len(call_cnt) >= 2:
  202. evt.set()
  203. self.stop()
  204. # Now let both waiting data at once
  205. client1.read_async(handler)
  206. client2.read_async(handler)
  207. self.pub_channel.publish("TEST")
  208. self.wait()
  209. self.assertEqual(len(call_cnt), 2)
  210. self.assertEqual(call_cnt[0], "TEST")
  211. self.assertEqual(call_cnt[1], "TEST")
  212. def test_sync_reading(self):
  213. # To be completely fair let's create 2 clients.
  214. client1 = self.sub_channel
  215. client2 = self._get_sub_channel()
  216. call_cnt = []
  217. # Now let both waiting data at once
  218. self.pub_channel.publish("TEST")
  219. ret1 = client1.read_sync()
  220. ret2 = client2.read_sync()
  221. self.assertEqual(ret1, "TEST")
  222. self.assertEqual(ret2, "TEST")