[pypy-commit] pypy default: hopefully make flatiter interoperate nicely with the rest by making it a part

fijal noreply at buildbot.pypy.org
Mon Nov 28 11:20:11 CET 2011


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: 
Changeset: r49886:9c707f4a6aa4
Date: 2011-11-28 12:15 +0200
http://bitbucket.org/pypy/pypy/changeset/9c707f4a6aa4/

Log:	hopefully make flatiter interoperate nicely with the rest by making
	it a part of BaseArray hierarchy

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -178,6 +178,35 @@
     def get_offset(self):
         return self.offset
 
+class OneDimIterator(BaseIterator):
+    """Iterator over `size` storage offsets spaced `step` apart.
+
+    The first offset produced is `start`.  Termination is decided by
+    counting consumed elements rather than by comparing the current
+    offset with `size`, so it stays correct when `start` is non-zero
+    or `step` is not 1.
+    """
+    def __init__(self, start, step, size):
+        self.offset = start
+        self.step = step
+        self.size = size
+        self.index = 0   # number of elements already produced
+
+    def next(self, shapelen):
+        arr = instantiate(OneDimIterator)
+        arr.size = self.size
+        arr.step = self.step
+        arr.offset = self.offset + self.step
+        # `instantiate` skips __init__, so every field must be copied here.
+        arr.index = self.index + 1
+        return arr
+
+    def done(self):
+        return self.index >= self.size
+
+    def get_offset(self):
+        return self.offset
+
 class ViewIterator(BaseIterator):
     def __init__(self, arr):
         self.indices = [0] * len(arr.shape)
@@ -1218,18 +1237,39 @@
 )
 
 
-class W_FlatIterator(Wrappable):
-    _immutable_fields_ = ['shapelen', 'arr']
+class W_FlatIterator(ViewArray):
+    signature = signature.BaseSignature()
+    
+    @jit.unroll_safe
+    def __init__(self, arr):
+        size = 1
+        for sh in arr.shape:
+            size *= sh
+        new_sig = signature.Signature.find_sig([
+            W_FlatIterator.signature, arr.signature
+        ])
+        ViewArray.__init__(self, arr, new_sig, [arr.strides[-1]],
+                           [arr.backstrides[-1]], [size])
+        self.shapelen = len(arr.shape)
+        self.arr = arr
+        self.iter = self.start_iter()
 
-    def __init__(self, arr):
-        self.arr = arr.get_concrete()
-        self.iter = arr.start_iter()
-        self.shapelen = len(arr.shape)
+    def start_iter(self, res_shape=None):
+        if res_shape is not None and res_shape != self.shape:
+            return BroadcastIterator(self, res_shape)
+        return OneDimIterator(self.arr.start, self.strides[0],
+                              self.shape[0])
+
+    def find_dtype(self):
+        return self.arr.find_dtype()
+
+    def find_size(self):
+        return self.shape[0]
 
     def descr_next(self, space):
         if self.iter.done():
             raise OperationError(space.w_StopIteration, space.wrap(''))
-        result = self.arr.eval(self.iter)
+        result = self.eval(self.iter)
         self.iter = self.iter.next(self.shapelen)
         return result.wrap(space)
 
@@ -1242,3 +1282,4 @@
     next = interp2app(W_FlatIterator.descr_next),
     __iter__ = interp2app(W_FlatIterator.descr_iter),
 )
+W_FlatIterator.acceptable_as_base_class = False
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -970,6 +970,11 @@
             s += k
         assert s == 140
 
+    def test_flatiter_array_conv(self):
+        from numpypy import array, dot
+        a = array([1, 2, 3])
+        assert dot(a.flat, a.flat) == 14
+
 class AppTestSupport(object):
     def setup_class(cls):
         import struct


More information about the pypy-commit mailing list