[pypy-commit] pypy default: hopefully make flatiter interoperate nicely with the rest by making it a part of BaseArray hierarchy
fijal
noreply at buildbot.pypy.org
Mon Nov 28 11:20:11 CET 2011
Author: Maciej Fijalkowski <fijall at gmail.com>
Branch:
Changeset: r49886:9c707f4a6aa4
Date: 2011-11-28 12:15 +0200
http://bitbucket.org/pypy/pypy/changeset/9c707f4a6aa4/
Log: hopefully make flatiter interoperate nicely with the rest by making
it a part of the BaseArray hierarchy
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -178,6 +178,25 @@
def get_offset(self):
return self.offset
+class OneDimIterator(BaseIterator):
+ def __init__(self, start, step, size):
+ self.offset = start
+ self.step = step
+ self.size = size
+
+ def next(self, shapelen):
+ arr = instantiate(OneDimIterator)
+ arr.size = self.size
+ arr.step = self.step
+ arr.offset = self.offset + self.step
+ return arr
+
+ def done(self):
+ return self.offset >= self.size
+
+ def get_offset(self):
+ return self.offset
+
class ViewIterator(BaseIterator):
def __init__(self, arr):
self.indices = [0] * len(arr.shape)
@@ -1218,18 +1237,39 @@
)
-class W_FlatIterator(Wrappable):
- _immutable_fields_ = ['shapelen', 'arr']
+class W_FlatIterator(ViewArray):
+ signature = signature.BaseSignature()
+
+ @jit.unroll_safe
+ def __init__(self, arr):
+ size = 1
+ for sh in arr.shape:
+ size *= sh
+ new_sig = signature.Signature.find_sig([
+ W_FlatIterator.signature, arr.signature
+ ])
+ ViewArray.__init__(self, arr, new_sig, [arr.strides[-1]],
+ [arr.backstrides[-1]], [size])
+ self.shapelen = len(arr.shape)
+ self.arr = arr
+ self.iter = self.start_iter()
- def __init__(self, arr):
- self.arr = arr.get_concrete()
- self.iter = arr.start_iter()
- self.shapelen = len(arr.shape)
+ def start_iter(self, res_shape=None):
+ if res_shape is not None and res_shape != self.shape:
+ return BroadcastIterator(self, res_shape)
+ return OneDimIterator(self.arr.start, self.strides[0],
+ self.shape[0])
+
+ def find_dtype(self):
+ return self.arr.find_dtype()
+
+ def find_size(self):
+ return self.shape[0]
def descr_next(self, space):
if self.iter.done():
raise OperationError(space.w_StopIteration, space.wrap(''))
- result = self.arr.eval(self.iter)
+ result = self.eval(self.iter)
self.iter = self.iter.next(self.shapelen)
return result.wrap(space)
@@ -1242,3 +1282,4 @@
next = interp2app(W_FlatIterator.descr_next),
__iter__ = interp2app(W_FlatIterator.descr_iter),
)
+W_FlatIterator.acceptable_as_base_class = False
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -970,6 +970,11 @@
s += k
assert s == 140
+ def test_flatiter_array_conv(self):
+ from numpypy import array, dot
+ a = array([1, 2, 3])
+ assert dot(a.flat, a.flat) == 14
+
class AppTestSupport(object):
def setup_class(cls):
import struct
More information about the pypy-commit
mailing list