[pypy-svn] pypy default: Hack around, mostly importing the _PyImport_{Find,Fixup}Extension()

arigo commits-noreply at bitbucket.org
Mon Dec 20 17:18:08 CET 2010


Author: Armin Rigo <arigo at tunes.org>
Branch: 
Changeset: r40149:cd083843b67a
Date: 2010-12-20 17:17 +0100
http://bitbucket.org/pypy/pypy/changeset/cd083843b67a/

Log:	Hack around, mostly importing the _PyImport_{Find,Fixup}Extension()
	from CPython. Prevents the same extension module from being
	initialized twice.

diff --git a/pypy/module/cpyext/state.py b/pypy/module/cpyext/state.py
--- a/pypy/module/cpyext/state.py
+++ b/pypy/module/cpyext/state.py
@@ -16,14 +16,19 @@
         self.operror = None
         self.new_method_def = lltype.nullptr(PyMethodDef)
 
-        # When importing a package, use this to keep track of its name.  This is
+        # When importing a package, use this to keep track
+        # of its name and path (as a 2-tuple).  This is
         # necessary because an extension module in a package might not supply
         # its own fully qualified name to Py_InitModule.  If it doesn't, we need
         # to be able to figure out what module is being initialized.  Recursive
         # imports will clobber this value, which might be confusing, but it
         # doesn't hurt anything because the code that cares about it will have
         # already read it by that time.
-        self.package_context = None
+        self.package_context = None, None
+
+        # A mapping {filename: copy-of-the-w_dict}, similar to CPython's
+        # variable 'extensions' in Python/import.c.
+        self.extensions = {}
 
     def set_exception(self, operror):
         self.clear_exception()
@@ -96,3 +101,29 @@
             self.programname = rffi.str2charp(progname)
             lltype.render_immortal(self.programname)
         return self.programname
+
+    def find_extension(self, name, path):  # cached module for path, or None
+        from pypy.module.cpyext.modsupport import PyImport_AddModule
+        from pypy.interpreter.module import Module
+        try:
+            w_dict = self.extensions[path]  # copy saved by fixup_extension()
+        except KeyError:
+            return None  # path not loaded before
+        w_mod = PyImport_AddModule(self.space, name)  # presumably sys.modules[name]
+        assert isinstance(w_mod, Module)
+        w_mdict = w_mod.getdict()
+        self.space.call_method(w_mdict, 'update', w_dict)  # restore saved contents
+        return w_mod
+
+    def fixup_extension(self, name, path):  # snapshot a newly initialized module
+        from pypy.interpreter.module import Module
+        space = self.space
+        w_modules = space.sys.get('modules')
+        w_mod = space.finditem_str(w_modules, name)  # looked up in sys.modules
+        if not isinstance(w_mod, Module):  # init did not register a Module here
+            msg = "fixup_extension: module '%s' not loaded" % name
+            raise OperationError(space.w_SystemError,
+                                 space.wrap(msg))
+        w_dict = w_mod.getdict()
+        w_copy = space.call_method(w_dict, 'copy')  # dict.copy(): shallow copy
+        self.extensions[path] = w_copy  # read back by find_extension()

diff --git a/pypy/module/cpyext/test/test_cpyext.py b/pypy/module/cpyext/test/test_cpyext.py
--- a/pypy/module/cpyext/test/test_cpyext.py
+++ b/pypy/module/cpyext/test/test_cpyext.py
@@ -222,6 +222,12 @@
         else:
             return os.path.dirname(mod)
 
+    def reimport_module(self, mod, name):  # load path 'mod' again as 'name'
+        api.load_extension_module(self.space, mod, name)
+        return self.space.getitem(  # the resulting sys.modules[name]
+            self.space.sys.get('modules'),
+            self.space.wrap(name))
+
     def import_extension(self, modname, functions, prologue=""):
         methods_table = []
         codes = []
@@ -261,6 +267,7 @@
         self.imported_module_names = []
 
         self.w_import_module = self.space.wrap(self.import_module)
+        self.w_reimport_module = self.space.wrap(self.reimport_module)
         self.w_import_extension = self.space.wrap(self.import_extension)
         self.w_compile_module = self.space.wrap(self.compile_module)
         self.w_record_imported_module = self.space.wrap(
@@ -709,3 +716,43 @@
         p = mod.get_programname()
         print p
         assert 'py' in p
+
+    def test_no_double_imports(self):  # init must run once per path
+        import sys, os
+        try:
+            init = """
+            static int _imported_already = 0;
+            FILE *f = fopen("_imported_already", "w");
+            fprintf(f, "imported_already: %d\\n", _imported_already);
+            fclose(f);
+            _imported_already = 1;
+            if (Py_IsInitialized()) {
+                Py_InitModule("foo", NULL);
+            }
+            """
+            self.import_module(name='foo', init=init)  # first load runs init
+            assert 'foo' in sys.modules
+
+            f = open('_imported_already')
+            data = f.read()
+            f.close()
+            assert data == 'imported_already: 0\n'  # init ran for the first time
+
+            f = open('_imported_already', 'w')
+            f.write('not again!\n')  # a second init run would overwrite this
+            f.close()
+            m1 = sys.modules['foo']
+            m2 = self.reimport_module(m1.__file__, name='foo')  # same path again
+            assert m1 is m2  # cached module object reused
+            assert m1 is sys.modules['foo']
+
+            f = open('_imported_already')
+            data = f.read()
+            f.close()
+            assert data == 'not again!\n'  # init was not run again
+
+        finally:  # always remove the marker file
+            try:
+                os.unlink('_imported_already')
+            except OSError:
+                pass

diff --git a/pypy/module/cpyext/modsupport.py b/pypy/module/cpyext/modsupport.py
--- a/pypy/module/cpyext/modsupport.py
+++ b/pypy/module/cpyext/modsupport.py
@@ -54,9 +54,10 @@
     from pypy.module.cpyext.typeobjectdefs import PyTypeObjectPtr
     modname = rffi.charp2str(name)
     state = space.fromcache(State)
-    w_mod = PyImport_AddModule(space, state.package_context)
+    f_name, f_path = state.package_context
+    w_mod = PyImport_AddModule(space, f_name)
 
-    dict_w = {}
+    dict_w = {'__file__': space.wrap(f_path)}
     convert_method_defs(space, dict_w, methods, None, w_self, modname)
     for key, w_value in dict_w.items():
         space.setattr(w_mod, space.wrap(key), w_value)

diff --git a/pypy/module/cpyext/api.py b/pypy/module/cpyext/api.py
--- a/pypy/module/cpyext/api.py
+++ b/pypy/module/cpyext/api.py
@@ -939,7 +939,9 @@
     if os.sep not in path:
         path = os.curdir + os.sep + path      # force a '/' in the path
     state = space.fromcache(State)
-    state.package_context = name
+    if state.find_extension(name, path) is not None:
+        return
+    state.package_context = name, path
     try:
         from pypy.rlib import rdynload
         try:
@@ -964,7 +966,8 @@
         generic_cpy_call(space, initfunc)
         state.check_and_raise_exception()
     finally:
-        state.package_context = None
+        state.package_context = None, None
+    state.fixup_extension(name, path)
 
 @specialize.ll()
 def generic_cpy_call(space, func, *args):


More information about the Pypy-commit mailing list