[pypy-commit] pypy refactor-signature: reinstitute broadcast - no code addition

fijal noreply at buildbot.pypy.org
Mon Dec 19 13:52:03 CET 2011


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: refactor-signature
Changeset: r50705:0ce8fad59c36
Date: 2011-12-19 14:51 +0200
http://bitbucket.org/pypy/pypy/changeset/0ce8fad59c36/

Log:	reinstitute broadcast - no code addition

diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -62,7 +62,6 @@
 
 class ViewIterator(BaseIterator):
     def __init__(self, arr, res_shape=None):
-        self.indices = [0] * len(arr.shape)
         self.offset  = arr.start
         self._done   = False
         if res_shape is not None and res_shape != arr.shape:
@@ -82,6 +81,7 @@
             self.strides = arr.strides
             self.backstrides = arr.backstrides
             self.res_shape = arr.shape
+        self.indices = [0] * len(self.res_shape)
 
 
     @jit.unroll_safe
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -842,7 +842,7 @@
 
     def create_sig(self, res_shape):
         if self.forced_result is not None:
-            return signature.ArraySignature(self.forced_result.dtype)
+            return self.forced_result.array_sig(res_shape)
         return signature.Call1(self.ufunc, self.name,
                                self.values.create_sig(res_shape))
 
@@ -869,7 +869,7 @@
 
     def create_sig(self, res_shape):
         if self.forced_result is not None:
-            return signature.ArraySignature(self.forced_result.dtype)
+            return self.forced_result.array_sig(res_shape)
         return signature.Call2(self.ufunc, self.name, self.calc_dtype,
                                self.left.create_sig(res_shape),
                                self.right.create_sig(res_shape))
@@ -930,6 +930,11 @@
         self.strides = strides[:]
         self.backstrides = backstrides[:]
 
+    def array_sig(self, res_shape):
+        if res_shape is not None and self.shape != res_shape:
+            return signature.ViewSignature(self.dtype)
+        return signature.ArraySignature(self.dtype)
+
 class W_NDimSlice(ConcreteArray):
     def __init__(self, start, strides, backstrides, shape, parent):
         if isinstance(parent, W_NDimSlice):
@@ -949,7 +954,7 @@
 
     def _sliceloop(self, source, res_shape):
         sig = source.find_sig(res_shape)
-        frame = sig.create_frame(source)
+        frame = sig.create_frame(source, res_shape)
         res_iter = ViewIterator(self)
         shapelen = len(res_shape)
         while not res_iter.done():
@@ -1028,7 +1033,7 @@
         self.calc_strides(new_shape)
 
     def create_sig(self, res_shape):
-        return signature.ArraySignature(self.dtype)
+        return self.array_sig(res_shape)
 
     def __del__(self):
         lltype.free(self.storage, flavor='raw', track_allocation=False)
diff --git a/pypy/module/micronumpy/signature.py b/pypy/module/micronumpy/signature.py
--- a/pypy/module/micronumpy/signature.py
+++ b/pypy/module/micronumpy/signature.py
@@ -88,10 +88,11 @@
             allnumbers.append(no)
         self.iter_no = no
 
-    def create_frame(self, arr):
+    def create_frame(self, arr, res_shape=None):
+        res_shape = res_shape or arr.shape
         iterlist = []
         arraylist = []
-        self._create_iter(iterlist, arraylist, arr)
+        self._create_iter(iterlist, arraylist, arr, res_shape)
         return NumpyEvalFrame(iterlist, arraylist)
 
 class ConcreteSignature(Signature):
@@ -120,14 +121,14 @@
         storage = arr.get_concrete().storage
         self.array_no = _add_ptr_to_cache(storage, cache)
 
-    def _create_iter(self, iterlist, arraylist, arr):
+    def _create_iter(self, iterlist, arraylist, arr, res_shape):
         storage = arr.get_concrete().storage
         if self.iter_no >= len(iterlist):
-            iterlist.append(self.allocate_iter(arr))
+            iterlist.append(self.allocate_iter(arr, res_shape))
         if self.array_no >= len(arraylist):
             arraylist.append(storage)
 
-    def allocate_iter(self, arr):
+    def allocate_iter(self, arr, res_shape):
         return ArrayIterator(arr.size)
 
     def eval(self, frame, arr):
@@ -141,7 +142,7 @@
     def _invent_array_numbering(self, arr, cache):
         pass
 
-    def _create_iter(self, iterlist, arraylist, arr):
+    def _create_iter(self, iterlist, arraylist, arr, res_shape):
         if self.iter_no >= len(iterlist):
             iter = ConstantIterator()
             iterlist.append(iter)
@@ -161,14 +162,14 @@
         allnumbers.append(no)
         self.iter_no = no
 
-    def allocate_iter(self, arr):
-        return ViewIterator(arr)
+    def allocate_iter(self, arr, res_shape):
+        return ViewIterator(arr, res_shape)
 
 class FlatiterSignature(ViewSignature):
     def debug_repr(self):
         return 'FlatIter(%s)' % self.child.debug_repr()
 
-    def _create_iter(self, iterlist, arraylist, arr):
+    def _create_iter(self, iterlist, arraylist, arr, res_shape):
         raise NotImplementedError
 
 class Call1(Signature):
@@ -200,10 +201,10 @@
         assert isinstance(arr, Call1)
         self.child._invent_array_numbering(arr.values, cache)
 
-    def _create_iter(self, iterlist, arraylist, arr):
+    def _create_iter(self, iterlist, arraylist, arr, res_shape):
         from pypy.module.micronumpy.interp_numarray import Call1
         assert isinstance(arr, Call1)
-        self.child._create_iter(iterlist, arraylist, arr.values)
+        self.child._create_iter(iterlist, arraylist, arr.values, res_shape)
 
     def eval(self, frame, arr):
         from pypy.module.micronumpy.interp_numarray import Call1
@@ -244,12 +245,12 @@
         self.left._invent_numbering(cache, allnumbers)
         self.right._invent_numbering(cache, allnumbers)
 
-    def _create_iter(self, iterlist, arraylist, arr):
+    def _create_iter(self, iterlist, arraylist, arr, res_shape):
         from pypy.module.micronumpy.interp_numarray import Call2
         
         assert isinstance(arr, Call2)
-        self.left._create_iter(iterlist, arraylist, arr.left)
-        self.right._create_iter(iterlist, arraylist, arr.right)
+        self.left._create_iter(iterlist, arraylist, arr.left, res_shape)
+        self.right._create_iter(iterlist, arraylist, arr.right, res_shape)
 
     def eval(self, frame, arr):
         from pypy.module.micronumpy.interp_numarray import Call2
@@ -263,8 +264,8 @@
                                   self.right.debug_repr())
 
 class ReduceSignature(Call2):
-    def _create_iter(self, iterlist, arraylist, arr):
-        self.right._create_iter(iterlist, arraylist, arr)
+    def _create_iter(self, iterlist, arraylist, arr, res_shape):
+        self.right._create_iter(iterlist, arraylist, arr, res_shape)
 
     def _invent_numbering(self, cache, allnumbers):
         self.right._invent_numbering(cache, allnumbers)
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -865,7 +865,6 @@
         assert (a == [8, 6, 4, 2, 0]).all()
 
     def test_debug_repr(self):
-        skip("for now")
         from numpypy import zeros, sin
         a = zeros(1)
         assert a.__debug_repr__() == 'Array'
@@ -1001,7 +1000,6 @@
         assert a[0, 1, 2] == 1.0
 
     def test_broadcast_ufunc(self):
-        skip("broadcast unsupported")
         from numpypy import array
         a = array([[1, 2], [3, 4], [5, 6]])
         b = array([5, 6])
@@ -1009,15 +1007,13 @@
         assert c.all()
 
     def test_broadcast_setslice(self):
-        skip("broadcast unsupported")
         from numpypy import zeros, ones
-        a = zeros((100, 100))
-        b = ones(100)
+        a = zeros((10, 10))
+        b = ones(10)
         a[:, :] = b
-        assert a[13, 15] == 1
+        assert a[3, 5] == 1
 
     def test_broadcast_shape_agreement(self):
-        skip("broadcast unsupported")
         from numpypy import zeros, array
         a = zeros((3, 1, 3))
         b = array(((10, 11, 12), (20, 21, 22), (30, 31, 32)))
@@ -1032,7 +1028,6 @@
         assert c.all()
 
     def test_broadcast_scalar(self):
-        skip("broadcast unsupported")
         from numpypy import zeros
         a = zeros((4, 5), 'd')
         a[:, 1] = 3
@@ -1044,7 +1039,6 @@
         assert a[3, 2] == 0
 
     def test_broadcast_call2(self):
-        skip("broadcast unsupported")
         from numpypy import zeros, ones
         a = zeros((4, 1, 5))
         b = ones((4, 3, 5))


More information about the pypy-commit mailing list