test_ssh.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. # -*- coding: utf-8 -*-
  2. """
  3. Test the ssh module
  4. """
  5. from __future__ import absolute_import, print_function, unicode_literals
  6. import os
  7. import shutil
  8. import pytest
  9. import salt.utils.files
  10. import salt.utils.platform
  11. from salt.ext.tornado.httpclient import HTTPClient
  12. from tests.support.case import ModuleCase
  13. from tests.support.runtests import RUNTIME_VARS
  14. GITHUB_FINGERPRINT = "9d:38:5b:83:a9:17:52:92:56:1a:5e:c4:d4:81:8e:0a:ca:51:a2:64:f1:74:20:11:2e:f8:8a:c3:a1:39:49:8f"
  15. def check_status():
  16. """
  17. Check the status of Github for remote operations
  18. """
  19. try:
  20. return HTTPClient().fetch("http://github.com").code == 200
  21. except Exception: # pylint: disable=broad-except
  22. return False
  23. @pytest.mark.windows_whitelisted
  24. @pytest.mark.skip_if_binaries_missing("ssh", "ssh-keygen", check_all=True)
  25. class SSHModuleTest(ModuleCase):
  26. """
  27. Test the ssh module
  28. """
  29. @classmethod
  30. def setUpClass(cls):
  31. cls.subsalt_dir = os.path.join(RUNTIME_VARS.TMP, "subsalt")
  32. cls.authorized_keys = os.path.join(cls.subsalt_dir, "authorized_keys")
  33. cls.known_hosts = os.path.join(cls.subsalt_dir, "known_hosts")
  34. def setUp(self):
  35. """
  36. Set up the ssh module tests
  37. """
  38. if not check_status():
  39. self.skipTest("External source, github.com is down")
  40. super(SSHModuleTest, self).setUp()
  41. if not os.path.isdir(self.subsalt_dir):
  42. os.makedirs(self.subsalt_dir)
  43. ssh_raw_path = os.path.join(RUNTIME_VARS.FILES, "ssh", "raw")
  44. with salt.utils.files.fopen(ssh_raw_path) as fd:
  45. self.key = fd.read().strip()
  46. def tearDown(self):
  47. """
  48. Tear down the ssh module tests
  49. """
  50. if os.path.isdir(self.subsalt_dir):
  51. shutil.rmtree(self.subsalt_dir)
  52. super(SSHModuleTest, self).tearDown()
  53. del self.key
  54. @pytest.mark.slow_test(seconds=30) # Test takes >10 and <=30 seconds
  55. def test_auth_keys(self):
  56. """
  57. test ssh.auth_keys
  58. """
  59. shutil.copyfile(
  60. os.path.join(RUNTIME_VARS.FILES, "ssh", "authorized_keys"),
  61. self.authorized_keys,
  62. )
  63. user = "root"
  64. if salt.utils.platform.is_windows():
  65. user = "Administrator"
  66. ret = self.run_function("ssh.auth_keys", [user, self.authorized_keys])
  67. self.assertEqual(len(list(ret.items())), 1) # exactly one key is found
  68. key_data = list(ret.items())[0][1]
  69. try:
  70. self.assertEqual(key_data["comment"], "github.com")
  71. self.assertEqual(key_data["enc"], "ssh-rsa")
  72. self.assertEqual(
  73. key_data["options"], ['command="/usr/local/lib/ssh-helper"']
  74. )
  75. self.assertEqual(key_data["fingerprint"], GITHUB_FINGERPRINT)
  76. except AssertionError as exc:
  77. raise AssertionError(
  78. "AssertionError: {0}. Function returned: {1}".format(exc, ret)
  79. )
  80. @pytest.mark.slow_test(seconds=30) # Test takes >10 and <=30 seconds
  81. def test_bad_enctype(self):
  82. """
  83. test to make sure that bad key encoding types don't generate an
  84. invalid key entry in authorized_keys
  85. """
  86. shutil.copyfile(
  87. os.path.join(RUNTIME_VARS.FILES, "ssh", "authorized_badkeys"),
  88. self.authorized_keys,
  89. )
  90. ret = self.run_function("ssh.auth_keys", ["root", self.authorized_keys])
  91. # The authorized_badkeys file contains a key with an invalid ssh key
  92. # encoding (dsa-sha2-nistp256 instead of ecdsa-sha2-nistp256)
  93. # auth_keys should skip any keys with invalid encodings. Internally
  94. # the minion will throw a CommandExecutionError so the
  95. # user will get an indicator of what went wrong.
  96. self.assertEqual(len(list(ret.items())), 0) # Zero keys found
  97. @pytest.mark.slow_test(seconds=30) # Test takes >10 and <=30 seconds
  98. def test_get_known_host_entries(self):
  99. """
  100. Check that known host information is returned from ~/.ssh/config
  101. """
  102. shutil.copyfile(
  103. os.path.join(RUNTIME_VARS.FILES, "ssh", "known_hosts"), self.known_hosts
  104. )
  105. arg = ["root", "github.com"]
  106. kwargs = {"config": self.known_hosts}
  107. ret = self.run_function("ssh.get_known_host_entries", arg, **kwargs)[0]
  108. try:
  109. self.assertEqual(ret["enc"], "ssh-rsa")
  110. self.assertEqual(ret["key"], self.key)
  111. self.assertEqual(ret["fingerprint"], GITHUB_FINGERPRINT)
  112. except AssertionError as exc:
  113. raise AssertionError(
  114. "AssertionError: {0}. Function returned: {1}".format(exc, ret)
  115. )
  116. @pytest.mark.slow_test(seconds=30) # Test takes >10 and <=30 seconds
  117. def test_recv_known_host_entries(self):
  118. """
  119. Check that known host information is returned from remote host
  120. """
  121. ret = self.run_function("ssh.recv_known_host_entries", ["github.com"])
  122. try:
  123. self.assertNotEqual(ret, None)
  124. self.assertEqual(ret[0]["enc"], "ssh-rsa")
  125. self.assertEqual(ret[0]["key"], self.key)
  126. self.assertEqual(ret[0]["fingerprint"], GITHUB_FINGERPRINT)
  127. except AssertionError as exc:
  128. raise AssertionError(
  129. "AssertionError: {0}. Function returned: {1}".format(exc, ret)
  130. )
  131. @pytest.mark.slow_test(seconds=30) # Test takes >10 and <=30 seconds
  132. def test_check_known_host_add(self):
  133. """
  134. Check known hosts by its fingerprint. File needs to be updated
  135. """
  136. arg = ["root", "github.com"]
  137. kwargs = {"fingerprint": GITHUB_FINGERPRINT, "config": self.known_hosts}
  138. ret = self.run_function("ssh.check_known_host", arg, **kwargs)
  139. self.assertEqual(ret, "add")
  140. @pytest.mark.slow_test(seconds=30) # Test takes >10 and <=30 seconds
  141. def test_check_known_host_update(self):
  142. """
  143. ssh.check_known_host update verification
  144. """
  145. shutil.copyfile(
  146. os.path.join(RUNTIME_VARS.FILES, "ssh", "known_hosts"), self.known_hosts
  147. )
  148. arg = ["root", "github.com"]
  149. kwargs = {"config": self.known_hosts}
  150. # wrong fingerprint
  151. ret = self.run_function(
  152. "ssh.check_known_host", arg, **dict(kwargs, fingerprint="aa:bb:cc:dd")
  153. )
  154. self.assertEqual(ret, "update")
  155. # wrong keyfile
  156. ret = self.run_function("ssh.check_known_host", arg, **dict(kwargs, key="YQ=="))
  157. self.assertEqual(ret, "update")
  158. @pytest.mark.slow_test(seconds=30) # Test takes >10 and <=30 seconds
  159. def test_check_known_host_exists(self):
  160. """
  161. Verify check_known_host_exists
  162. """
  163. shutil.copyfile(
  164. os.path.join(RUNTIME_VARS.FILES, "ssh", "known_hosts"), self.known_hosts
  165. )
  166. arg = ["root", "github.com"]
  167. kwargs = {"config": self.known_hosts}
  168. # wrong fingerprint
  169. ret = self.run_function(
  170. "ssh.check_known_host", arg, **dict(kwargs, fingerprint=GITHUB_FINGERPRINT)
  171. )
  172. self.assertEqual(ret, "exists")
  173. # wrong keyfile
  174. ret = self.run_function(
  175. "ssh.check_known_host", arg, **dict(kwargs, key=self.key)
  176. )
  177. self.assertEqual(ret, "exists")
  178. @pytest.mark.slow_test(seconds=60) # Test takes >30 and <=60 seconds
  179. def test_rm_known_host(self):
  180. """
  181. ssh.rm_known_host
  182. """
  183. shutil.copyfile(
  184. os.path.join(RUNTIME_VARS.FILES, "ssh", "known_hosts"), self.known_hosts
  185. )
  186. arg = ["root", "github.com"]
  187. kwargs = {"config": self.known_hosts, "key": self.key}
  188. # before removal
  189. ret = self.run_function("ssh.check_known_host", arg, **kwargs)
  190. self.assertEqual(ret, "exists")
  191. # remove
  192. self.run_function("ssh.rm_known_host", arg, config=self.known_hosts)
  193. # after removal
  194. ret = self.run_function("ssh.check_known_host", arg, **kwargs)
  195. self.assertEqual(ret, "add")
  196. @pytest.mark.slow_test(seconds=60) # Test takes >30 and <=60 seconds
  197. def test_set_known_host(self):
  198. """
  199. ssh.set_known_host
  200. """
  201. # add item
  202. ret = self.run_function(
  203. "ssh.set_known_host", ["root", "github.com"], config=self.known_hosts
  204. )
  205. try:
  206. self.assertEqual(ret["status"], "updated")
  207. self.assertEqual(ret["old"], None)
  208. self.assertEqual(ret["new"][0]["fingerprint"], GITHUB_FINGERPRINT)
  209. except AssertionError as exc:
  210. raise AssertionError(
  211. "AssertionError: {0}. Function returned: {1}".format(exc, ret)
  212. )
  213. # check that item does exist
  214. ret = self.run_function(
  215. "ssh.get_known_host_entries",
  216. ["root", "github.com"],
  217. config=self.known_hosts,
  218. )[0]
  219. try:
  220. self.assertEqual(ret["fingerprint"], GITHUB_FINGERPRINT)
  221. except AssertionError as exc:
  222. raise AssertionError(
  223. "AssertionError: {0}. Function returned: {1}".format(exc, ret)
  224. )
  225. # add the same item once again
  226. ret = self.run_function(
  227. "ssh.set_known_host", ["root", "github.com"], config=self.known_hosts
  228. )
  229. try:
  230. self.assertEqual(ret["status"], "exists")
  231. except AssertionError as exc:
  232. raise AssertionError(
  233. "AssertionError: {0}. Function returned: {1}".format(exc, ret)
  234. )