[pypy-commit] pypy tealet: Saving and restoring the shadowstack in the JIT.

arigo noreply at buildbot.pypy.org
Wed Jul 6 20:28:48 CEST 2011


Author: Armin Rigo <arigo at tunes.org>
Branch: tealet
Changeset: r45385:841c213583c6
Date: 2011-06-13 10:51 +0200
http://bitbucket.org/pypy/pypy/changeset/841c213583c6/

Log:	Saving and restoring the shadowstack in the JIT.

diff --git a/pypy/jit/backend/llsupport/gc.py b/pypy/jit/backend/llsupport/gc.py
--- a/pypy/jit/backend/llsupport/gc.py
+++ b/pypy/jit/backend/llsupport/gc.py
@@ -1,7 +1,7 @@
 import os
 from pypy.rlib import rgc
 from pypy.rlib.objectmodel import we_are_translated
-from pypy.rlib.debug import fatalerror
+from pypy.rlib.debug import fatalerror, ll_assert
 from pypy.rlib.rarithmetic import ovfcheck
 from pypy.rpython.lltypesystem import lltype, llmemory, rffi, rclass, rstr
 from pypy.rpython.lltypesystem import llgroup
@@ -365,12 +365,24 @@
         self.force_index_ofs = gcdescr.force_index_ofs
 
     def add_jit2gc_hooks(self, jit2gc):
+        INTARRAYPTR = self.INTARRAYPTR
+        def read(addr):
+            return rffi.cast(INTARRAYPTR, addr)[0]
+        def write(addr, newvalue):
+            rffi.cast(INTARRAYPTR, addr)[0] = newvalue
+        # for tests: 'test_*' keys in jit2gc override the helpers above/below
+        read  = jit2gc.get('test_read',  read)
+        write = jit2gc.get('test_write', write)
+        cast_int_to_adr = jit2gc.get('test_i2a', llmemory.cast_int_to_adr)
+        cast_int_to_ptr = jit2gc.get('test_i2p', lltype.cast_int_to_ptr)
+        cast_ptr_to_int = jit2gc.get('test_p2i', lltype.cast_ptr_to_int)
         #
-        def collect_jit_stack_root(callback, gc, addr):
-            if addr.signed[0] != GcRootMap_shadowstack.MARKER:
+        def collect_jit_stack_root(callback, gc, realaddr):
+            addr = rffi.cast(lltype.Signed, realaddr)
+            if read(addr) != GcRootMap_shadowstack.MARKER:
                 # common case
-                if gc.points_to_valid_gc_object(addr):
-                    callback(gc, addr)
+                if gc.points_to_valid_gc_object(realaddr):
+                    callback(gc, realaddr)
                 return WORD
             else:
                 # case of a MARKER followed by an assembler stack frame
@@ -378,9 +390,8 @@
                 return 2 * WORD
         #
         def follow_stack_frame_of_assembler(callback, gc, addr):
-            frame_addr = addr.signed[1]
-            addr = llmemory.cast_int_to_adr(frame_addr + self.force_index_ofs)
-            force_index = addr.signed[0]
+            frame_addr = read(addr + WORD)
+            force_index = read(frame_addr + self.force_index_ofs)
             if force_index < 0:
                 force_index = ~force_index
             callshape = self._callshapes[force_index]
@@ -389,13 +400,145 @@
                 offset = rffi.cast(lltype.Signed, callshape[n])
                 if offset == 0:
                     break
-                addr = llmemory.cast_int_to_adr(frame_addr + offset)
+                addr = cast_int_to_adr(frame_addr + offset)
                 if gc.points_to_valid_gc_object(addr):
                     callback(gc, addr)
                 n += 1
         #
+        # ---------- tealet support ----------
+        GCPTR_ARRAY  = lltype.Ptr(lltype.GcArray(llmemory.GCREF))
+        SIGNED_ARRAY = lltype.Ptr(lltype.GcArray(lltype.Signed))
+        #
+        def save_roots(walker, gcdata):
+            gcptr_count = 0
+            signed_count = 0
+            gcptr_array = walker.gcptr_array
+            #
+            rsbase = gcdata.root_stack_base
+            rsend = gcdata.root_stack_top
+            rsaddr = rsbase
+            while rsaddr != rsend:
+                if read(rsaddr) != GcRootMap_shadowstack.MARKER:
+                    # common case
+                    if gcptr_array:
+                        gcobj = cast_int_to_ptr(llmemory.GCREF, read(rsaddr))
+                        gcptr_array[gcptr_count] = gcobj
+                    gcptr_count += 1
+                    rsaddr += WORD
+                else:
+                    # case of a MARKER followed by an assembler stack frame
+                    frame_addr = read(rsaddr + WORD)
+                    force_index = read(frame_addr + self.force_index_ofs)
+                    if force_index < 0:
+                        force_index = ~force_index
+                    if walker.signed_array:
+                        walker.signed_array[signed_count] = rsaddr - rsbase
+                        walker.signed_array[signed_count+1] = frame_addr
+                        walker.signed_array[signed_count+2] = force_index
+                        # NB. saving force_index is not strictly necessary,
+                        # but it costs little and lets the restore hook
+                        # assert that the frame still has the same index
+                    signed_count += 3
+                    callshape = self._callshapes[force_index]
+                    n = 0
+                    while True:
+                        offset = rffi.cast(lltype.Signed, callshape[n])
+                        if offset == 0:
+                            break
+                        if gcptr_array:
+                            addr = cast_int_to_adr(frame_addr + offset)
+                            gcobj = cast_int_to_ptr(llmemory.GCREF, read(addr))
+                            gcptr_array[gcptr_count] = gcobj
+                        gcptr_count += 1
+                        n += 1
+                    rsaddr += 2 * WORD
+            #
+            if walker.signed_array:
+                walker.signed_array[signed_count] = rsend - rsbase
+            signed_count += 1
+            #
+            if not walker.gcptr_array:
+                walker.gcptr_array = lltype.malloc(GCPTR_ARRAY.TO, gcptr_count)
+            if not walker.signed_array:
+                walker.signed_array = lltype.malloc(SIGNED_ARRAY.TO,
+                                                    signed_count)
+            ll_assert(signed_count == len(walker.signed_array),
+                      "varying stack signed count")
+            ll_assert(gcptr_count == len(walker.gcptr_array),
+                      "varying stack gcptr count")
+        #
+        def jit_save_stack_roots(walker, gcdata):
+            """Save the stack roots from the shadowstack piece of memory,
+            including the stack roots that are in assembler-generated code
+            with a MARKER followed by the address of the assembler frame.
+            Puts all this information in two arrays: walker.gcptr_array and
+            walker.signed_array.
+            """
+            walker.gcptr_array  = lltype.nullptr(GCPTR_ARRAY.TO)
+            walker.signed_array = lltype.nullptr(SIGNED_ARRAY.TO)
+            save_roots(walker, gcdata)      # at first, just to count
+            save_roots(walker, gcdata)      # this time, really save
+        #
+        def jit_restore_stack_roots(walker, gcdata):
+            """Restore the stack roots into the shadowstack piece of memory
+            and into the assembler frames, then reset gcdata.root_stack_top.
+            """
+            gcptr_count = 0
+            signed_count = 0
+            gcptr_array = walker.gcptr_array
+            #
+            rsbase = gcdata.root_stack_base
+            rsaddr = rsbase
+            rsmarker = rsbase + walker.signed_array[signed_count]
+            signed_count += 1
+            while True:
+                if rsaddr != rsmarker:
+                    # common case
+                    gcobj = gcptr_array[gcptr_count]
+                    write(rsaddr, cast_ptr_to_int(gcobj))
+                    gcptr_count += 1
+                    rsaddr += WORD
+                elif signed_count == len(walker.signed_array):
+                    # done
+                    break
+                else:
+                    # case of a MARKER followed by an assembler stack frame
+                    frame_addr = walker.signed_array[signed_count]
+                    write(rsaddr,        GcRootMap_shadowstack.MARKER)
+                    write(rsaddr + WORD, frame_addr)
+                    rsaddr += 2 * WORD
+                    #
+                    force_index = read(frame_addr + self.force_index_ofs)
+                    if force_index < 0:
+                        force_index = ~force_index
+                    ll_assert(force_index ==
+                              walker.signed_array[signed_count+1],
+                              "restoring bogus stack force_index")
+                    callshape = self._callshapes[force_index]
+                    n = 0
+                    while True:
+                        offset = rffi.cast(lltype.Signed, callshape[n])
+                        if offset == 0:
+                            break
+                        addr = cast_int_to_adr(frame_addr + offset)
+                        gcobj = gcptr_array[gcptr_count]
+                        write(addr, cast_ptr_to_int(gcobj))
+                        gcptr_count += 1
+                        n += 1
+                    #
+                    rsmarker = rsbase + walker.signed_array[signed_count+2]
+                    signed_count += 3
+            #
+            gcdata.root_stack_top = rsmarker
+            ll_assert(signed_count == len(walker.signed_array),
+                      "restoring bogus stack signed count")
+            ll_assert(gcptr_count == len(walker.gcptr_array),
+                      "restoring bogus stack gcptr count")
+        #
         jit2gc.update({
             'rootstackhook': collect_jit_stack_root,
+            'savestackhook': jit_save_stack_roots,
+            'restorestackhook': jit_restore_stack_roots,
             })
 
     def initialize(self):
diff --git a/pypy/jit/backend/llsupport/test/test_gc.py b/pypy/jit/backend/llsupport/test/test_gc.py
--- a/pypy/jit/backend/llsupport/test/test_gc.py
+++ b/pypy/jit/backend/llsupport/test/test_gc.py
@@ -1,4 +1,4 @@
-import random
+import sys, random
 from pypy.rpython.lltypesystem import lltype, llmemory, rffi, rstr
 from pypy.rpython.lltypesystem.lloperation import llop
 from pypy.rpython.annlowlevel import llhelper
@@ -241,6 +241,162 @@
         assert rffi.cast(lltype.Signed, p[1]) == -24
         assert rffi.cast(lltype.Signed, p[2]) == 0
 
+    def build_fake_stack(self):
+        self.gcrootmap = GcRootMap_shadowstack(self.FakeGcDescr())
+        self.gcrootmap.force_index_ofs = 16
+        self.writes = {}
+        #
+        def read_for_tests(addr):
+            assert addr % WORD == 0
+            if 3000 <= addr < 3000+8*WORD:
+                return self.shadowstack[(addr - 3000) // WORD]
+            if 20000 <= addr < 29000:
+                base = (addr // 1000) * 1000
+                return frames[base][addr-base]
+            raise AssertionError(addr)
+        def write_for_tests(addr, newvalue):
+            self.writes[addr] = newvalue
+        def cast_int_to_adr_for_tests(value):
+            return value
+        def cast_int_to_ptr_for_tests(TARGET, value):
+            assert TARGET == llmemory.GCREF
+            return lltype.opaqueptr(TARGET.TO, 'foo', x=value)
+        def cast_ptr_to_int_for_tests(value):
+            assert isinstance(value, int)
+            assert 10000 <= value < 11000 or value == 0
+            return value
+        #
+        self.jit2gc = {'test_read': read_for_tests,
+                       'test_write': write_for_tests,
+                       'test_i2a': cast_int_to_adr_for_tests,
+                       'test_i2p': cast_int_to_ptr_for_tests,
+                       'test_p2i': cast_ptr_to_int_for_tests}
+        self.gcrootmap.add_jit2gc_hooks(self.jit2gc)
+        #
+        def someobj(x):
+            return 10000 + x
+        #
+        frames = {}
+        #
+        def someframe(data, force_index):
+            num = 20000 + len(frames) * 1000
+            data[self.gcrootmap.force_index_ofs] = force_index
+            frames[num] = data
+            return num
+        #
+        MARKER = GcRootMap_shadowstack.MARKER
+        self.gcrootmap._callshapes = {61: (32, 64, 80, 0),
+                                      62: (32, 48, 0)}
+        self.shadowstack = [
+            someobj(42),
+            someobj(43),
+            0,
+            MARKER,
+            someframe({32:someobj(132), 64:someobj(164), 80:someobj(180)}, 61),
+            someobj(44),
+            MARKER,
+            someframe({32: someobj(232), 48: someobj(248)}, ~62),
+            ]
+        #
+        class FakeGC:
+            def points_to_valid_gc_object(self, addr):
+                to = read_for_tests(addr)
+                if to == 0:
+                    return False
+                if 10000 <= to < 11000:
+                    return True
+                raise AssertionError(to)
+        class FakeGCData:
+            pass
+        self.gc = FakeGC()
+        self.gcdata = FakeGCData()
+        self.gcdata.root_stack_base = 3000
+        self.gcdata.root_stack_top  = 3000 + 8*WORD
+
+    def test_jit_stack_root(self):
+        self.build_fake_stack()
+        collect_jit_stack_root = self.jit2gc['rootstackhook']
+        seen = []
+        def callback(gc, addr):
+            assert gc == self.gc
+            seen.append(addr)
+        def f(n):
+            return self.gcdata.root_stack_base + n * WORD
+        res = collect_jit_stack_root(callback, self.gc, f(0))   # someobj
+        assert res == WORD
+        assert seen == [3000]
+        res = collect_jit_stack_root(callback, self.gc, f(1))   # someobj
+        assert res == WORD
+        assert seen == [3000, 3000+WORD]
+        res = collect_jit_stack_root(callback, self.gc, f(2))   # 0
+        assert res == WORD
+        assert seen == [3000, 3000+WORD]
+        res = collect_jit_stack_root(callback, self.gc, f(3))   # MARKER
+        assert res == 2 * WORD
+        assert seen == [3000, 3000+WORD, 20032, 20064, 20080]
+        res = collect_jit_stack_root(callback, self.gc, f(5))   # someobj
+        assert res == WORD
+        assert seen == [3000, 3000+WORD, 20032, 20064, 20080, 3000+5*WORD]
+        res = collect_jit_stack_root(callback, self.gc, f(6))   # MARKER
+        assert res == 2 * WORD
+        assert seen == [3000, 3000+WORD, 20032, 20064, 20080, 3000+5*WORD,
+                        21032, 21048]
+
+    def test_jit_save_stack_roots(self):
+        class Walker:
+            pass
+        self.build_fake_stack()
+        jit_save_stack_roots = self.jit2gc['savestackhook']
+        walker = Walker()
+        jit_save_stack_roots(walker, self.gcdata)
+        assert list(walker.signed_array) == [
+            3 * WORD, 20000, 61,
+            6 * WORD, 21000, 62,
+            8 * WORD]
+        assert [gcref._obj.x for gcref in walker.gcptr_array] == [
+            10042,
+            10043,
+            0,
+            10132, 10164, 10180,
+            10044,
+            10232, 10248]
+
+    def test_jit_restore_stack_roots(self):
+        class Walker:
+            pass
+        self.build_fake_stack()
+        jit_restore_stack_roots = self.jit2gc['restorestackhook']
+        walker = Walker()
+        walker.signed_array = [
+            3 * WORD, 20000, 61,
+            6 * WORD, 21000, 62,
+            8 * WORD]
+        walker.gcptr_array = [
+            10042,
+            10043,
+            0,
+            10132, 10164, 10180,
+            10044,
+            10232, 10248]
+        self.gcdata.root_stack_top = 4444
+        jit_restore_stack_roots(walker, self.gcdata)
+        assert self.gcdata.root_stack_top == 3000 + 8*WORD
+        assert self.writes == {
+            3000 + 0*WORD: 10042,
+            3000 + 1*WORD: 10043,
+            3000 + 2*WORD: 0,
+            3000 + 3*WORD: GcRootMap_shadowstack.MARKER,
+            3000 + 4*WORD: 20000,
+            3000 + 5*WORD: 10044,
+            3000 + 6*WORD: GcRootMap_shadowstack.MARKER,
+            3000 + 7*WORD: 21000,
+            20032: 10132,
+            20064: 10164,
+            20080: 10180,
+            21032: 10232,
+            21048: 10248,
+            }
+
 
 class FakeLLOp(object):
     def __init__(self):


More information about the pypy-commit mailing list