test_pyobjects.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601
  1. # -*- coding: utf-8 -*-
  2. from __future__ import absolute_import
  3. import logging
  4. import os
  5. import shutil
  6. import tempfile
  7. import textwrap
  8. import uuid
  9. import jinja2
  10. import salt.config
  11. import salt.state
  12. import salt.utils.files
  13. from salt.template import compile_template
  14. from salt.utils.odict import OrderedDict
  15. from salt.utils.pyobjects import (
  16. DuplicateState,
  17. InvalidFunction,
  18. Registry,
  19. SaltObject,
  20. State,
  21. StateFactory,
  22. )
  23. from tests.support.helpers import slowTest
  24. from tests.support.runtests import RUNTIME_VARS
  25. from tests.support.unit import TestCase
  26. log = logging.getLogger(__name__)
  27. File = StateFactory("file")
  28. Service = StateFactory("service")
  29. pydmesg_expected = {
  30. "file.managed": [
  31. {"group": "root"},
  32. {"mode": "0755"},
  33. {"require": [{"file": "/usr/local/bin"}]},
  34. {"source": "salt://debian/files/pydmesg.py"},
  35. {"user": "root"},
  36. ]
  37. }
  38. pydmesg_salt_expected = OrderedDict([("/usr/local/bin/pydmesg", pydmesg_expected)])
  39. pydmesg_kwargs = dict(
  40. user="root", group="root", mode="0755", source="salt://debian/files/pydmesg.py"
  41. )
  42. basic_template = """#!pyobjects
  43. File.directory('/tmp', mode='1777', owner='root', group='root')
  44. """
  45. invalid_template = """#!pyobjects
  46. File.fail('/tmp')
  47. """
  48. include_template = """#!pyobjects
  49. include('http')
  50. """
  51. extend_template = """#!pyobjects
  52. include('http')
  53. from salt.utils.pyobjects import StateFactory
  54. Service = StateFactory('service')
  55. Service.running(extend('apache'), watch=[{'file': '/etc/file'}])
  56. """
  57. map_prefix = """\
  58. #!pyobjects
  59. from salt.utils.pyobjects import StateFactory
  60. Service = StateFactory('service')
  61. {% macro priority(value) %}
  62. priority = {{ value }}
  63. {% endmacro %}
  64. class Samba(Map):
  65. """
  66. map_suffix = """
  67. with Pkg.installed("samba", names=[Samba.server, Samba.client]):
  68. Service.running("samba", name=Samba.service)
  69. """
  70. map_data = {
  71. "debian": " class Debian:\n"
  72. " server = 'samba'\n"
  73. " client = 'samba-client'\n"
  74. " service = 'samba'\n",
  75. "centos": " class RougeChapeau:\n"
  76. " __match__ = 'RedHat'\n"
  77. " server = 'samba'\n"
  78. " client = 'samba'\n"
  79. " service = 'smb'\n",
  80. "ubuntu": " class Ubuntu:\n"
  81. " __grain__ = 'os'\n"
  82. " service = 'smbd'\n",
  83. }
  84. import_template = """#!pyobjects
  85. import salt://map.sls
  86. Pkg.removed("samba-imported", names=[map.Samba.server, map.Samba.client])
  87. """
  88. recursive_map_template = """#!pyobjects
  89. from salt://map.sls import Samba
  90. class CustomSamba(Samba):
  91. pass
  92. """
  93. recursive_import_template = """#!pyobjects
  94. from salt://recursive_map.sls import CustomSamba
  95. Pkg.removed("samba-imported", names=[CustomSamba.server, CustomSamba.client])"""
  96. scope_test_import_template = """#!pyobjects
  97. from salt://recursive_map.sls import CustomSamba
  98. # since we import CustomSamba we should shouldn't be able to see Samba
  99. Pkg.removed("samba-imported", names=[Samba.server, Samba.client])"""
  100. from_import_template = """#!pyobjects
  101. # this spacing is like this on purpose to ensure it's stripped properly
  102. from salt://map.sls import Samba
  103. Pkg.removed("samba-imported", names=[Samba.server, Samba.client])
  104. """
  105. import_as_template = """#!pyobjects
  106. from salt://map.sls import Samba as Other
  107. Pkg.removed("samba-imported", names=[Other.server, Other.client])
  108. """
  109. random_password_template = """#!pyobjects
  110. import random, string
  111. password = ''.join([random.SystemRandom().choice(
  112. string.ascii_letters + string.digits) for _ in range(20)])
  113. """
  114. random_password_import_template = """#!pyobjects
  115. from salt://password.sls import password
  116. """
  117. requisite_implicit_list_template = """#!pyobjects
  118. from salt.utils.pyobjects import StateFactory
  119. Service = StateFactory('service')
  120. with Pkg.installed("pkg"):
  121. Service.running("service", watch=File("file"), require=Cmd("cmd"))
  122. """
  123. class MapBuilder(object):
  124. def build_map(self, template=None):
  125. """
  126. Build from a specific template or just use a default if no template
  127. is passed to this function.
  128. """
  129. if template is None:
  130. template = textwrap.dedent(
  131. """\
  132. {{ ubuntu }}
  133. {{ centos }}
  134. {{ debian }}
  135. """
  136. )
  137. full_template = map_prefix + template + map_suffix
  138. ret = jinja2.Template(full_template).render(**map_data)
  139. log.debug("built map: \n%s", ret)
  140. return ret
  141. class StateTests(TestCase):
  142. def setUp(self):
  143. Registry.empty()
  144. def test_serialization(self):
  145. f = State(
  146. "/usr/local/bin/pydmesg",
  147. "file",
  148. "managed",
  149. require=File("/usr/local/bin"),
  150. **pydmesg_kwargs
  151. )
  152. self.assertEqual(f(), pydmesg_expected)
  153. def test_factory_serialization(self):
  154. File.managed(
  155. "/usr/local/bin/pydmesg", require=File("/usr/local/bin"), **pydmesg_kwargs
  156. )
  157. self.assertEqual(Registry.states["/usr/local/bin/pydmesg"], pydmesg_expected)
  158. def test_context_manager(self):
  159. with File("/usr/local/bin"):
  160. pydmesg = File.managed("/usr/local/bin/pydmesg", **pydmesg_kwargs)
  161. self.assertEqual(
  162. Registry.states["/usr/local/bin/pydmesg"], pydmesg_expected
  163. )
  164. with pydmesg:
  165. File.managed("/tmp/something", owner="root")
  166. self.assertEqual(
  167. Registry.states["/tmp/something"],
  168. {
  169. "file.managed": [
  170. {"owner": "root"},
  171. {
  172. "require": [
  173. {"file": "/usr/local/bin"},
  174. {"file": "/usr/local/bin/pydmesg"},
  175. ]
  176. },
  177. ]
  178. },
  179. )
  180. def test_salt_data(self):
  181. File.managed(
  182. "/usr/local/bin/pydmesg", require=File("/usr/local/bin"), **pydmesg_kwargs
  183. )
  184. self.assertEqual(Registry.states["/usr/local/bin/pydmesg"], pydmesg_expected)
  185. self.assertEqual(Registry.salt_data(), pydmesg_salt_expected)
  186. self.assertEqual(Registry.states, OrderedDict())
  187. def test_duplicates(self):
  188. def add_dup():
  189. File.managed("dup", name="/dup")
  190. add_dup()
  191. self.assertRaises(DuplicateState, add_dup)
  192. Service.running("dup", name="dup-service")
  193. self.assertEqual(
  194. Registry.states,
  195. OrderedDict(
  196. [
  197. (
  198. "dup",
  199. OrderedDict(
  200. [
  201. ("file.managed", [{"name": "/dup"}]),
  202. ("service.running", [{"name": "dup-service"}]),
  203. ]
  204. ),
  205. )
  206. ]
  207. ),
  208. )
  209. class RendererMixin(object):
  210. """
  211. This is a mixin that adds a ``.render()`` method to render a template
  212. It must come BEFORE ``TestCase`` in the declaration of your test case
  213. class so that our setUp & tearDown get invoked first, and super can
  214. trigger the methods in the ``TestCase`` class.
  215. """
  216. def setUp(self, *args, **kwargs):
  217. super(RendererMixin, self).setUp(*args, **kwargs)
  218. self.root_dir = tempfile.mkdtemp("pyobjects_test_root", dir=RUNTIME_VARS.TMP)
  219. self.state_tree_dir = os.path.join(self.root_dir, "state_tree")
  220. self.cache_dir = os.path.join(self.root_dir, "cachedir")
  221. if not os.path.isdir(self.root_dir):
  222. os.makedirs(self.root_dir)
  223. if not os.path.isdir(self.state_tree_dir):
  224. os.makedirs(self.state_tree_dir)
  225. if not os.path.isdir(self.cache_dir):
  226. os.makedirs(self.cache_dir)
  227. self.config = salt.config.minion_config(None)
  228. self.config["root_dir"] = self.root_dir
  229. self.config["state_events"] = False
  230. self.config["id"] = "match"
  231. self.config["file_client"] = "local"
  232. self.config["file_roots"] = dict(base=[self.state_tree_dir])
  233. self.config["cachedir"] = self.cache_dir
  234. self.config["test"] = False
  235. def tearDown(self, *args, **kwargs):
  236. shutil.rmtree(self.root_dir)
  237. del self.config
  238. super(RendererMixin, self).tearDown(*args, **kwargs)
  239. def write_template_file(self, filename, content):
  240. full_path = os.path.join(self.state_tree_dir, filename)
  241. with salt.utils.files.fopen(full_path, "w") as f:
  242. f.write(content)
  243. return full_path
  244. def render(self, template, opts=None, filename=None):
  245. if opts:
  246. self.config.update(opts)
  247. if not filename:
  248. filename = ".".join([str(uuid.uuid4()), "sls"])
  249. full_path = self.write_template_file(filename, template)
  250. state = salt.state.State(self.config)
  251. return compile_template(
  252. full_path,
  253. state.rend,
  254. state.opts["renderer"],
  255. state.opts["renderer_blacklist"],
  256. state.opts["renderer_whitelist"],
  257. )
  258. class RendererTests(RendererMixin, StateTests, MapBuilder):
  259. @slowTest
  260. def test_basic(self):
  261. ret = self.render(basic_template)
  262. self.assertEqual(
  263. ret,
  264. OrderedDict(
  265. [
  266. (
  267. "/tmp",
  268. {
  269. "file.directory": [
  270. {"group": "root"},
  271. {"mode": "1777"},
  272. {"owner": "root"},
  273. ]
  274. },
  275. ),
  276. ]
  277. ),
  278. )
  279. self.assertEqual(Registry.states, OrderedDict())
  280. @slowTest
  281. def test_invalid_function(self):
  282. def _test():
  283. self.render(invalid_template)
  284. self.assertRaises(InvalidFunction, _test)
  285. @slowTest
  286. def test_include(self):
  287. ret = self.render(include_template)
  288. self.assertEqual(ret, OrderedDict([("include", ["http"])]))
  289. @slowTest
  290. def test_extend(self):
  291. ret = self.render(
  292. extend_template, {"grains": {"os_family": "Debian", "os": "Debian"}}
  293. )
  294. self.assertEqual(
  295. ret,
  296. OrderedDict(
  297. [
  298. ("include", ["http"]),
  299. (
  300. "extend",
  301. OrderedDict(
  302. [
  303. (
  304. "apache",
  305. {
  306. "service.running": [
  307. {"watch": [{"file": "/etc/file"}]}
  308. ]
  309. },
  310. ),
  311. ]
  312. ),
  313. ),
  314. ]
  315. ),
  316. )
  317. @slowTest
  318. def test_sls_imports(self):
  319. def render_and_assert(template):
  320. ret = self.render(
  321. template, {"grains": {"os_family": "Debian", "os": "Debian"}}
  322. )
  323. self.assertEqual(
  324. ret,
  325. OrderedDict(
  326. [
  327. (
  328. "samba-imported",
  329. {"pkg.removed": [{"names": ["samba", "samba-client"]}]},
  330. )
  331. ]
  332. ),
  333. )
  334. self.write_template_file("map.sls", self.build_map())
  335. render_and_assert(import_template)
  336. render_and_assert(from_import_template)
  337. render_and_assert(import_as_template)
  338. self.write_template_file("recursive_map.sls", recursive_map_template)
  339. render_and_assert(recursive_import_template)
  340. @slowTest
  341. def test_import_scope(self):
  342. self.write_template_file("map.sls", self.build_map())
  343. self.write_template_file("recursive_map.sls", recursive_map_template)
  344. def do_render():
  345. ret = self.render(
  346. scope_test_import_template,
  347. {"grains": {"os_family": "Debian", "os": "Debian"}},
  348. )
  349. self.assertRaises(NameError, do_render)
  350. @slowTest
  351. def test_random_password(self):
  352. """Test for https://github.com/saltstack/salt/issues/21796"""
  353. ret = self.render(random_password_template)
  354. @slowTest
  355. def test_import_random_password(self):
  356. """Import test for https://github.com/saltstack/salt/issues/21796"""
  357. self.write_template_file("password.sls", random_password_template)
  358. ret = self.render(random_password_import_template)
  359. @slowTest
  360. def test_requisite_implicit_list(self):
  361. """Ensure that the implicit list characteristic works as expected"""
  362. ret = self.render(
  363. requisite_implicit_list_template,
  364. {"grains": {"os_family": "Debian", "os": "Debian"}},
  365. )
  366. self.assertEqual(
  367. ret,
  368. OrderedDict(
  369. [
  370. ("pkg", OrderedDict([("pkg.installed", [])])),
  371. (
  372. "service",
  373. OrderedDict(
  374. [
  375. (
  376. "service.running",
  377. [
  378. {"require": [{"cmd": "cmd"}, {"pkg": "pkg"}]},
  379. {"watch": [{"file": "file"}]},
  380. ],
  381. )
  382. ]
  383. ),
  384. ),
  385. ]
  386. ),
  387. )
  388. class MapTests(RendererMixin, TestCase, MapBuilder):
  389. maxDiff = None
  390. debian_grains = {"os_family": "Debian", "os": "Debian"}
  391. ubuntu_grains = {"os_family": "Debian", "os": "Ubuntu"}
  392. centos_grains = {"os_family": "RedHat", "os": "CentOS"}
  393. debian_attrs = ("samba", "samba-client", "samba")
  394. ubuntu_attrs = ("samba", "samba-client", "smbd")
  395. centos_attrs = ("samba", "samba", "smb")
  396. def samba_with_grains(self, template, grains):
  397. return self.render(template, {"grains": grains})
  398. def assert_equal(self, ret, server, client, service):
  399. self.assertDictEqual(
  400. ret,
  401. OrderedDict(
  402. [
  403. (
  404. "samba",
  405. OrderedDict(
  406. [
  407. ("pkg.installed", [{"names": [server, client]}]),
  408. (
  409. "service.running",
  410. [
  411. {"name": service},
  412. {"require": [{"pkg": "samba"}]},
  413. ],
  414. ),
  415. ]
  416. ),
  417. )
  418. ]
  419. ),
  420. )
  421. def assert_not_equal(self, ret, server, client, service):
  422. try:
  423. self.assert_equal(ret, server, client, service)
  424. except AssertionError:
  425. pass
  426. else:
  427. raise AssertionError("both dicts are equal")
  428. @slowTest
  429. def test_map(self):
  430. """
  431. Test declarative ordering
  432. """
  433. # With declarative ordering, the ubuntu-specific service name should
  434. # override the one inherited from debian.
  435. template = self.build_map(
  436. textwrap.dedent(
  437. """\
  438. {{ debian }}
  439. {{ centos }}
  440. {{ ubuntu }}
  441. """
  442. )
  443. )
  444. ret = self.samba_with_grains(template, self.debian_grains)
  445. self.assert_equal(ret, *self.debian_attrs)
  446. ret = self.samba_with_grains(template, self.ubuntu_grains)
  447. self.assert_equal(ret, *self.ubuntu_attrs)
  448. ret = self.samba_with_grains(template, self.centos_grains)
  449. self.assert_equal(ret, *self.centos_attrs)
  450. # Switching the order, debian should still work fine but ubuntu should
  451. # no longer match, since the debian service name should override the
  452. # ubuntu one.
  453. template = self.build_map(
  454. textwrap.dedent(
  455. """\
  456. {{ ubuntu }}
  457. {{ debian }}
  458. """
  459. )
  460. )
  461. ret = self.samba_with_grains(template, self.debian_grains)
  462. self.assert_equal(ret, *self.debian_attrs)
  463. ret = self.samba_with_grains(template, self.ubuntu_grains)
  464. self.assert_not_equal(ret, *self.ubuntu_attrs)
  465. @slowTest
  466. def test_map_with_priority(self):
  467. """
  468. With declarative ordering, the debian service name would override the
  469. ubuntu one since debian comes second. This will test overriding this
  470. behavior using the priority attribute.
  471. """
  472. template = self.build_map(
  473. textwrap.dedent(
  474. """\
  475. {{ priority(('os_family', 'os')) }}
  476. {{ ubuntu }}
  477. {{ centos }}
  478. {{ debian }}
  479. """
  480. )
  481. )
  482. ret = self.samba_with_grains(template, self.debian_grains)
  483. self.assert_equal(ret, *self.debian_attrs)
  484. ret = self.samba_with_grains(template, self.ubuntu_grains)
  485. self.assert_equal(ret, *self.ubuntu_attrs)
  486. ret = self.samba_with_grains(template, self.centos_grains)
  487. self.assert_equal(ret, *self.centos_attrs)
  488. class SaltObjectTests(TestCase):
  489. def test_salt_object(self):
  490. def attr_fail():
  491. Salt.fail.blah()
  492. def times2(x):
  493. return x * 2
  494. __salt__ = {"math.times2": times2}
  495. Salt = SaltObject(__salt__)
  496. self.assertRaises(AttributeError, attr_fail)
  497. self.assertEqual(Salt.math.times2, times2)
  498. self.assertEqual(Salt.math.times2(2), 4)