[pypy-commit] pypy stm-gc: End-of-transaction collections.

arigo noreply at buildbot.pypy.org
Sat Feb 4 13:23:53 CET 2012


Author: Armin Rigo <arigo at tunes.org>
Branch: stm-gc
Changeset: r52081:94c8bc20d9e0
Date: 2012-02-03 18:07 +0100
http://bitbucket.org/pypy/pypy/changeset/94c8bc20d9e0/

Log:	End-of-transaction collections.

diff --git a/pypy/rpython/memory/gc/stmgc.py b/pypy/rpython/memory/gc/stmgc.py
--- a/pypy/rpython/memory/gc/stmgc.py
+++ b/pypy/rpython/memory/gc/stmgc.py
@@ -3,7 +3,8 @@
 from pypy.rpython.lltypesystem.llmemory import raw_malloc_usage
 from pypy.rpython.memory.gc.base import GCBase
 from pypy.rlib.rarithmetic import LONG_BIT
-from pypy.rlib.debug import ll_assert
+from pypy.rlib.debug import ll_assert, debug_start, debug_stop
+from pypy.module.thread import ll_thread
 
 
 WORD = LONG_BIT // 8
@@ -37,7 +38,7 @@
     malloc_zero_filled = True    # xxx?
 
     HDR = lltype.Struct('header', ('tid', lltype.Signed),
-                                  ('version', lltype.Signed))
+                                  ('version', llmemory.Address))
     typeid_is_in_field = 'tid'
     withhash_flag_is_in_field = 'tid', 'XXX'
 
@@ -45,14 +46,16 @@
                                    ('nursery_top', llmemory.Address),
                                    ('nursery_start', llmemory.Address),
                                    ('nursery_size', lltype.Signed),
-                                   ('malloc_flags', lltype.Signed))
-
+                                   ('malloc_flags', lltype.Signed),
+                                   ('pending_list', llmemory.Address),
+                          )
 
     def __init__(self, config, stm_operations,
                  max_nursery_size=1024,
                  **kwds):
         GCBase.__init__(self, config, **kwds)
         self.stm_operations = stm_operations
+        self.collector = Collector(self)
         self.max_nursery_size = max_nursery_size
         #
         self.declare_readers()
@@ -62,6 +65,7 @@
         """Called at run-time to initialize the GC."""
         GCBase.setup(self)
         self.main_thread_tls = self.setup_thread(True)
+        self.mutex_lock = ll_thread.allocate_ll_lock()
 
     def _alloc_nursery(self):
         nursery = llarena.arena_malloc(self.max_nursery_size, 1)
@@ -73,6 +77,8 @@
         llarena.arena_free(nursery)
 
     def setup_thread(self, in_main_thread):
+        """Setup a thread.  Allocates the thread-local data structures.
+        Must be called only once per OS-level thread."""
         tls = lltype.malloc(self.GCTLS, flavor='raw')
         self.stm_operations.set_tls(self, llmemory.cast_ptr_to_adr(tls))
         tls.nursery_start = self._alloc_nursery()
@@ -90,19 +96,25 @@
             tls.malloc_flags = 0
         return tls
 
+    @staticmethod
+    def reset_nursery(tls):
+        """Clear and forget all objects allocated in the nursery of 'tls'."""
+        size = tls.nursery_free - tls.nursery_start   # only the used part
+        llarena.arena_reset(tls.nursery_start, size, 2)
+        tls.nursery_free = tls.nursery_start          # rewind bump pointer
+
     def teardown_thread(self):
-        tls = self.get_tls()
+        """Teardown a thread: clear its TLS pointer and free its nursery
+        and GCTLS.  Call this just before the OS-level thread disappears."""
+        tls = self.collector.get_tls()
         self.stm_operations.set_tls(self, NULL)
         self._free_nursery(tls.nursery_start)
         lltype.free(tls, flavor='raw')
 
-    @always_inline
-    def get_tls(self):
-        tls = self.stm_operations.get_tls()
-        return llmemory.cast_adr_to_ptr(tls, lltype.Ptr(self.GCTLS))
+    # ----------
 
     def allocate_bump_pointer(self, size):
-        return self._allocate_bump_pointer(self.get_tls(), size)
+        return self._allocate_bump_pointer(self.collector.get_tls(), size)
 
     @always_inline
     def _allocate_bump_pointer(self, tls, size):
@@ -129,7 +141,7 @@
         # Check the mode: either in a transactional thread, or in
         # the main thread.  For now we do the same thing in both
         # modes, but set different flags.
-        tls = self.get_tls()
+        tls = self.collector.get_tls()
         flags = tls.malloc_flags
         #
         # Get the memory from the nursery.
@@ -145,12 +157,12 @@
         return llmemory.cast_adr_to_ptr(obj, llmemory.GCREF)
 
 
-    def _malloc_local_raw(self, size):
+    def _malloc_local_raw(self, tls, size):
         # for _stm_write_barrier_global(): a version of malloc that does
         # no initialization of the malloc'ed object
         size_gc_header = self.gcheaderbuilder.size_gc_header
         totalsize = size_gc_header + size
-        result = self.allocate_bump_pointer(totalsize)
+        result = self._allocate_bump_pointer(tls, totalsize)
         llarena.arena_reserve(result, totalsize)
         obj = result + size_gc_header
         return obj
@@ -229,8 +241,9 @@
             #
             # Here, we need to really make a local copy
             size = self.get_size(obj)
+            tls = self.collector.get_tls()
             try:
-                localobj = self._malloc_local_raw(size)
+                localobj = self._malloc_local_raw(tls, size)
             except MemoryError:
                 # XXX
                 fatalerror("MemoryError in _stm_write_barrier_global -- sorry")
@@ -252,7 +265,189 @@
             # Remove the GCFLAG_GLOBAL from the copy
             localhdr.tid &= ~GCFLAG_GLOBAL
             #
+            # Set the 'version' field of the local copy to be a pointer
+            # to the global obj.  (The field is called 'version' because
+            # of its use by the C STM library: on global objects (only),
+            # it is a version number.)
+            localhdr.version = obj
+            #
             # Register the object as a valid copy
             stm_operations.tldict_add(obj, localobj)
             #
             return localobj
+
+    # ---------- lock helpers, used by Collector.commit_transaction()
+
+    def acquire(self, lock):
+        ll_thread.c_thread_acquirelock(lock, 1)   # waitflag=1: presumably blocks
+
+    def release(self, lock):
+        ll_thread.c_thread_releaselock(lock)
+
+
+# ------------------------------------------------------------
+
+
+class Collector(object):
+    """A separate frozen class.  Useful to prevent any buggy concurrent
+    access to GC data.  The methods here use the GCTLS instead for
+    storing things in a thread-local way."""
+
+    def __init__(self, gc):
+        self.gc = gc
+        self.stm_operations = gc.stm_operations
+
+    def _freeze_(self):
+        return True
+
+    def get_tls(self):
+        """Return the GCTLS of the current thread, as a typed pointer."""
+        tls = self.stm_operations.get_tls()
+        return llmemory.cast_adr_to_ptr(tls, lltype.Ptr(StmGC.GCTLS))
+
+    def is_in_nursery(self, tls, addr):
+        """Check if 'addr' points inside the nursery of 'tls'."""
+        ll_assert(llmemory.cast_adr_to_int(addr) & 1 == 0,
+                  "odd-valued (i.e. tagged) pointer unexpected here")
+        return tls.nursery_start <= addr < tls.nursery_top
+
+    def header(self, obj):
+        return self.gc.header(obj)
+
+    def start_transaction(self):
+        """Start a transaction, by clearing and resetting the tls nursery."""
+        tls = self.get_tls()
+        self.gc.reset_nursery(tls)
+
+    def commit_transaction(self):
+        """End of a transaction, just before its end.  No more GC
+        operations should occur afterwards!  Note that the C code that
+        does the commit runs afterwards, and may still abort."""
+        debug_start("gc-collect-commit")
+        tls = self.get_tls()
+        #
+        # Do a mark-and-move minor collection out of the tls' nursery
+        # into the main thread's global area (which is right now also
+        # called a nursery).  To simplify things, we use a global lock
+        # around the whole mark-and-move.  It is released in a 'finally',
+        # so that an exception escaping from the collection (e.g. a
+        # MemoryError from _malloc_local_raw()) cannot leave it held.
+        self.gc.acquire(self.gc.mutex_lock)
+        try:
+            # We are starting from the tldict's local objects as roots.
+            # At this point, these objects have GCFLAG_WAS_COPIED, and
+            # the other local objects don't.  We want to move all
+            # reachable local objects to the global area.
+            #
+            # Start from tracing the root objects
+            self.collect_roots_from_tldict(tls)
+            #
+            # Continue iteratively until we have reached all the
+            # reachable local objects
+            self.collect_from_pending_list(tls)
+        finally:
+            self.gc.release(self.gc.mutex_lock)
+        #
+        # Now, all indirectly reachable local objects have been copied into
+        # the global area, and all pointers have been fixed to point to the
+        # global copies, including in the local copy of the roots.  What
+        # remains is only overwriting of the global copy of the roots.
+        # This is done by the C code.
+        debug_stop("gc-collect-commit")
+
+    def collect_roots_from_tldict(self, tls):
+        """Trace the roots: the local copies registered in the tldict."""
+        tls.pending_list = NULL
+        # Enumerate the roots, which are the local copies of global objects.
+        # For each root, trace it.
+        self.stm_operations.enum_tldict_start()
+        while self.stm_operations.enum_tldict_find_next():
+            globalobj = self.stm_operations.enum_tldict_globalobj()
+            localobj = self.stm_operations.enum_tldict_localobj()
+            #
+            localhdr = self.header(localobj)
+            ll_assert(localhdr.version == globalobj,
+                      "in a root: localobj.version != globalobj")
+            ll_assert(localhdr.tid & GCFLAG_GLOBAL == 0,
+                      "in a root: unexpected GCFLAG_GLOBAL")
+            ll_assert(localhdr.tid & GCFLAG_WAS_COPIED != 0,
+                      "in a root: missing GCFLAG_WAS_COPIED")
+            #
+            self.trace_and_drag_out_of_nursery(tls, localobj)
+
+    def collect_from_pending_list(self, tls):
+        """Trace the fresh global copies until 'pending_list' is empty."""
+        while tls.pending_list != NULL:
+            pending_obj = tls.pending_list
+            pending_hdr = self.header(pending_obj)
+            #
+            # 'pending_list' is a chained list of fresh global objects,
+            # linked together via their 'version' field.  The 'version'
+            # must be replaced with NULL after we pop the object from
+            # the linked list.
+            tls.pending_list = pending_hdr.version
+            pending_hdr.version = NULL
+            #
+            # Check the flags of pending_obj: it should be a fresh global
+            # object, without GCFLAG_WAS_COPIED
+            ll_assert(pending_hdr.tid & GCFLAG_GLOBAL != 0,
+                      "from pending list: missing GCFLAG_GLOBAL")
+            ll_assert(pending_hdr.tid & GCFLAG_WAS_COPIED == 0,
+                      "from pending list: unexpected GCFLAG_WAS_COPIED")
+            #
+            self.trace_and_drag_out_of_nursery(tls, pending_obj)
+
+    def trace_and_drag_out_of_nursery(self, tls, obj):
+        # This is called to fix the references inside 'obj', to ensure that
+        # they are global.  If necessary, the referenced objects are copied
+        # into the global area first.  This is called on the *local* copy of
+        # the roots, and on the fresh *global* copy of all other reached
+        # objects.
+        self.gc.trace(obj, self._trace_drag_out, tls)
+
+    def _trace_drag_out(self, root, tls):
+        """Callback for gc.trace(): make the pointer at 'root' global."""
+        obj = root.address[0]
+        hdr = self.header(obj)
+        #
+        # Figure out if the object is GLOBAL or not by looking at its
+        # address, not at its header --- to avoid cache misses and
+        # pollution for all global objects
+        if not self.is_in_nursery(tls, obj):
+            ll_assert(hdr.tid & GCFLAG_GLOBAL != 0,
+                      "trace_and_mark: non-GLOBAL obj is not in nursery")
+            return        # ignore global objects
+        #
+        ll_assert(hdr.tid & GCFLAG_GLOBAL == 0,
+                  "trace_and_mark: GLOBAL obj in nursery")
+        #
+        if hdr.tid & GCFLAG_WAS_COPIED != 0:
+            # this local object is a root or was already marked.  Either
+            # way, its 'version' field should point to the corresponding
+            # global object.
+            globalobj = hdr.version
+            #
+        else:
+            # First visit to a local-only 'obj': copy it into the global area
+            size = self.gc.get_size(obj)
+            main_tls = self.gc.main_thread_tls
+            globalobj = self.gc._malloc_local_raw(main_tls, size)
+            llmemory.raw_memcopy(obj, globalobj, size)
+            #
+            # Initialize the header of the 'globalobj'
+            globalhdr = self.header(globalobj)
+            globalhdr.tid = hdr.tid | GCFLAG_GLOBAL
+            #
+            # Add the flags to 'localobj' to say 'has been copied now'
+            hdr.tid |= GCFLAG_WAS_COPIED
+            hdr.version = globalobj
+            #
+            # Set a temporary linked list through the globalobj's version
+            # numbers.  This is normally not allowed, but it works here
+            # because these new globalobjs are not visible to any other
+            # thread before the commit is really complete.
+            globalhdr.version = tls.pending_list
+            tls.pending_list = globalobj
+        #
+        # Fix the original root.address[0] to point to the globalobj
+        root.address[0] = globalobj
diff --git a/pypy/rpython/memory/gc/test/test_stmgc.py b/pypy/rpython/memory/gc/test/test_stmgc.py
--- a/pypy/rpython/memory/gc/test/test_stmgc.py
+++ b/pypy/rpython/memory/gc/test/test_stmgc.py
@@ -3,11 +3,23 @@
 from pypy.rpython.memory.gc.stmgc import GCFLAG_GLOBAL, GCFLAG_WAS_COPIED
 
 
-S = lltype.GcStruct('S', ('a', lltype.Signed), ('b', lltype.Signed))
+S = lltype.GcStruct('S', ('a', lltype.Signed), ('b', lltype.Signed),
+                         ('c', lltype.Signed))
 ofs_a = llmemory.offsetof(S, 'a')
 
+SR = lltype.GcForwardReference()
+SR.become(lltype.GcStruct('SR', ('s1', lltype.Ptr(S)),
+                                ('sr2', lltype.Ptr(SR)),
+                                ('sr3', lltype.Ptr(SR))))
+
 
 class FakeStmOperations:
+    # The point of this class is to make sure about the distinction between
+    # RPython code in the GC versus C code in translator/stm/src_stm.  This
+    # class contains a fake implementation of what should be in C.  So almost
+    # any use of 'self._gc' is wrong here: it's stmgc.py that should call
+    # the methods of this class, and not the other way around.
+
     threadnum = 0          # 0 = main thread; 1,2,3... = transactional threads
 
     def set_tls(self, gc, tls):
@@ -17,6 +29,7 @@
             assert not hasattr(self, '_gc')
             self._tls_dict = {0: tls}
             self._tldicts = {0: {}}
+            self._tldicts_iterators = {}
             self._gc = gc
             self._transactional_copies = []
         else:
@@ -39,6 +52,32 @@
         assert obj not in tldict
         tldict[obj] = localobj
 
+    def enum_tldict_start(self):    # state: [iterator, globalobj, localobj]
+        it = self._tldicts[self.threadnum].iteritems()
+        self._tldicts_iterators[self.threadnum] = [it, None, None]
+
+    def enum_tldict_find_next(self):    # returns False when exhausted
+        state = self._tldicts_iterators[self.threadnum]
+        try:
+            next_key, next_value = state[0].next()
+        except StopIteration:
+            state[1] = None
+            state[2] = None
+            return False
+        state[1] = next_key
+        state[2] = next_value
+        return True
+
+    def enum_tldict_globalobj(self):    # key of the current entry
+        state = self._tldicts_iterators[self.threadnum]
+        assert state[1] is not None
+        return state[1]
+
+    def enum_tldict_localobj(self):     # value of the current entry
+        state = self._tldicts_iterators[self.threadnum]
+        assert state[2] is not None
+        return state[2]
+
     class stm_read_word:
         def __init__(self, obj, offset):
             self.obj = obj
@@ -67,6 +106,21 @@
     else:
         assert 0
 
+def fake_trace(obj, callback, arg):    # test stand-in for gc.trace()
+    TYPE = obj.ptr._TYPE.TO
+    if TYPE == S:
+        ofslist = []     # no pointers in S
+    elif TYPE == SR:
+        ofslist = [llmemory.offsetof(SR, 's1'),
+                   llmemory.offsetof(SR, 'sr2'),
+                   llmemory.offsetof(SR, 'sr3')]
+    else:
+        assert 0         # only S and SR are supported
+    for ofs in ofslist:
+        addr = obj + ofs
+        if addr.address[0]:      # report only non-NULL pointers
+            callback(addr, arg)
+
 
 class TestBasic:
     GCClass = StmGC
@@ -78,6 +132,7 @@
                                translated_to_c=False)
         self.gc.DEBUG = True
         self.gc.get_size = fake_get_size
+        self.gc.trace = fake_trace
         self.gc.setup()
 
     def teardown_method(self, meth):
@@ -97,6 +152,9 @@
         self.gc.stm_operations.threadnum = threadnum
         if threadnum not in self.gc.stm_operations._tls_dict:
             self.gc.setup_thread(False)
+    def gcsize(self, S):     # size of one 'S' object, GC header included
+        return (llmemory.raw_malloc_usage(llmemory.sizeof(self.gc.HDR)) +
+                llmemory.raw_malloc_usage(llmemory.sizeof(S)))
 
     def test_gc_creation_works(self):
         pass
@@ -193,3 +251,94 @@
         #
         u_adr = self.gc.write_barrier(u_adr)  # local object
         assert u_adr == t_adr
+
+    def test_commit_transaction_empty(self):
+        self.select_thread(1)     # a transactional thread
+        s, s_adr = self.malloc(S)     # local and unreachable
+        t, t_adr = self.malloc(S)     # local and unreachable
+        self.gc.collector.commit_transaction()    # no roots (empty tldict)
+        main_tls = self.gc.main_thread_tls
+        assert main_tls.nursery_free == main_tls.nursery_start   # empty
+
+    def test_commit_transaction_no_references(self):
+        s, s_adr = self.malloc(S)
+        s.b = 12345
+        self.select_thread(1)
+        t_adr = self.gc.write_barrier(s_adr)   # make a local copy
+        t = llmemory.cast_adr_to_ptr(t_adr, lltype.Ptr(S))
+        assert s != t
+        assert self.gc.header(t_adr).version == s_adr
+        t.b = 67890
+        # 't' is a root, but S contains no pointers: nothing to copy out
+        main_tls = self.gc.main_thread_tls
+        assert main_tls.nursery_free != main_tls.nursery_start  # contains s
+        old_value = main_tls.nursery_free
+        #
+        self.gc.collector.commit_transaction()
+        # overwriting 's' with 't' is left to the C code, not done here
+        assert main_tls.nursery_free == old_value    # no new object
+        assert s.b == 12345     # not updated by the GC code
+        assert t.b == 67890     # still valid
+
+    def test_commit_transaction_with_one_reference(self):
+        sr, sr_adr = self.malloc(SR)
+        assert sr.s1 == lltype.nullptr(S)
+        assert sr.sr2 == lltype.nullptr(SR)
+        self.select_thread(1)
+        tr_adr = self.gc.write_barrier(sr_adr)   # make a local copy
+        tr = llmemory.cast_adr_to_ptr(tr_adr, lltype.Ptr(SR))
+        assert sr != tr
+        t, t_adr = self.malloc(S)
+        t.b = 67890
+        assert tr.s1 == lltype.nullptr(S)
+        assert tr.sr2 == lltype.nullptr(SR)
+        tr.s1 = t
+        # 't' is local-only and reachable from the root 'tr'
+        main_tls = self.gc.main_thread_tls
+        old_value = main_tls.nursery_free
+        #
+        self.gc.collector.commit_transaction()
+        # exactly one object, the global copy of 't', was allocated
+        assert main_tls.nursery_free - old_value == self.gcsize(S)
+
+    def test_commit_transaction_with_graph(self):
+        sr1, sr1_adr = self.malloc(SR)
+        sr2, sr2_adr = self.malloc(SR)
+        self.select_thread(1)
+        tr1_adr = self.gc.write_barrier(sr1_adr)   # make a local copy
+        tr2_adr = self.gc.write_barrier(sr2_adr)   # make a local copy
+        tr1 = llmemory.cast_adr_to_ptr(tr1_adr, lltype.Ptr(SR))
+        tr2 = llmemory.cast_adr_to_ptr(tr2_adr, lltype.Ptr(SR))
+        tr3, tr3_adr = self.malloc(SR)
+        tr4, tr4_adr = self.malloc(SR)
+        t, t_adr = self.malloc(S)
+        # graph with cycles: roots tr1, tr2; local-only tr3, tr4, t
+        tr1.sr2 = tr3; tr1.sr3 = tr1
+        tr2.sr2 = tr3; tr2.sr3 = tr3
+        tr3.sr2 = tr4; tr3.sr3 = tr2
+        tr4.sr2 = tr3; tr4.sr3 = tr3; tr4.s1 = t
+        # plus some unreachable local objects, which must not be copied
+        for i in range(4):
+            self.malloc(S)     # forgotten
+        #
+        main_tls = self.gc.main_thread_tls
+        old_value = main_tls.nursery_free
+        #
+        self.gc.collector.commit_transaction()
+        # only tr3, tr4 and t were copied into the global area
+        assert main_tls.nursery_free - old_value == (
+            self.gcsize(SR) + self.gcsize(SR) + self.gcsize(S))
+        # each copied local object's 'version' points to its global copy
+        sr3_adr = self.gc.header(tr3_adr).version
+        sr4_adr = self.gc.header(tr4_adr).version
+        s_adr   = self.gc.header(t_adr  ).version
+        assert len(set([sr3_adr, sr4_adr, s_adr])) == 3
+        #
+        sr3 = llmemory.cast_adr_to_ptr(sr3_adr, lltype.Ptr(SR))
+        sr4 = llmemory.cast_adr_to_ptr(sr4_adr, lltype.Ptr(SR))
+        s   = llmemory.cast_adr_to_ptr(s_adr,   lltype.Ptr(S))
+        assert tr1.sr2 == sr3; assert tr1.sr3 == sr1     # roots: local obj
+        assert tr2.sr2 == sr3; assert tr2.sr3 == sr3     #        is modified
+        assert sr3.sr2 == sr4; assert sr3.sr3 == sr2     # non-roots: global
+        assert sr4.sr2 == sr3; assert sr4.sr3 == sr3     #      obj is modified
+        assert sr4.s1 == s


More information about the pypy-commit mailing list