[pypy-commit] pypy default: fix nditer getitem return types

bdkearns noreply at buildbot.pypy.org
Fri Apr 18 18:45:34 CEST 2014


Author: Brian Kearns <bdkearns at gmail.com>
Branch: 
Changeset: r70756:f2b222d84fb1
Date: 2014-04-18 12:29 -0400
http://bitbucket.org/pypy/pypy/changeset/f2b222d84fb1/

Log:	fix nditer getitem return types

diff --git a/pypy/module/micronumpy/nditer.py b/pypy/module/micronumpy/nditer.py
--- a/pypy/module/micronumpy/nditer.py
+++ b/pypy/module/micronumpy/nditer.py
@@ -2,9 +2,8 @@
 from pypy.interpreter.typedef import TypeDef, GetSetProperty
 from pypy.interpreter.gateway import interp2app, unwrap_spec, WrappedDefault
 from pypy.interpreter.error import OperationError, oefmt
-from pypy.module.micronumpy import ufuncs, support
+from pypy.module.micronumpy import ufuncs, support, concrete
 from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
-from pypy.module.micronumpy.concrete import SliceArray
 from pypy.module.micronumpy.descriptor import decode_w_dtype
 from pypy.module.micronumpy.iterators import ArrayIter, SliceIterator
 from pypy.module.micronumpy.strides import (calculate_broadcast_strides,
@@ -25,7 +24,8 @@
 class IteratorMixin(object):
     _mixin_ = True
 
-    def __init__(self, it, op_flags):
+    def __init__(self, nditer, it, op_flags):
+        self.nditer = nditer
         self.it = it
         self.st = it.reset()
         self.op_flags = op_flags
@@ -37,7 +37,7 @@
         self.st = self.it.next(self.st)
 
     def getitem(self, space, array):
-        return self.op_flags.get_it_item[self.index](space, array, self.it, self.st)
+        return self.op_flags.get_it_item[self.index](space, self.nditer, self.it, self.st)
 
     def setitem(self, space, array, val):
         xxx
@@ -90,14 +90,17 @@
         self.get_it_item = (get_readonly_item, get_readonly_slice)
 
 
-def get_readonly_item(space, array, it, st):
-    return space.wrap(it.getitem(st))
+def get_readonly_item(space, nditer, it, st):
+    res = concrete.ConcreteNonWritableArrayWithBase(
+        [], it.array.dtype, it.array.order, [], [], it.array.storage, nditer)
+    res.start = st.offset
+    return W_NDimArray(res)
 
 
-def get_readwrite_item(space, array, it, st):
-    #create a single-value view (since scalars are not views)
-    res = SliceArray(it.array.start + st.offset, [0], [0], [1], it.array, array)
-    #it.dtype.setitem(res, 0, it.getitem())
+def get_readwrite_item(space, nditer, it, st):
+    res = concrete.ConcreteArrayWithBase(
+        [], it.array.dtype, it.array.order, [], [], it.array.storage, nditer)
+    res.start = st.offset
     return W_NDimArray(res)
 
 
@@ -398,12 +401,14 @@
         if self.external_loop:
             for i in range(len(self.seq)):
                 self.iters.append(ExternalLoopIterator(
+                    self,
                     get_external_loop_iter(
                         space, self.order, self.seq[i], iter_shape),
                     self.op_flags[i]))
         else:
             for i in range(len(self.seq)):
                 self.iters.append(BoxIterator(
+                    self,
                     get_iter(
                         space, self.order, self.seq[i], iter_shape, self.dtypes[i]),
                     self.op_flags[i]))
diff --git a/pypy/module/micronumpy/test/test_nditer.py b/pypy/module/micronumpy/test/test_nditer.py
--- a/pypy/module/micronumpy/test/test_nditer.py
+++ b/pypy/module/micronumpy/test/test_nditer.py
@@ -4,14 +4,20 @@
 
 class AppTestNDIter(BaseNumpyAppTest):
     def test_basic(self):
-        from numpy import arange, nditer
+        from numpy import arange, nditer, ndarray
         a = arange(6).reshape(2,3)
+        i = nditer(a)
         r = []
-        for x in nditer(a):
+        for x in i:
+            assert type(x) is ndarray
+            assert x.base is i
+            assert x.shape == ()
+            assert x.strides == ()
+            exc = raises(ValueError, "x[()] = 42")
+            assert str(exc.value) == 'assignment destination is read-only'
             r.append(x)
         assert r == [0, 1, 2, 3, 4, 5]
         r = []
-
         for x in nditer(a.T):
             r.append(x)
         assert r == [0, 1, 2, 3, 4, 5]
@@ -29,9 +35,14 @@
         assert r == [0, 3, 1, 4, 2, 5]
 
     def test_readwrite(self):
-        from numpy import arange, nditer
+        from numpy import arange, nditer, ndarray
         a = arange(6).reshape(2,3)
-        for x in nditer(a, op_flags=['readwrite']):
+        i = nditer(a, op_flags=['readwrite'])
+        for x in i:
+            assert type(x) is ndarray
+            assert x.base is i
+            assert x.shape == ()
+            assert x.strides == ()
             x[...] = 2 * x
         assert (a == [[0, 2, 4], [6, 8, 10]]).all()
 


More information about the pypy-commit mailing list