[Python-Dev] multiple dispatch (ii)

Samuele Pedroni pedroni@inf.ethz.ch
Wed, 14 Aug 2002 18:17:14 +0200


Here is my old code.
It is alpha-quality prototype code:
no syntactic sugar, no integration, pure Python.

The "_redispatch" mechanism is the moral
equivalent of

class A:
  def meth(self): ...

class B(A):
  def meth(self):
     A.meth(self)

it is used both for call-next-method functionality
(that means super for multiple dispatch)
and to solve ambiguities.

(This is pre-2.2 code. Nowadays the MRO of the actual
argument type can be used to resolve ambiguities, as CLOS
and Dylan do. If you add interfaces/protocols to the picture,
you also have to decide how to merge them into the MRO,
where that applies.)

[It uses memoization, so you can't fiddle
with __bases__.]

#test_mdisp.py:

# Demo script for mdisp.Generic: methods are registered one at a time and the
# same call is repeated, so each printed line shows which method is selected
# at that point.  The expected output is listed after the script in the post.
print "** mdisp test"
import mdisp

# Small class hierarchy used as dispatch arguments.
class Panel: pass

class PadPanel(Panel): pass

class Specific: pass

# Generic function without a tierule pattern.
present = mdisp.Generic()

# The actual arguments used for every call below.
panel = PadPanel()
spec = Specific()

# Candidate methods; each one only announces itself.
def pan(p,o):
    print "generic panel present"

def pad(p,o):
    print "pad panel present"

def speci(p,o):
    print "generic panel <specific> present"

def padspeci(p,o):
    print "pad panel <specific> present"

# Only (Panel, Any) is registered: the posted output shows pan being called.
present.add_method((Panel,mdisp.Any),pan)

present(panel,spec)

# (Panel, Specific) is added: the posted output shows speci being called.
present.add_method((Panel,Specific),speci)

present(panel,spec)

# (PadPanel, Any) is added: it is more specific in the first argument and
# less specific in the second than (Panel, Specific), and the posted output
# shows "ambiguity" for this call.
present.add_method((PadPanel,mdisp.Any),pad)

try:
    present(panel,spec)
except mdisp.AmbiguousMethodError:
    print "ambiguity"

# The trailing comma keeps the method's own message on the same output line.
print "_redispatch = (None,Any)",

# _redispatch overrides the class used for the second argument with Any
# (None keeps the actual class); the posted output shows pad being called.
present(panel,spec,_redispatch=(None,mdisp.Any))

# An exact (PadPanel, Specific) method is added: the posted output shows
# padspeci being called.
present.add_method((PadPanel,Specific),padspeci)

present(panel,spec)

print "* again... panel:obj tierule"

# Same sequence again, this time with a "panel:obj" tierule pattern.
present=mdisp.Generic("panel:obj")

present.add_method((Panel,mdisp.Any),pan)

present(panel,spec)

present.add_method((Panel,Specific),speci)

present(panel,spec)

present.add_method((PadPanel,mdisp.Any),pad)

# With the tierule the posted output shows "pad panel present" here instead
# of "ambiguity".
try:
    present(panel,spec)
except mdisp.AmbiguousMethodError:
    print "ambiguity"

present.add_method((PadPanel,Specific),padspeci)

present(panel,spec)

OUTPUT
** mdisp test
generic panel present
generic panel <specific> present
ambiguity
_redispatch = (None,Any) pad panel present
pad panel <specific> present
* again... panel:obj tierule
generic panel present
generic panel <specific> present
pad panel present
pad panel <specific> present

#actual mdisp.py:

import types
import re

def class_of(obj):
    """Return the class to dispatch on for *obj*.

    Instances of old-style classes all share the type types.InstanceType,
    so for them the class is taken from obj.__class__; for every other
    object it is type(obj).
    """
    # types.InstanceType is missing on interpreters that have no old-style
    # classes; there type(obj) is always the right answer.
    instance_type = getattr(types, 'InstanceType', None)
    if instance_type is not None and type(obj) is instance_type:
        return obj.__class__
    return type(obj)

# Answer given when two classes (or class tuples) are not ordered with
# respect to each other.  It is None, so callers have to test for it
# explicitly: a plain truth test cannot tell it apart from the
# "comparable, but not <=" answer 0.
NonComparable = None

# Wildcard specializer: class_le() treats every class as <= Any.
class Any: pass

def class_le(cl1,cl2):
    """Compare two classes in the specificity ordering.

    Returns a true value when cl1 is cl2, when cl2 is Any, or when cl1 is a
    subclass of cl2; a false value (not NonComparable) when cl1 is strictly
    more general than cl2; and NonComparable when neither is a subclass of
    the other or when issubclass() rejects the arguments.
    """
    if cl1 == cl2: return 1
    if cl2 == Any: return 1
    try:
        cl_lt = issubclass(cl1,cl2)
        # Any is the top of the ordering even though nothing inherits from
        # it, so Any compared with a real class is "greater", not unordered.
        cl_gt = cl1 is Any or issubclass(cl2,cl1)
    except TypeError:
        # issubclass() refuses arguments that are not classes.
        return NonComparable
    if not (cl_lt or cl_gt): return NonComparable
    return cl_lt

def classes_tuple_le(tup1,tup2):
    """Compare two class tuples position by position.

    Returns 1 when every position of tup1 is <= the matching position of
    tup2 according to class_le(), 0 when no position is (this includes two
    empty tuples), and NonComparable when the lengths differ, when any pair
    of classes is not comparable, or when the positions disagree on the
    direction.
    """
    if len(tup1) != len(tup2):
        return NonComparable
    seen_le = seen_gt = 0
    for left, right in zip(tup1, tup2):
        verdict = class_le(left, right)
        if verdict == NonComparable:
            return NonComparable
        if verdict:
            seen_le = 1
        else:
            seen_gt = 1
        # One position says "<=" and another says ">": no overall order.
        if seen_le and seen_gt:
            return NonComparable
    return seen_le

def classes_tuple_le_ex(tup1,tup2, tierule = None):
    """Compare two class tuples group by group, as directed by a tierule.

    tierule is a sequence of end offsets that cut both tuples into
    consecutive groups; without one the whole tuple is a single group.
    Groups are examined in order and the first group in which the tuples
    differ decides the answer: a one-element group is decided by class_le(),
    a longer group position by position (1, 0 or NonComparable).  Returns 1
    when all groups are equal and NonComparable when the lengths differ.
    """
    if len(tup1) != len(tup2):
        return NonComparable
    boundaries = tierule or (len(tup1),)
    start = 0
    for stop in boundaries:
        group1, group2 = tup1[start:stop], tup2[start:stop]
        start = stop
        if group1 == group2:
            # Tie in this group; let the next one decide.
            continue
        if len(group1) == 1:
            return class_le(group1[0], group2[0])
        any_le = any_gt = 0
        for left, right in zip(group1, group2):
            verdict = class_le(left, right)
            if verdict == NonComparable:
                return NonComparable
            if verdict:
                any_le = 1
            else:
                any_gt = 1
            # Positions inside the group disagree on the direction.
            if any_le and any_gt:
                return NonComparable
        return any_le
    return 1

# Pattern every argument name in a tierule pattern has to start with.
_id_regexp = re.compile(r"\w+")

def build_tierule(patt):
    """Turn a tierule pattern such as "panel:obj" into a list of end offsets.

    The pattern lists argument names; ':' separates groups and ',' separates
    the names inside a group.  The result holds, for each group, the index
    just past its last argument, e.g. "a,b:c" -> [2, 3].

    Raises ValueError when a name does not start with a word character
    (which includes empty names).
    """
    tierule = []
    last = 0
    for uni in patt.split(':'):
        names = uni.split(',')
        for arg in names:
            # NOTE(review): match() only anchors at the start, so trailing
            # junk after a valid prefix is still accepted, as before.
            if not _id_regexp.match(arg):
                raise ValueError("invalid Generic (tierule) pattern")
        last += len(names)
        tierule.append(last)
    return tierule

def forge_classes_tuple(model,tup):
    """Overlay *model* on *tup* and return the result as a tuple.

    For each position the entry from model is used when it is true,
    otherwise the entry from tup is kept; Generic.__call__ uses this so that
    None in a _redispatch tuple means "keep the actual class".  The result
    is as long as the shorter of the two inputs.
    """
    # A comprehension instead of a tuple-unpacking lambda, which newer
    # interpreters no longer accept.
    return tuple([m or cl for m, cl in zip(model, tup)])

class GenericDispatchError(TypeError):
    """Base class for the errors raised while dispatching a Generic call."""


class NoApplicableMethodError(GenericDispatchError):
    """No registered method accepts the classes of the arguments."""


class AmbiguousMethodError(GenericDispatchError):
    """Several applicable methods remain and none is the most specific."""

class Generic:
    """A generic function that dispatches on the classes of all arguments.

    Methods are registered with add_method() under a tuple of classes (Any
    matches everything in its position).  Calling the instance selects the
    most specific applicable method and calls it with the positional
    arguments.  Selections are memoized per tuple of argument classes.
    """

    def __init__(self,args=None):
        """Create an empty generic function.

        args is an optional tierule pattern (see build_tierule) that says
        in which order groups of arguments break ties between methods.
        """
        # Dispatch tuple -> (method class tuple, function) chosen for it.
        self.cache = {}
        # Method class tuple -> function, as registered.
        self.methods = {}
        if args:
            self.args = args
            self.tierule = build_tierule(args)
        else:
            self.args = "???"
            self.tierule = None

    def add_method(self,cltup,func):
        """Register func for the class tuple cltup and fix up the cache."""
        self.methods[cltup] = func
        new_meth = (cltup,func)
        self.cache[cltup] = new_meth
        # Entries are deleted inside the loop, so walk over a snapshot.
        for d_cltup,(meth_cltup,_) in list(self.cache.items()):
            if not classes_tuple_le(d_cltup,cltup):
                # The new method does not apply to this cached call.
                continue
            le = classes_tuple_le_ex(cltup,meth_cltup,self.tierule)
            if le == NonComparable:
                # Cannot be ranked against the cached choice: forget the
                # entry so the next call recomputes (and may report
                # ambiguity).
                del self.cache[d_cltup]
            elif le:
                # The new method is at least as specific: it takes over.
                self.cache[d_cltup] = new_meth

    def __call__(self,*args,**kw):
        """Dispatch on the classes of args and call the selected method.

        The keyword _redispatch may hold a tuple that overrides, position
        by position, the class used for dispatch (None keeps the actual
        class); other keywords are ignored.  Raises NoApplicableMethodError
        when no method applies and AmbiguousMethodError when more than one
        candidate is left.
        """
        redispatch = kw.get('_redispatch',None)
        d_cltup = [class_of(arg) for arg in args]
        if redispatch:
            d_cltup = forge_classes_tuple(redispatch,d_cltup)
        else:
            d_cltup = tuple(d_cltup)

        if d_cltup in self.cache:
            return self.cache[d_cltup][1](*args) # 1 retrieves func

        # An exact match always wins; add_method() may have dropped its
        # cache entry, so look it up in the registry as well.
        if d_cltup in self.methods:
            return self.methods[d_cltup](*args)

        cands = []
        for cltup in self.methods:
            if not classes_tuple_le(d_cltup,cltup): # applicable?
                continue
            # The first applicable method is always a candidate.
            app = not cands
            # Walk backwards so that deleting an entry does not shift the
            # ones still to be visited.
            for i in range(len(cands)-1,-1,-1):
                cand = cands[i]
                le = classes_tuple_le_ex(cltup,cand,self.tierule)
                if le == NonComparable:
                    # Neither beats the other: keep both.
                    app = 1
                elif le:
                    if cand != cltup:
                        # cltup is more specific: it replaces cand.
                        app = 1
                        del cands[i]
            if app:
                cands.append(cltup)
        if len(cands) == 0:
            raise NoApplicableMethodError(
                "no applicable method for %s" % (d_cltup,))
        if len(cands)>1:
            raise AmbiguousMethodError(
                "ambiguous methods %s for %s" % (cands, d_cltup))
        cltup = cands[0]
        func = self.methods[cltup]
        self.cache[d_cltup] = (cltup,func)
        return func(*args)