[pypy-svn] r43568 - in pypy/branch/prolog-jit-experiments/pypy/lang/prolog/interpreter: . test

cfbolz at codespeak.net cfbolz at codespeak.net
Wed May 23 09:24:43 CEST 2007


Author: cfbolz
Date: Wed May 23 09:24:42 2007
New Revision: 43568

Modified:
   pypy/branch/prolog-jit-experiments/pypy/lang/prolog/interpreter/engine.py
   pypy/branch/prolog-jit-experiments/pypy/lang/prolog/interpreter/portal.py
   pypy/branch/prolog-jit-experiments/pypy/lang/prolog/interpreter/term.py
   pypy/branch/prolog-jit-experiments/pypy/lang/prolog/interpreter/test/test_jit.py
Log:
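cleaning up a bit: use printable bytecode constants, drop the unused `inline` argument of `Engine.call`, replace `unify_hash_of_child` with a heap-aware `unify_hash_of_children`, fold the last-rule case into the rule-trying loop, let `portal_try_rule` handle the no-choice-point case, fix the `self.pypy` reference in portal.py, and pass `inline=0.0` in the JIT tests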


Modified: pypy/branch/prolog-jit-experiments/pypy/lang/prolog/interpreter/engine.py
==============================================================================
--- pypy/branch/prolog-jit-experiments/pypy/lang/prolog/interpreter/engine.py	(original)
+++ pypy/branch/prolog-jit-experiments/pypy/lang/prolog/interpreter/engine.py	Wed May 23 09:24:42 2007
@@ -10,11 +10,11 @@
 DEBUG = False
 
 # bytecodes:
-CALL = chr(0)
-USER_CALL = chr(1)
-TRY_RULE = chr(2)
-CONTINUATION = chr(3)
-DONE = chr(4)
+CALL = 'a'
+USER_CALL = 'u'
+TRY_RULE = 't'
+CONTINUATION = 'c'
+DONE = 'd'
 
 
 class Continuation(object):
@@ -232,8 +232,7 @@
         from pypy.lang.prolog.interpreter.parsing import parse_file
         trees = parse_file(s, self.parser, Engine._build_and_run, self)
 
-    def call(self, query, continuation=DONOTHING, choice_point=True,
-             inline=False):
+    def call(self, query, continuation=DONOTHING, choice_point=True):
         assert isinstance(query, Callable)
         if not choice_point:
             return (CALL, query, continuation, None)
@@ -263,15 +262,12 @@
         hint(where, concrete=True)
         hint(rule, concrete=True)
         while 1:
-            #hint(None, global_merge_point=True)
-            #print "  " * self.depth, where, query
             if where == DONE:
                 return next
             next = self.dispatch_bytecode(where, query, continuation, rule)
             where, query, continuation, rule = next
             where = hint(where, promote=True)
 
-
     def dispatch_bytecode(self, where, query, continuation, rule):
         if where == CALL:
             next = self._call(query, continuation)
@@ -309,15 +305,7 @@
             error.throw_existence_error(
                 "procedure", query.get_prolog_signature())
 
-        #XXX make a nice method
-        if isinstance(query, Term):
-            unify_hash = []
-            i = 0
-            while i < len(query.args):
-                unify_hash.append(query.unify_hash_of_child(i))
-                i += 1
-        else:
-            unify_hash = []
+        unify_hash = query.unify_hash_of_children(self.heap)
         rulechain = startrulechain.find_applicable_rule(unify_hash)
         if rulechain is None:
             # none of the rules apply
@@ -325,11 +313,12 @@
         rule = rulechain.rule
         rulechain = rulechain.next
         oldstate = self.heap.branch()
-        while rulechain:
-            rulechain = rulechain.find_applicable_rule(unify_hash)
-            if rulechain is None:
-                self.heap.discard(oldstate)
-                break
+        while 1:
+            if rulechain is not None:
+                rulechain = rulechain.find_applicable_rule(unify_hash)
+                choice_point = rulechain is not None
+            else:
+                choice_point = False
             hint(rule, concrete=True)
             if rule.contains_cut:
                 continuation = LimitedScopeContinuation(continuation)
@@ -345,31 +334,29 @@
                                                        continuation)
                     raise
             else:
+                inline = False #XXX rule.body is None # inline facts
                 try:
-                    result = self.try_rule(rule, query, continuation)
+                    # for the last rule (rulechain is None), this will always
+                    # return, because choice_point is False
+                    result = self.try_rule(rule, query, continuation,
+                                           choice_point=choice_point,
+                                           inline=inline)
                     self.heap.discard(oldstate)
                     return result
                 except UnificationFailed:
+                    assert choice_point
                     self.heap.revert(oldstate)
             rule = rulechain.rule
             rulechain = rulechain.next
-        hint(rule, concrete=True)
-        if rule.contains_cut:
-            continuation = LimitedScopeContinuation(continuation)
-            try:
-                return self.try_rule(rule, query, continuation)
-            except CutException, e:
-                if continuation.scope_active:
-                    return self.continue_after_cut(e.continuation, continuation)
-                raise
-        return self.try_rule(rule, query, continuation, choice_point=False)
 
     def try_rule(self, rule, query, continuation=DONOTHING, choice_point=True,
                  inline=False):
-        if not we_are_jitted():
-            return self.portal_try_rule(rule, query, continuation, choice_point)
         if not choice_point:
             return (TRY_RULE, query, continuation, rule)
+        if not we_are_jitted():
+            return self.portal_try_rule(rule, query, continuation, choice_point)
+        if inline:
+            return self.main_loop(TRY_RULE, query, continuation, rule)
         #if _is_early_constant(rule):
         #    rule = hint(rule, promote=True)
         #    return self.portal_try_rule(rule, query, continuation, choice_point)
@@ -380,20 +367,25 @@
 
     def portal_try_rule(self, rule, query, continuation, choice_point):
         hint(None, global_merge_point=True)
-        #hint(choice_point, concrete=True)
-        #if not choice_point:
-        #    return self._try_rule(rule, query, continuation)
+        hint(choice_point, concrete=True)
+        if not choice_point:
+            return self._try_rule(rule, query, continuation)
         where = TRY_RULE
         next = (DONE, None, None, None)
         hint(where, concrete=True)
         hint(rule, concrete=True)
+        signature = hint(query.signature, promote=True)
         while 1:
             hint(None, global_merge_point=True)
-            rule = hint(rule, promote=True)
             if where == DONE:
                 return next
+            if rule is not None:
+                assert rule.signature == signature
             next = self.dispatch_bytecode(where, query, continuation, rule)
             where, query, continuation, rule = next
+            rule = hint(rule, promote=True)
+            if query is not None:
+                signature = hint(query.signature, promote=True)
             where = hint(where, promote=True)
 
     def _try_rule(self, rule, query, continuation):

Modified: pypy/branch/prolog-jit-experiments/pypy/lang/prolog/interpreter/portal.py
==============================================================================
--- pypy/branch/prolog-jit-experiments/pypy/lang/prolog/interpreter/portal.py	(original)
+++ pypy/branch/prolog-jit-experiments/pypy/lang/prolog/interpreter/portal.py	Wed May 23 09:24:42 2007
@@ -11,9 +11,11 @@
                 'pypy.lang.prolog.builtin.register': True
                }
 
+
 PORTAL = engine.Engine.portal_try_rule.im_func
 
 class PyrologHintAnnotatorPolicy(ManualGraphPolicy):
+    PORTAL = PORTAL
     def look_inside_graph_of_module(self, graph, mod):
         if mod in forbidden_modules:
             return False
@@ -24,6 +26,7 @@
         return True
 
     def fill_timeshift_graphs(self, t, portal_graph):
+        import pypy
         for cls in [term.Var, term.Term, term.Number, term.Float, term.Atom]:
             self.seegraph(cls.copy)
             self.seegraph(cls.__init__)
@@ -36,6 +39,7 @@
             self.seegraph(cls.get_unify_hash)
         for cls in [term.Callable, term.Atom, term.Term]:
             self.seegraph(cls.get_prolog_signature)
+            self.seegraph(cls.unify_hash_of_children)
         self.seegraph(PORTAL)
         self.seegraph(engine.Heap.newvar)
         self.seegraph(term.Rule.clone_and_unify_head)
@@ -48,10 +52,11 @@
         self.seegraph(engine.Engine.main_loop)
         self.seegraph(engine.Engine.dispatch_bytecode)
         self.seegraph(engine.LinkedRules.find_applicable_rule)
+        for method in "branch revert discard newvar extend maxvar".split():
+            self.seegraph(getattr(engine.Heap, method))
         self.seegraph(engine.Continuation.call)
-        self.seegraph(term.Term.unify_hash_of_child)
         for cls in [engine.Continuation, engine.LimitedScopeContinuation,
-                    self.pypy.lang.prolog.builtin.control.AndContinuation]:
+                    pypy.lang.prolog.builtin.control.AndContinuation]:
             self.seegraph(cls._call)
 
 def get_portal(drv):

Modified: pypy/branch/prolog-jit-experiments/pypy/lang/prolog/interpreter/term.py
==============================================================================
--- pypy/branch/prolog-jit-experiments/pypy/lang/prolog/interpreter/term.py	(original)
+++ pypy/branch/prolog-jit-experiments/pypy/lang/prolog/interpreter/term.py	Wed May 23 09:24:42 2007
@@ -46,14 +46,11 @@
     def clone_compress_vars(self, vars_new_indexes, offset):
         return self
 
-    def get_unify_hash(self):
+    def get_unify_hash(self, heap):
         # if two non-var objects return two different numbers
         # they must not be unifiable
         raise NotImplementedError("abstract base class")
 
-    def unify_hash_of_child(self, i):
-        raise KeyError
-
     @specialize.arg(3)
     def unify(self, other, heap, occurs_check=False):
         raise NotImplementedError("abstract base class")
@@ -147,7 +144,12 @@
         vars_new_indexes[self.index] = index
         return Var.newvar(index)
     
-    def get_unify_hash(self):
+    def get_unify_hash(self, heap):
+        if heap is not None:
+            self = self.dereference(heap)
+            if isinstance(self, Var):
+                return 0
+            return self.get_unify_hash(heap)
         return 0
 
     def contains_var(self, var, heap):
@@ -220,6 +222,9 @@
     def get_prolog_signature(self):
         raise NotImplementedError("abstract base")
 
+    def unify_hash_of_children(self, heap):
+        raise NotImplementedError("abstract base")
+
 
 class Atom(Callable):
     TAG = tag()
@@ -250,16 +255,19 @@
 
     def copy_and_basic_unify(self, other, heap, memo):
         hint(self, concrete=True)
-        if isinstance(other, Atom) and (hint(self is other, promote=True) or
+        if isinstance(other, Atom) and (self is other or
                                         other.name == self.name):
             return self
         else:
             raise UnificationFailed
 
-    def get_unify_hash(self):
+    def get_unify_hash(self, heap):
         name = hint(self.name, promote=True)
         return intmask(hash(name) << TAGBITS | self.TAG)
 
+    def unify_hash_of_children(self, heap):
+        return []
+
     def get_prolog_signature(self):
         return Term("/", [self, NUMBER_0])
 
@@ -301,7 +309,7 @@
     def __repr__(self):
         return "Number(%r)" % (self.num, )
 
-    def get_unify_hash(self):
+    def get_unify_hash(self, heap):
         return intmask(self.num << TAGBITS | self.TAG)
 
 NUMBER_0 = Number(0)
@@ -329,7 +337,7 @@
         else:
             raise UnificationFailed
 
-    def get_unify_hash(self):
+    def get_unify_hash(self, heap):
         #XXX no clue whether this is a good idea...
         m, e = math.frexp(self.num)
         m = intmask(int(m / 2 * 2 ** (32 - TAGBITS)))
@@ -364,7 +372,7 @@
         else:
             raise UnificationFailed
 
-    def get_unify_hash(self):
+    def get_unify_hash(self, heap):
         return intmask(id(self) << TAGBITS | self.TAG)
 
 
@@ -461,12 +469,17 @@
         else:
             return self
 
-    def get_unify_hash(self):
+    def get_unify_hash(self, heap):
         signature = hint(self.signature, promote=True)
         return intmask(hash(signature) << TAGBITS | self.TAG)
 
-    def unify_hash_of_child(self, i):
-        return self.args[i].get_unify_hash()
+    def unify_hash_of_children(self, heap):
+        unify_hash = []
+        i = 0
+        while i < len(self.args):
+            unify_hash.append(self.args[i].get_unify_hash(heap))
+            i += 1
+        return unify_hash
 
     def get_prolog_signature(self):
         return Term("/", [Atom.newatom(self.name), Number(len(self.args))])
@@ -495,7 +508,7 @@
         self.numvars = len(d)
         self.signature = self.head.signature
         if isinstance(head, Term):
-            self.unify_hash = [arg.get_unify_hash() for arg in head.args]
+            self.unify_hash = [arg.get_unify_hash(None) for arg in head.args]
         self._does_contain_cut()
 
     def _does_contain_cut(self):

Modified: pypy/branch/prolog-jit-experiments/pypy/lang/prolog/interpreter/test/test_jit.py
==============================================================================
--- pypy/branch/prolog-jit-experiments/pypy/lang/prolog/interpreter/test/test_jit.py	(original)
+++ pypy/branch/prolog-jit-experiments/pypy/lang/prolog/interpreter/test/test_jit.py	Wed May 23 09:24:42 2007
@@ -39,16 +39,20 @@
 
         res = self.timeshift_from_portal(main, portal.PORTAL,
                                          [1], policy=POLICY,
-                                         backendoptimize=True)
+                                         backendoptimize=True, 
+                                         inline=0.0)
         assert res == True
         
         res = self.timeshift_from_portal(main, portal.PORTAL,
                                          [0], policy=POLICY,
-                                         backendoptimize=True)
+                                         backendoptimize=True, 
+                                         inline=0.0)
         assert res == True
 
     def test_and(self):
         e = get_engine("""
+            h(X) :- f(X).
+            h(a).
             b(a).
             a(a).
             f(X) :- b(X), a(X).
@@ -59,7 +63,7 @@
         def main(n):
             e.heap.reset()
             if n == 0:
-                e.call(term.Term("f", [X]))
+                e.call(term.Term("h", [X]))
                 return isinstance(X.dereference(e.heap), term.Atom)
             else:
                 return False
@@ -69,7 +73,8 @@
 
         res = self.timeshift_from_portal(main, portal.PORTAL,
                                          [0], policy=POLICY,
-                                         backendoptimize=True)
+                                         backendoptimize=True, 
+                                         inline=0.0)
         assert res == True
 
     def test_append(self):
@@ -93,7 +98,8 @@
         e.heap.reset()
         res = self.timeshift_from_portal(main, portal.PORTAL,
                                          [0], policy=POLICY,
-                                         backendoptimize=True)
+                                         backendoptimize=True, 
+                                         inline=0.0)
         assert res == True
 
 
@@ -120,7 +126,8 @@
 
         res = self.timeshift_from_portal(main, portal.PORTAL,
                                          [0], policy=POLICY,
-                                         backendoptimize=True)
+                                         backendoptimize=True, 
+                                         inline=0.0)
         assert res == True
 
     def test_loop(self):
@@ -141,6 +148,7 @@
 
         res = self.timeshift_from_portal(main, portal.PORTAL,
                                          [0], policy=POLICY,
-                                         backendoptimize=True)
+                                         backendoptimize=True, 
+                                         inline=0.0)
         assert res == True
 



More information about the Pypy-commit mailing list