[Python-checkins] r43723 - sandbox/trunk/overload sandbox/trunk/overload/overloading.py sandbox/trunk/overload/test_overloading.py sandbox/trunk/overload/time_overloading.py

guido.van.rossum python-checkins at python.org
Fri Apr 7 23:33:23 CEST 2006


Author: guido.van.rossum
Date: Fri Apr  7 23:33:23 2006
New Revision: 43723

Added:
   sandbox/trunk/overload/
   sandbox/trunk/overload/overloading.py   (contents, props changed)
   sandbox/trunk/overload/test_overloading.py   (contents, props changed)
   sandbox/trunk/overload/time_overloading.py   (contents, props changed)
Log:
Snapshot of my function overloading implementation.  See python-3000 list.


Added: sandbox/trunk/overload/overloading.py
==============================================================================
--- (empty file)
+++ sandbox/trunk/overload/overloading.py	Fri Apr  7 23:33:23 2006
@@ -0,0 +1,129 @@
+#!/usr/bin/env python2.5
+
+# overloading.py
+
+"""Here's another executable email.
+
+This is an implementation of (dynamically, or run-time) overloaded
+functions, formerly known as generic functions or multi-methods.
+
+I actually blogged about multi-methods on Artima about a year ago, but
+at the time I hadn't figured out the trick of explicitly declaring the
+MM before registering the implementations (Phillip helpfully pointed
+that out in a comment); also, I was using a global registry then.
+
+This version is an improvement over my earlier attempts (both last
+year and at the start of this thread) because it supports subclasses
+in call signatures.  If an implementation is registered for a
+signature (T1, T2), then a call with a signature (S1, S2) is
+acceptable, assuming that S1 is a subclass of T1, S2 a subclass of T2,
+and there is no ambiguity in the match (see below).
+
+I came up with an algorithm for doing this that may or may not
+resemble the one in PEAK's RuleDispatch.  I kind of doubt that it's
+all that similar because RuleDispatch supports arbitrary predicates.
+In contrast, I'm just using the argument types for dispatch, similar
+to (compile-time) overloaded functions in C++ and methods in Java.  I
+do use a concept that I overheard Phillip mention: if there are
+multiple matches and one of those doesn't *dominate* the others, the
+match is deemed ambiguous and an exception is raised.  I added one
+refinement of my own: if, after removing the dominated matches, there
+are still multiple matches left, but they all map to the same
+function, then the match is not deemed ambiguous and that function is
+used.  Read the method find_func() below for details.
+
+The example is a bit lame; it's not a very good pretty-printer and it
+only dispatches on a single argument; but it does exercise the weeding
+out of dominated matches.  I'll try to post a complete unit test suite
+later.
+
+Python 2.5 is required due to the use of predicates any() and all().
+
+"""
+
+# Make the environment more like Python 3.0
+__metaclass__ = type
+from itertools import izip as zip
+
+
+class overloaded:
+    # An implementation of overloaded functions.
+
+    def __init__(self, default_func):
+        # Decorator to declare new overloaded function.
+        self.registry = {}
+        self.cache = {}
+        self.default_func = default_func
+
+    def register(self, *types):
+        # Decorator to register an implementation for a specific set of types.
+        # .register(t1, t2)(f) is equivalent to .register_func((t1, t2), f).
+        def helper(func):
+            self.register_func(types, func)
+            return func
+        return helper
+
+    def register_func(self, types, func):
+        # Helper to register an implementation.
+        self.registry[tuple(types)] = func
+        self.cache = {} # Clear the cache (later we can optimize this).
+
+    def __call__(self, *args):
+        # Call the overloaded function.
+        types = tuple(map(type, args))
+        func = self.cache.get(types)
+        if func is None:
+            self.cache[types] = func = self.find_func(types)
+        return func(*args)
+
+    def find_func(self, types):
+        # Find the appropriate overloaded function; don't call it.
+        # NB. This won't work for old-style classes or classes without __mro__.
+        func = self.registry.get(types)
+        if func is not None:
+            # Easy case -- direct hit in registry.
+            return func
+        # I can't help myself -- this is going to be intense functional code.
+        # Find all possible candidate signatures.
+        mros = tuple(t.__mro__ for t in types)
+        n = len(mros)
+        candidates = [sig for sig in self.registry
+                      if len(sig) == n and
+                         all(t in mro for t, mro in zip(sig, mros))]
+        if not candidates:
+            # No match at all -- use the default function.
+            return self.default_func
+        if len(candidates) == 1:
+            # Unique match -- that's an easy case.
+            return self.registry[candidates[0]]
+        # More than one match -- weed out the subordinate ones.
+        def dominates(dom, sub,
+                      orders=tuple(dict((t, i) for i, t in enumerate(mro))
+                                   for mro in mros)):
+            # Predicate to decide whether dom strictly dominates sub.
+            # Strict domination is defined as domination without equality.
+            # The arguments dom and sub are type tuples of equal length.
+            # The orders argument is a precomputed auxiliary data structure
+            # giving dicts of ordering information corresponding to the
+            # positions in the type tuples.
+            # A type d dominates a type s iff order[d] <= order[s].
+            # A type tuple (d1, d2, ...) dominates a type tuple of equal length
+            # (s1, s2, ...) iff d1 dominates s1, d2 dominates s2, etc.
+            if dom is sub:
+                return False
+            return all(order[d] <= order[s]
+                       for d, s, order in zip(dom, sub, orders))
+        # I suppose I could inline dominates() but it wouldn't get any clearer.
+        candidates = [cand
+                      for cand in candidates
+                      if not any(dominates(dom, cand) for dom in candidates)]
+        if len(candidates) == 1:
+            # There's exactly one candidate left.
+            return self.registry[candidates[0]]
+        # Perhaps these multiple candidates all have the same implementation?
+        funcs = set(self.registry[cand] for cand in candidates)
+        if len(funcs) == 1:
+            return funcs.pop()
+        # No, the situation is irreducibly ambiguous.
+        raise TypeError("ambiguous call; types=%r; candidates=%r" %
+                        (types, candidates))
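The resolution rules in find_func() (candidate selection via each argument's __mro__, removal of strictly dominated signatures, and the same-implementation refinement) can be traced outside the class. The following is a minimal sketch, assuming Python 3 rather than the 2.5 target of the commit. resolve(), default() and shared() are hypothetical names used only for illustration. The class hierarchy and registrations mirror test_2. The two separate len(candidates) == 1 checks are folded into one set-of-implementations check, which gives the same result.

```python
# Hypothetical Python 3 helper mirroring find_func(); not part of overloading.py.

def resolve(registry, types, default):
    """Return the implementation to use for a tuple of argument types."""
    if types in registry:
        return registry[types]  # exact hit
    mros = [t.__mro__ for t in types]
    # Applicable signatures: each type occurs in the matching argument's MRO.
    cands = [sig for sig in registry
             if len(sig) == len(types)
             and all(t in mro for t, mro in zip(sig, mros))]
    if not cands:
        return default
    # Index of each type within each argument's MRO (lower = more specific).
    orders = [{t: i for i, t in enumerate(mro)} for mro in mros]

    def dominates(dom, sub):
        # Strict domination: at least as specific in every position, not equal.
        return dom != sub and all(
            order[d] <= order[s] for d, s, order in zip(dom, sub, orders))

    survivors = [c for c in cands
                 if not any(dominates(d, c) for d in cands)]
    funcs = {registry[c] for c in survivors}
    if len(funcs) == 1:
        # One survivor, or several that share a single implementation.
        return funcs.pop()
    raise TypeError("ambiguous call; types=%r; candidates=%r"
                    % (types, survivors))


class A: pass
class B: pass
class C(A, B): pass

def fooAB(x, y): return "AB"
def fooAC(x, y): return "AC"
def fooBA(x, y): return "BA"
def fooCB(x, y): return "CB"
def default(x, y): return "default"

registry = {(A, B): fooAB, (A, C): fooAC, (B, A): fooBA, (C, B): fooCB}
print(resolve(registry, (C, A), default)(C(), A()))  # BA

try:
    resolve(registry, (C, C), default)  # (A, C) and (C, B) both survive
except TypeError as exc:
    print(exc)

def shared(x, y): return "shared"
same_impl = {(A, B): shared, (B, A): shared}
print(resolve(same_impl, (C, C), default)(C(), C()))  # shared
```

For (C, C), neither (A, B) nor (B, A) dominates the other. Because both map to the same function, however, the call is not treated as ambiguous.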

Added: sandbox/trunk/overload/test_overloading.py
==============================================================================
--- (empty file)
+++ sandbox/trunk/overload/test_overloading.py	Fri Apr  7 23:33:23 2006
@@ -0,0 +1,111 @@
+#!/usr/bin/env python2.5
+
+"""Unit tests for overloading.py."""
+
+import timeit
+import unittest
+
+from overloading import overloaded
+
+__metaclass__ = type # New-style classes by default
+
+# Helper classes
+class List(list):
+    pass
+class SubList(List):
+    pass
+
+# Sample test data
+test_data = (
+    "this is a string", [1, 2, 3, 4], ("more tuples",
+    1.0, 2.3, 4.5), "this is yet another string", (99,)
+    )
+
+class OverloadingTests(unittest.TestCase):
+
+    def test_1(self):
+        @overloaded
+        def pprint(obj):
+            return repr(obj)
+        @pprint.register(List)
+        @pprint.register(list)
+        def pprint_list(obj):
+            if not obj:
+                return "[]"
+            s = "["
+            for item in obj:
+                s += pprint(item).replace("\n", "\n ") + ",\n "
+            return s[:-3] + "]"
+        @pprint.register(tuple)
+        def pprint_tuple(obj):
+            if not obj:
+                return "()"
+            s = "("
+            for item in obj:
+                s += pprint(item).replace("\n", "\n ") + ",\n "
+            if len(obj) == 1:
+                return s[:-2] + ")"
+            return s[:-3] + ")"
+        @pprint.register(dict)
+        def pprint_dict(obj):
+            if not obj:
+                return "{}"
+            s = "{"
+            for key, value in obj.iteritems():
+                s += (pprint(key).replace("\n", "\n ") + ": " +
+                      pprint(value).replace("\n", "\n ") + ",\n ")
+            return s[:-3] + "}"
+        @pprint.register(set)
+        def pprint_set(obj):
+            if not obj:
+                return "{/}"
+            s = "{"
+            for item in obj:
+                s += pprint(item).replace("\n", "\n ") + ",\n "
+            return s[:-3] + "}"
+        # This is not a very good test
+        a = pprint(test_data)
+        b = pprint(List(test_data))
+        c = pprint(SubList(test_data))
+        self.assertEqual(a[1:-1], b[1:-1])
+        self.assertEqual(b, c)
+
+    def test_2(self):
+        class A: pass
+        class B: pass
+        class C(A, B): pass
+        def defaultfoo(x, y): return "default"
+        @overloaded
+        def foo(x, y): return defaultfoo(x, y)
+        @foo.register(A, B)
+        def fooAB(x, y): return "AB"
+        @foo.register(A, C)
+        def fooAC(x, y): return "AC"
+        @foo.register(B, A)
+        def fooBA(x, y): return "BA"
+        @foo.register(C, B)
+        def fooCB(x, y): return "CB"
+
+        self.assertEqual(foo(A(), A()), "default")
+        self.assertEqual(foo(A(), B()), "AB")
+        self.assertEqual(foo(A(), C()), "AC")
+        self.assertEqual(foo(A(), 123), "default")
+
+        self.assertEqual(foo(B(), A()), "BA")
+        self.assertEqual(foo(B(), B()), "default")
+        self.assertEqual(foo(B(), C()), "BA")
+        self.assertEqual(foo(B(), 123), "default")
+
+        self.assertEqual(foo(C(), A()), "BA")
+        self.assertEqual(foo(C(), B()), "CB")
+        self.assertRaises(TypeError, foo, C(), C())
+        self.assertEqual(foo(C(), 123), "default")
+
+        self.assertEqual(foo("x", A()), "default")
+        self.assertEqual(foo("x", B()), "default")
+        self.assertEqual(foo("x", C()), "default")
+        self.assertEqual(foo("x", 123), "default")
+
+
+if __name__ == "__main__":
+    unittest.main()
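Since test_1 dispatches on only one argument, a point of comparison is the standard library's single-dispatch decorator. This is a swapped-in technique, not the overloaded class from this commit. The sketch assumes Python 3.4+ and functools.singledispatch. It reproduces only the list and tuple cases of test_1's pretty-printer, and the _indent() helper is a hypothetical name. As with overloaded, SubList reaches the list implementation through its MRO.

```python
# Python 3 sketch using functools.singledispatch (stdlib, Python 3.4+) in
# place of the overloaded class; single-argument dispatch only.
from functools import singledispatch

class List(list): pass
class SubList(List): pass

def _indent(text):
    # Indent continuation lines of a nested item by one space.
    return text.replace("\n", "\n ")

@singledispatch
def pprint(obj):
    return repr(obj)  # fallback, like the default function in test_1

@pprint.register(list)
def _(obj):
    return "[" + ",\n ".join(_indent(pprint(item)) for item in obj) + "]"

@pprint.register(tuple)
def _(obj):
    body = ",\n ".join(_indent(pprint(item)) for item in obj)
    # A one-element tuple keeps its trailing comma, as in pprint_tuple().
    return "(" + body + ("," if len(obj) == 1 else "") + ")"

print(pprint(SubList([1, (2,), "x"])))
```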

Added: sandbox/trunk/overload/time_overloading.py
==============================================================================
--- (empty file)
+++ sandbox/trunk/overload/time_overloading.py	Fri Apr  7 23:33:23 2006
@@ -0,0 +1,71 @@
+#!/usr/bin/env python2.5
+
+"""Timing benchmarks for overloading.py."""
+
+import timeit
+
+from overloading import overloaded
+
+__metaclass__ = type # New-style classes by default
+
+class A: pass
+class B: pass
+class C(A, B): pass
+def defaultfoo(x, y): pass
+@overloaded
+def foo(x, y): defaultfoo(x, y)
+@foo.register(A, B)
+def fooAB(x, y): pass
+@foo.register(A, C)
+def fooAC(x, y): pass
+@foo.register(B, A)
+def fooBA(x, y): pass
+@foo.register(C, B)
+def fooCB(x, y): pass
+foo(C(), B())
+foo(A(), C())
+try:
+    foo(C(), C())
+except TypeError:
+    pass
+else:
+    assert False
+timeit.test_func = foo
+timeit.test_arg1 = C()
+timeit.test_arg2 = B()
+t = timeit.Timer("test_func(test_arg1, test_arg2)")
+r = t.repeat(3, 10000)
+print "foo(C(), B())   %.3f" % min(r)
+def bar(x, y):
+    if isinstance(x, C):
+        if isinstance(y, B):
+            return fooCB(x, y)
+    if isinstance(x, B):
+        if isinstance(y, A):
+            return fooBA(x, y)
+    if isinstance(x, A):
+        if isinstance(y, C):
+            return fooAC(x, y)
+        if isinstance(y, B):
+            return fooAB(x, y)
+    return defaultfoo(x, y)
+timeit.test_func = bar
+r = t.repeat(3, 10000)
+print "bar(C(), B())   %.3f" % min(r)
+timeit.test_arg1 = A()
+timeit.test_arg2 = A()
+r = t.repeat(3, 10000)
+print "bar(A(), A())   %.3f" % min(r)
+class Bar:
+    def __init__(self):
+        self.cache = {}
+        for t1 in A, B, C:
+            for t2 in A, B, C:
+                self.cache[t1, t2] = defaultfoo
+    def __call__(self, x, y):
+        return defaultfoo(x, y)
+##        t = type(x), type(y)
+##        return self.cache[t](x, y)
+timeit.test_func = Bar()
+r = t.repeat(3, 10000)
+print "Bar()(A(), A()) %.3f" % min(r)
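The script makes its function and arguments visible to the timed statement by assigning attributes on the timeit module. This works because, in Python 2.5, Timer runs the compiled statement with the timeit module's globals. On Python 3.5 and later, timeit.Timer accepts a globals argument instead. The following is a minimal sketch under that assumption. It reuses the script's isinstance() chain and adds a hypothetical exact-type lookup table (TABLE, table_dispatch and best_of are illustrative names), similar in spirit to the commented-out cache in Bar. No timings are claimed here.

```python
# Python 3 sketch: pass names via Timer(globals=...) (Python 3.5+) instead of
# assigning attributes on the timeit module as time_overloading.py does.
import timeit

class A: pass
class B: pass
class C(A, B): pass

def defaultfoo(x, y): return "default"
def fooAB(x, y): return "AB"
def fooAC(x, y): return "AC"
def fooBA(x, y): return "BA"
def fooCB(x, y): return "CB"

def bar(x, y):
    # The same hand-written isinstance() chain as bar() in the script.
    if isinstance(x, C) and isinstance(y, B):
        return fooCB(x, y)
    if isinstance(x, B) and isinstance(y, A):
        return fooBA(x, y)
    if isinstance(x, A):
        if isinstance(y, C):
            return fooAC(x, y)
        if isinstance(y, B):
            return fooAB(x, y)
    return defaultfoo(x, y)

# Hypothetical exact-type table: no MRO walk, so only the listed type pairs
# dispatch; everything else falls back to defaultfoo.
TABLE = {(C, B): fooCB, (A, B): fooAB}

def table_dispatch(x, y):
    return TABLE.get((type(x), type(y)), defaultfoo)(x, y)

def best_of(func, x, y, number=10000):
    # Names in the statement are resolved in the globals dict given here.
    timer = timeit.Timer("func(x, y)", globals={"func": func, "x": x, "y": y})
    return min(timer.repeat(repeat=3, number=number))

for name, func in [("bar", bar), ("table_dispatch", table_dispatch)]:
    print("%-15s C(), B()  %.3f" % (name, best_of(func, C(), B())))
```

The table is fast but only handles the exact type pairs it lists. That trade-off is why overloaded computes its per-type-tuple cache lazily from the MRO-aware find_func() rather than requiring every combination up front.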

