test_ipc.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. # -*- coding: utf-8 -*-
  2. """
  3. :codeauthor: Mike Place <mp@saltstack.com>
  4. """
  5. # Import python libs
  6. from __future__ import absolute_import, print_function, unicode_literals
  7. import errno
  8. import logging
  9. import os
  10. import socket
  11. import threading
  12. import pytest
  13. import salt.config
  14. import salt.exceptions
  15. import salt.ext.tornado.gen
  16. import salt.ext.tornado.ioloop
  17. import salt.ext.tornado.testing
  18. import salt.transport.client
  19. import salt.transport.ipc
  20. import salt.transport.server
  21. import salt.utils.platform
  22. from salt.ext import six
  23. from salt.ext.six.moves import range
  24. from tests.support.mock import MagicMock
  25. # Import Salt Testing libs
  26. from tests.support.runtests import RUNTIME_VARS
  27. from tests.support.unit import skipIf
  28. log = logging.getLogger(__name__)
  29. @skipIf(salt.utils.platform.is_windows(), "Windows does not support Posix IPC")
  30. class BaseIPCReqCase(salt.ext.tornado.testing.AsyncTestCase):
  31. """
  32. Test the req server/client pair
  33. """
  34. def setUp(self):
  35. super(BaseIPCReqCase, self).setUp()
  36. # self._start_handlers = dict(self.io_loop._handlers)
  37. self.socket_path = os.path.join(RUNTIME_VARS.TMP, "ipc_test.ipc")
  38. self.server_channel = salt.transport.ipc.IPCMessageServer(
  39. self.socket_path,
  40. io_loop=self.io_loop,
  41. payload_handler=self._handle_payload,
  42. )
  43. self.server_channel.start()
  44. self.payloads = []
  45. def tearDown(self):
  46. super(BaseIPCReqCase, self).tearDown()
  47. # failures = []
  48. try:
  49. self.server_channel.close()
  50. except socket.error as exc:
  51. if exc.errno != errno.EBADF:
  52. # If its not a bad file descriptor error, raise
  53. raise
  54. os.unlink(self.socket_path)
  55. # for k, v in six.iteritems(self.io_loop._handlers):
  56. # if self._start_handlers.get(k) != v:
  57. # failures.append((k, v))
  58. # if len(failures) > 0:
  59. # raise Exception('FDs still attached to the IOLoop: {0}'.format(failures))
  60. del self.payloads
  61. del self.socket_path
  62. del self.server_channel
  63. # del self._start_handlers
  64. @salt.ext.tornado.gen.coroutine
  65. def _handle_payload(self, payload, reply_func):
  66. self.payloads.append(payload)
  67. yield reply_func(payload)
  68. if isinstance(payload, dict) and payload.get("stop"):
  69. self.stop()
  70. class IPCMessageClient(BaseIPCReqCase):
  71. """
  72. Test all of the clear msg stuff
  73. """
  74. def _get_channel(self):
  75. if not hasattr(self, "channel") or self.channel is None:
  76. self.channel = salt.transport.ipc.IPCMessageClient(
  77. socket_path=self.socket_path, io_loop=self.io_loop,
  78. )
  79. self.channel.connect(callback=self.stop)
  80. self.wait()
  81. return self.channel
  82. def setUp(self):
  83. super(IPCMessageClient, self).setUp()
  84. self.channel = self._get_channel()
  85. def tearDown(self):
  86. super(IPCMessageClient, self).tearDown()
  87. try:
  88. # Make sure we close no matter what we've done in the tests
  89. del self.channel
  90. except socket.error as exc:
  91. if exc.errno != errno.EBADF:
  92. # If its not a bad file descriptor error, raise
  93. raise
  94. finally:
  95. self.channel = None
  96. def test_singleton(self):
  97. channel = self._get_channel()
  98. assert self.channel is channel
  99. # Delete the local channel. Since there's still one more refefence
  100. # __del__ wasn't called
  101. del channel
  102. assert self.channel
  103. msg = {"foo": "bar", "stop": True}
  104. self.channel.send(msg)
  105. self.wait()
  106. self.assertEqual(self.payloads[0], msg)
  107. def test_basic_send(self):
  108. msg = {"foo": "bar", "stop": True}
  109. self.channel.send(msg)
  110. self.wait()
  111. self.assertEqual(self.payloads[0], msg)
  112. @pytest.mark.slow_test(seconds=1) # Test takes >0.1 and <=1 seconds
  113. def test_many_send(self):
  114. msgs = []
  115. self.server_channel.stream_handler = MagicMock()
  116. for i in range(0, 1000):
  117. msgs.append("test_many_send_{0}".format(i))
  118. for i in msgs:
  119. self.channel.send(i)
  120. self.channel.send({"stop": True})
  121. self.wait()
  122. self.assertEqual(self.payloads[:-1], msgs)
  123. def test_very_big_message(self):
  124. long_str = "".join([six.text_type(num) for num in range(10 ** 5)])
  125. msg = {"long_str": long_str, "stop": True}
  126. self.channel.send(msg)
  127. self.wait()
  128. self.assertEqual(msg, self.payloads[0])
  129. def test_multistream_sends(self):
  130. local_channel = self._get_channel()
  131. for c in (self.channel, local_channel):
  132. c.send("foo")
  133. self.channel.send({"stop": True})
  134. self.wait()
  135. self.assertEqual(self.payloads[:-1], ["foo", "foo"])
  136. def test_multistream_errors(self):
  137. local_channel = self._get_channel()
  138. for c in (self.channel, local_channel):
  139. c.send(None)
  140. for c in (self.channel, local_channel):
  141. c.send("foo")
  142. self.channel.send({"stop": True})
  143. self.wait()
  144. self.assertEqual(self.payloads[:-1], [None, None, "foo", "foo"])
  145. @skipIf(salt.utils.platform.is_windows(), "Windows does not support Posix IPC")
  146. class IPCMessagePubSubCase(salt.ext.tornado.testing.AsyncTestCase):
  147. """
  148. Test all of the clear msg stuff
  149. """
  150. def setUp(self):
  151. super(IPCMessagePubSubCase, self).setUp()
  152. self.opts = {"ipc_write_buffer": 0}
  153. self.socket_path = os.path.join(RUNTIME_VARS.TMP, "ipc_test.ipc")
  154. self.pub_channel = self._get_pub_channel()
  155. self.sub_channel = self._get_sub_channel()
  156. def _get_pub_channel(self):
  157. pub_channel = salt.transport.ipc.IPCMessagePublisher(
  158. self.opts, self.socket_path,
  159. )
  160. pub_channel.start()
  161. return pub_channel
  162. def _get_sub_channel(self):
  163. sub_channel = salt.transport.ipc.IPCMessageSubscriber(
  164. socket_path=self.socket_path, io_loop=self.io_loop,
  165. )
  166. sub_channel.connect(callback=self.stop)
  167. self.wait()
  168. return sub_channel
  169. def tearDown(self):
  170. super(IPCMessagePubSubCase, self).tearDown()
  171. try:
  172. self.pub_channel.close()
  173. except socket.error as exc:
  174. if exc.errno != errno.EBADF:
  175. # If its not a bad file descriptor error, raise
  176. raise
  177. try:
  178. self.sub_channel.close()
  179. except socket.error as exc:
  180. if exc.errno != errno.EBADF:
  181. # If its not a bad file descriptor error, raise
  182. raise
  183. os.unlink(self.socket_path)
  184. del self.pub_channel
  185. del self.sub_channel
  186. def test_multi_client_reading(self):
  187. # To be completely fair let's create 2 clients.
  188. client1 = self.sub_channel
  189. client2 = self._get_sub_channel()
  190. call_cnt = []
  191. # Create a watchdog to be safe from hanging in sync loops (what old code did)
  192. evt = threading.Event()
  193. def close_server():
  194. if evt.wait(1):
  195. return
  196. client2.close()
  197. self.stop()
  198. watchdog = threading.Thread(target=close_server)
  199. watchdog.start()
  200. # Runs in ioloop thread so we're safe from race conditions here
  201. def handler(raw):
  202. call_cnt.append(raw)
  203. if len(call_cnt) >= 2:
  204. evt.set()
  205. self.stop()
  206. # Now let both waiting data at once
  207. client1.read_async(handler)
  208. client2.read_async(handler)
  209. self.pub_channel.publish("TEST")
  210. self.wait()
  211. self.assertEqual(len(call_cnt), 2)
  212. self.assertEqual(call_cnt[0], "TEST")
  213. self.assertEqual(call_cnt[1], "TEST")
  214. def test_sync_reading(self):
  215. # To be completely fair let's create 2 clients.
  216. client1 = self.sub_channel
  217. client2 = self._get_sub_channel()
  218. call_cnt = []
  219. # Now let both waiting data at once
  220. self.pub_channel.publish("TEST")
  221. ret1 = client1.read_sync()
  222. ret2 = client2.read_sync()
  223. self.assertEqual(ret1, "TEST")
  224. self.assertEqual(ret2, "TEST")