[pypy-svn] r43952 - in pypy/dist/pypy/lang/prolog: builtin interpreter interpreter/test

cfbolz at codespeak.net cfbolz at codespeak.net
Fri Jun 1 01:18:59 CEST 2007


Author: cfbolz
Date: Fri Jun  1 01:18:58 2007
New Revision: 43952

Modified:
   pypy/dist/pypy/lang/prolog/builtin/allsolution.py
   pypy/dist/pypy/lang/prolog/builtin/exception.py
   pypy/dist/pypy/lang/prolog/builtin/formatting.py
   pypy/dist/pypy/lang/prolog/builtin/register.py
   pypy/dist/pypy/lang/prolog/builtin/termconstruction.py
   pypy/dist/pypy/lang/prolog/interpreter/engine.py
   pypy/dist/pypy/lang/prolog/interpreter/parsing.py
   pypy/dist/pypy/lang/prolog/interpreter/term.py
   pypy/dist/pypy/lang/prolog/interpreter/test/test_builtin.py
   pypy/dist/pypy/lang/prolog/interpreter/test/test_engine.py
   pypy/dist/pypy/lang/prolog/interpreter/test/test_unification.py
   pypy/dist/pypy/lang/prolog/interpreter/test/tool.py
Log:
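Refactor how logic variables are handled in the Prolog interpreter to make it
considerably simpler: each Var now stores its own binding instead of being an
index into a heap-wide variable array, and the heap only keeps a trail of
(variable, old binding) pairs so bindings can be undone on backtracking.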


Modified: pypy/dist/pypy/lang/prolog/builtin/allsolution.py
==============================================================================
--- pypy/dist/pypy/lang/prolog/builtin/allsolution.py	(original)
+++ pypy/dist/pypy/lang/prolog/builtin/allsolution.py	Fri Jun  1 01:18:58 2007
@@ -26,8 +26,7 @@
     for i in range(len(collector.found) - 1, -1, -1):
         copy = collector.found[i]
         d = {}
-        copy = copy.clone_compress_vars(d, engine.heap.maxvar())
-        engine.heap.extend(len(d))
+        copy = copy.copy(engine.heap, d)
         result = term.Term(".", [copy, result])
     bag.unify(result, engine.heap)
 expose_builtin(impl_findall, "findall", unwrap_spec=['raw', 'callable', 'raw'])

Modified: pypy/dist/pypy/lang/prolog/builtin/exception.py
==============================================================================
--- pypy/dist/pypy/lang/prolog/builtin/exception.py	(original)
+++ pypy/dist/pypy/lang/prolog/builtin/exception.py	Fri Jun  1 01:18:58 2007
@@ -17,8 +17,7 @@
         exc_term = e.term.getvalue(engine.heap)
         engine.heap.revert(old_state)
         d = {}
-        exc_term = exc_term.clone_compress_vars(d, engine.heap.maxvar())
-        engine.heap.extend(len(d))
+        exc_term = exc_term.copy(engine.heap, d)
         try:
             impl_ground(engine, exc_term)
         except error.UnificationFailed:

Modified: pypy/dist/pypy/lang/prolog/builtin/formatting.py
==============================================================================
--- pypy/dist/pypy/lang/prolog/builtin/formatting.py	(original)
+++ pypy/dist/pypy/lang/prolog/builtin/formatting.py	Fri Jun  1 01:18:58 2007
@@ -14,6 +14,7 @@
         self.ignore_ops = ignore_ops
         self.curr_depth = 0
         self._make_reverse_op_mapping()
+        self.var_to_number = {}
     
     def from_option_list(engine, options):
         # XXX add numbervars support
@@ -76,8 +77,11 @@
         return str(num.floatval)
 
     def format_var(self, var):
-        return "_G%s" % (var.index, )
-
+        try:
+            num = self.var_to_number[var]
+        except KeyError:
+            num = self.var_to_number[var] = len(self.var_to_number)
+        return "_G%s" % (num, )
 
     def format_term_normally(self, term):
         return "%s(%s)" % (self.format_atom(term.name),

Modified: pypy/dist/pypy/lang/prolog/builtin/register.py
==============================================================================
--- pypy/dist/pypy/lang/prolog/builtin/register.py	(original)
+++ pypy/dist/pypy/lang/prolog/builtin/register.py	Fri Jun  1 01:18:58 2007
@@ -1,5 +1,4 @@
 import py
-from pypy.lang.prolog.interpreter import arithmetic
 from pypy.lang.prolog.interpreter.parsing import parse_file, TermBuilder
 from pypy.lang.prolog.interpreter import engine, helper, term, error
 from pypy.lang.prolog.builtin import builtins, builtins_list
@@ -66,7 +65,7 @@
         elif spec == "atom":
             code.append("    %s = helper.unwrap_atom(%s)" % (varname, varname))
         elif spec == "arithmetic":
-            code.append("    %s = arithmetic.eval_arithmetic(engine, %s)" %
+            code.append("    %s = %s.eval_arithmetic(engine)" %
                         (varname, varname))
         elif spec == "list":
             code.append("    %s = helper.unwrap_list(%s)" % (varname, varname))

Modified: pypy/dist/pypy/lang/prolog/builtin/termconstruction.py
==============================================================================
--- pypy/dist/pypy/lang/prolog/builtin/termconstruction.py	(original)
+++ pypy/dist/pypy/lang/prolog/builtin/termconstruction.py	Fri Jun  1 01:18:58 2007
@@ -15,8 +15,6 @@
     elif isinstance(t, term.Var):
         if isinstance(functor, term.Var):
             error.throw_instantiation_error()
-        elif isinstance(functor, term.Var):
-            error.throw_instantiation_error()
         a = helper.unwrap_int(arity)
         if a < 0:
             error.throw_domain_error("not_less_than_zero", arity)
@@ -26,11 +24,8 @@
                 t.unify(helper.ensure_atomic(functor), engine.heap)
             else:
                 name = helper.unwrap_atom(functor)
-                start = engine.heap.needed_vars
-                engine.heap.extend(a)
                 t.unify(
-                    term.Term(name,
-                              [term.Var(i) for i in range(start, start + a)]),
+                    term.Term(name, [term.Var() for i in range(a)]),
                     engine.heap)
 expose_builtin(impl_functor, "functor", unwrap_spec=["obj", "obj", "obj"])
 
@@ -92,8 +87,7 @@
 
 def impl_copy_term(engine, interm, outterm):
     d = {}
-    copy = interm.clone_compress_vars(d, engine.heap.maxvar())
-    engine.heap.extend(len(d))
+    copy = interm.copy(engine.heap, d)
     outterm.unify(copy, engine.heap)
 expose_builtin(impl_copy_term, "copy_term", unwrap_spec=["obj", "obj"])
 

Modified: pypy/dist/pypy/lang/prolog/interpreter/engine.py
==============================================================================
--- pypy/dist/pypy/lang/prolog/interpreter/engine.py	(original)
+++ pypy/dist/pypy/lang/prolog/interpreter/engine.py	Fri Jun  1 01:18:58 2007
@@ -37,76 +37,36 @@
         self.scope_active = False
         return self.continuation.call(engine, choice_point=False)
 
-START_NUMBER_OF_VARS = 4096
-
-
 class Heap(object):
     def __init__(self):
-        self.vars = [None] * START_NUMBER_OF_VARS
         self.trail = []
-        self.needed_vars = 0
-        self.last_branch = 0
 
     def reset(self):
-        self.vars = [None] * len(self.vars)
         self.trail = []
         self.last_branch = 0
 
-    def clear(self, length):
-        l = max(START_NUMBER_OF_VARS, length)
-        self.vars = [None] * l
-        self.needed_vars = length
-        self.last_branch = length
-        self.trail = []
-
-    def getvar(self, index):
-        return self.vars[index]
-
-    def setvar(self, index, val):
-        oldval = self.vars[index]
-        self.vars[index] = val
-        # only trail for variables that have a chance to get restored
-        # on the last choice point
-        if index < self.last_branch and oldval is not val:
-            self.trail.append((index, oldval))
+    def add_trail(self, var):
+        self.trail.append((var, var.binding))
 
     def branch(self):
-        old_last_branch = self.last_branch
-        self.last_branch = self.needed_vars
-        return len(self.trail), self.needed_vars, old_last_branch
+        return len(self.trail)
 
     def revert(self, state):
-        trails, length, old_last_branch = state
-        assert length == self.last_branch
+        trails = state
         for i in range(len(self.trail) - 1, trails - 1, -1):
-            index, val = self.trail[i]
-            if index >= length:
-                val = None
-            self.vars[index] = val
-        for i in range(length, self.needed_vars):
-            self.vars[i] = None
+            var, val = self.trail[i]
+            var.binding = val
         del self.trail[trails:]
-        self.needed_vars = length
 
     def discard(self, state):
-        old_last_branch = state[2]
-        self.last_branch = old_last_branch
-
-    def extend(self, numvars):
-        if numvars:
-            self.needed_vars += numvars
-            newvars = max(0, numvars - (len(self.vars) - self.needed_vars))
-            if newvars == 0:
-                return
-            self.vars.extend([None] * (2 * newvars)) # allocate a bit more
-            assert self.needed_vars <= len(self.vars)
+        pass #XXX for now
 
     def maxvar(self):
+        XXX
         return self.needed_vars
 
     def newvar(self):
-        result = Var.newvar(self.maxvar())
-        self.extend(1)
+        result = Var(self)
         return result
 
 class LinkedRules(object):
@@ -211,8 +171,6 @@
     def run(self, query, continuation=DONOTHING):
         if not isinstance(query, Callable):
             error.throw_type_error("callable", query)
-        vars = query.get_max_var() + 1
-        self.heap.clear(vars)
         try:
             return self.call(query, continuation, choice_point=True)
         except CutException, e:
@@ -412,7 +370,7 @@
         builder = TermBuilder()
         trees = parse_file(s, self.parser)
         terms = builder.build_many(trees)
-        return terms, builder.var_to_pos
+        return terms, builder.varname_to_var
 
     def getoperations(self):
         from pypy.lang.prolog.interpreter.parsing import default_operations

Modified: pypy/dist/pypy/lang/prolog/interpreter/parsing.py
==============================================================================
--- pypy/dist/pypy/lang/prolog/interpreter/parsing.py	(original)
+++ pypy/dist/pypy/lang/prolog/interpreter/parsing.py	Fri Jun  1 01:18:58 2007
@@ -219,7 +219,7 @@
     s = parser_query.parse(tokens, lazy=False)
     builder = TermBuilder()
     query = builder.build(s)
-    return query, builder.var_to_pos
+    return query, builder.varname_to_var
 
 class OrderTransformer(object):
     def transform(self, node):
@@ -271,8 +271,7 @@
 class TermBuilder(RPythonVisitor):
 
     def __init__(self):
-        self.var_to_pos = {}
-        self.freevar = 0
+        self.varname_to_var = {}
 
     def build(self, s):
         "NOT_RPYTHON"
@@ -294,8 +293,7 @@
         return self.visit(s.children[0])
 
     def build_fact(self, node):
-        self.var_to_pos = {}
-        self.freevar = 0
+        self.varname_to_var = {}
         return self.visit(node.children[0])
 
     def visit(self, node):
@@ -355,14 +353,11 @@
         from pypy.lang.prolog.interpreter.term import Var
         varname = node.additional_info
         if varname == "_":
-            pos = self.freevar
-            self.freevar += 1
-            return Var.newvar(pos)
-        if varname in self.var_to_pos:
-            return self.var_to_pos[varname]
-        res = Var.newvar(self.freevar)
-        self.freevar += 1
-        self.var_to_pos[varname] = res
+            return Var()
+        if varname in self.varname_to_var:
+            return self.varname_to_var[varname]
+        res = Var()
+        self.varname_to_var[varname] = res
         return res
 
     def visit_NUMBER(self, node):

Modified: pypy/dist/pypy/lang/prolog/interpreter/term.py
==============================================================================
--- pypy/dist/pypy/lang/prolog/interpreter/term.py	(original)
+++ pypy/dist/pypy/lang/prolog/interpreter/term.py	Fri Jun  1 01:18:58 2007
@@ -36,18 +36,12 @@
     def dereference(self, heap):
         raise NotImplementedError("abstract base class")
 
-    def get_max_var(self):
-        return -1
-
     def copy(self, heap, memo):
         raise NotImplementedError("abstract base class")
 
     def copy_and_unify(self, other, heap, memo):
         raise NotImplementedError("abstract base class")
 
-    def clone_compress_vars(self, vars_new_indexes, offset):
-        return self
-
     def get_unify_hash(self, heap):
         # if two non-var objects return two different numbers
         # they must not be unifiable
@@ -80,12 +74,11 @@
     TAG = 0
     STANDARD_ORDER = 0
 
-    __slots__ = ('index', )
+    __slots__ = ('binding', )
     cache = {}
-    _immutable_ = True
 
-    def __init__(self, index):
-        self.index = index
+    def __init__(self, heap=None):
+        self.binding = None
 
     @specialize.arg(3)
     def unify(self, other, heap, occurs_check=False):
@@ -99,16 +92,16 @@
         elif occurs_check and other.contains_var(self, heap):
             raise UnificationFailed()
         else:
-            heap.setvar(self.index, other)
+            self.setvalue(other, heap)
 
     def dereference(self, heap):
-        next = heap.getvar(self.index)
+        next = self.binding
         if next is None:
             return self
         else:
             result = next.dereference(heap)
             # do path compression
-            heap.setvar(self.index, result)
+            self.setvalue(result, heap)
             return result
 
     def getvalue(self, heap):
@@ -117,6 +110,10 @@
             return res.getvalue(heap)
         return res
 
+    def setvalue(self, value, heap):
+        heap.add_trail(self)
+        self.binding = value
+
     def copy(self, heap, memo):
         hint(self, concrete=True)
         try:
@@ -137,17 +134,6 @@
             seen_value.unify(other, heap)
             return seen_value
 
-
-    def get_max_var(self):
-        return self.index
-
-    def clone_compress_vars(self, vars_new_indexes, offset):
-        if self.index in vars_new_indexes:
-            return Var.newvar(vars_new_indexes[self.index])
-        index = len(vars_new_indexes) + offset
-        vars_new_indexes[self.index] = index
-        return Var.newvar(index)
-    
     def get_unify_hash(self, heap):
         if heap is not None:
             self = self.dereference(heap)
@@ -165,22 +151,12 @@
         return False
 
     def __repr__(self):
-        return "Var(%s)" % (self.index, )
+        return "Var(%s)" % (self.binding, )
 
 
     def __eq__(self, other):
         # for testing
-        return (self.__class__ == other.__class__ and
-                self.index == other.index)
-
-    @staticmethod
-    @purefunction
-    def newvar(index):
-        result = Var.cache.get(index, None)
-        if result is not None:
-            return result
-        Var.cache[index] = result = Var(index)
-        return result
+        return self is other
 
     def eval_arithmetic(self, engine):
         self = self.dereference(engine.heap)
@@ -413,9 +389,6 @@
 def _clone(obj, offset):
     return obj.clone(offset)
 
-def _clone_compress_vars(obj, vars_new_indexes, offset):
-    return obj.clone_compress_vars(vars_new_indexes, offset)
-
 def _getvalue(obj, heap):
     return obj.getvalue(heap)
 
@@ -475,15 +448,6 @@
         else:
             raise UnificationFailed
 
-    def get_max_var(self):
-        result = -1
-        for subterm in self.args:
-            result = max(result, subterm.get_max_var())
-        return result
-    
-    def clone_compress_vars(self, vars_new_indexes, offset):
-        return self._copy_term(_clone_compress_vars, vars_new_indexes, offset)
-
     def getvalue(self, heap):
         return self._copy_term(_getvalue, heap)
 
@@ -542,16 +506,13 @@
     unify_hash = []
     def __init__(self, head, body):
         from pypy.lang.prolog.interpreter import helper
-        d = {}
-        head = head.clone_compress_vars(d, 0)
         assert isinstance(head, Callable)
         self.head = head
         if body is not None:
             body = helper.ensure_callable(body)
-            self.body = body.clone_compress_vars(d, 0)
+            self.body = body
         else:
             self.body = None
-        self.numvars = len(d)
         self.signature = self.head.signature
         if isinstance(head, Term):
             self.unify_hash = [arg.get_unify_hash(None) for arg in head.args]
@@ -611,7 +572,7 @@
         return c
     if isinstance(obj1, Var):
         assert isinstance(obj2, Var)
-        return rcmp(obj1.index, obj2.index)
+        return rcmp(id(obj1), id(obj2))
     if isinstance(obj1, Atom):
         assert isinstance(obj2, Atom)
         return rcmp(obj1.name, obj2.name)

Modified: pypy/dist/pypy/lang/prolog/interpreter/test/test_builtin.py
==============================================================================
--- pypy/dist/pypy/lang/prolog/interpreter/test/test_builtin.py	(original)
+++ pypy/dist/pypy/lang/prolog/interpreter/test/test_builtin.py	Fri Jun  1 01:18:58 2007
@@ -105,12 +105,12 @@
     assert_true("assertz(f(a, a)).", e)
     assert_true("A = a, asserta(h(A, A)).", e)
     f = assert_true("g(B, B).", e)
-    assert f.vars[0].name == "b"
+    assert f['B'].name == "b"
     f = assert_true("f(B, B).", e)
-    assert f.vars[0].name == "b"
+    assert f['B'].name == "b"
     assert_false("h(c, c).", e)
     f = assert_true("h(B, B).", e)
-    assert f.vars[0].name == "a"
+    assert f['B'].name == "a"
 
 def test_assert_logical_update_view():
     e = get_engine("""
@@ -281,7 +281,7 @@
     assert_false("between(12, 15, 16).")
     heaps = collect_all(Engine(), "between(1, 4, X).")
     assert len(heaps) == 4
-    assert heaps[0].vars[0].num == 1
+    assert heaps[0]['X'].num == 1
 
 def test_is():
     assert_true("5 is 1 + 1 + 1 + 1 + 1.")
@@ -319,10 +319,12 @@
 def test_standard_comparison():
     assert_true("X = Y, f(X, Y, X, Y) == f(X, X, Y, Y).")
     assert_true("X = Y, f(X, Y, X, Z) \\== f(X, X, Y, Y).")
-    assert_true("X @< Y, X @=< X, X @=< Y, Y @> X.")
+    assert_true("""X \\== Y, ((X @< Y, X @=< X, X @=< Y, Y @> X);
+                              (X @> Y, X @>= X, X @>= Y, Y @< X)).""")
     assert_true("'\\\\=='(f(X, Y), 12).")
     assert_true("X = f(a), Y = f(b), Y @> X.")
 
+
 def test_atom_length():
     assert_true("atom_length('abc', 3).")
     assert_true("atom_length('\\\\', 1).")

Modified: pypy/dist/pypy/lang/prolog/interpreter/test/test_engine.py
==============================================================================
--- pypy/dist/pypy/lang/prolog/interpreter/test/test_engine.py	(original)
+++ pypy/dist/pypy/lang/prolog/interpreter/test/test_engine.py	Fri Jun  1 01:18:58 2007
@@ -1,6 +1,7 @@
 import py
 from pypy.lang.prolog.interpreter.parsing import parse_file, TermBuilder
 from pypy.lang.prolog.interpreter.parsing import parse_query_term, get_engine
+from pypy.lang.prolog.interpreter.parsing import get_query_and_vars
 from pypy.lang.prolog.interpreter.error import UnificationFailed, CatchableError
 from pypy.lang.prolog.interpreter.test.tool import collect_all, assert_true, assert_false
 from pypy.lang.prolog.interpreter.test.tool import prolog_raises
@@ -9,8 +10,9 @@
     e = get_engine("""
         f(a).
     """)
-    e.run(parse_query_term("f(X)."))
-    assert e.heap.getvar(0).name == "a"
+    t, vars = get_query_and_vars("f(X).")
+    e.run(t)
+    assert vars['X'].dereference(e.heap).name == "a"
 
 def test_and():
     e = get_engine("""
@@ -20,9 +22,9 @@
         f(X, Z) :- g(X, Y), g(Y, Z).
     """)
     e.run(parse_query_term("f(a, c)."))
-    e.run(parse_query_term("f(X, c)."))
-    print e.heap.vars[:10]
-    assert e.heap.getvar(0).name == "a"
+    t, vars = get_query_and_vars("f(X, c).")
+    e.run(t)
+    assert vars['X'].dereference(e.heap).name == "a"
 
 def test_and_long():
     e = get_engine("""
@@ -52,8 +54,9 @@
         return "succ(%s)" % nstr(n - 1)
     e.run(parse_query_term("num(0)."))
     e.run(parse_query_term("num(succ(0))."))
-    e.run(parse_query_term("num(X)."))
-    assert e.heap.getvar(0).num == 0
+    t, vars = get_query_and_vars("num(X).")
+    e.run(t)
+    assert vars['X'].dereference(e.heap).num == 0
     e.run(parse_query_term("add(0, 0, 0)."))
     py.test.raises(UnificationFailed, e.run, parse_query_term("""
         add(0, 0, succ(0))."""))
@@ -88,8 +91,9 @@
         g(a, a).
         f(X, Y, Z) :- (g(X, Z); g(X, Z); g(Z, Y)), a(Z).
         """)
-    e.run(parse_query_term("f(a, b, Z)."))
-    assert e.heap.getvar(0).name == "a"
+    t, vars = get_query_and_vars("f(a, b, Z).")
+    e.run(t)
+    assert vars['Z'].dereference(e.heap).name == "a"
     f = collect_all(e, "X = 1; X = 2.")
     assert len(f) == 2
 
@@ -127,9 +131,9 @@
     """)
     heaps = collect_all(e, "g(X).")
     assert len(heaps) == 3
-    assert heaps[0].getvar(0).name == "a"
-    assert heaps[1].getvar(0).name == "b"
-    assert heaps[2].getvar(0).name == "c"
+    assert heaps[0]['X'].name == "a"
+    assert heaps[1]['X'].name == "b"
+    assert heaps[2]['X'].name == "c"
 
 def test_cut():
     e = get_engine("""

Modified: pypy/dist/pypy/lang/prolog/interpreter/test/test_unification.py
==============================================================================
--- pypy/dist/pypy/lang/prolog/interpreter/test/test_unification.py	(original)
+++ pypy/dist/pypy/lang/prolog/interpreter/test/test_unification.py	Fri Jun  1 01:18:58 2007
@@ -11,41 +11,35 @@
     py.test.raises(UnificationFailed, "a.unify(Atom.newatom('xxx'), None)")
 
 def test_var():
-    b = Var.newvar(0)
+    b = Var()
     heap = Heap()
-    heap.clear(1)
     b.unify(Atom.newatom("hallo"), heap)
     assert b.getvalue(heap).name == "hallo"
-    a = Var.newvar(0)
-    b = Var.newvar(1)
-    heap.clear(2)
+    a = Var()
+    b = Var()
     a.unify(b, heap)
     a.unify(Atom.newatom("hallo"), heap)
     assert a.getvalue(heap).name == "hallo"
     assert b.getvalue(heap).name == "hallo"
 
 def test_unify_var():
-    b = Var.newvar(0)
+    b = Var()
     heap = Heap()
-    heap.clear(1)
     b.unify(b, heap)
     b.unify(Atom.newatom("hallo"), heap)
     py.test.raises(UnificationFailed, b.unify, Atom.newatom("bye"), heap)
 
 def test_recursive():
-    b = Var.newvar(0)
+    b = Var()
     heap = Heap()
-    heap.clear(1)
     b.unify(Term("hallo", [b]), heap)
-    
 
 def test_term():
-    X = Var.newvar(0)
-    Y = Var.newvar(1)
+    X = Var()
+    Y = Var()
     t1 = Term("f", [Atom.newatom("hallo"), X])
     t2 = Term("f", [Y, Atom.newatom("HALLO")])
     heap = Heap()
-    heap.clear(2)
     print t1, t2
     t1.unify(t2, heap)
     assert X.getvalue(heap).name == "HALLO"
@@ -61,9 +55,11 @@
 def test_run():
     e = Engine()
     e.add_rule(Term("f", [Atom.newatom("a"), Atom.newatom("b")]))
-    e.add_rule(Term("f", [Var.newvar(0), Var.newvar(0)]))
-    e.add_rule(Term(":-", [Term("f", [Var.newvar(0), Var.newvar(1)]),
-                           Term("f", [Var.newvar(1), Var.newvar(0)])]))
+    X = Var()
+    Y = Var()
+    e.add_rule(Term("f", [X, X]))
+    e.add_rule(Term(":-", [Term("f", [X, Y]),
+                           Term("f", [Y, X])]))
     X = e.heap.newvar()
     e.run(Term("f", [Atom.newatom("b"), X]))
     assert X.dereference(e.heap).name == "b"

Modified: pypy/dist/pypy/lang/prolog/interpreter/test/tool.py
==============================================================================
--- pypy/dist/pypy/lang/prolog/interpreter/test/tool.py	(original)
+++ pypy/dist/pypy/lang/prolog/interpreter/test/tool.py	Fri Jun  1 01:18:58 2007
@@ -6,12 +6,11 @@
 def assert_true(query, e=None):
     if e is None:
         e = Engine()
-    term = e.parse(query)[0][0]
+    terms, vars = e.parse(query)
+    term, = terms
     e.run(term)
-    f = Heap()
-    f.vars = e.heap.vars[:]
-    return f
-
+    return dict([(name, var.dereference(e.heap))
+                     for name, var in vars.iteritems()])
 def assert_false(query, e=None):
     if e is None:
         e = Engine()
@@ -23,20 +22,20 @@
                        (query, exc), e)
 
 class CollectAllContinuation(Continuation):
-    def __init__(self):
+    def __init__(self, vars):
         self.heaps = []
+        self.vars = vars
 
     def _call(self, engine):
-        f = Heap()
-        f.vars = engine.heap.vars[:]
-        self.heaps.append(f)
-#        import pdb; pdb.set_trace()
+        self.heaps.append(dict([(name, var.dereference(engine.heap))
+                                    for name, var in self.vars.iteritems()]))
         print "restarting computation"
         raise UnificationFailed
 
 def collect_all(engine, s):
-    collector = CollectAllContinuation()
-    term = engine.parse(s)[0][0]
+    terms, vars = engine.parse(s)
+    term, = terms
+    collector = CollectAllContinuation(vars)
     py.test.raises(UnificationFailed, engine.run, term,
                    collector)
     return collector.heaps



More information about the Pypy-commit mailing list