[3.10] bpo-40173: Fix test.support.import_helper.import_fresh_module() (GH-28654) (GH-28657)

https://github.com/python/cpython/commit/7873884d4730d7e637a968011b8958bd79fd3398
commit: 7873884d4730d7e637a968011b8958bd79fd3398
branch: 3.10
author: Serhiy Storchaka <storchaka@gmail.com>
committer: serhiy-storchaka <storchaka@gmail.com>
date: 2021-09-30T19:56:41+03:00
summary:

[3.10] bpo-40173: Fix test.support.import_helper.import_fresh_module() (GH-28654) (GH-28657)

* Work correctly if an additional fresh module imports another additional fresh module which imports a blocked module.
* Raise ImportError if the specified module cannot be imported while all additional fresh modules are successfully imported.
* Support blocking packages.
* Always restore the import state of fresh and blocked modules and their submodules.
* Fix test_decimal and test_xml_etree, which depended on an undesired side effect of import_fresh_module().

(cherry picked from commit ec4d917a6a68824f1895f75d113add9410283da7)

files:
A Misc/NEWS.d/next/Tests/2021-09-30-16-54-39.bpo-40173.J_slCw.rst
M Lib/test/support/import_helper.py
M Lib/test/test_decimal.py
M Lib/test/test_xml_etree.py

diff --git a/Lib/test/support/import_helper.py b/Lib/test/support/import_helper.py index 5d1e9406879cc..43ae31483420d 100644 --- a/Lib/test/support/import_helper.py +++ b/Lib/test/support/import_helper.py @@ -80,33 +80,13 @@ def import_module(name, deprecated=False, *, required_on=()): raise unittest.SkipTest(str(msg)) -def _save_and_remove_module(name, orig_modules): - """Helper function to save and remove a module from sys.modules - - Raise ImportError if the module can't be imported. - """ - # try to import the module and raise an error if it can't be imported - if name not in sys.modules: - __import__(name) - del sys.modules[name] +def _save_and_remove_modules(names): + orig_modules = {} + prefixes = tuple(name + '.' 
for name in names) for modname in list(sys.modules): - if modname == name or modname.startswith(name + '.'): - orig_modules[modname] = sys.modules[modname] - del sys.modules[modname] - - -def _save_and_block_module(name, orig_modules): - """Helper function to save and block a module in sys.modules - - Return True if the module was in sys.modules, False otherwise. - """ - saved = True - try: - orig_modules[name] = sys.modules[name] - except KeyError: - saved = False - sys.modules[name] = None - return saved + if modname in names or modname.startswith(prefixes): + orig_modules[modname] = sys.modules.pop(modname) + return orig_modules def import_fresh_module(name, fresh=(), blocked=(), deprecated=False): @@ -118,7 +98,8 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False): this operation. *fresh* is an iterable of additional module names that are also removed - from the sys.modules cache before doing the import. + from the sys.modules cache before doing the import. If one of these + modules can't be imported, None is returned. 
*blocked* is an iterable of module names that are replaced with None in the module cache during the import to ensure that attempts to import @@ -139,24 +120,24 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False): with _ignore_deprecated_imports(deprecated): # Keep track of modules saved for later restoration as well # as those which just need a blocking entry removed - orig_modules = {} - names_to_remove = [] - _save_and_remove_module(name, orig_modules) + fresh = list(fresh) + blocked = list(blocked) + names = {name, *fresh, *blocked} + orig_modules = _save_and_remove_modules(names) + for modname in blocked: + sys.modules[modname] = None + try: - for fresh_name in fresh: - _save_and_remove_module(fresh_name, orig_modules) - for blocked_name in blocked: - if not _save_and_block_module(blocked_name, orig_modules): - names_to_remove.append(blocked_name) - fresh_module = importlib.import_module(name) - except ImportError: - fresh_module = None + # Return None when one of the "fresh" modules can not be imported. + try: + for modname in fresh: + __import__(modname) + except ImportError: + return None + return importlib.import_module(name) finally: - for orig_name, module in orig_modules.items(): - sys.modules[orig_name] = module - for name_to_remove in names_to_remove: - del sys.modules[name_to_remove] - return fresh_module + _save_and_remove_modules(names) + sys.modules.update(orig_modules) class CleanImport(object): diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index 99263bb13b0d1..b6173a5ffec96 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -62,7 +62,7 @@ C = import_fresh_module('decimal', fresh=['_decimal']) P = import_fresh_module('decimal', blocked=['_decimal']) -orig_sys_decimal = sys.modules['decimal'] +import decimal as orig_sys_decimal # fractions module must import the correct decimal module. 
cfractions = import_fresh_module('fractions', fresh=['fractions']) diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py index c79b5462b931d..5a8824a78ffa4 100644 --- a/Lib/test/test_xml_etree.py +++ b/Lib/test/test_xml_etree.py @@ -26,7 +26,7 @@ from test import support from test.support import os_helper from test.support import warnings_helper -from test.support import findfile, gc_collect, swap_attr +from test.support import findfile, gc_collect, swap_attr, swap_item from test.support.import_helper import import_fresh_module from test.support.os_helper import TESTFN @@ -167,12 +167,11 @@ def setUpClass(cls): cls.modules = {pyET, ET} def pickleRoundTrip(self, obj, name, dumper, loader, proto): - save_m = sys.modules[name] try: - sys.modules[name] = dumper - temp = pickle.dumps(obj, proto) - sys.modules[name] = loader - result = pickle.loads(temp) + with swap_item(sys.modules, name, dumper): + temp = pickle.dumps(obj, proto) + with swap_item(sys.modules, name, loader): + result = pickle.loads(temp) except pickle.PicklingError as pe: # pyET must be second, because pyET may be (equal to) ET. human = dict([(ET, "cET"), (pyET, "pyET")]) @@ -180,8 +179,6 @@ def pickleRoundTrip(self, obj, name, dumper, loader, proto): % (obj, human.get(dumper, dumper), human.get(loader, loader))) from pe - finally: - sys.modules[name] = save_m return result def assertEqualElements(self, alice, bob): diff --git a/Misc/NEWS.d/next/Tests/2021-09-30-16-54-39.bpo-40173.J_slCw.rst b/Misc/NEWS.d/next/Tests/2021-09-30-16-54-39.bpo-40173.J_slCw.rst new file mode 100644 index 0000000000000..21671473c16cc --- /dev/null +++ b/Misc/NEWS.d/next/Tests/2021-09-30-16-54-39.bpo-40173.J_slCw.rst @@ -0,0 +1,2 @@ +Fix :func:`test.support.import_helper.import_fresh_module`. +
Participants (1): serhiy-storchaka