[pypy-commit] pypy default: hg merge ec-threadlocal

arigo noreply at buildbot.pypy.org
Mon Jun 23 20:01:27 CEST 2014


Author: Armin Rigo <arigo at tunes.org>
Branch: 
Changeset: r72171:abe70e8eeca9
Date: 2014-06-23 19:59 +0200
http://bitbucket.org/pypy/pypy/changeset/abe70e8eeca9/

Log:	hg merge ec-threadlocal

	Change the ExecutionContext lookup to read a thread-local variable
	(implemented in C with '__thread' where possible, and with
	pthread_getspecific() otherwise). On Linux x86 and x86-64, the JIT
	backend has a special optimization that lets it directly emit a
	single MOV from a %gs- or %fs-based address. This actually seems to
	give a good performance boost.

diff --git a/pypy/goal/targetpypystandalone.py b/pypy/goal/targetpypystandalone.py
--- a/pypy/goal/targetpypystandalone.py
+++ b/pypy/goal/targetpypystandalone.py
@@ -30,8 +30,6 @@
     if w_dict is not None: # for tests
         w_entry_point = space.getitem(w_dict, space.wrap('entry_point'))
         w_run_toplevel = space.getitem(w_dict, space.wrap('run_toplevel'))
-        w_call_finish_gateway = space.wrap(gateway.interp2app(call_finish))
-        w_call_startup_gateway = space.wrap(gateway.interp2app(call_startup))
         withjit = space.config.objspace.usemodules.pypyjit
 
     def entry_point(argv):
@@ -53,7 +51,7 @@
             argv = argv[:1] + argv[3:]
         try:
             try:
-                space.call_function(w_run_toplevel, w_call_startup_gateway)
+                space.startup()
                 w_executable = space.wrap(argv[0])
                 w_argv = space.newlist([space.wrap(s) for s in argv[1:]])
                 w_exitcode = space.call_function(w_entry_point, w_executable, w_argv)
@@ -69,7 +67,7 @@
                 return 1
         finally:
             try:
-                space.call_function(w_run_toplevel, w_call_finish_gateway)
+                space.finish()
             except OperationError, e:
                 debug("OperationError:")
                 debug(" operror-type: " + e.w_type.getname(space))
@@ -184,11 +182,6 @@
                          'pypy_thread_attach': pypy_thread_attach,
                          'pypy_setup_home': pypy_setup_home}
 
-def call_finish(space):
-    space.finish()
-
-def call_startup(space):
-    space.startup()
 
 # _____ Define and setup target ___
 
diff --git a/pypy/interpreter/baseobjspace.py b/pypy/interpreter/baseobjspace.py
--- a/pypy/interpreter/baseobjspace.py
+++ b/pypy/interpreter/baseobjspace.py
@@ -395,6 +395,7 @@
 
     def startup(self):
         # To be called before using the space
+        self.threadlocals.enter_thread(self)
 
         # Initialize already imported builtin modules
         from pypy.interpreter.module import Module
@@ -639,30 +640,33 @@
         """NOT_RPYTHON: Abstract method that should put some minimal
         content into the w_builtins."""
 
-    @jit.loop_invariant
     def getexecutioncontext(self):
         "Return what we consider to be the active execution context."
         # Important: the annotator must not see a prebuilt ExecutionContext:
         # you should not see frames while you translate
         # so we make sure that the threadlocals never *have* an
         # ExecutionContext during translation.
-        if self.config.translating and not we_are_translated():
-            assert self.threadlocals.getvalue() is None, (
-                "threadlocals got an ExecutionContext during translation!")
-            try:
-                return self._ec_during_translation
-            except AttributeError:
-                ec = self.createexecutioncontext()
-                self._ec_during_translation = ec
+        if not we_are_translated():
+            if self.config.translating:
+                assert self.threadlocals.get_ec() is None, (
+                    "threadlocals got an ExecutionContext during translation!")
+                try:
+                    return self._ec_during_translation
+                except AttributeError:
+                    ec = self.createexecutioncontext()
+                    self._ec_during_translation = ec
+                    return ec
+            else:
+                ec = self.threadlocals.get_ec()
+                if ec is None:
+                    self.threadlocals.enter_thread(self)
+                    ec = self.threadlocals.get_ec()
                 return ec
-        # normal case follows.  The 'thread' module installs a real
-        # thread-local object in self.threadlocals, so this builds
-        # and caches a new ec in each thread.
-        ec = self.threadlocals.getvalue()
-        if ec is None:
-            ec = self.createexecutioncontext()
-            self.threadlocals.setvalue(ec)
-        return ec
+        else:
+            # Translated case follows.  self.threadlocals is either from
+            # 'pypy.interpreter.miscutils' or 'pypy.module.thread.threadlocals'.
+            # The result is assumed non-null: enter_thread() must have been called.
+            return self.threadlocals.get_ec()
 
     def _freeze_(self):
         return True
diff --git a/pypy/interpreter/miscutils.py b/pypy/interpreter/miscutils.py
--- a/pypy/interpreter/miscutils.py
+++ b/pypy/interpreter/miscutils.py
@@ -11,11 +11,11 @@
     """
     _value = None
 
-    def getvalue(self):
+    def get_ec(self):
         return self._value
 
-    def setvalue(self, value):
-        self._value = value
+    def enter_thread(self, space):
+        self._value = space.createexecutioncontext()
 
     def signals_enabled(self):
         return True
diff --git a/pypy/module/thread/__init__.py b/pypy/module/thread/__init__.py
--- a/pypy/module/thread/__init__.py
+++ b/pypy/module/thread/__init__.py
@@ -26,10 +26,11 @@
         "NOT_RPYTHON: patches space.threadlocals to use real threadlocals"
         from pypy.module.thread import gil
         MixedModule.__init__(self, space, *args)
-        prev = space.threadlocals.getvalue()
+        prev_ec = space.threadlocals.get_ec()
         space.threadlocals = gil.GILThreadLocals()
         space.threadlocals.initialize(space)
-        space.threadlocals.setvalue(prev)
+        if prev_ec is not None:
+            space.threadlocals._set_ec(prev_ec)
 
         from pypy.module.posix.interp_posix import add_fork_hook
         from pypy.module.thread.os_thread import reinit_threads
diff --git a/pypy/module/thread/os_thread.py b/pypy/module/thread/os_thread.py
--- a/pypy/module/thread/os_thread.py
+++ b/pypy/module/thread/os_thread.py
@@ -126,6 +126,8 @@
     release = staticmethod(release)
 
     def run(space, w_callable, args):
+        # add the ExecutionContext to space.threadlocals
+        space.threadlocals.enter_thread(space)
         try:
             space.call_args(w_callable, args)
         except OperationError, e:
diff --git a/pypy/module/thread/test/test_gil.py b/pypy/module/thread/test/test_gil.py
--- a/pypy/module/thread/test/test_gil.py
+++ b/pypy/module/thread/test/test_gil.py
@@ -64,13 +64,14 @@
             except Exception, e:
                 assert 0
             thread.gc_thread_die()
+        my_gil_threadlocals = gil.GILThreadLocals()
         def f():
             state.data = []
             state.datalen1 = 0
             state.datalen2 = 0
             state.datalen3 = 0
             state.datalen4 = 0
-            state.threadlocals = gil.GILThreadLocals()
+            state.threadlocals = my_gil_threadlocals
             state.threadlocals.setup_threads(space)
             subident = thread.start_new_thread(bootstrap, ())
             mainident = thread.get_ident()
diff --git a/pypy/module/thread/threadlocals.py b/pypy/module/thread/threadlocals.py
--- a/pypy/module/thread/threadlocals.py
+++ b/pypy/module/thread/threadlocals.py
@@ -1,4 +1,5 @@
 from rpython.rlib import rthread
+from rpython.rlib.objectmodel import we_are_translated
 from pypy.module.thread.error import wrap_thread_error
 from pypy.interpreter.executioncontext import ExecutionContext
 
@@ -13,53 +14,62 @@
     os_thread.bootstrap()."""
 
     def __init__(self):
+        "NOT_RPYTHON"
         self._valuedict = {}   # {thread_ident: ExecutionContext()}
         self._cleanup_()
+        self.raw_thread_local = rthread.ThreadLocalReference(ExecutionContext)
 
     def _cleanup_(self):
         self._valuedict.clear()
         self._mainthreadident = 0
-        self._mostrecentkey = 0        # fast minicaching for the common case
-        self._mostrecentvalue = None   # fast minicaching for the common case
 
-    def getvalue(self):
+    def enter_thread(self, space):
+        "Notification that the current thread is about to start running."
+        self._set_ec(space.createexecutioncontext())
+
+    def _set_ec(self, ec):
         ident = rthread.get_ident()
-        if ident == self._mostrecentkey:
-            result = self._mostrecentvalue
-        else:
-            value = self._valuedict.get(ident, None)
-            # slow path: update the minicache
-            self._mostrecentkey = ident
-            self._mostrecentvalue = value
-            result = value
-        return result
+        if self._mainthreadident == 0 or self._mainthreadident == ident:
+            ec._signals_enabled = 1    # the main thread is enabled
+            self._mainthreadident = ident
+        self._valuedict[ident] = ec
+        # This logic relies on hacks and _make_sure_does_not_move().
+        # It only works because we keep the 'ec' alive in '_valuedict' too.
+        self.raw_thread_local.set(ec)
 
-    def setvalue(self, value):
-        ident = rthread.get_ident()
-        if value is not None:
-            if self._mainthreadident == 0:
-                value._signals_enabled = 1    # the main thread is enabled
-                self._mainthreadident = ident
-            self._valuedict[ident] = value
-        else:
+    def leave_thread(self, space):
+        "Notification that the current thread is about to stop."
+        from pypy.module.thread.os_local import thread_is_stopping
+        ec = self.get_ec()
+        if ec is not None:
             try:
-                del self._valuedict[ident]
-            except KeyError:
-                pass
-        # update the minicache to prevent it from containing an outdated value
-        self._mostrecentkey = ident
-        self._mostrecentvalue = value
+                thread_is_stopping(ec)
+            finally:
+                self.raw_thread_local.set(None)
+                ident = rthread.get_ident()
+                try:
+                    del self._valuedict[ident]
+                except KeyError:
+                    pass
+
+    def get_ec(self):
+        ec = self.raw_thread_local.get()
+        if not we_are_translated():
+            assert ec is self._valuedict.get(rthread.get_ident(), None)
+        return ec
 
     def signals_enabled(self):
-        ec = self.getvalue()
+        ec = self.get_ec()
         return ec is not None and ec._signals_enabled
 
     def enable_signals(self, space):
-        ec = self.getvalue()
+        ec = self.get_ec()
+        assert ec is not None
         ec._signals_enabled += 1
 
     def disable_signals(self, space):
-        ec = self.getvalue()
+        ec = self.get_ec()
+        assert ec is not None
         new = ec._signals_enabled - 1
         if new < 0:
             raise wrap_thread_error(space,
@@ -69,22 +79,15 @@
     def getallvalues(self):
         return self._valuedict
 
-    def leave_thread(self, space):
-        "Notification that the current thread is about to stop."
-        from pypy.module.thread.os_local import thread_is_stopping
-        ec = self.getvalue()
-        if ec is not None:
-            try:
-                thread_is_stopping(ec)
-            finally:
-                self.setvalue(None)
-
     def reinit_threads(self, space):
         "Called in the child process after a fork()"
         ident = rthread.get_ident()
-        ec = self.getvalue()
+        ec = self.get_ec()
+        assert ec is not None
+        old_sig = ec._signals_enabled
         if ident != self._mainthreadident:
-            ec._signals_enabled += 1
+            old_sig += 1
         self._cleanup_()
         self._mainthreadident = ident
-        self.setvalue(ec)
+        self._set_ec(ec)
+        ec._signals_enabled = old_sig
diff --git a/pypy/objspace/fake/objspace.py b/pypy/objspace/fake/objspace.py
--- a/pypy/objspace/fake/objspace.py
+++ b/pypy/objspace/fake/objspace.py
@@ -314,6 +314,9 @@
         t = TranslationContext(config=config)
         self.t = t     # for debugging
         ann = t.buildannotator()
+        def _do_startup():
+            self.threadlocals.enter_thread(self)
+        ann.build_types(_do_startup, [], complete_now=False)
         if func is not None:
             ann.build_types(func, argtypes, complete_now=False)
         if seeobj_w:
diff --git a/rpython/config/translationoption.py b/rpython/config/translationoption.py
--- a/rpython/config/translationoption.py
+++ b/rpython/config/translationoption.py
@@ -22,6 +22,12 @@
 
 IS_64_BITS = sys.maxint > 2147483647
 
+SUPPORT__THREAD = (    # whether the particular C compiler supports __thread
+    sys.platform.startswith("linux"))     # Linux works
+    # OS X doesn't work, because we still target 10.5/10.6 and __thread
+    # requires at least 10.7.  Windows doesn't work.  Please add other
+    # platforms here once __thread is known to work on them.
+
 MAINDIR = os.path.dirname(os.path.dirname(__file__))
 CACHE_DIR = os.path.realpath(os.path.join(MAINDIR, '_cache'))
 
@@ -156,7 +162,8 @@
     # portability options
     BoolOption("no__thread",
                "don't use __thread for implementing TLS",
-               default=False, cmdline="--no__thread", negation=False),
+               default=not SUPPORT__THREAD, cmdline="--no__thread",
+               negation=False),
     IntOption("make_jobs", "Specify -j argument to make for compilation"
               " (C backend only)",
               cmdline="--make-jobs", default=detect_number_of_processors()),
diff --git a/rpython/jit/backend/llsupport/test/ztranslation_test.py b/rpython/jit/backend/llsupport/test/ztranslation_test.py
--- a/rpython/jit/backend/llsupport/test/ztranslation_test.py
+++ b/rpython/jit/backend/llsupport/test/ztranslation_test.py
@@ -4,6 +4,8 @@
 from rpython.rlib.jit import PARAMETERS, dont_look_inside
 from rpython.rlib.jit import promote
 from rpython.rlib import jit_hooks
+from rpython.rlib.objectmodel import keepalive_until_here
+from rpython.rlib.rthread import ThreadLocalReference
 from rpython.jit.backend.detect_cpu import getcpuclass
 from rpython.jit.backend.test.support import CCompiledMixin
 from rpython.jit.codewriter.policy import StopAtXPolicy
@@ -21,6 +23,7 @@
         # - profiler
         # - full optimizer
         # - floats neg and abs
+        # - threadlocalref_get
 
         class Frame(object):
             _virtualizable_ = ['i']
@@ -28,6 +31,10 @@
             def __init__(self, i):
                 self.i = i
 
+        class Foo(object):
+            pass
+        t = ThreadLocalReference(Foo)
+
         @dont_look_inside
         def myabs(x):
             return abs(x)
@@ -56,6 +63,7 @@
                 k = myabs(j)
                 if k - abs(j):  raise ValueError
                 if k - abs(-j): raise ValueError
+                if t.get().nine != 9: raise ValueError
             return chr(total % 253)
         #
         from rpython.rtyper.lltypesystem import lltype, rffi
@@ -78,8 +86,12 @@
             return res
         #
         def main(i, j):
+            foo = Foo()
+            foo.nine = -(i + j)
+            t.set(foo)
             a_char = f(i, j)
             a_float = libffi_stuff(i, j)
+            keepalive_until_here(foo)
             return ord(a_char) * 10 + int(a_float)
         expected = main(40, -49)
         res = self.meta_interp(main, [40, -49])
diff --git a/rpython/jit/backend/x86/assembler.py b/rpython/jit/backend/x86/assembler.py
--- a/rpython/jit/backend/x86/assembler.py
+++ b/rpython/jit/backend/x86/assembler.py
@@ -2351,10 +2351,29 @@
         assert isinstance(reg, RegLoc)
         self.mc.MOV_rr(reg.value, ebp.value)
 
+    def threadlocalref_get(self, op, resloc):
+        # this function is only called on Linux
+        from rpython.jit.codewriter.jitcode import ThreadLocalRefDescr
+        from rpython.jit.backend.x86 import stmtlocal
+        assert isinstance(resloc, RegLoc)
+        effectinfo = op.getdescr().get_extra_info()
+        assert effectinfo.extradescrs is not None
+        ed = effectinfo.extradescrs[0]
+        assert isinstance(ed, ThreadLocalRefDescr)
+        addr1 = rffi.cast(lltype.Signed, ed.get_tlref_addr())
+        addr0 = stmtlocal.threadlocal_base()
+        addr = addr1 - addr0
+        assert rx86.fits_in_32bits(addr)
+        mc = self.mc
+        mc.writechar(stmtlocal.SEGMENT_TL)     # prefix
+        mc.MOV_rj(resloc.value, addr)
+
+
 genop_discard_list = [Assembler386.not_implemented_op_discard] * rop._LAST
 genop_list = [Assembler386.not_implemented_op] * rop._LAST
 genop_llong_list = {}
 genop_math_list = {}
+genop_tlref_list = {}
 genop_guard_list = [Assembler386.not_implemented_op_guard] * rop._LAST
 
 for name, value in Assembler386.__dict__.iteritems():
diff --git a/rpython/jit/backend/x86/regalloc.py b/rpython/jit/backend/x86/regalloc.py
--- a/rpython/jit/backend/x86/regalloc.py
+++ b/rpython/jit/backend/x86/regalloc.py
@@ -2,7 +2,7 @@
 """ Register allocation scheme.
 """
 
-import os
+import os, sys
 from rpython.jit.backend.llsupport import symbolic
 from rpython.jit.backend.llsupport.descr import (ArrayDescr, CallDescr,
     unpack_arraydescr, unpack_fielddescr, unpack_interiorfielddescr)
@@ -692,6 +692,15 @@
         loc0 = self.xrm.force_result_in_reg(op.result, op.getarg(1))
         self.perform_math(op, [loc0], loc0)
 
+    TLREF_SUPPORT = sys.platform.startswith('linux')
+
+    def _consider_threadlocalref_get(self, op):
+        if self.TLREF_SUPPORT:
+            resloc = self.force_allocate_reg(op.result)
+            self.assembler.threadlocalref_get(op, resloc)
+        else:
+            self._consider_call(op)
+
     def _call(self, op, arglocs, force_store=[], guard_not_forced_op=None):
         # we need to save registers on the stack:
         #
@@ -769,6 +778,8 @@
                         return
             if oopspecindex == EffectInfo.OS_MATH_SQRT:
                 return self._consider_math_sqrt(op)
+            if oopspecindex == EffectInfo.OS_THREADLOCALREF_GET:
+                return self._consider_threadlocalref_get(op)
         self._consider_call(op)
 
     def consider_call_may_force(self, op, guard_op):
diff --git a/rpython/jit/backend/x86/stmtlocal.py b/rpython/jit/backend/x86/stmtlocal.py
new file mode 100644
--- /dev/null
+++ b/rpython/jit/backend/x86/stmtlocal.py
@@ -0,0 +1,32 @@
+from rpython.rtyper.lltypesystem import lltype, rffi
+from rpython.translator.tool.cbuild import ExternalCompilationInfo
+from rpython.jit.backend.x86.arch import WORD
+
+SEGMENT_FS = '\x64'
+SEGMENT_GS = '\x65'
+
+if WORD == 4:
+    SEGMENT_TL = SEGMENT_GS
+    _instruction = "movl %%gs:0, %0"
+else:
+    SEGMENT_TL = SEGMENT_FS
+    _instruction = "movq %%fs:0, %0"
+
+eci = ExternalCompilationInfo(post_include_bits=['''
+#define RPY_STM_JIT  1
+static long pypy__threadlocal_base(void)
+{
+    /* XXX ONLY LINUX WITH GCC/CLANG FOR NOW XXX */
+    long result;
+    asm("%s" : "=r"(result));
+    return result;
+}
+''' % _instruction])
+
+
+threadlocal_base = rffi.llexternal(
+    'pypy__threadlocal_base',
+    [], lltype.Signed,
+    compilation_info=eci,
+    _nowrapper=True,
+    ) #transactionsafe=True)
diff --git a/rpython/jit/codewriter/effectinfo.py b/rpython/jit/codewriter/effectinfo.py
--- a/rpython/jit/codewriter/effectinfo.py
+++ b/rpython/jit/codewriter/effectinfo.py
@@ -22,6 +22,7 @@
     OS_STR2UNICODE              = 2    # "str.str2unicode"
     OS_SHRINK_ARRAY             = 3    # rgc.ll_shrink_array
     OS_DICT_LOOKUP              = 4    # ll_dict_lookup
+    OS_THREADLOCALREF_GET       = 5    # llop.threadlocalref_get
     #
     OS_STR_CONCAT               = 22   # "stroruni.concat"
     OS_STR_SLICE                = 23   # "stroruni.slice"
diff --git a/rpython/jit/codewriter/jitcode.py b/rpython/jit/codewriter/jitcode.py
--- a/rpython/jit/codewriter/jitcode.py
+++ b/rpython/jit/codewriter/jitcode.py
@@ -117,6 +117,26 @@
         raise NotImplementedError
 
 
+class ThreadLocalRefDescr(AbstractDescr):
+    # A special descr used as the extradescr in a call to a
+    # threadlocalref_get function.  If the backend supports it,
+    # it can call 'get_tlref_addr()' to get the address, *in the
+    # current thread*, of the thread-local variable.  If, on the current
+    # platform, "__thread" variables are implemented as an offset
+    # from some base register (e.g. %fs on x86-64), then the backend will
+    # immediately subtract the current value of the base register.
+    # This gives an offset from the base register, and this can be
+    # written down in an assembler instruction to load the "__thread"
+    # variable from anywhere.
+
+    def __init__(self, opaque_id):
+        from rpython.rtyper.lltypesystem.lloperation import llop
+        from rpython.rtyper.lltypesystem import llmemory
+        def get_tlref_addr():
+            return llop.threadlocalref_getaddr(llmemory.Address, opaque_id)
+        self.get_tlref_addr = get_tlref_addr
+
+
 class LiveVarsInfo(object):
     def __init__(self, live_i, live_r, live_f):
         self.live_i = live_i
diff --git a/rpython/jit/codewriter/jtransform.py b/rpython/jit/codewriter/jtransform.py
--- a/rpython/jit/codewriter/jtransform.py
+++ b/rpython/jit/codewriter/jtransform.py
@@ -390,11 +390,15 @@
         lst.append(v)
 
     def handle_residual_call(self, op, extraargs=[], may_call_jitcodes=False,
-                             oopspecindex=EffectInfo.OS_NONE):
+                             oopspecindex=EffectInfo.OS_NONE,
+                             extraeffect=None,
+                             extradescr=None):
         """A direct_call turns into the operation 'residual_call_xxx' if it
         is calling a function that we don't want to JIT.  The initial args
         of 'residual_call_xxx' are the function to call, and its calldescr."""
-        calldescr = self.callcontrol.getcalldescr(op, oopspecindex=oopspecindex)
+        calldescr = self.callcontrol.getcalldescr(op, oopspecindex=oopspecindex,
+                                                  extraeffect=extraeffect,
+                                                  extradescr=extradescr)
         op1 = self.rewrite_call(op, 'residual_call',
                                 [op.args[0]] + extraargs, calldescr=calldescr)
         if may_call_jitcodes or self.callcontrol.calldescr_canraise(calldescr):
@@ -1903,6 +1907,18 @@
                              None)
         return [op0, op1]
 
+    def rewrite_op_threadlocalref_get(self, op):
+        from rpython.jit.codewriter.jitcode import ThreadLocalRefDescr
+        opaqueid = op.args[0].value
+        op1 = self.prepare_builtin_call(op, 'threadlocalref_getter', [],
+                                        extra=(opaqueid,),
+                                        extrakey=opaqueid._obj)
+        extradescr = ThreadLocalRefDescr(opaqueid)
+        return self.handle_residual_call(op1,
+            oopspecindex=EffectInfo.OS_THREADLOCALREF_GET,
+            extraeffect=EffectInfo.EF_LOOPINVARIANT,
+            extradescr=[extradescr])
+
 # ____________________________________________________________
 
 class NotSupported(Exception):
diff --git a/rpython/jit/codewriter/support.py b/rpython/jit/codewriter/support.py
--- a/rpython/jit/codewriter/support.py
+++ b/rpython/jit/codewriter/support.py
@@ -712,6 +712,11 @@
     build_ll_1_raw_free_no_track_allocation = (
         build_raw_free_builder(track_allocation=False))
 
+    def build_ll_0_threadlocalref_getter(opaqueid):
+        def _ll_0_threadlocalref_getter():
+            return llop.threadlocalref_get(rclass.OBJECTPTR, opaqueid)
+        return _ll_0_threadlocalref_getter
+
     def _ll_1_weakref_create(obj):
         return llop.weakref_create(llmemory.WeakRefPtr, obj)
 
diff --git a/rpython/jit/codewriter/test/test_jtransform.py b/rpython/jit/codewriter/test/test_jtransform.py
--- a/rpython/jit/codewriter/test/test_jtransform.py
+++ b/rpython/jit/codewriter/test/test_jtransform.py
@@ -147,6 +147,7 @@
              EI.OS_UNIEQ_LENGTHOK:       ([PUNICODE, PUNICODE], INT),
              EI.OS_RAW_MALLOC_VARSIZE_CHAR: ([INT], ARRAYPTR),
              EI.OS_RAW_FREE:             ([ARRAYPTR], lltype.Void),
+             EI.OS_THREADLOCALREF_GET:   ([], rclass.OBJECTPTR),
             }
             argtypes = argtypes[oopspecindex]
             assert argtypes[0] == [v.concretetype for v in op.args[1:]]
@@ -157,6 +158,8 @@
                 assert extraeffect == EI.EF_CAN_RAISE
             elif oopspecindex == EI.OS_RAW_FREE:
                 assert extraeffect == EI.EF_CANNOT_RAISE
+            elif oopspecindex == EI.OS_THREADLOCALREF_GET:
+                assert extraeffect == EI.EF_LOOPINVARIANT
             else:
                 assert extraeffect == EI.EF_ELIDABLE_CANNOT_RAISE
         return 'calldescr-%d' % oopspecindex
@@ -1300,6 +1303,23 @@
     assert op1.result is None
     assert op2 is None
 
+def test_threadlocalref_get():
+    from rpython.rtyper.lltypesystem import rclass
+    from rpython.rlib.rthread import ThreadLocalReference
+    OS_THREADLOCALREF_GET = effectinfo.EffectInfo.OS_THREADLOCALREF_GET
+    class Foo: pass
+    t = ThreadLocalReference(Foo)
+    v2 = varoftype(rclass.OBJECTPTR)
+    c_opaqueid = const(t.opaque_id)
+    op = SpaceOperation('threadlocalref_get', [c_opaqueid], v2)
+    tr = Transformer(FakeCPU(), FakeBuiltinCallControl())
+    op0 = tr.rewrite_operation(op)
+    assert op0.opname == 'residual_call_r_r'
+    assert op0.args[0].value == 'threadlocalref_getter' # pseudo-function as str
+    assert op0.args[1] == ListOfKind("ref", [])
+    assert op0.args[2] == 'calldescr-%d' % OS_THREADLOCALREF_GET
+    assert op0.result == v2
+
 def test_unknown_operation():
     op = SpaceOperation('foobar', [], varoftype(lltype.Void))
     tr = Transformer()
diff --git a/rpython/jit/metainterp/test/test_threadlocal.py b/rpython/jit/metainterp/test/test_threadlocal.py
new file mode 100644
--- /dev/null
+++ b/rpython/jit/metainterp/test/test_threadlocal.py
@@ -0,0 +1,30 @@
+import py
+from rpython.jit.metainterp.test.support import LLJitMixin
+from rpython.rlib.rthread import ThreadLocalReference
+from rpython.rlib.jit import dont_look_inside
+
+
+class ThreadLocalTest(object):
+
+    def test_threadlocalref_get(self):
+        class Foo:
+            pass
+        t = ThreadLocalReference(Foo)
+        x = Foo()
+
+        @dont_look_inside
+        def setup():
+            t.set(x)
+
+        def f():
+            setup()
+            if t.get() is x:
+                return 42
+            return -666
+
+        res = self.interp_operations(f, [])
+        assert res == 42
+
+
+class TestLLtype(ThreadLocalTest, LLJitMixin):
+    pass
diff --git a/rpython/rlib/rthread.py b/rpython/rlib/rthread.py
--- a/rpython/rlib/rthread.py
+++ b/rpython/rlib/rthread.py
@@ -272,3 +272,65 @@
         llop.gc_thread_after_fork(lltype.Void, result_of_fork, opaqueaddr)
     else:
         assert opaqueaddr == llmemory.NULL
+
+# ____________________________________________________________
+#
+# Thread-locals.  Only for references that change "not too often" --
+# for now, the JIT compiles get() as a loop-invariant, so in practice
+# don't change them.
+# KEEP THE REFERENCED OBJECT ALIVE: THE GC DOES NOT TRACE THESE YET!
+# We use _make_sure_does_not_move() to make sure the pointer will not move.
+
+ecitl = ExternalCompilationInfo(
+    includes = ['src/threadlocal.h'],
+    separate_module_files = [translator_c_dir / 'src' / 'threadlocal.c'])
+ensure_threadlocal = rffi.llexternal_use_eci(ecitl)
+
+class ThreadLocalReference(object):
+    _COUNT = 1
+    OPAQUEID = lltype.OpaqueType("ThreadLocalRef",
+                                 hints={"threadlocalref": True,
+                                        "external": "C",
+                                        "c_name": "RPyThreadStaticTLS"})
+
+    def __init__(self, Cls):
+        "NOT_RPYTHON: must be prebuilt"
+        import thread
+        self.Cls = Cls
+        self.local = thread._local()      # <- NOT_RPYTHON
+        unique_id = ThreadLocalReference._COUNT
+        ThreadLocalReference._COUNT += 1
+        opaque_id = lltype.opaqueptr(ThreadLocalReference.OPAQUEID,
+                                     'tlref%d' % unique_id)
+        self.opaque_id = opaque_id
+
+        def get():
+            if we_are_translated():
+                from rpython.rtyper.lltypesystem import rclass
+                from rpython.rtyper.annlowlevel import cast_base_ptr_to_instance
+                ptr = llop.threadlocalref_get(rclass.OBJECTPTR, opaque_id)
+                return cast_base_ptr_to_instance(Cls, ptr)
+            else:
+                return getattr(self.local, 'value', None)
+
+        @jit.dont_look_inside
+        def set(value):
+            assert isinstance(value, Cls) or value is None
+            if we_are_translated():
+                from rpython.rtyper.annlowlevel import cast_instance_to_base_ptr
+                from rpython.rlib.rgc import _make_sure_does_not_move
+                from rpython.rlib.objectmodel import running_on_llinterp
+                ptr = cast_instance_to_base_ptr(value)
+                if not running_on_llinterp:
+                    gcref = lltype.cast_opaque_ptr(llmemory.GCREF, ptr)
+                    _make_sure_does_not_move(gcref)
+                llop.threadlocalref_set(lltype.Void, opaque_id, ptr)
+                ensure_threadlocal()
+            else:
+                self.local.value = value
+
+        self.get = get
+        self.set = set
+
+    def _freeze_(self):
+        return True
diff --git a/rpython/rlib/test/test_rthread.py b/rpython/rlib/test/test_rthread.py
--- a/rpython/rlib/test/test_rthread.py
+++ b/rpython/rlib/test/test_rthread.py
@@ -1,4 +1,4 @@
-import gc
+import gc, time
 from rpython.rlib.rthread import *
 from rpython.translator.c.test.test_boehm import AbstractGCTestClass
 from rpython.rtyper.lltypesystem import lltype, rffi
@@ -29,6 +29,23 @@
     else:
         py.test.fail("Did not raise")
 
+def test_tlref_untranslated():
+    class FooBar(object):
+        pass
+    t = ThreadLocalReference(FooBar)
+    results = []
+    def subthread():
+        x = FooBar()
+        results.append(t.get() is None)
+        t.set(x)
+        results.append(t.get() is x)
+        time.sleep(0.2)
+        results.append(t.get() is x)
+    for i in range(5):
+        start_new_thread(subthread, ())
+    time.sleep(0.5)
+    assert results == [True] * 15
+
 
 class AbstractThreadTests(AbstractGCTestClass):
     use_threads = True
@@ -198,6 +215,20 @@
         res = fn()
         assert res >= 0.95
 
+    def test_tlref(self):
+        class FooBar(object):
+            pass
+        t = ThreadLocalReference(FooBar)
+        def f():
+            x1 = FooBar()
+            t.set(x1)
+            import gc; gc.collect()
+            assert t.get() is x1
+            return 42
+        fn = self.getcompiled(f, [])
+        res = fn()
+        assert res == 42
+
 #class TestRunDirectly(AbstractThreadTests):
 #    def getcompiled(self, f, argtypes):
 #        return f
@@ -208,4 +239,4 @@
     gcpolicy = 'boehm'
 
 class TestUsingFramework(AbstractThreadTests):
-    gcpolicy = 'generation'
+    gcpolicy = 'minimark'
diff --git a/rpython/rtyper/llinterp.py b/rpython/rtyper/llinterp.py
--- a/rpython/rtyper/llinterp.py
+++ b/rpython/rtyper/llinterp.py
@@ -919,6 +919,20 @@
     def op_stack_current(self):
         return 0
 
+    def op_threadlocalref_set(self, key, value):
+        interp = self.llinterpreter
+        if not hasattr(interp, 'tlrefsdict'):
+            interp.tlrefsdict = {}   # one table per interpreter, made lazily
+        interp.tlrefsdict[key._obj] = value
+
+    def op_threadlocalref_get(self, key):
+        # raises AttributeError if no set ever ran, KeyError for an unset key
+        return self.llinterpreter.tlrefsdict[key._obj]
+
+    def op_threadlocalref_getaddr(self, key):
+        # not supported when running on the llinterpreter
+        raise NotImplementedError("threadlocalref_getaddr")
+
     # __________________________________________________________
     # operations on addresses
 
diff --git a/rpython/rtyper/lltypesystem/lloperation.py b/rpython/rtyper/lltypesystem/lloperation.py
--- a/rpython/rtyper/lltypesystem/lloperation.py
+++ b/rpython/rtyper/lltypesystem/lloperation.py
@@ -541,6 +541,10 @@
     'getslice':             LLOp(canraise=(Exception,)),
     'check_and_clear_exc':  LLOp(),
 
+    'threadlocalref_get':   LLOp(sideeffects=False),
+    'threadlocalref_getaddr': LLOp(sideeffects=False),
+    'threadlocalref_set':   LLOp(),
+
     # __________ debugging __________
     'debug_view':           LLOp(),
     'debug_print':          LLOp(canrun=True),
diff --git a/rpython/translator/c/node.py b/rpython/translator/c/node.py
--- a/rpython/translator/c/node.py
+++ b/rpython/translator/c/node.py
@@ -959,12 +959,30 @@
                 args.append('0')
         yield 'RPyOpaque_SETUP_%s(%s);' % (T.tag, ', '.join(args))
 
+class ThreadLocalRefOpaqueNode(ContainerNode):
+    """C node for the opaque key object of a thread-local reference."""
+    nodekind = 'tlrefopaque'
+
+    def basename(self):
+        return self.obj._name
+
+    def enum_dependencies(self):
+        return []          # no other nodes are referenced
+
+    def initializationexpr(self, decoration=''):
+        return ['0']       # statically zero; created in startupcode()
+
+    def startupcode(self):
+        yield 'RPyThreadStaticTLS_Create(%s);' % (self.getptrname(),)
+
 
 def opaquenode_factory(db, T, obj):
     if T == RuntimeTypeInfo:
         return db.gcpolicy.rtti_node_factory()(db, T, obj)
     if T.hints.get("render_structure", False):
         return ExtType_OpaqueNode(db, T, obj)
+    if T.hints.get("threadlocalref", False):
+        return ThreadLocalRefOpaqueNode(db, T, obj)
     raise Exception("don't know about %r" % (T,))
 
 
diff --git a/rpython/translator/c/src/g_prerequisite.h b/rpython/translator/c/src/g_prerequisite.h
--- a/rpython/translator/c/src/g_prerequisite.h
+++ b/rpython/translator/c/src/g_prerequisite.h
@@ -23,3 +23,6 @@
 # define RPY_LENGTH0     1       /* array decl [0] are bad */
 # define RPY_DUMMY_VARLENGTH     /* nothing */
 #endif
+
+
+#include "src/threadlocal.h"
diff --git a/rpython/translator/c/src/stack.c b/rpython/translator/c/src/stack.c
--- a/rpython/translator/c/src/stack.c
+++ b/rpython/translator/c/src/stack.c
@@ -32,12 +32,7 @@
 		/* XXX We assume that initialization is performed early,
 		   when there is still only one thread running.  This
 		   allows us to ignore race conditions here */
-		char *errmsg = RPyThreadStaticTLS_Create(&end_tls_key);
-		if (errmsg) {
-			/* XXX should we exit the process? */
-			fprintf(stderr, "Internal PyPy error: %s\n", errmsg);
-			return 1;
-		}
+		RPyThreadStaticTLS_Create(&end_tls_key);
 	}
 
 	baseptr = (char *) RPyThreadStaticTLS_Get(end_tls_key);
diff --git a/rpython/translator/c/src/threadlocal.c b/rpython/translator/c/src/threadlocal.c
--- a/rpython/translator/c/src/threadlocal.c
+++ b/rpython/translator/c/src/threadlocal.c
@@ -1,24 +1,28 @@
+#include <stdio.h>
+#include <stdlib.h>
 #include "src/threadlocal.h"
 
 #ifdef _WIN32
 
-char *RPyThreadTLS_Create(RPyThreadTLS *result)
+void RPyThreadTLS_Create(RPyThreadTLS *result)   /* aborts on failure */
 {
     *result = TlsAlloc();
-    if (*result == TLS_OUT_OF_INDEXES)
-        return "out of thread-local storage indexes";
-    else
-        return NULL;
+    if (*result == TLS_OUT_OF_INDEXES) {
+        fprintf(stderr, "Internal RPython error: "
+                        "out of thread-local storage indexes\n");
+        abort();    /* no error code to return: the function is void */
+    }
 }
 
 #else
 
-char *RPyThreadTLS_Create(RPyThreadTLS *result)
+void RPyThreadTLS_Create(RPyThreadTLS *result)   /* aborts on failure */
 {
-    if (pthread_key_create(result, NULL) != 0)
-        return "out of thread-local storage keys";
-    else
-        return NULL;
+    if (pthread_key_create(result, NULL) != 0) {
+        fprintf(stderr, "Internal RPython error: "
+                        "out of thread-local storage keys\n");
+        abort();    /* no error code to return: the function is void */
+    }
 }
 
 #endif
diff --git a/rpython/translator/c/src/threadlocal.h b/rpython/translator/c/src/threadlocal.h
--- a/rpython/translator/c/src/threadlocal.h
+++ b/rpython/translator/c/src/threadlocal.h
@@ -1,4 +1,7 @@
 /* Thread-local storage */
+#ifndef _SRC_THREADLOCAL_H
+#define _SRC_THREADLOCAL_H
+
 
 #ifdef _WIN32
 
@@ -22,9 +25,10 @@
 #ifdef USE___THREAD
 
 #define RPyThreadStaticTLS                  __thread void *
-#define RPyThreadStaticTLS_Create(tls)      NULL
+#define RPyThreadStaticTLS_Create(tls)      (void)0
 #define RPyThreadStaticTLS_Get(tls)         tls
 #define RPyThreadStaticTLS_Set(tls, value)  tls = value
+#define OP_THREADLOCALREF_GETADDR(tlref, ptr)  ptr = tlref
 
 #endif
 
@@ -34,7 +38,13 @@
 #define RPyThreadStaticTLS_Create(key) RPyThreadTLS_Create(key)
 #define RPyThreadStaticTLS_Get(key)    RPyThreadTLS_Get(key)
 #define RPyThreadStaticTLS_Set(key, value) RPyThreadTLS_Set(key, value)
-char *RPyThreadTLS_Create(RPyThreadTLS *result);
+void RPyThreadTLS_Create(RPyThreadTLS *result);
 
 #endif
 
+/* threadlocalref ops: 'tlref' points to the TLS slot; SET's 3rd arg is unused */
+#define OP_THREADLOCALREF_SET(tlref, ptr, _) RPyThreadStaticTLS_Set(*tlref, ptr)
+#define OP_THREADLOCALREF_GET(tlref, ptr)   ptr = RPyThreadStaticTLS_Get(*tlref)
+/* NOTE(review): OP_THREADLOCALREF_GETADDR is only defined under USE___THREAD */
+
+#endif /* _SRC_THREADLOCAL_H */


More information about the pypy-commit mailing list