test_virtualname.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  """
  tests.unit.test_virtualname
  ~~~~~~~~~~~~~~~~~~~~~~~~~~~

  Check that every module under the Salt code directory which declares a
  ``__virtualname__`` carries that virtual name in its filename.
  """
  5. import importlib.util
  6. import logging
  7. import os
  8. from tests.support.runtests import RUNTIME_VARS
  9. from tests.support.unit import TestCase
  # Module-level logger; test_check_virtualname reports each mismatch through
  # log.critical before asserting.
  log = logging.getLogger(__name__)
  class FakeEntry:
      """
      Minimal stand-in exposing ``name``, ``path`` and ``is_file()``.

      The attribute and method names mirror those of ``os.DirEntry``. It is not
      referenced in the code shown here; presumably a test double for
      directory-scan results (confirm before removing).
      """

      def __init__(self, name, path, is_file=True):
          # Keep the arguments verbatim; the file flag is stored privately so
          # that ``is_file`` can be exposed as a method, as on os.DirEntry.
          self.name, self.path, self._is_file = name, path, is_file

      def is_file(self):
          """Return the ``is_file`` flag supplied at construction time."""
          return self._is_file
  class VirtualNameTestCase(TestCase):
      """
      Test that the virtualname is in the module name, to speed up lookup of
      modules.

      Every module found under the Salt code directory that defines
      ``__virtualname__`` must contain that virtual name as a substring of its
      filename (``*.py`` modules) or directory name (packages).
      """

      maxDiff = None

      @staticmethod
      def _import_module(testpath):
          """
          Execute the Python source file ``testpath`` and return the module.

          The module is created under the throwaway name ``tmpmodule`` and is
          not inserted into ``sys.modules``; each call yields a fresh module.
          """
          spec = importlib.util.spec_from_file_location("tmpmodule", testpath)
          module = importlib.util.module_from_spec(spec)
          spec.loader.exec_module(module)
          return module

      def _check_modules(self, path):
          """
          Check the modules in the directory ``path``.

          Candidates are plain ``*.py`` files and package directories that
          contain an ``__init__.py``. Entries whose names start with ``.`` or
          ``_`` are skipped, as are non-Python files and non-package
          directories. Entries are visited in sorted order so the result is
          deterministic.

          Returns a list of error strings, one per module whose
          ``__virtualname__`` is not a substring of its (extension-less) name.
          """
          ret = []
          for entry_name in sorted(os.listdir(path)):
              # os.listdir() yields bare names; anchor them to the scanned
              # directory so the checks below do not depend on the CWD.
              entry_path = os.path.join(path, entry_name)
              name = os.path.splitext(entry_name)[0]
              if name.startswith((".", "_")):
                  continue
              if os.path.isfile(entry_path):
                  # Test the extension on the full filename: ``name`` has
                  # already had its extension stripped by splitext().
                  if not entry_name.endswith(".py"):
                      continue
                  testpath = entry_path
              else:
                  testpath = os.path.join(entry_path, "__init__.py")
                  if not os.path.isfile(testpath):
                      # Not a package (e.g. a data directory): nothing to load.
                      continue
              module = self._import_module(testpath)
              if not hasattr(module, "__virtualname__"):
                  continue
              if module.__virtualname__ not in name:
                  ret.append(
                      'Virtual name "{}" is not in the module filename "{}": {}'.format(
                          module.__virtualname__, name, entry_path
                      )
                  )
          return ret

      def test_check_virtualname(self):
          """
          Test that each module's virtualname appears in its filename.

          Every sub-package directly under the Salt code directory is scanned,
          except the names listed in ``skipped``; for ``cloud`` the
          ``cloud/clouds`` subdirectory is scanned instead. All mismatches are
          logged before the assertion fails.
          """
          errors = []
          skipped = ("cli", "defaults", "spm", "daemons", "ext", "templates")
          for entry_name in sorted(os.listdir(RUNTIME_VARS.SALT_CODE_DIR)):
              # Join onto SALT_CODE_DIR: os.listdir() returns bare names.
              entry_path = os.path.join(RUNTIME_VARS.SALT_CODE_DIR, entry_name)
              name = os.path.splitext(entry_name)[0]
              if name.startswith((".", "_")) or not os.path.isdir(entry_path):
                  continue
              if name in skipped:
                  continue
              if name == "cloud":
                  # The cloud modules are checked in cloud/clouds, not cloud/.
                  entry_path = os.path.join(
                      RUNTIME_VARS.SALT_CODE_DIR, "cloud", "clouds"
                  )
              errors.extend(self._check_modules(entry_path))
          for error in errors:
              log.critical(error)
          assert not errors