[pypy-svn] r56438 - in pypy/dist/pypy/module/itertools: . test

adurdin at codespeak.net adurdin at codespeak.net
Fri Jul 11 11:06:50 CEST 2008


Author: adurdin
Date: Fri Jul 11 11:06:50 2008
New Revision: 56438

Modified:
   pypy/dist/pypy/module/itertools/__init__.py
   pypy/dist/pypy/module/itertools/interp_itertools.py
   pypy/dist/pypy/module/itertools/test/test_itertools.py
Log:
(adurdin, jlg) Implemented interp_itertools.chain


Modified: pypy/dist/pypy/module/itertools/__init__.py
==============================================================================
--- pypy/dist/pypy/module/itertools/__init__.py	(original)
+++ pypy/dist/pypy/module/itertools/__init__.py	Fri Jul 11 11:06:50 2008
@@ -25,6 +25,7 @@
     """
 
     interpleveldefs = {
+        'chain'     : 'interp_itertools.W_Chain',
         'count'     : 'interp_itertools.W_Count',
         'dropwhile' : 'interp_itertools.W_DropWhile',
         'ifilter'   : 'interp_itertools.W_IFilter',

Modified: pypy/dist/pypy/module/itertools/interp_itertools.py
==============================================================================
--- pypy/dist/pypy/module/itertools/interp_itertools.py	(original)
+++ pypy/dist/pypy/module/itertools/interp_itertools.py	Fri Jul 11 11:06:50 2008
@@ -370,3 +370,70 @@
     report may list a name field on every third line).
     """)
 
+
+class W_Chain(Wrappable):
+    """Interp-level implementation of itertools.chain.
+
+    Every argument is turned into an iterator eagerly in __init__, so a
+    non-iterable argument raises TypeError at construction time; next_w()
+    then consumes those iterators one after the other.
+    """
+
+    def __init__(self, space, args_w):
+        self.space = space
+        iterators_w = []
+        for i, iterable_w in enumerate(args_w):
+            try:
+                iterator_w = space.iter(iterable_w)
+            except OperationError, e:
+                # Replace the generic TypeError with one that names the
+                # offending argument by its 1-based position.
+                if e.match(space, space.w_TypeError):
+                    raise OperationError(space.w_TypeError, space.wrap(
+                        "chain argument #" + str(i + 1) +
+                        " must support iteration"))
+                else:
+                    raise
+            else:
+                iterators_w.append(iterator_w)
+        self.iterators_w = iterators_w
+        # Position of the iterator currently being consumed.  Once it
+        # reaches len(self.iterators_w) the chain is exhausted for good and
+        # no underlying iterator is ever called again.
+        self.index = 0
+
+    def iter_w(self):
+        return self.space.wrap(self)
+
+    def next_w(self):
+        """Return the next item, advancing through the iterators in order.
+
+        Raises app-level StopIteration when every iterator is exhausted
+        (and on every later call, without touching the iterators again).
+        Any other exception raised by an underlying iterator propagates
+        unchanged and leaves the chain positioned on that iterator.
+        """
+        space = self.space
+        while self.index < len(self.iterators_w):
+            w_it = self.iterators_w[self.index]
+            try:
+                return space.next(w_it)
+            except OperationError, e:
+                if not e.match(space, space.w_StopIteration):
+                    raise
+            # the current iterator is exhausted: move on to the next one
+            self.index += 1
+        raise OperationError(space.w_StopIteration, space.w_None)
+
+def W_Chain___new__(space, w_subtype, args_w):
+    """App-level chain.__new__: allocate an instance of w_subtype (chain or
+    a subclass of it) and initialise it from the positional arguments."""
+    chain = space.allocate_instance(W_Chain, w_subtype)
+    W_Chain.__init__(chain, space, args_w)
+    return space.wrap(chain)
+
+# App-level type 'chain': __new__ accepts any number of positional arguments
+# (unwrap_spec 'args_w'); __iter__ returns the object itself and items are
+# produced by next().
+W_Chain.typedef = TypeDef(
+        'chain',
+        __new__  = interp2app(W_Chain___new__, unwrap_spec=[ObjSpace, W_Root, 'args_w']),
+        __iter__ = interp2app(W_Chain.iter_w, unwrap_spec=['self']),
+        next     = interp2app(W_Chain.next_w, unwrap_spec=['self']),
+        __doc__  = """Make an iterator that returns elements from the first iterable
+    until it is exhausted, then proceeds to the next iterable, until
+    all of the iterables are exhausted. Used for treating consecutive
+    sequences as a single sequence.
+
+    Equivalent to :
+
+    def chain(*iterables):
+        for it in iterables:
+            for element in it:
+                yield element
+    """)

Modified: pypy/dist/pypy/module/itertools/test/test_itertools.py
==============================================================================
--- pypy/dist/pypy/module/itertools/test/test_itertools.py	(original)
+++ pypy/dist/pypy/module/itertools/test/test_itertools.py	Fri Jul 11 11:06:50 2008
@@ -15,6 +15,7 @@
             itertools.ifilter(None, []),
             itertools.ifilterfalse(None, []),
             itertools.islice([], 0),
+            itertools.chain(),
             ]
 
         for it in iterables:
@@ -242,10 +243,44 @@
 
         raises(TypeError, itertools.islice, [], 0, 0, 0, 0)
 
+    def test_chain(self):
+        import itertools
+
+        it = itertools.chain()
+        raises(StopIteration, it.next)
+        raises(StopIteration, it.next)
+
+        it = itertools.chain([1, 2, 3])
+        for x in [1, 2, 3]:
+            assert it.next() == x
+        raises(StopIteration, it.next)
+
+        it = itertools.chain([1, 2, 3], [4], [5, 6])
+        for x in [1, 2, 3, 4, 5, 6]:
+            assert it.next() == x
+        raises(StopIteration, it.next)
+        # an exhausted non-empty chain must stay exhausted
+        raises(StopIteration, it.next)
+
+        it = itertools.chain([], [], [1], [])
+        assert it.next() == 1
+        raises(StopIteration, it.next)
+
+        # arbitrary iterables are accepted, not only lists
+        def gen():
+            yield 'c'
+        it = itertools.chain("ab", iter(()), gen(), (4,))
+        assert list(it) == ['a', 'b', 'c', 4]
+
+    def test_chain_wrongargs(self):
+        import itertools
+
+        raises(TypeError, itertools.chain, None)
+        raises(TypeError, itertools.chain, [], None)
+
+        # The error message must report the 1-based position of the
+        # non-iterable argument wherever it appears; a missing TypeError
+        # is a failure, not a silent pass.
+        for x in range(10):
+            args = [()] * x + [None] + [()] * (9 - x)
+            try:
+                itertools.chain(*args)
+            except TypeError, e:
+                assert str(e) == "chain argument #%d must support iteration" % (x + 1)
+            else:
+                assert False, "chain(*%r) should raise TypeError" % (args,)
+
     def test_docstrings(self):
         import itertools
         
-        assert itertools.__doc__ != ""
+        assert itertools.__doc__
         methods = [
             itertools.count,
             itertools.repeat,
@@ -254,10 +289,10 @@
             itertools.ifilter,
             itertools.ifilterfalse,
             itertools.islice,
+            itertools.chain,
             ]
         for method in methods:
-            assert method.__doc__ != ""
-        
+            assert method.__doc__
 
     def test_subclassing(self):
         import itertools
@@ -270,6 +305,7 @@
             (itertools.ifilter, (None, [])),
             (itertools.ifilterfalse, (None, [])),
             (itertools.islice, ([], 0)),
+            (itertools.chain, ()),
             ]
         for cls, args in iterables:
             class A(cls):



More information about the Pypy-commit mailing list