helpers.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. # -*- coding: utf-8 -*-
  2. """
  3. tests.support.pytest.helpers
  4. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  5. PyTest helpers functions
  6. """
  7. import logging
  8. import os
  9. import shutil
  10. import tempfile
  11. import textwrap
  12. import types
  13. from contextlib import contextmanager
  14. import pytest
  15. import salt.utils.files
  16. from tests.support.pytest.loader import LoaderModuleMock
  17. from tests.support.runtests import RUNTIME_VARS
# Module-level logger named after this module (not referenced by the helpers shown here).
log = logging.getLogger(__name__)
  19. @pytest.helpers.register
  20. @contextmanager
  21. def temp_directory(name=None):
  22. """
  23. This helper creates a temporary directory. It should be used as a context manager
  24. which returns the temporary directory path, and, once out of context, deletes it.
  25. Can be directly imported and used, or, it can be used as a pytest helper function if
  26. ``pytest-helpers-namespace`` is installed.
  27. .. code-block:: python
  28. import os
  29. import pytest
  30. def test_blah():
  31. with pytest.helpers.temp_directory() as tpath:
  32. print(tpath)
  33. assert os.path.exists(tpath)
  34. assert not os.path.exists(tpath)
  35. """
  36. try:
  37. if name is not None:
  38. directory_path = os.path.join(RUNTIME_VARS.TMP, name)
  39. else:
  40. directory_path = tempfile.mkdtemp(dir=RUNTIME_VARS.TMP)
  41. if not os.path.isdir(directory_path):
  42. os.makedirs(directory_path)
  43. yield directory_path
  44. finally:
  45. shutil.rmtree(directory_path, ignore_errors=True)
  46. @pytest.helpers.register
  47. @contextmanager
  48. def temp_file(name=None, contents=None, directory=None, strip_first_newline=True):
  49. """
  50. This helper creates a temporary file. It should be used as a context manager
  51. which returns the temporary file path, and, once out of context, deletes it.
  52. Can be directly imported and used, or, it can be used as a pytest helper function if
  53. ``pytest-helpers-namespace`` is installed.
  54. .. code-block:: python
  55. import os
  56. import pytest
  57. def test_blah():
  58. with pytest.helpers.temp_file("blah.txt") as tpath:
  59. print(tpath)
  60. assert os.path.exists(tpath)
  61. assert not os.path.exists(tpath)
  62. Args:
  63. name(str):
  64. The temporary file name
  65. contents(str):
  66. The contents of the temporary file
  67. directory(str):
  68. The directory where to create the temporary file. If ``None``, then ``RUNTIME_VARS.TMP``
  69. will be used.
  70. strip_first_newline(bool):
  71. Wether to strip the initial first new line char or not.
  72. """
  73. try:
  74. if directory is None:
  75. directory = RUNTIME_VARS.TMP
  76. if name is not None:
  77. file_path = os.path.join(directory, name)
  78. else:
  79. handle, file_path = tempfile.mkstemp(dir=directory)
  80. os.close(handle)
  81. file_directory = os.path.dirname(file_path)
  82. if not os.path.isdir(file_directory):
  83. os.makedirs(file_directory)
  84. if contents is not None:
  85. if contents:
  86. if contents.startswith("\n") and strip_first_newline:
  87. contents = contents[1:]
  88. file_contents = textwrap.dedent(contents)
  89. else:
  90. file_contents = contents
  91. with salt.utils.files.fopen(file_path, "w") as wfh:
  92. wfh.write(file_contents)
  93. yield file_path
  94. finally:
  95. try:
  96. os.unlink(file_path)
  97. except OSError:
  98. # Already deleted
  99. pass
  100. @pytest.helpers.register
  101. def temp_state_file(name, contents, saltenv="base", strip_first_newline=True):
  102. """
  103. This helper creates a temporary state file. It should be used as a context manager
  104. which returns the temporary state file path, and, once out of context, deletes it.
  105. Can be directly imported and used, or, it can be used as a pytest helper function if
  106. ``pytest-helpers-namespace`` is installed.
  107. .. code-block:: python
  108. import os
  109. import pytest
  110. def test_blah():
  111. with pytest.helpers.temp_state_file("blah.sls") as tpath:
  112. print(tpath)
  113. assert os.path.exists(tpath)
  114. assert not os.path.exists(tpath)
  115. Depending on the saltenv, it will be created under ``RUNTIME_VARS.TMP_STATE_TREE`` or
  116. ``RUNTIME_VARS.TMP_PRODENV_STATE_TREE``.
  117. Args:
  118. name(str):
  119. The temporary state file name
  120. contents(str):
  121. The contents of the temporary file
  122. saltenv(str):
  123. The salt env to use. Either ``base`` or ``prod``
  124. strip_first_newline(bool):
  125. Wether to strip the initial first new line char or not.
  126. """
  127. if saltenv == "base":
  128. directory = RUNTIME_VARS.TMP_STATE_TREE
  129. elif saltenv == "prod":
  130. directory = RUNTIME_VARS.TMP_PRODENV_STATE_TREE
  131. else:
  132. raise RuntimeError(
  133. '"saltenv" can only be "base" or "prod", not "{}"'.format(saltenv)
  134. )
  135. return temp_file(
  136. name, contents, directory=directory, strip_first_newline=strip_first_newline
  137. )
  138. @pytest.helpers.register
  139. def temp_pillar_file(name, contents, saltenv="base", strip_first_newline=True):
  140. """
  141. This helper creates a temporary pillar file. It should be used as a context manager
  142. which returns the temporary pillar file path, and, once out of context, deletes it.
  143. Can be directly imported and used, or, it can be used as a pytest helper function if
  144. ``pytest-helpers-namespace`` is installed.
  145. .. code-block:: python
  146. import os
  147. import pytest
  148. def test_blah():
  149. with pytest.helpers.temp_pillar_file("blah.sls") as tpath:
  150. print(tpath)
  151. assert os.path.exists(tpath)
  152. assert not os.path.exists(tpath)
  153. Depending on the saltenv, it will be created under ``RUNTIME_VARS.TMP_PILLAR_TREE`` or
  154. ``RUNTIME_VARS.TMP_PRODENV_PILLAR_TREE``.
  155. Args:
  156. name(str):
  157. The temporary state file name
  158. contents(str):
  159. The contents of the temporary file
  160. saltenv(str):
  161. The salt env to use. Either ``base`` or ``prod``
  162. strip_first_newline(bool):
  163. Wether to strip the initial first new line char or not.
  164. """
  165. if saltenv == "base":
  166. directory = RUNTIME_VARS.TMP_PILLAR_TREE
  167. elif saltenv == "prod":
  168. directory = RUNTIME_VARS.TMP_PRODENV_PILLAR_TREE
  169. else:
  170. raise RuntimeError(
  171. '"saltenv" can only be "base" or "prod", not "{}"'.format(saltenv)
  172. )
  173. return temp_file(
  174. name, contents, directory=directory, strip_first_newline=strip_first_newline
  175. )
  176. @pytest.helpers.register
  177. def loader_mock(request, loader_modules, **kwargs):
  178. return LoaderModuleMock(request, loader_modules, **kwargs)
  179. @pytest.helpers.register
  180. def salt_loader_module_functions(module):
  181. if not isinstance(module, types.ModuleType):
  182. raise RuntimeError(
  183. "The passed 'module' argument must be an imported "
  184. "imported module, not {}".format(type(module))
  185. )
  186. funcs = {}
  187. func_alias = getattr(module, "__func_alias__", {})
  188. virtualname = getattr(module, "__virtualname__")
  189. for name in dir(module):
  190. if name.startswith("_"):
  191. continue
  192. func = getattr(module, name)
  193. if getattr(func, "__module__", None) != module.__name__:
  194. # Not eve defined on the module being processed, carry on
  195. continue
  196. if not isinstance(func, types.FunctionType):
  197. # Not a function? carry on
  198. continue
  199. funcname = func_alias.get(func.__name__) or func.__name__
  200. funcs["{}.{}".format(virtualname, funcname)] = func
  201. return funcs
  202. # Only allow star importing the functions defined in this module
  203. __all__ = [
  204. name
  205. for (name, func) in locals().items()
  206. if getattr(func, "__module__", None) == __name__
  207. ]