123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487 |
- """
- :codeauthor: Thomas Jackson <jacksontj.89@gmail.com>
- """
- import logging
- import socket
- import threading
- import salt.config
- import salt.exceptions
- import salt.ext.tornado.concurrent
- import salt.ext.tornado.gen
- import salt.ext.tornado.ioloop
- import salt.transport.client
- import salt.transport.server
- import salt.utils.platform
- import salt.utils.process
- from salt.ext.tornado.testing import AsyncTestCase, gen_test
- from salt.transport.tcp import (
- SaltMessageClient,
- SaltMessageClientPool,
- TCPPubServerChannel,
- )
- from saltfactories.utils.ports import get_unused_localhost_port
- from tests.support.helpers import flaky, slowTest
- from tests.support.mixins import AdaptedConfigurationTestCaseMixin
- from tests.support.mock import MagicMock, patch
- from tests.support.unit import TestCase, skipIf
- from tests.unit.transport.mixins import (
- PubChannelMixin,
- ReqChannelMixin,
- run_loop_in_thread,
- )
# Module-level logger named after this test module (not referenced elsewhere
# in the visible part of the file).
log = logging.getLogger(__name__)
class BaseTCPReqCase(TestCase, AdaptedConfigurationTestCaseMixin):
    """
    Test the req server/client pair

    ``setUpClass`` builds TCP master/minion configs on unused localhost ports,
    starts a ``ReqServerChannel`` and runs its IOLoop in a background thread
    (via ``run_loop_in_thread``). Subclasses override ``_handle_payload`` to
    decide what the server answers.
    """

    @classmethod
    def setUpClass(cls):
        # Guard kept for subclasses; note that ``_handle_payload`` is defined
        # on this class itself, so the guard is normally satisfied.
        if not hasattr(cls, "_handle_payload"):
            return
        # Every endpoint the TCP transport needs gets its own free port.
        ret_port = get_unused_localhost_port()
        publish_port = get_unused_localhost_port()
        tcp_master_pub_port = get_unused_localhost_port()
        tcp_master_pull_port = get_unused_localhost_port()
        tcp_master_publish_pull = get_unused_localhost_port()
        tcp_master_workers = get_unused_localhost_port()
        cls.master_config = cls.get_temp_config(
            "master",
            **{
                "transport": "tcp",
                "auto_accept": True,
                "ret_port": ret_port,
                "publish_port": publish_port,
                "tcp_master_pub_port": tcp_master_pub_port,
                "tcp_master_pull_port": tcp_master_pull_port,
                "tcp_master_publish_pull": tcp_master_publish_pull,
                "tcp_master_workers": tcp_master_workers,
            }
        )

        # The minion talks to the master's ret_port on localhost.
        cls.minion_config = cls.get_temp_config(
            "minion",
            **{
                "transport": "tcp",
                "master_ip": "127.0.0.1",
                "master_port": ret_port,
                "master_uri": "tcp://127.0.0.1:{}".format(ret_port),
            }
        )

        cls.process_manager = salt.utils.process.ProcessManager(
            name="ReqServer_ProcessManager"
        )

        cls.server_channel = salt.transport.server.ReqServerChannel.factory(
            cls.master_config
        )
        cls.server_channel.pre_fork(cls.process_manager)

        # The server's IOLoop runs in its own thread until ``cls.stop`` is set.
        cls.io_loop = salt.ext.tornado.ioloop.IOLoop()
        cls.stop = threading.Event()
        cls.server_channel.post_fork(cls._handle_payload, io_loop=cls.io_loop)
        cls.server_thread = threading.Thread(
            target=run_loop_in_thread, args=(cls.io_loop, cls.stop,),
        )
        cls.server_thread.start()

    @classmethod
    def tearDownClass(cls):
        # setUpClass may have returned early without starting a server; in
        # that case none of the attributes below exist and there is nothing
        # to tear down.
        if not hasattr(cls, "server_channel"):
            return
        cls.server_channel.close()
        # Signal the loop thread to exit, then wait for it before killing
        # any child processes the process manager started.
        cls.stop.set()
        cls.server_thread.join()
        cls.process_manager.kill_children()
        del cls.server_channel

    @classmethod
    @salt.ext.tornado.gen.coroutine
    def _handle_payload(cls, payload):
        """
        TODO: something besides echo

        Echo ``payload`` back together with reply options
        ``{"fun": "send_clear"}`` (the second element is presumably read by
        the req server to choose how to send the reply -- confirm).
        """
        raise salt.ext.tornado.gen.Return((payload, {"fun": "send_clear"}))
@skipIf(salt.utils.platform.is_darwin(), "hanging test suite on MacOS")
class ClearReqTestCases(BaseTCPReqCase, ReqChannelMixin):
    """
    Request-channel tests over unencrypted ("clear") messages.

    This class defines no test methods itself; they presumably come from
    ``ReqChannelMixin``.
    """

    def setUp(self):
        # A fresh, unencrypted client channel for every test.
        make_channel = salt.transport.client.ReqChannel.factory
        self.channel = make_channel(self.minion_config, crypt="clear")

    def tearDown(self):
        channel = self.channel
        channel.close()
        del self.channel

    @classmethod
    @salt.ext.tornado.gen.coroutine
    def _handle_payload(cls, payload):
        """
        TODO: something besides echo

        Server-side handler: hand ``payload`` straight back, paired with the
        reply options ``{"fun": "send_clear"}``.
        """
        reply = (payload, {"fun": "send_clear"})
        raise salt.ext.tornado.gen.Return(reply)
@skipIf(salt.utils.platform.is_darwin(), "hanging test suite on MacOS")
class AESReqTestCases(BaseTCPReqCase, ReqChannelMixin):
    """
    Request-channel tests using the channel factory's default crypt setting
    (AES, judging by the class name), plus a bad-payload check.
    """

    def setUp(self):
        self.channel = salt.transport.client.ReqChannel.factory(self.minion_config)

    def tearDown(self):
        self.channel.close()
        del self.channel

    @classmethod
    @salt.ext.tornado.gen.coroutine
    def _handle_payload(cls, payload):
        """
        TODO: something besides echo

        Echo ``payload`` back with reply options ``{"fun": "send"}``.
        """
        raise salt.ext.tornado.gen.Return((payload, {"fun": "send"}))

    # TODO: make failed returns have a specific framing so we can raise the same exception
    # on encrypted channels
    @flaky
    @slowTest
    def test_badload(self):
        """
        Test a variety of bad requests, make sure that we get some sort of error
        """
        # Empty string, empty list and empty tuple must all be rejected.
        for msg in ("", [], ()):
            with self.assertRaises(salt.exceptions.AuthenticationError):
                self.channel.send(msg)
class BaseTCPPubCase(AsyncTestCase, AdaptedConfigurationTestCaseMixin):
    """
    Test the pub server/client pair

    ``setUpClass`` starts a TCP ``PubServerChannel`` together with the
    ``ReqServerChannel`` it needs for authentication; the req server's IOLoop
    runs in a background thread. Per test, ``setUp``/``tearDown`` verify that
    no handlers were left registered on ``self.io_loop`` (presumably the
    per-test loop provided by ``AsyncTestCase``).
    """

    @classmethod
    def setUpClass(cls):
        # Every endpoint the TCP transport needs gets its own free port.
        ret_port = get_unused_localhost_port()
        publish_port = get_unused_localhost_port()
        tcp_master_pub_port = get_unused_localhost_port()
        tcp_master_pull_port = get_unused_localhost_port()
        tcp_master_publish_pull = get_unused_localhost_port()
        tcp_master_workers = get_unused_localhost_port()
        cls.master_config = cls.get_temp_config(
            "master",
            **{
                "transport": "tcp",
                "auto_accept": True,
                "ret_port": ret_port,
                "publish_port": publish_port,
                "tcp_master_pub_port": tcp_master_pub_port,
                "tcp_master_pull_port": tcp_master_pull_port,
                "tcp_master_publish_pull": tcp_master_publish_pull,
                "tcp_master_workers": tcp_master_workers,
            }
        )

        cls.minion_config = cls.get_temp_config(
            "minion",
            **{
                "transport": "tcp",
                "master_ip": "127.0.0.1",
                "auth_timeout": 1,
                "master_port": ret_port,
                "master_uri": "tcp://127.0.0.1:{}".format(ret_port),
            }
        )

        cls.process_manager = salt.utils.process.ProcessManager(
            name="ReqServer_ProcessManager"
        )

        cls.server_channel = salt.transport.server.PubServerChannel.factory(
            cls.master_config
        )
        cls.server_channel.pre_fork(cls.process_manager)

        # we also require req server for auth
        cls.req_server_channel = salt.transport.server.ReqServerChannel.factory(
            cls.master_config
        )
        cls.req_server_channel.pre_fork(cls.process_manager)

        # The req server's IOLoop runs in its own thread until ``cls.stop``
        # is set.
        cls.io_loop = salt.ext.tornado.ioloop.IOLoop()
        cls.stop = threading.Event()
        cls.req_server_channel.post_fork(cls._handle_payload, io_loop=cls.io_loop)
        cls.server_thread = threading.Thread(
            target=run_loop_in_thread, args=(cls.io_loop, cls.stop,),
        )
        cls.server_thread.start()

    @classmethod
    @salt.ext.tornado.gen.coroutine
    def _handle_payload(cls, payload):
        """
        TODO: something besides echo

        Echo ``payload`` back with reply options ``{"fun": "send_clear"}``.
        Declared as a tornado coroutine to match
        ``BaseTCPReqCase._handle_payload``: the req server presumably yields
        the handler's result, which a plain tuple does not support.
        """
        raise salt.ext.tornado.gen.Return((payload, {"fun": "send_clear"}))

    @classmethod
    def tearDownClass(cls):
        cls.req_server_channel.close()
        cls.server_channel.close()
        # Signal the loop thread to exit and wait for it before killing the
        # process manager's children.
        cls.stop.set()
        cls.server_thread.join()
        cls.process_manager.kill_children()
        del cls.req_server_channel
        del cls.server_channel

    def setUp(self):
        super().setUp()
        # Snapshot the handlers registered on the loop before the test runs.
        self._start_handlers = dict(self.io_loop._handlers)

    def tearDown(self):
        super().tearDown()
        try:
            # Any entry that was added or replaced during the test counts as
            # a leaked file descriptor.
            failures = [
                (fd, entry)
                for fd, entry in self.io_loop._handlers.items()
                if self._start_handlers.get(fd) != entry
            ]
            if failures:
                raise Exception("FDs still attached to the IOLoop: {}".format(failures))
        finally:
            # This class never assigns ``self.channel`` itself; remove it
            # only when a test (or mixin) set it.
            try:
                del self.channel
            except AttributeError:
                pass
            del self._start_handlers
class AsyncTCPPubChannelTest(AsyncTestCase, AdaptedConfigurationTestCaseMixin):
    """
    Unit tests for ``salt.transport.tcp.AsyncTCPPubChannel.connect``.
    """

    @slowTest
    def test_connect_publish_port(self):
        """
        ``connect`` must pass the configured (non-default, 1234) publish_port
        through to ``SaltMessageClientPool``.
        """
        opts = self.get_temp_config("master")
        for key, value in (
            ("master_uri", ""),
            ("master_ip", "127.0.0.1"),
            ("publish_port", 1234),
        ):
            opts[key] = value
        channel = salt.transport.tcp.AsyncTCPPubChannel(opts)

        # Authentication is short-circuited and the client pool is replaced
        # so only the options handed to the pool are inspected.
        auth_mock = MagicMock(return_value=True)
        pool_mock = MagicMock(spec=SaltMessageClientPool)
        with patch("salt.crypt.AsyncAuth.gen_token", auth_mock), patch(
            "salt.crypt.AsyncAuth.authenticated", auth_mock
        ), patch("salt.transport.tcp.SaltMessageClientPool", pool_mock):
            channel.connect()
            pool_opts = pool_mock.call_args[0][0]
            assert pool_opts["publish_port"] == opts["publish_port"]
@skipIf(True, "Skip until we can devote time to fix this test")
class AsyncPubChannelTest(BaseTCPPubCase, PubChannelMixin):
    """
    Publish-system tests run over the TCP transport.

    No test methods are defined here; they presumably come from
    ``PubChannelMixin``. The whole class is currently skipped unconditionally.
    """
class SaltMessageClientPoolTest(AsyncTestCase):
    """
    Unit tests for ``SaltMessageClientPool`` with every pooled client mocked.
    """

    def setUp(self):
        super().setUp()
        pool_size = 5
        init_patch = patch(
            "salt.transport.tcp.SaltMessageClient.__init__",
            MagicMock(return_value=None),
        )
        with init_patch:
            self.message_client_pool = SaltMessageClientPool(
                {"sock_pool_size": pool_size}, args=({}, "", 0)
            )
        # Hold on to the real (never initialised) clients so they are only
        # released in tearDown, where ``SaltMessageClient.close`` is patched.
        self.original_message_clients = self.message_client_pool.message_clients
        self.message_client_pool.message_clients = [
            MagicMock() for _ in range(pool_size)
        ]

    def tearDown(self):
        close_patch = patch(
            "salt.transport.tcp.SaltMessageClient.close", MagicMock(return_value=None)
        )
        with close_patch:
            del self.original_message_clients
        super().tearDown()

    def test_send(self):
        for client in self.message_client_pool.message_clients:
            client.send_queue = [0, 0, 0]
            client.send.return_value = []
        self.assertEqual([], self.message_client_pool.send())

        # Give client #2 the shortest queue; its result should come back.
        idle_client = self.message_client_pool.message_clients[2]
        idle_client.send_queue = [0]
        idle_client.send.return_value = [1]
        self.assertEqual([1], self.message_client_pool.send())

    def test_write_to_stream(self):
        for client in self.message_client_pool.message_clients:
            client.send_queue = [0, 0, 0]
            client._stream.write.return_value = []
        self.assertEqual([], self.message_client_pool.write_to_stream(""))

        # Give client #2 the shortest queue; its stream should be written.
        idle_client = self.message_client_pool.message_clients[2]
        idle_client.send_queue = [0]
        idle_client._stream.write.return_value = [1]
        self.assertEqual([1], self.message_client_pool.write_to_stream(""))

    def test_close(self):
        self.message_client_pool.close()
        self.assertEqual([], self.message_client_pool.message_clients)

    def test_on_recv(self):
        for client in self.message_client_pool.message_clients:
            client.on_recv.return_value = None
        self.message_client_pool.on_recv()
        for client in self.message_client_pool.message_clients:
            self.assertTrue(client.on_recv.called)

    def test_connect_all(self):
        @gen_test
        def run_connect(self):
            yield self.message_client_pool.connect()

        # Every client connects immediately.
        for client in self.message_client_pool.message_clients:
            finished = salt.ext.tornado.concurrent.Future()
            finished.set_result("foo")
            client.connect.return_value = finished
        self.assertIsNone(run_connect(self))

    def test_connect_partial(self):
        @gen_test(timeout=0.1)
        def run_connect(self):
            yield self.message_client_pool.connect()

        # Only even-indexed clients ever finish connecting, so the pool's
        # connect must time out.
        for position, client in enumerate(self.message_client_pool.message_clients):
            pending = salt.ext.tornado.concurrent.Future()
            if position % 2 == 0:
                pending.set_result("foo")
            client.connect.return_value = pending
        with self.assertRaises(salt.ext.tornado.ioloop.TimeoutError):
            run_connect(self)
class SaltMessageClientCleanupTest(TestCase, AdaptedConfigurationTestCaseMixin):
    """
    Check that ``SaltMessageClient.close`` stops and releases its IOLoop.

    A plain listening socket on localhost stands in for the server so the
    client has something to connect to.
    """

    def setUp(self):
        self.listen_on = "127.0.0.1"
        self.port = get_unused_localhost_port()
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.sock.bind((self.listen_on, self.port))
        self.sock.listen(1)

    def tearDown(self):
        self.sock.close()
        del self.sock

    def test_message_client(self):
        """
        test message client cleanup on close
        """
        orig_loop = salt.ext.tornado.ioloop.IOLoop()
        orig_loop.make_current()
        # This test owns the loop: always un-set it as the current loop and
        # close it (cleanups run in reverse registration order), even when
        # client construction or an assertion fails. Otherwise its file
        # descriptors leak and later tests inherit a stale "current" loop.
        self.addCleanup(orig_loop.close, all_fds=True)
        self.addCleanup(orig_loop.clear_current)

        opts = self.get_temp_config("master")
        client = SaltMessageClient(opts, self.listen_on, self.port)

        # Mock the io_loop's stop method so we know when it has been called.
        orig_loop.real_stop = orig_loop.stop
        orig_loop.stop_called = False

        def stop(*args, **kwargs):
            orig_loop.stop_called = True
            orig_loop.real_stop()

        orig_loop.stop = stop
        try:
            assert client.io_loop is orig_loop
            client.io_loop.run_sync(client.connect)

            # Ensure we are testing the _read_until_future and io_loop teardown
            assert client._stream is not None
            assert client._read_until_future is not None
            assert orig_loop.stop_called is True

            # The run_sync call will set stop_called, reset it
            orig_loop.stop_called = False
            client.close()

            # Stop should be called again, client's io_loop should be None
            assert orig_loop.stop_called is True
            assert client.io_loop is None
        finally:
            # Dropping the instance attribute restores the class's own stop().
            del orig_loop.stop
            del orig_loop.real_stop
            del orig_loop.stop_called
class TCPPubServerChannelTest(TestCase, AdaptedConfigurationTestCaseMixin):
    """
    Check which ``topic_lst`` ``TCPPubServerChannel.publish`` attaches to the
    payload it sends, for different targeting options.
    """

    def _build_channel(self, sync_wrapper, crypticle, secrets):
        """
        Create a ``TCPPubServerChannel`` with message signing disabled, then
        configure the patched ``SyncWrapper``/``Crypticle``/``secrets`` mocks.

        Returns ``(opts, channel, wrap)``; the published payload is the first
        positional argument of ``wrap.send``.
        """
        opts = self.get_temp_config("master")
        opts["sign_pub_messages"] = False
        channel = TCPPubServerChannel(opts)
        wrap = MagicMock()
        crypt = MagicMock()
        crypt.dumps.return_value = {"test": "value"}
        secrets.return_value = {"aes": {"secret": None}}
        crypticle.return_value = crypt
        sync_wrapper.return_value = wrap
        return opts, channel, wrap

    @patch("salt.master.SMaster.secrets")
    @patch("salt.crypt.Crypticle")
    @patch("salt.utils.asynchronous.SyncWrapper")
    def test_publish_filtering(self, sync_wrapper, crypticle, secrets):
        opts, channel, wrap = self._build_channel(sync_wrapper, crypticle, secrets)

        # glob targeting: no topic list at all
        channel.publish({"test": "value", "tgt_type": "glob", "tgt": "*"})
        sent = wrap.send.call_args[0][0]
        self.assertNotIn("topic_lst", sent)

        # list targeting: topic list mirrors the target list
        channel.publish({"test": "value", "tgt_type": "list", "tgt": ["minion01"]})
        sent = wrap.send.call_args[0][0]
        self.assertIn("topic_lst", sent)
        self.assertEqual(sent["topic_lst"], ["minion01"])

        # with order_masters (syndic) enabled the topic list is dropped again
        opts["order_masters"] = True
        channel.publish({"test": "value", "tgt_type": "list", "tgt": ["minion01"]})
        sent = wrap.send.call_args[0][0]
        self.assertNotIn("topic_lst", sent)

    @patch("salt.utils.minions.CkMinions.check_minions")
    @patch("salt.master.SMaster.secrets")
    @patch("salt.crypt.Crypticle")
    @patch("salt.utils.asynchronous.SyncWrapper")
    def test_publish_filtering_str_list(
        self, sync_wrapper, crypticle, secrets, check_minions
    ):
        _, channel, wrap = self._build_channel(sync_wrapper, crypticle, secrets)
        check_minions.return_value = {"minions": ["minion02"]}

        # list targeting given as a plain string is resolved via check_minions
        channel.publish({"test": "value", "tgt_type": "list", "tgt": "minion02"})
        sent = wrap.send.call_args[0][0]
        self.assertIn("topic_lst", sent)
        self.assertEqual(sent["topic_lst"], ["minion02"])

        check_minions.assert_called_with("minion02", tgt_type="list")
|