[Python-checkins] r55321 - sandbox/trunk/abc/abc.py sandbox/trunk/abc/test_abc.py
guido.van.rossum
python-checkins at python.org
Mon May 14 23:04:01 CEST 2007
Author: guido.van.rossum
Date: Mon May 14 23:03:55 2007
New Revision: 55321
Modified:
sandbox/trunk/abc/abc.py
sandbox/trunk/abc/test_abc.py
Log:
Add overloading example that works across ABC registrations.
Add mutable sequence and mapping.
Register some built-in types.
Modified: sandbox/trunk/abc/abc.py
==============================================================================
--- sandbox/trunk/abc/abc.py (original)
+++ sandbox/trunk/abc/abc.py Mon May 14 23:03:55 2007
@@ -9,6 +9,7 @@
__author__ = "Guido van Rossum <guido at python.org>"
import sys
+import inspect
import itertools
@@ -149,7 +150,7 @@
elif subclass in cls.__abc_negative_cache__:
return False
# Check if it's a direct subclass
- if cls in subclass.mro():
+ if cls in subclass.__mro__:
cls.__abc_cache__.add(subclass)
return True
# Check if it's a subclass of a registered class (recursive)
@@ -398,15 +399,9 @@
### MAPPINGS ###
+# XXX Get rid of _BasicMapping and view types
-class BasicMapping(Container):
-
- """A basic mapping has __getitem__(), __contains__() and get().
-
- The idea is that you only need to override __getitem__().
-
- Other dict methods are not supported.
- """
+class _BasicMapping(Container, Iterable):
@abstractmethod
def __getitem__(self, key):
@@ -425,9 +420,6 @@
except KeyError:
return False
-
-class IterableMapping(BasicMapping, Iterable):
-
def keys(self):
return KeysView(self)
@@ -484,7 +476,7 @@
yield self._mapping[key]
-class Mapping(IterableMapping, Sized):
+class Mapping(_BasicMapping, Sized):
def keys(self):
return KeysView(self)
@@ -555,6 +547,59 @@
return False
class MutableMapping(Mapping):

    """A Mapping that additionally supports item assignment and deletion.

    Subclasses provide the abstract Mapping methods plus __setitem__()
    and __delitem__(); pop(), popitem(), clear() and update() are built
    on top of those.
    """

    @abstractmethod
    def __setitem__(self, key, value):
        # Abstract: store value under key (the target of `self[key] = value`).
        raise NotImplementedError

    @abstractmethod
    def __delitem__(self, key):
        # Abstract: remove key (the target of `del self[key]`).
        raise NotImplementedError

    # Private sentinel so pop() can tell "no default given" apart from an
    # explicit default such as None.  Name-mangled to
    # _MutableMapping__marker, both here and in pop()'s default.
    __marker = object()

    def pop(self, key, default=__marker):
        """Remove key and return its value.

        If key is missing, return default when one was given; otherwise
        re-raise the KeyError from the lookup.
        """
        try:
            value = self[key]
        except KeyError:
            if default is self.__marker:
                raise
            return default
        else:
            del self[key]
            return value

    def popitem(self):
        """Remove and return the first (key, value) pair in iteration order.

        Raises KeyError if the mapping is empty.
        """
        try:
            key = next(iter(self))
        except StopIteration:
            raise KeyError
        value = self[key]
        del self[key]
        return key, value

    def clear(self):
        """Remove all items by calling popitem() until it raises KeyError."""
        try:
            while True:
                self.popitem()
        except KeyError:
            pass

    def update(self, other=(), **kwds):
        """Copy items from other, then from keyword arguments, into self.

        other may be a Mapping, any object with a keys() method (its values
        are fetched via other[key]), or an iterable of (key, value) pairs.
        """
        if isinstance(other, Mapping):
            for key in other:
                self[key] = other[key]
        elif hasattr(other, "keys"):
            for key in other.keys():
                self[key] = other[key]
        else:
            for key, value in other:
                self[key] = value
        for key, value in kwds.items():
            self[key] = value
+
### SEQUENCES ###
@@ -708,8 +753,74 @@
return len(self) <= len(other)
class MutableSequence(Sequence):

    """A Sequence that additionally supports in-place modification.

    Subclasses provide the abstract Sequence methods plus __setitem__(),
    __delitem__() and insert(); append(), reverse(), extend(), pop() and
    remove() are built on top of those.
    """

    @abstractmethod
    def __setitem__(self, i, value):
        # Abstract: replace the item at index i (`self[i] = value`).
        raise NotImplementedError

    @abstractmethod
    def __delitem__(self, i):
        # Abstract: remove the item at index i (`del self[i]`).
        raise NotImplementedError

    @abstractmethod
    def insert(self, i, value):
        # Abstract: insert value before index i.
        raise NotImplementedError

    def append(self, value):
        """Add value at the end (insert at index len(self))."""
        self.insert(len(self), value)

    def reverse(self):
        """Reverse the items in place by swapping mirrored positions."""
        n = len(self)
        for i in range(n // 2):
            j = n - i - 1
            self[i], self[j] = self[j], self[i]

    def extend(self, it):
        """Append every item of the iterable it."""
        if it is self:
            # Snapshot first: iterating a sequence while appending to it
            # would keep seeing the newly appended items and never stop.
            it = list(it)
        for x in it:
            self.append(x)

    def pop(self, i=None):
        """Remove and return the item at index i (default: the last item).

        Errors from indexing (e.g. on an empty sequence) propagate
        unchanged.
        """
        if i is None:
            i = len(self) - 1
        value = self[i]
        del self[i]
        return value

    def remove(self, value):
        """Delete the first item equal to value; ValueError if none is."""
        for i in range(len(self)):
            if self[i] == value:
                del self[i]
                return
        raise ValueError
+
+
+
### PRE-DEFINED REGISTRATIONS ###

# Register built-in types with the ABCs defined above; the subclass check
# consults registered classes, so these types are then accepted by
# issubclass()/isinstance() against the ABCs.  _cls is a throwaway loop
# name, deleted at the end so it does not linger in the module namespace.

for _cls in (int, float, complex, basestring, tuple, frozenset, type):
    Hashable.register(_cls)

Set.register(frozenset)
MutableSet.register(set)

MutableMapping.register(dict)

for _cls in (tuple, basestring):
    Sequence.register(_cls)
# NOTE(review): bytes is registered as a *mutable* sequence here; confirm
# this matches the bytes type of the targeted Python (immutable in released
# Python 3, where bytearray is the mutable one).
for _cls in (list, bytes):
    MutableSequence.register(_cls)

del _cls
+
+
### ADAPTERS ###
+# This is just an example, not something to go into the stdlib
+
class AdaptToSequence(Sequence):
@@ -762,3 +873,130 @@
def __len__(self):
return len(self.adaptee)
+
+
+### OVERLOADING ###
+
+# This is a modest alternative proposal to PEP 3124. It uses
+# issubclass() exclusively, meaning that any issubclass() overloading
+# automatically works. If accepted, it probably ought to go into a
+# separate module (overloading.py?) as it has nothing to do directly
+# with ABCs. The code here is an evolution from my earlier attempt in
+# sandbox/overload/overloading.py.
+
+
class overloadable:

    """An implementation of overloadable functions.

    Usage example:

        @overloadable
        def flatten(x):
            yield x

        @flatten.overload
        def _(it: Iterable):
            for x in it:
                yield x

        @flatten.overload
        def _(x: basestring):
            yield x

    Calls dispatch on the classes of the positional arguments only.  Each
    overload's signature is the tuple of its positional argument
    annotations (unannotated arguments count as object), and a signature
    matches when it has the same length and every argument class is an
    issubclass() of the corresponding annotation.  An exact match wins;
    otherwise the most specific match is used, and the function raises
    TypeError if no match is more specific than all the others.  With no
    match at all, the default function is called.
    """

    def __init__(self, default_func):
        # Decorator to declare new overloaded function.
        self.registry = {}      # signature tuple -> implementation
        self.cache = {}         # argument-class tuple -> ordered impl list
        self.default_func = default_func

    def __get__(self, obj, cls=None):
        """Descriptor hook so an overloadable can be used as a method.

        Accessed on the class, it returns itself; accessed on an instance,
        it returns a bound method that passes the instance as the first
        argument.
        """
        if obj is None:
            return self
        import types  # replaces the removed `new` module (new.instancemethod)
        return types.MethodType(self, obj)

    def overload(self, func):
        """Decorator to overload a function using its argument annotations.

        Returns the overloadable itself when func has the same name as the
        default function, so `@f.overload def f(...)` keeps `f` bound to the
        overloadable; otherwise it returns func unchanged.
        """
        self.register_func(self.extract_types(func), func)
        if func.__name__ == self.default_func.__name__:
            return self
        else:
            return func

    def extract_types(self, func):
        """Helper to extract argument annotations as a tuple of types."""
        spec = inspect.getfullargspec(func)
        return tuple(spec.annotations.get(arg, object) for arg in spec.args)

    def register_func(self, types, func):
        """Helper to register an implementation."""
        self.registry[types] = func
        self.cache = {}  # Clear the cache (later we might optimize this).

    def __call__(self, *args):
        """Call the overloaded function."""
        types = tuple(arg.__class__ for arg in args)
        funcs = self.cache.get(types)
        if funcs is None:
            self.cache[types] = funcs = list(self.find_funcs(types))
        return funcs[0](*args)

    def find_funcs(self, types):
        """Yield the appropriate overloaded functions, in order.

        types is a tuple of argument classes.  __call__ only uses the first
        function yielded; when no candidate beats all the others, that first
        item is raise_ambiguity.
        """
        func = self.registry.get(types)
        if func is not None:
            # Easy case -- direct hit in registry.
            yield func
            return

        candidates = [cand
                      for cand in self.registry
                      if self.implies(types, cand)]

        if not candidates:
            # Easy case -- return the default function
            yield self.default_func
            return

        if len(candidates) == 1:
            # Easy case -- return this and the default function
            yield self.registry[candidates[0]]
            yield self.default_func
            return

        # Sort on a partial ordering, most specific first.  list.sort() no
        # longer accepts a cmp function, so wrap the comparator.  The sort
        # alone cannot detect incomparable signatures; the loop below does.
        import functools
        candidates.sort(key=functools.cmp_to_key(self.comparator))
        while candidates:
            cand = candidates.pop(0)
            if all(self.implies(cand, c) for c in candidates):
                yield self.registry[cand]
            else:
                yield self.raise_ambiguity
                break
        else:
            yield self.default_func

    def comparator(self, xs, ys):
        """cmp-style comparison: negative if xs is more specific than ys."""
        return self.implies(ys, xs) - self.implies(xs, ys)

    def implies(self, xs, ys):
        """True if signature xs is at least as specific as ys, position-wise."""
        return len(xs) == len(ys) and all(issubclass(x, y)
                                          for x, y in zip(xs, ys))

    def raise_ambiguity(self, *args):
        # XXX Should be more specific
        raise TypeError("ambiguous signature of overloadable function")

    def raise_exhausted(self, *args):
        # XXX Should be more specific
        raise TypeError("no remaining candidates for overloadable function")
Modified: sandbox/trunk/abc/test_abc.py
==============================================================================
--- sandbox/trunk/abc/test_abc.py (original)
+++ sandbox/trunk/abc/test_abc.py Mon May 14 23:03:55 2007
@@ -76,6 +76,40 @@
self.assertEqual(42 in a, False)
self.assertEqual(len(a), 3)
def test_overloading(self):
    # Exercises abc.overloadable with a recursive 'flatten': dispatch on
    # an ABC annotation (abc.Iterable), a more specific string overload,
    # a 2-argument overload, and replacing an existing overload.

    # Basic 'flatten' example
    @abc.overloadable
    def flatten(x):
        yield x
    @flatten.overload
    def _(x: abc.Iterable):
        for a in x:
            for b in flatten(a):
                yield b
    @flatten.overload
    def _(x: basestring):
        # Strings are yielded whole rather than recursed into.
        yield x
    self.assertEqual(list(flatten([1, 2, 3])), [1, 2, 3])
    self.assertEqual(list(flatten([1, [2], 3])), [1, 2, 3])
    self.assertEqual(list(flatten([1, [2, 3]])), [1, 2, 3])
    # Deeper nesting (previously this line duplicated the one above).
    self.assertEqual(list(flatten([1, [2, [3]]])), [1, 2, 3])
    self.assertEqual(list(flatten([1, "abc", 3])), [1, "abc", 3])

    # Add 2-arg version
    @flatten.overload
    def _(t: type, x):
        return t(flatten(x))
    self.assertEqual(flatten(tuple, [1, 2, 3]), (1, 2, 3))
    self.assertEqual(flatten(tuple, [1, [2], 3]), (1, 2, 3))
    self.assertEqual(flatten(tuple, [1, "abc", 3]), (1, "abc", 3))

    # Change an overload: re-registering the (basestring,) signature
    # replaces the earlier implementation.  Because the name matches the
    # default function, overload() returns the overloadable, so `flatten`
    # stays bound to it.
    @flatten.overload
    def flatten(x: basestring):
        for c in x:
            yield c
    self.assertEqual(list(flatten([1, "abc", 3])), [1, "a", "b", "c", 3])
+
if __name__ == "__main__":
unittest.main()
More information about the Python-checkins
mailing list