[Python-checkins] r43723 - sandbox/trunk/overload sandbox/trunk/overload/overloading.py sandbox/trunk/overload/test_overloading.py sandbox/trunk/overload/time_overloading.py
guido.van.rossum
python-checkins at python.org
Fri Apr 7 23:33:23 CEST 2006
Author: guido.van.rossum
Date: Fri Apr 7 23:33:23 2006
New Revision: 43723
Added:
sandbox/trunk/overload/
sandbox/trunk/overload/overloading.py (contents, props changed)
sandbox/trunk/overload/test_overloading.py (contents, props changed)
sandbox/trunk/overload/time_overloading.py (contents, props changed)
Log:
Snapshot of my function overloading implementation. See python-3000 list.
Added: sandbox/trunk/overload/overloading.py
==============================================================================
--- (empty file)
+++ sandbox/trunk/overload/overloading.py Fri Apr 7 23:33:23 2006
@@ -0,0 +1,129 @@
+#!/usr/bin/env python2.5
+
+# overloading.py
+
+"""Here's another executable email.
+
+This is an implementation of (dynamically, or run-time) overloaded
+functions, formerly known as generic functions or multi-methods.
+
+I actually blogged on Artima about multi-methods about a year ago; but
+at the time I hadn't figured out the trick of explicitly declaring the
+MM before registering the implementations (Phillip helpfully pointed
+that out in a comment); also, I was using a global registry then.
+
+This version is an improvement over my earlier attempts (both last
+year and at the start of this thread) because it supports subclasses
+in call signatures. If an implementation is registered for a
+signature (T1, T2), then a call with a signature (S1, S2) is
+acceptable, assuming that S1 is a subclass of T1, S2 a subclass of T2,
+and there is no ambiguity in the match (see below).
+
+I came up with an algorithm for doing this that may or may not
+resemble the one in PEAK's RuleDispatch. I kind of doubt that it's
+all that similar because RuleDispatch supports arbitrary predicates.
+In contrast, I'm just using the argument types for dispatch, similar
+to (compile-time) overloaded functions in C++ and methods in Java. I
+do use a concept that I overheard Phillip mention: if there are
+multiple matches and one of those doesn't *dominate* the others, the
+match is deemed ambiguous and an exception is raised. I added one
+refinement of my own: if, after removing the dominated matches, there
+are still multiple matches left, but they all map to the same
+function, then the match is not deemed ambiguous and that function is
+used. Read the method find_func() below for details.
+
+The example is a bit lame; it's not a very good pretty-printer and it
+only dispatches on a single argument; but it does exercise the weeding
+out of dominated matches. I'll try to post a complete unit test suite
+later.
+
+Python 2.5 is required due to the use of predicates any() and all().
+
+"""
+
+# Make the environment more like Python 3.0
+__metaclass__ = type
+from itertools import izip as zip
+
+
+class overloaded:
+    """Callable that dispatches on the run-time types of its arguments.
+
+    Used as a decorator, it wraps the default implementation.  Further
+    implementations are attached with register() or register_func().
+
+    Attributes:
+      registry -- maps a tuple of types (a signature) to an implementation.
+      cache -- maps a tuple of argument types to the function chosen for
+               it; emptied whenever the registry changes.
+      default_func -- the function used when no registered signature
+                      matches.
+    """
+
+    def __init__(self, default_func):
+        """Declare a new overloaded function with default_func as fallback."""
+        self.registry = {}
+        self.cache = {}
+        self.default_func = default_func
+
+    def register(self, *types):
+        """Return a decorator that registers a function for these types.
+
+        .register(t1, t2)(f) is equivalent to .register_func((t1, t2), f).
+        The decorator returns f unchanged, so several register() decorators
+        can be stacked on one function.
+        """
+        def helper(func):
+            self.register_func(types, func)
+            return func
+        return helper
+
+    def register_func(self, types, func):
+        """Register func as the implementation for the signature types.
+
+        A previous registration for the same signature is replaced.
+        """
+        self.registry[tuple(types)] = func
+        self.cache = {}  # Clear the cache (later we can optimize this).
+
+    def __call__(self, *args):
+        """Call the implementation selected by the types of args.
+
+        Only positional arguments are accepted.  The selected function is
+        remembered in the cache under the tuple of argument types.
+        Raises TypeError (from find_func) when the call is ambiguous.
+        """
+        types = tuple(map(type, args))
+        func = self.cache.get(types)
+        if func is None:
+            self.cache[types] = func = self.find_func(types)
+        return func(*args)
+
+    def find_func(self, types):
+        """Return the implementation for a call with these argument types.
+
+        The function is looked up but not called.  Resolution proceeds as
+        follows:
+
+        1. An exact hit in the registry wins.
+        2. Otherwise a registered signature is a candidate when it has the
+           same length as types and each of its types occurs in the
+           __mro__ of the corresponding argument type.
+        3. No candidate: the default function.  One candidate: its
+           function.
+        4. Several candidates: every candidate strictly dominated by
+           another candidate is dropped.  If one candidate remains, or all
+           remaining candidates map to the same function, that function
+           is returned.
+
+        Raises TypeError when the remaining candidates map to different
+        functions, i.e. the call is ambiguous.
+
+        NB. This won't work for old-style classes or classes without
+        __mro__.
+        """
+        func = self.registry.get(types)
+        if func is not None:
+            # Easy case -- direct hit in registry.
+            return func
+        # Find all possible candidate signatures.
+        mros = tuple(t.__mro__ for t in types)
+        n = len(mros)
+        candidates = [sig for sig in self.registry
+                      if len(sig) == n and
+                      all(t in mro for t, mro in zip(sig, mros))]
+        if not candidates:
+            # No match at all -- use the default function.
+            return self.default_func
+        if len(candidates) == 1:
+            # Unique match -- that's an easy case.
+            return self.registry[candidates[0]]
+
+        # More than one match -- weed out the subordinate ones.
+        def dominates(dom, sub,
+                      orders=tuple(dict((t, i) for i, t in enumerate(mro))
+                                   for mro in mros)):
+            """Return True if signature dom strictly dominates signature sub.
+
+            dom and sub are type tuples of equal length.  orders is a
+            precomputed tuple with one dict per argument position, mapping
+            each type in that argument's MRO to its index in the MRO.
+
+            A type d dominates a type s iff order[d] <= order[s]; a tuple
+            dominates another iff it does so at every position.  The
+            identity test makes the relation strict: a signature never
+            dominates itself.
+            """
+            if dom is sub:
+                return False
+            return all(order[d] <= order[s]
+                       for d, s, order in zip(dom, sub, orders))
+
+        candidates = [cand
+                      for cand in candidates
+                      if not any(dominates(dom, cand) for dom in candidates)]
+        if len(candidates) == 1:
+            # There's exactly one candidate left.
+            return self.registry[candidates[0]]
+        # Perhaps these multiple candidates all have the same implementation?
+        funcs = set(self.registry[cand] for cand in candidates)
+        if len(funcs) == 1:
+            return funcs.pop()
+        # No, the situation is irreducibly ambiguous.
+        raise TypeError("ambiguous call; types=%r; candidates=%r" %
+                        (types, candidates))
Added: sandbox/trunk/overload/test_overloading.py
==============================================================================
--- (empty file)
+++ sandbox/trunk/overload/test_overloading.py Fri Apr 7 23:33:23 2006
@@ -0,0 +1,111 @@
+#!/usr/bin/env python2.5
+
+"""Unit tests for overloading.py."""
+
+import timeit
+import unittest
+
+from overloading import overloaded
+
+__metaclass__ = type # New-style classes by default
+
+# Helper classes: a list subclass and a subclass of that subclass.  test_1
+# registers an implementation for List but not for SubList.
+class List(list):
+ pass
+class SubList(List):
+ pass
+
+# Sample test data: a tuple holding strings, a list, a tuple of a string and
+# floats, and a one-element tuple.
+test_data = (
+ "this is a string", [1, 2, 3, 4], ("more tuples",
+ 1.0, 2.3, 4.5), "this is yet another string", (99,)
+ )
+
+class OverloadingTests(unittest.TestCase):
+    """Dispatch tests for overloaded functions of one and two arguments."""
+
+    def test_1(self):
+        """Pretty-print nested containers through a one-argument overload.
+
+        pprint_list is registered for both list and List; SubList is not
+        registered, so it has to be resolved through its base classes.
+        """
+        separator = ",\n "
+
+        def indent(text):
+            # Shift the continuation lines of a nested item right by one.
+            return text.replace("\n", "\n ")
+
+        @overloaded
+        def pprint(obj):
+            return repr(obj)
+
+        @pprint.register(List)
+        @pprint.register(list)
+        def pprint_list(obj):
+            if not obj:
+                return "[]"
+            body = separator.join(indent(pprint(item)) for item in obj)
+            return "[" + body + "]"
+
+        @pprint.register(tuple)
+        def pprint_tuple(obj):
+            if not obj:
+                return "()"
+            body = separator.join(indent(pprint(item)) for item in obj)
+            if len(obj) == 1:
+                # A one-element tuple keeps its trailing comma.
+                return "(" + body + ",)"
+            return "(" + body + ")"
+
+        @pprint.register(dict)
+        def pprint_dict(obj):
+            if not obj:
+                return "{}"
+            pairs = [indent(pprint(key)) + ": " + indent(pprint(value))
+                     for key, value in obj.iteritems()]
+            return "{" + separator.join(pairs) + "}"
+
+        @pprint.register(set)
+        def pprint_set(obj):
+            if not obj:
+                return "{/}"
+            body = separator.join(indent(pprint(item)) for item in obj)
+            return "{" + body + "}"
+
+        # Weak check: the renderings are only compared with each other.
+        as_tuple = pprint(test_data)
+        as_list = pprint(List(test_data))
+        as_sublist = pprint(SubList(test_data))
+        self.assertEqual(as_tuple[1:-1], as_list[1:-1])
+        self.assertEqual(as_list, as_sublist)
+
+    def test_2(self):
+        """Check two-argument dispatch, including the ambiguous case."""
+        class A: pass
+        class B: pass
+        class C(A, B): pass
+
+        def defaultfoo(x, y):
+            return "default"
+
+        @overloaded
+        def foo(x, y):
+            return defaultfoo(x, y)
+
+        @foo.register(A, B)
+        def fooAB(x, y):
+            return "AB"
+
+        @foo.register(A, C)
+        def fooAC(x, y):
+            return "AC"
+
+        @foo.register(B, A)
+        def fooBA(x, y):
+            return "BA"
+
+        @foo.register(C, B)
+        def fooCB(x, y):
+            return "CB"
+
+        a, b, c = A(), B(), C()
+        # (first argument, second argument, expected result); TypeError as
+        # the expected result means the call must be rejected as ambiguous.
+        cases = [
+            (a, a, "default"),
+            (a, b, "AB"),
+            (a, c, "AC"),
+            (a, 123, "default"),
+
+            (b, a, "BA"),
+            (b, b, "default"),
+            (b, c, "BA"),
+            (b, 123, "default"),
+
+            (c, a, "BA"),
+            (c, b, "CB"),
+            (c, c, TypeError),
+            (c, 123, "default"),
+
+            ("x", a, "default"),
+            ("x", b, "default"),
+            ("x", c, "default"),
+            ("x", 123, "default"),
+        ]
+        for x, y, expected in cases:
+            if expected is TypeError:
+                self.assertRaises(TypeError, foo, x, y)
+            else:
+                self.assertEqual(foo(x, y), expected)
+
+
+# Run the test suite when this file is executed as a script.
+if __name__ == "__main__":
+ unittest.main()
Added: sandbox/trunk/overload/time_overloading.py
==============================================================================
--- (empty file)
+++ sandbox/trunk/overload/time_overloading.py Fri Apr 7 23:33:23 2006
@@ -0,0 +1,71 @@
+#!/usr/bin/env python2.5
+
+"""Timing tests (micro-benchmarks) for overloading.py."""
+
+import timeit
+
+from overloading import overloaded
+
+__metaclass__ = type # New-style classes by default
+
+# Class hierarchy used for dispatch: C inherits from both A and B.
+class A: pass
+class B: pass
+class C(A, B): pass
+
+def defaultfoo(x, y):
+    """No-op fallback implementation."""
+    pass
+
+# The decorator lines below had been mangled to " at ..." by the mail
+# archive's address obfuscation; the "@" is restored.
+@overloaded
+def foo(x, y):
+    """Overloaded function under test; the default calls defaultfoo()."""
+    defaultfoo(x, y)
+
+@foo.register(A, B)
+def fooAB(x, y): pass
+
+# NOTE: the parameter names A and C shadow the classes inside this function;
+# harmless here because the body does not use them.
+@foo.register(A, C)
+def fooAC(A, C): pass
+
+@foo.register(B, A)
+def fooBA(x, y): pass
+
+@foo.register(C, B)
+def fooCB(x, y): pass
+# Smoke-test dispatch before timing: two calls that match a registered
+# signature exactly, then one that must be rejected as ambiguous.
+foo(C(), B())
+foo(A(), C())
+try:
+    foo(C(), C())
+except TypeError:
+    pass
+else:
+    # Raise explicitly: a bare "assert False" is stripped under python -O.
+    raise AssertionError("foo(C(), C()) should raise TypeError")
+# The function and its arguments are stored as attributes of the timeit
+# module; presumably this is what lets the timed statement find them by
+# name -- confirm against the timeit implementation in use.
+timeit.test_func = foo
+timeit.test_arg1 = C()
+timeit.test_arg2 = B()
+t = timeit.Timer("test_func(test_arg1, test_arg2)")
+r = t.repeat(3, 10000)
+# The timed call is foo(C(), B()), so the label says so.
+print "foo(C(), B()) %.3f" % min(r)
+def bar(x, y):
+    """Dispatch with hand-written isinstance() checks.
+
+    Returns the result of fooCB, fooBA, fooAC or fooAB for the first pair
+    of checks that succeeds (tested in that order), else of defaultfoo.
+    """
+    if isinstance(x, C) and isinstance(y, B):
+        return fooCB(x, y)
+    if isinstance(x, B) and isinstance(y, A):
+        return fooBA(x, y)
+    if isinstance(x, A):
+        # Both remaining registered signatures start with A.
+        if isinstance(y, C):
+            return fooAC(x, y)
+        if isinstance(y, B):
+            return fooAB(x, y)
+    return defaultfoo(x, y)
+# Time the hand-written dispatcher with the arguments still in place from
+# the previous run, i.e. C() and B() -- the label reflects that.
+timeit.test_func = bar
+r = t.repeat(3, 10000)
+print "bar(C(), B()) %.3f" % min(r)
+# Time it again with a pair of arguments that falls through to defaultfoo().
+timeit.test_arg1 = A()
+timeit.test_arg2 = A()
+r = t.repeat(3, 10000)
+print "bar(A(), A()) %.3f" % min(r)
+class Bar:
+    """Callable object whose __call__ goes straight to defaultfoo().
+
+    A table keyed by (type, type) pairs is built in __init__ but is not
+    consulted by __call__; the commented-out lines show the table-driven
+    variant.
+    """
+
+    def __init__(self):
+        kinds = (A, B, C)
+        # Every combination of the three classes maps to defaultfoo.
+        self.cache = dict(((t1, t2), defaultfoo)
+                          for t1 in kinds for t2 in kinds)
+
+    def __call__(self, x, y):
+        return defaultfoo(x, y)
+##        t = tuple(map(type, args))
+##        return self.cache[t](*args)
+# Time a Bar instance, reusing the Timer and the arguments set earlier in
+# this script.
+timeit.test_func = Bar()
+r = t.repeat(3, 10000)
+print "Bar()(A(), A()) %.3f" % min(r)
More information about the Python-checkins
mailing list