test_context.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. # -*- coding: utf-8 -*-
  2. '''
  3. tests.unit.context_test
  4. ~~~~~~~~~~~~~~~~~~~~
  5. '''
  6. # Import python libs
  7. from __future__ import absolute_import
  8. import tornado.stack_context
  9. import tornado.gen
  10. from tornado.testing import AsyncTestCase, gen_test
  11. import threading
  12. import time
  13. # Import Salt Testing libs
  14. from tests.support.unit import TestCase
  15. from salt.ext.six.moves import range
  16. # Import Salt libs
  17. import salt.utils.json
  18. from salt.utils.context import ContextDict, NamespacedDictWrapper
  class ContextDictTests(AsyncTestCase):
      '''
      Tests for ContextDict: a global value set on the parent, per-clone
      overrides, and isolation of those overrides across threads and
      Tornado coroutines.
      '''
      # how many threads/coroutines to run at a time
      num_concurrent_tasks = 5

      def setUp(self):
          super(ContextDictTests, self).setUp()
          self.cd = ContextDict()
          # set a global value
          self.cd['foo'] = 'global'

      def test_threads(self):
          '''Verify that ContextDict overrides properly within threads
          '''
          rets = []

          def tgt(x, s):
              inner_ret = []
              over = self.cd.clone()

              inner_ret.append(self.cd.get('foo'))
              with over:
                  inner_ret.append(over.get('foo'))
                  over['foo'] = x
                  inner_ret.append(over.get('foo'))
                  time.sleep(s)
                  inner_ret.append(over.get('foo'))
                  rets.append(inner_ret)

          threads = []
          for x in range(0, self.num_concurrent_tasks):
              # Later tasks sleep less, so they finish first and the overrides
              # interleave. Scaled by 5.0 (as in test_coroutines) so the test
              # takes ~1s instead of ~5s.
              s = (self.num_concurrent_tasks - x) / 5.0
              t = threading.Thread(target=tgt, args=(x, s))
              t.start()
              threads.append(t)

          for t in threads:
              t.join()

          # An exception raised inside a worker thread does not fail the test
          # on its own; it only leaves a missing entry in rets. Make sure every
          # task reported back so the checks below are not vacuous.
          self.assertEqual(len(rets), self.num_concurrent_tasks)
          for r in rets:
              self.assertEqual(r[0], r[1])  # clone starts with the global value
              self.assertEqual(r[2], r[3])  # override survives the sleep
          # each thread saw its own override, never another thread's
          self.assertEqual(
              sorted(r[2] for r in rets),
              list(range(self.num_concurrent_tasks)),
          )

      @gen_test
      def test_coroutines(self):
          '''Verify that ContextDict overrides properly within coroutines
          '''
          @tornado.gen.coroutine
          def secondary_coroutine(over):
              raise tornado.gen.Return(over.get('foo'))

          @tornado.gen.coroutine
          def tgt(x, s, over):
              inner_ret = []
              # first grab the global
              inner_ret.append(self.cd.get('foo'))
              # grab the child's global (should match)
              inner_ret.append(over.get('foo'))
              # override the global
              over['foo'] = x
              inner_ret.append(over.get('foo'))
              # sleep for some time to let other coroutines do this section of code
              yield tornado.gen.sleep(s)
              # get the value of the global again.
              inner_ret.append(over.get('foo'))
              # Call another coroutine to verify that we keep our context
              r = yield secondary_coroutine(over)
              inner_ret.append(r)
              raise tornado.gen.Return(inner_ret)

          futures = []

          for x in range(0, self.num_concurrent_tasks):
              s = self.num_concurrent_tasks - x
              over = self.cd.clone()

              # Bind the loop variables as lambda defaults instead of closing
              # over them. The StackContext factory may be called again when
              # the context is re-established (e.g. when the coroutine resumes
              # after its sleep). By then the loop has rebound `over`, and a
              # plain closure would return the last clone created here rather
              # than this task's clone.
              f = tornado.stack_context.run_with_stack_context(
                  tornado.stack_context.StackContext(lambda over=over: over),
                  lambda x=x, s=s, over=over: tgt(x, s / 5.0, over),
              )
              futures.append(f)

          results = []
          wait_iterator = tornado.gen.WaitIterator(*futures)
          while not wait_iterator.done():
              r = yield wait_iterator.next()  # pylint: disable=incompatible-py3-code
              results.append(r)
              self.assertEqual(r[0], r[1])  # verify that the global value remains
              self.assertEqual(r[2], r[3])  # verify that the override sticks locally
              self.assertEqual(r[3], r[4])  # verify that the override sticks across coroutines

          # every coroutine finished, and each one kept its own override
          self.assertEqual(
              sorted(r[2] for r in results),
              list(range(self.num_concurrent_tasks)),
          )

      def test_basic(self):
          '''Test that the contextDict is a dict
          '''
          # ensure we get the global value
          self.assertEqual(
              dict(self.cd),
              {'foo': 'global'},
          )

      def test_override(self):
          '''Verify that values set on a clone show through the parent only
          while the clone is active (inside its ``with`` block), and that the
          clone keeps them after the block exits.
          '''
          over = self.cd.clone()
          over['bar'] = 'global'
          self.assertEqual(
              dict(over),
              {'foo': 'global', 'bar': 'global'},
          )
          self.assertEqual(
              dict(self.cd),
              {'foo': 'global'},
          )
          with over:
              self.assertEqual(
                  dict(over),
                  {'foo': 'global', 'bar': 'global'},
              )
              self.assertEqual(
                  dict(self.cd),
                  {'foo': 'global', 'bar': 'global'},
              )
              over['bar'] = 'baz'
              self.assertEqual(
                  dict(over),
                  {'foo': 'global', 'bar': 'baz'},
              )
              self.assertEqual(
                  dict(self.cd),
                  {'foo': 'global', 'bar': 'baz'},
              )
          self.assertEqual(
              dict(over),
              {'foo': 'global', 'bar': 'baz'},
          )
          self.assertEqual(
              dict(self.cd),
              {'foo': 'global'},
          )

      def test_multiple_contexts(self):
          '''Verify that clones created with keyword overrides each apply
          their own value to the parent only while they are active.
          '''
          cds = []
          for x in range(0, 10):
              cds.append(self.cd.clone(bar=x))
          for x, cd in enumerate(cds):
              self.assertNotIn('bar', self.cd)
              with cd:
                  self.assertEqual(
                      dict(self.cd),
                      {'bar': x, 'foo': 'global'},
                  )
              self.assertNotIn('bar', self.cd)
  class NamespacedDictWrapperTests(TestCase):
      '''
      Tests for NamespacedDictWrapper. The wrapper exposes the sub-dict of the
      wrapped dict selected by a single key or by a tuple of nested keys.
      '''
      # top-level key under which each test stores its namespaced data
      PREFIX = 'prefix'

      def setUp(self):
          super(NamespacedDictWrapperTests, self).setUp()
          self._dict = {}

      def test_single_key(self):
          '''A single key selects the sub-dict stored under it.'''
          self._dict[self.PREFIX] = {'foo': 'bar'}
          w = NamespacedDictWrapper(self._dict, self.PREFIX)
          self.assertEqual(w['foo'], 'bar')

      def test_multiple_key(self):
          '''A tuple of keys walks down the nested dicts.'''
          self._dict[self.PREFIX] = {'foo': {'bar': 'baz'}}
          w = NamespacedDictWrapper(self._dict, (self.PREFIX, 'foo'))
          self.assertEqual(w['bar'], 'baz')

      def test_json_dumps_single_key(self):
          '''JSON-serializing the wrapper emits only the selected sub-dict.'''
          self._dict[self.PREFIX] = {'foo': {'bar': 'baz'}}
          w = NamespacedDictWrapper(self._dict, self.PREFIX)
          self.assertEqual(salt.utils.json.dumps(w), '{"foo": {"bar": "baz"}}')

      def test_json_dumps_multiple_key(self):
          '''JSON-serializing a multi-key wrapper emits the innermost sub-dict.'''
          self._dict[self.PREFIX] = {'foo': {'bar': 'baz'}}
          w = NamespacedDictWrapper(self._dict, (self.PREFIX, 'foo'))
          self.assertEqual(salt.utils.json.dumps(w), '{"bar": "baz"}')