[pypy-svn] r22537 - pypy/dist/pypy/lib/logic

auc at codespeak.net auc at codespeak.net
Mon Jan 23 17:56:55 CET 2006


Author: auc
Date: Mon Jan 23 17:56:53 2006
New Revision: 22537

Added:
   pypy/dist/pypy/lib/logic/constraint.py
Modified:
   pypy/dist/pypy/lib/logic/test_unification.py
   pypy/dist/pypy/lib/logic/unification.py
   pypy/dist/pypy/lib/logic/variable.py
Log:
(auc, ale)
* many more misc. checks
* add constraint stuff


Added: pypy/dist/pypy/lib/logic/constraint.py
==============================================================================
--- (empty file)
+++ pypy/dist/pypy/lib/logic/constraint.py	Mon Jan 23 17:56:53 2006
@@ -0,0 +1,369 @@
+# a) new requirement : be able to postpone asking fo the
+# values of the domain
+
+class ConsistencyFailure(Exception):
+    """The repository is not in a consistent state"""
+    pass
+
+
+
+#-- Domains --------------------------------------------
+
+class AbstractDomain(object):
+    """Implements the functionnality related to the changed flag.
+    Can be used as a starting point for concrete domains"""
+
+    #__implements__ = DomainInterface
+    def __init__(self):
+        self.__changed = 0
+
+    def reset_flags(self):
+        self.__changed = 0
+    
+    def has_changed(self):
+        return self.__changed
+
+    def _value_removed(self):
+        """The implementation of remove_value should call this method"""
+        self.__changed = 1
+        if self.size() == 0:
+            raise ConsistencyFailure()
+
+
+class FiniteDomain(AbstractDomain):
+    """
+    Variable Domain with a finite set of possible values
+    """
+
+    _copy_count = 0
+    _write_count = 0
+    
+    def __init__(self, values):
+        """values is a list of values in the domain
+        This class uses a dictionnary to make sure that there are
+        no duplicate values"""
+        AbstractDomain.__init__(self)
+        if isinstance(values,FiniteDomain):
+            # do a copy on write
+            self._cow = True
+            values._cow = True
+            FiniteDomain._copy_count += 1
+            self._values = values._values
+        else:
+            # don't check this there (a)
+            #assert len(values) > 0
+            self.set_values(values)
+            
+        ##self.getValues = self._values.keys
+
+    def set_values(self, values):
+        self._cow = False
+        FiniteDomain._write_count += 1
+        self._values = set(values)
+        
+    def remove_value(self, value):
+        """Remove value of domain and check for consistency"""
+##         print "removing", value, "from", self._values.keys()
+        if self._cow:
+            self.set_values(self._values)
+        del self._values[value]
+        self._value_removed()
+
+    def remove_values(self, values):
+        """Remove values of domain and check for consistency"""
+        if self._cow:
+            self.set_values(self._values)
+        if values:
+##             print "removing", values, "from", self._values.keys()
+            for val in values :
+                self._values.remove(val)
+            self._value_removed()
+    __delitem__ = remove_value
+    
+    def size(self):
+        """computes the size of a finite domain"""
+        return len(self._values)
+    __len__ = size
+    
+    def get_values(self):
+        """return all the values in the domain"""
+        return self._values
+
+    def __iter__(self):
+        return iter(self._values)
+    
+    def copy(self):
+        """clone the domain"""
+        return FiniteDomain(self)
+    
+    def __repr__(self):
+        return '<FiniteDomain %s>' % str(self.get_values())
+
+    def __eq__(self, other):
+        if other is None: return False
+        return self._values == other._values
+
+    def intersection(self, other):
+        if other is None: return self.get_values()
+        return self.get_values() & other.get_values()
+
+#-- Constraints ------------------------------------------
+
+class AbstractConstraint(object):
+    #__implements__ = ConstraintInterface
+    
+    def __init__(self, variables):
+        """variables is a list of variables which appear in the formula"""
+        self._variables = variables
+
+    def affectedVariables(self):
+        """ Return a list of all variables affected by this constraint """
+        return self._variabl
+
+    def isVariableRelevant(self, variable):
+        return variable in self._variables
+
+    def estimateCost(self, domains):
+        """Return an estimate of the cost of the narrowing of the constraint"""
+        return reduce(operator.mul,
+                      [domains[var].size() for var in self._variables])
+
+
+class BasicConstraint(object):
+    """A BasicConstraint, which is never queued by the Repository
+    A BasicConstraint affects only one variable, and will be entailed
+    on the first call to narrow()"""
+    
+    def __init__(self, variable, reference, operator):
+        """variables is a list of variables on which
+        the constraint is applied"""
+        self._variable = variable
+        self._reference = reference
+        self._operator = operator
+
+    def __repr__(self):
+        return '<%s %s %s>'% (self.__class__, self._variable, self._reference)
+
+    def isVariableRelevant(self, variable):
+        return variable == self._variable
+
+    def estimateCost(self, domains):
+        return 0 # get in the first place in the queue
+    
+    def affectedVariables(self):
+        return [self._variable]
+    
+    def getVariable(self):
+        return self._variable
+        
+    def narrow(self, domains):
+        domain = domains[self._variable]
+        operator = self._operator
+        ref = self._reference
+        try:
+            for val in domain.get_values() :
+                if not operator(val, ref) :
+                    domain.remove_value(val)
+        except ConsistencyFailure:
+            raise ConsistencyFailure('inconsistency while applying %s' % \
+                                     repr(self))
+        return 1
+
+
+
+class Expression(AbstractConstraint):
+    """A constraint represented as a python expression."""
+    _FILTER_CACHE = {}
+
+    def __init__(self, variables, formula, type='fd.Expression'):
+        """variables is a list of variables which appear in the formula
+        formula is a python expression that will be evaluated as a boolean"""
+        AbstractConstraint.__init__(self, variables)
+        self.formula = formula
+        self.type = type
+        try:
+            self.filterFunc = Expression._FILTER_CACHE[formula]
+        except KeyError:
+            self.filterFunc = eval('lambda %s: %s' % \
+                                        (','.join(variables), formula), {}, {})
+            Expression._FILTER_CACHE[formula] = self.filterFunc
+
+    def _init_result_cache(self):
+        """key = (variable,value), value = [has_success,has_failure]"""
+        result_cache = {}
+        for var_name in self._variables:
+            result_cache[var_name] = {}
+        return result_cache
+
+
+    def _assign_values(self, domains):
+        variables = []
+        kwargs = {}
+        for variable in self._variables:
+            domain = domains[variable]
+            values = domain.get_values()
+            variables.append((domain.size(), [variable, values, 0, len(values)]))
+            kwargs[variable] = values[0]
+        # sort variables to instanciate those with fewer possible values first
+        variables.sort()
+
+        go_on = 1
+        while go_on:
+            yield kwargs
+            # try to instanciate the next variable
+            for size, curr in variables:
+                if (curr[2] + 1) < curr[-1]:
+                    curr[2] += 1
+                    kwargs[curr[0]] = curr[1][curr[2]]
+                    break
+                else:
+                    curr[2] = 0
+                    kwargs[curr[0]] = curr[1][0]
+            else:
+                # it's over
+                go_on = 0
+            
+        
+    def narrow(self, domains):
+        """generic narrowing algorithm for n-ary expressions"""
+        maybe_entailed = 1
+        ffunc = self.filterFunc
+        result_cache = self._init_result_cache()
+        for kwargs in self._assign_values(domains):
+            if maybe_entailed:
+                for var, val in kwargs.iteritems():
+                    if val not in result_cache[var]:
+                        break
+                else:
+                    continue
+            if ffunc(**kwargs):
+                for var, val in kwargs.items():
+                    result_cache[var][val] = 1
+            else:
+                maybe_entailed = 0
+
+        try:
+            for var, keep in result_cache.iteritems():
+                domain = domains[var]
+                domain.remove_values([val for val in domain if val not in keep])
+                
+        except ConsistencyFailure:
+            raise ConsistencyFailure('Inconsistency while applying %s' % \
+                                     repr(self))
+        except KeyError:
+            # There are no more value in result_cache
+            pass
+
+        return maybe_entailed
+
+    def __repr__(self):
+        return '<%s "%s">' % (self.type, self.formula)
+
+class BinaryExpression(Expression):
+    """A binary constraint represented as a python expression
+
+    This implementation uses a narrowing algorithm optimized for
+    binary constraints."""
+    
+    def __init__(self, variables, formula, type = 'fd.BinaryExpression'):
+        assert len(variables) == 2
+        Expression.__init__(self, variables, formula, type)
+
+    def narrow(self, domains):
+        """specialized narrowing algorithm for binary expressions
+        Runs much faster than the generic version"""
+        maybe_entailed = 1
+        var1 = self._variables[0]
+        dom1 = domains[var1]
+        values1 = dom1.get_values()
+        var2 = self._variables[1]
+        dom2 = domains[var2]
+        values2 = dom2.get_values()
+        ffunc = self.filterFunc
+        if dom2.size() < dom1.size():
+            var1, var2 = var2, var1
+            dom1, dom2 = dom2, dom1
+            values1, values2 = values2, values1
+            
+        kwargs = {}
+        keep1 = {}
+        keep2 = {}
+        maybe_entailed = 1
+        try:
+            # iterate for all values
+            for val1 in values1:
+                kwargs[var1] = val1
+                for val2 in values2:
+                    kwargs[var2] = val2
+                    if val1 in keep1 and val2 in keep2 and maybe_entailed == 0:
+                        continue
+                    if ffunc(**kwargs):
+                        keep1[val1] = 1
+                        keep2[val2] = 1
+                    else:
+                        maybe_entailed = 0
+
+            dom1.remove_values([val for val in values1 if val not in keep1])
+            dom2.remove_values([val for val in values2 if val not in keep2])
+            
+        except ConsistencyFailure:
+            raise ConsistencyFailure('Inconsistency while applying %s' % \
+                                     repr(self))
+        except Exception:
+            print self, kwargs
+            raise 
+        return maybe_entailed
+
+
+def make_expression(variables, formula, constraint_type=None):
+    """create a new constraint of type Expression or BinaryExpression
+    The chosen class depends on the number of variables in the constraint"""
+    # encode unicode
+    vars = []
+    for var in variables:
+        if type(var) == type(u''):
+            vars.append(var.encode())
+        else:
+            vars.append(var)
+    if len(vars) == 2:
+        if constraint_type is not None:
+            return BinaryExpression(vars, formula, constraint_type)
+        else:
+            return BinaryExpression(vars, formula)
+
+    else:
+        if constraint_type is not None:
+            return Expression(vars, formula, constraint_type)
+        else:
+            return Expression(vars, formula)
+
+
+class Equals(BasicConstraint):
+    """A basic constraint variable == constant value"""
+    def __init__(self, variable, reference):
+        BasicConstraint.__init__(self, variable, reference, operator.eq)
+
+class NotEquals(BasicConstraint):
+    """A basic constraint variable != constant value"""
+    def __init__(self, variable, reference):
+        BasicConstraint.__init__(self, variable, reference, operator.ne)
+
+class LesserThan(BasicConstraint):
+    """A basic constraint variable < constant value"""
+    def __init__(self, variable, reference):
+        BasicConstraint.__init__(self, variable, reference, operator.lt)
+
+class LesserOrEqual(BasicConstraint):
+    """A basic constraint variable <= constant value"""
+    def __init__(self, variable, reference):
+        BasicConstraint.__init__(self, variable, reference, operator.le)
+
+class GreaterThan(BasicConstraint):
+    """A basic constraint variable > constant value"""
+    def __init__(self, variable, reference):
+        BasicConstraint.__init__(self, variable, reference, operator.gt)
+
+class GreaterOrEqual(BasicConstraint):
+    """A basic constraint variable >= constant value"""
+    def __init__(self, variable, reference):
+        BasicConstraint.__init__(self, variable, reference, operator.ge)

Modified: pypy/dist/pypy/lib/logic/test_unification.py
==============================================================================
--- pypy/dist/pypy/lib/logic/test_unification.py	(original)
+++ pypy/dist/pypy/lib/logic/test_unification.py	Mon Jan 23 17:56:53 2006
@@ -1,5 +1,6 @@
 import unification as u
 import variable as v
+from constraint import FiniteDomain
 from py.test import raises, skip
 from threading import Thread
 
@@ -46,11 +47,23 @@
         assert z.val == 3.14
 
     def test_unify_same(self):
-        x,y,z = (u.var('x'), u.var('y'), u.var('z'))
+        """x = [42, z] and y = [z, 42] unify, giving z = 42, while
+        w = [z, 43] cannot be unified with x."""
+        x,y,z,w = (u.var('x'), u.var('y'),
+                   u.var('z'), u.var('w'))
         u.bind(x, [42, z])
         u.bind(y, [z, 42])
+        u.bind(w, [z, 43])
+        raises(u.UnificationFailure, u.unify, x, w)
+        u.unify(x, y)
+        assert z.val == 42
+
+    def test_double_unification(self):
+        """x bound to 42 and y bound to the unbound z: unifying x and y
+        binds z to 42."""
+        x, y, z = (u.var('x'), u.var('y'),
+                   u.var('z'))
+        u.bind(x, 42)
+        u.bind(y, z)
         u.unify(x, y)
         assert z.val == 42
+        # NOTE(review): disabled check; whether a second unify(x, y)
+        # should fail is not settled here -- confirm intent
+        #raises(u.UnificationFailure, u.unify, x, y)
 
     def test_unify_values(self):
         x, y = u.var('x'), u.var('y')
@@ -199,7 +212,43 @@
         t2.start()
         t1.join()
         t2.join()
-        assert z.val == 0
+        print "Z", z
         assert (t2.raised and not t1.raised) or \
                (t1.raised and not t2.raised)
+        assert z.val == 0
             
+
+    def test_set_var_domain(self):
+        x = u.var('x')
+        u.set_domain(x, [1, 3, 5])
+        assert x.dom == FiniteDomain([1, 3, 5])
+        assert u._store.domains[x] == FiniteDomain([1, 3, 5])
+
+    def test_bind_with_domain(self):
+        x = u.var('x')
+        u.set_domain(x, [1, 2, 3])
+        raises(u.OutOfDomain, u.bind, x, 42)
+        u.bind(x, 3)
+        assert x.val == 3
+
+    def test_bind_with_incompatible_domains(self):
+        x, y = u.var('x'), u.var('y')
+        u.set_domain(x, [1, 2])
+        u.set_domain(y, [3, 4])
+        raises(u.IncompatibleDomains, u.bind, x, y)
+        u.set_domain(y, [2, 4])
+        u.bind(x, y)
+        # check x and y are in the same equiv. set
+        assert x.val == y.val
+
+
+    def test_unify_with_domains(self):
+        x,y,z = u.var('x'), u.var('y'), u.var('z')
+        u.bind(x, [42, z])
+        u.bind(y, [z, 42])
+        u.set_domain(z, [1, 2, 3])
+        raises(u.UnificationFailure, u.unify, x, y)
+        u.set_domain(z, [41, 42, 43])
+        u.unify(x, y)
+        assert z.val == 42
+        assert z.dom == FiniteDomain([41, 42, 43])

Modified: pypy/dist/pypy/lib/logic/unification.py
==============================================================================
--- pypy/dist/pypy/lib/logic/unification.py	(original)
+++ pypy/dist/pypy/lib/logic/unification.py	Mon Jan 23 17:56:53 2006
@@ -121,6 +121,7 @@
 import threading
 
 from variable import EqSet, Var, VariableException, NotAVariable
+from constraint import FiniteDomain
 
 #----------- Store Exceptions ----------------------------
 class UnboundVariable(VariableException):
@@ -135,13 +136,27 @@
     def __str__(self):
         return "%s already in store" % self.name
 
+class OutOfDomain(VariableException):
+    def __str__(self):
+        return "value not in domain of %s" % self.name
+
 class UnificationFailure(Exception):
+    def __init__(self, var1, var2, cause=None):
+        self.var1, self.var2 = (var1, var2)
+        self.cause = cause
+    def __str__(self):
+        diag = "%s %s can't be unified"
+        if self.cause:
+            diag += " because %s" % self.cause
+        return diag % (self.var1, self.var2)
+        
+class IncompatibleDomains(Exception):
     def __init__(self, var1, var2):
         self.var1, self.var2 = (var1, var2)
     def __str__(self):
-        return "%s %s can't be unified" % (self.var1,
-                                           self.var2)
-              
+        return "%s %s have incompatible domains" % \
+               (self.var1, self.var2)
+    
 #----------- Store ------------------------------------
 class Store(object):
     """The Store consists of a set of k variables
@@ -157,6 +172,11 @@
         # mapping of names to vars (all of them)
         self.vars = set()
         self.names = set()
+        # mapping of vars to domains
+        self.domains = {}
+        # mapping of names to constraints (all...)
+        self.contraints = {}
+        # consistency-preserving stuff
         self.in_transaction = False
         self.lock = threading.Lock()
 
@@ -170,6 +190,15 @@
         # put into new singleton equiv. set
         var.val = EqSet([var])
 
+    #-- Bind var to domain --------------------
+
+    def set_domain(self, var, dom):
+        assert(isinstance(var, Var) and (var in self.vars))
+        if var.is_bound():
+            raise AlreadyBound
+        var.dom = FiniteDomain(dom)
+        self.domains[var] = var.dom
+
     #-- BIND -------------------------------------------
 
     def bind(self, var, val):
@@ -177,39 +206,54 @@
            2. (unbound)Variable/(bound)Variable or
            3. (unbound)Variable/Value binding
         """
-        self.lock.acquire()
-        assert(isinstance(var, Var) and (var in self.vars))
-        if var == val:
-            return
-        if _both_are_vars(var, val):
-            if _both_are_bound(var, val):
-                raise AlreadyBound(var.name)
-            if var._is_bound(): # 2b. var is bound, not var
-                self.bind(val, var)
-            elif val._is_bound(): # 2a.val is bound, not val
-                self._bind(var.val, val.val)
-            else: # 1. both are unbound
-                self._merge(var.val, val.val)
-        else: # 3. val is really a value
-            if var._is_bound():
-                raise AlreadyBound(var.name)
-            self._bind(var.val, val)
-        self.lock.release()
+        try:
+            self.lock.acquire()
+            assert(isinstance(var, Var) and (var in self.vars))
+            if var == val:
+                return
+            if _both_are_vars(var, val):
+                if _both_are_bound(var, val):
+                    raise AlreadyBound(var.name)
+                if var._is_bound(): # 2b. var is bound, not var
+                    self.bind(val, var)
+                elif val._is_bound(): # 2a.var is bound, not val
+                    self._bind(var.val, val.val)
+                else: # 1. both are unbound
+                    self._merge(var, val)
+            else: # 3. val is really a value
+                if var._is_bound():
+                    raise AlreadyBound(var.name)
+                self._bind(var.val, val)
+        finally:
+            self.lock.release()
 
 
     def _bind(self, eqs, val):
         # print "variable - value binding : %s %s" % (eqs, val)
-        # bind all vars in the eqset to obj
+        # bind all vars in the eqset to val
         for var in eqs:
+            if var.dom != None:
+                if val not in var.dom.get_values():
+                    # undo the half-done binding
+                    for v in eqs:
+                        v.val = eqs
+                    raise OutOfDomain(var)
             var.val = val
 
-    def _merge(self, eqs1, eqs2):
+    def _merge(self, v1, v2):
+        for v in v1.val:
+            if not _compatible_domains(v, v2.val):
+                raise IncompatibleDomains(v1, v2)
+        self._really_merge(v1.val, v2.val)
+
+    def _really_merge(self, eqs1, eqs2):
         # print "unbound variables binding : %s %s" % (eqs1, eqs2)
         if eqs1 == eqs2: return
         # merge two equisets into one
         eqs1 |= eqs2
-        # let's reassign everybody to neweqs
-        self._bind(eqs1, eqs1)
+        # let's reassign everybody to the merged eq
+        for var in eqs1:
+            var.val = eqs1
 
     #-- UNIFY ------------------------------------------
 
@@ -221,11 +265,13 @@
                 for var in self.vars:
                     if var.changed:
                         var._commit()
-            except:
+            except Exception, cause:
                 for var in self.vars:
                     if var.changed:
                         var._abort()
-                raise
+                if isinstance(cause, UnificationFailure):
+                    raise
+                raise UnificationFailure(x, y, cause)
         finally:
             self.in_transaction = False
 
@@ -242,7 +288,7 @@
             self._unify_var_val(x, y)
         elif _both_are_bound(x, y):
             self._unify_bound(x,y)
-        elif x.isbound():
+        elif x._is_bound():
             self.bind(x,y)
         else:
             self.bind(y,x)
@@ -279,6 +325,20 @@
         for xk in vx.keys():
             self._really_unify(vx[xk], vy[xk])
 
+
+def _compatible_domains(var, eqs):
+    """check that the domain of var is compatible
+       with the domains of the vars in the eqs
+    """
+    if var.dom == None: return True
+    empty = set()
+    for v in eqs:
+        if v.dom == None: continue
+        if v.dom.intersection(var.dom) == empty:
+            return False
+    return True
+
+
 #-- Unifiability checks---------------------------------------
 #--
 #-- quite costly & could be merged back in unify
@@ -293,6 +353,7 @@
 _unifiable_memo = set()
 
 def _unifiable(term1, term2):
+    """Reset the module-level _unifiable_memo, then return
+    _really_unifiable(term1, term2)."""
+    # the global declaration makes the assignment below reset the
+    # module-level memo instead of creating an unused local
+    global _unifiable_memo
     _unifiable_memo = set()
     return _really_unifiable(term1, term2)
         
@@ -351,6 +412,9 @@
     _store.add_unbound(v)
     return v
 
+def set_domain(var, dom):
+    """Module-level shortcut for _store.set_domain(var, dom)."""
+    return _store.set_domain(var, dom)
+
 def bind(var, val):
+    """Module-level shortcut for _store.bind(var, val)."""
     return _store.bind(var, val)
 

Modified: pypy/dist/pypy/lib/logic/variable.py
==============================================================================
--- pypy/dist/pypy/lib/logic/variable.py	(original)
+++ pypy/dist/pypy/lib/logic/variable.py	Mon Jan 23 17:56:53 2006
@@ -37,6 +37,8 @@
         self.store = store
         # top-level 'commited' binding
         self._val = NoValue
+        # domain
+        self.dom = None
         # when updated in a 'transaction', keep track
         # of our initial value (for abort cases)
         self.previous = None
@@ -93,9 +95,11 @@
     #---- Concurrent public ops --------------------------
 
     def is_bound(self):
-        self.mutex.acquire()
-        res = self._is_bound()
-        self.mutex.release()
+        try:
+            self.mutex.acquire()
+            res = self._is_bound()
+        finally:
+            self.mutex.release()
         return res
 
     # should be used by threads that want to block on



More information about the Pypy-commit mailing list