1
0

test_context.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. # -*- coding: utf-8 -*-
  2. """
  3. tests.unit.test_context
  4. ~~~~~~~~~~~~~~~~~~~~~~~
  5. """
  6. from __future__ import absolute_import
  7. import threading
  8. import time
  9. import salt.ext.tornado.gen
  10. import salt.ext.tornado.stack_context
  11. import salt.utils.json
  12. from salt.ext.six.moves import range
  13. from salt.ext.tornado.testing import AsyncTestCase, gen_test
  14. from salt.utils.context import ContextDict, NamespacedDictWrapper
  15. from tests.support.helpers import slowTest
  16. from tests.support.unit import TestCase
  17. class ContextDictTests(AsyncTestCase):
  18. # how many threads/coroutines to run at a time
  19. num_concurrent_tasks = 5
  20. def setUp(self):
  21. super(ContextDictTests, self).setUp()
  22. self.cd = ContextDict()
  23. # set a global value
  24. self.cd["foo"] = "global"
  25. @slowTest
  26. def test_threads(self):
  27. """Verify that ContextDict overrides properly within threads
  28. """
  29. rets = []
  30. def tgt(x, s):
  31. inner_ret = []
  32. over = self.cd.clone()
  33. inner_ret.append(self.cd.get("foo"))
  34. with over:
  35. inner_ret.append(over.get("foo"))
  36. over["foo"] = x
  37. inner_ret.append(over.get("foo"))
  38. time.sleep(s)
  39. inner_ret.append(over.get("foo"))
  40. rets.append(inner_ret)
  41. threads = []
  42. for x in range(0, self.num_concurrent_tasks):
  43. s = self.num_concurrent_tasks - x
  44. t = threading.Thread(target=tgt, args=(x, s))
  45. t.start()
  46. threads.append(t)
  47. for t in threads:
  48. t.join()
  49. for r in rets:
  50. self.assertEqual(r[0], r[1])
  51. self.assertEqual(r[2], r[3])
  52. @gen_test
  53. @slowTest
  54. def test_coroutines(self):
  55. """Verify that ContextDict overrides properly within coroutines
  56. """
  57. @salt.ext.tornado.gen.coroutine
  58. def secondary_coroutine(over):
  59. raise salt.ext.tornado.gen.Return(over.get("foo"))
  60. @salt.ext.tornado.gen.coroutine
  61. def tgt(x, s, over):
  62. inner_ret = []
  63. # first grab the global
  64. inner_ret.append(self.cd.get("foo"))
  65. # grab the child's global (should match)
  66. inner_ret.append(over.get("foo"))
  67. # override the global
  68. over["foo"] = x
  69. inner_ret.append(over.get("foo"))
  70. # sleep for some time to let other coroutines do this section of code
  71. yield salt.ext.tornado.gen.sleep(s)
  72. # get the value of the global again.
  73. inner_ret.append(over.get("foo"))
  74. # Call another coroutine to verify that we keep our context
  75. r = yield secondary_coroutine(over)
  76. inner_ret.append(r)
  77. raise salt.ext.tornado.gen.Return(inner_ret)
  78. futures = []
  79. for x in range(0, self.num_concurrent_tasks):
  80. s = self.num_concurrent_tasks - x
  81. over = self.cd.clone()
  82. # pylint: disable=cell-var-from-loop
  83. f = salt.ext.tornado.stack_context.run_with_stack_context(
  84. salt.ext.tornado.stack_context.StackContext(lambda: over),
  85. lambda: tgt(x, s / 5.0, over),
  86. )
  87. # pylint: enable=cell-var-from-loop
  88. futures.append(f)
  89. wait_iterator = salt.ext.tornado.gen.WaitIterator(*futures)
  90. while not wait_iterator.done():
  91. r = yield wait_iterator.next() # pylint: disable=incompatible-py3-code
  92. self.assertEqual(r[0], r[1]) # verify that the global value remails
  93. self.assertEqual(r[2], r[3]) # verify that the override sticks locally
  94. self.assertEqual(
  95. r[3], r[4]
  96. ) # verify that the override sticks across coroutines
  97. def test_basic(self):
  98. """Test that the contextDict is a dict
  99. """
  100. # ensure we get the global value
  101. self.assertEqual(
  102. dict(self.cd), {"foo": "global"},
  103. )
  104. def test_override(self):
  105. over = self.cd.clone()
  106. over["bar"] = "global"
  107. self.assertEqual(
  108. dict(over), {"foo": "global", "bar": "global"},
  109. )
  110. self.assertEqual(
  111. dict(self.cd), {"foo": "global"},
  112. )
  113. with over:
  114. self.assertEqual(
  115. dict(over), {"foo": "global", "bar": "global"},
  116. )
  117. self.assertEqual(
  118. dict(self.cd), {"foo": "global", "bar": "global"},
  119. )
  120. over["bar"] = "baz"
  121. self.assertEqual(
  122. dict(over), {"foo": "global", "bar": "baz"},
  123. )
  124. self.assertEqual(
  125. dict(self.cd), {"foo": "global", "bar": "baz"},
  126. )
  127. self.assertEqual(
  128. dict(over), {"foo": "global", "bar": "baz"},
  129. )
  130. self.assertEqual(
  131. dict(self.cd), {"foo": "global"},
  132. )
  133. def test_multiple_contexts(self):
  134. cds = []
  135. for x in range(0, 10):
  136. cds.append(self.cd.clone(bar=x))
  137. for x, cd in enumerate(cds):
  138. self.assertNotIn("bar", self.cd)
  139. with cd:
  140. self.assertEqual(
  141. dict(self.cd), {"bar": x, "foo": "global"},
  142. )
  143. self.assertNotIn("bar", self.cd)
  144. class NamespacedDictWrapperTests(TestCase):
  145. PREFIX = "prefix"
  146. def setUp(self):
  147. self._dict = {}
  148. def test_single_key(self):
  149. self._dict["prefix"] = {"foo": "bar"}
  150. w = NamespacedDictWrapper(self._dict, "prefix")
  151. self.assertEqual(w["foo"], "bar")
  152. def test_multiple_key(self):
  153. self._dict["prefix"] = {"foo": {"bar": "baz"}}
  154. w = NamespacedDictWrapper(self._dict, ("prefix", "foo"))
  155. self.assertEqual(w["bar"], "baz")
  156. def test_json_dumps_single_key(self):
  157. self._dict["prefix"] = {"foo": {"bar": "baz"}}
  158. w = NamespacedDictWrapper(self._dict, "prefix")
  159. self.assertEqual(salt.utils.json.dumps(w), '{"foo": {"bar": "baz"}}')
  160. def test_json_dumps_multiple_key(self):
  161. self._dict["prefix"] = {"foo": {"bar": "baz"}}
  162. w = NamespacedDictWrapper(self._dict, ("prefix", "foo"))
  163. self.assertEqual(salt.utils.json.dumps(w), '{"bar": "baz"}')