[pypy-svn] r78075 - in pypy/branch/micronumpy-resync/pypy/module/micronumpy: . test

dan at codespeak.net dan at codespeak.net
Tue Oct 19 07:19:28 CEST 2010


Author: dan
Date: Tue Oct 19 07:19:25 2010
New Revision: 78075

Modified:
   pypy/branch/micronumpy-resync/pypy/module/micronumpy/array.py
   pypy/branch/micronumpy-resync/pypy/module/micronumpy/test/test_numpy.py
Log:
Added support code for array broadcasting, along with a small (unfinished) test.

Modified: pypy/branch/micronumpy-resync/pypy/module/micronumpy/array.py
==============================================================================
--- pypy/branch/micronumpy-resync/pypy/module/micronumpy/array.py	(original)
+++ pypy/branch/micronumpy-resync/pypy/module/micronumpy/array.py	Tue Oct 19 07:19:25 2010
@@ -32,6 +32,47 @@
     else:
         return 0
 
+def broadcast_shapes(a_shape, a_strides, b_shape, b_strides):
+    """Broadcast two array shapes against each other.
+
+    Shapes are aligned at their trailing dimensions; each aligned pair must
+    be equal or contain a 1 (the shorter shape acts as if left-padded with
+    1s).  Raises ValueError if the shapes cannot be broadcast together.
+
+    Returns (result_shape, a_strides, b_strides) as new lists; the input
+    lists are never modified.  The strides of the operand with fewer
+    dimensions are left-padded with 0 (one per missing leading dimension).
+
+    NOTE(review): a size-1 dimension broadcast against a larger one keeps
+    its original stride here; iterating the result shape presumably needs
+    a 0 stride there instead -- TODO confirm before relying on it.
+    """
+    a_dim = len(a_shape)
+    b_dim = len(b_shape)
+
+    # Copy the longer shape so the caller's list is not mutated when its
+    # trailing entries are overwritten below.
+    result = list(a_shape) if a_dim > b_dim else list(b_shape)
+    result_dim = len(result)
+
+    # Walk the aligned trailing dimensions from the last one backwards.
+    for i in range(1, min(a_dim, b_dim) + 1):
+        a = a_shape[a_dim - i]
+        b = b_shape[b_dim - i]
+        if a == b or b == 1:
+            # Equal sizes, or b is broadcast along this dimension.
+            result[result_dim - i] = a
+        elif a == 1:
+            # a is broadcast; a 1 stretches to any size, including 0.
+            result[result_dim - i] = b
+        else:
+            raise ValueError("frames are not aligned") # FIXME: applevel?
+
+    # Left-pad the strides of the operand with fewer dimensions.
+    a_strides = [0] * (result_dim - a_dim) + list(a_strides)
+    b_strides = [0] * (result_dim - b_dim) + list(b_strides)
+    return result, a_strides, b_strides
+
 def normalize_slice_starts(slice_starts, shape):
     for i in range(len(slice_starts)):
         if slice_starts[i] < 0:

Modified: pypy/branch/micronumpy-resync/pypy/module/micronumpy/test/test_numpy.py
==============================================================================
--- pypy/branch/micronumpy-resync/pypy/module/micronumpy/test/test_numpy.py	(original)
+++ pypy/branch/micronumpy-resync/pypy/module/micronumpy/test/test_numpy.py	Tue Oct 19 07:19:25 2010
@@ -425,6 +425,46 @@
         for w_xs, typecode in data:
             assert typecode == infer_from_iterable(space, w_xs).typecode
 
+class TestArraySupport(object):
+    def test_broadcast_shapes(self, space):
+        from pypy.module.micronumpy.array import broadcast_shapes
+        from pypy.module.micronumpy.array import stride_row as stride
+
+        def test(shape_a, shape_b, expected_result, expected_strides_a=None, expected_strides_b=None):
+            """Check broadcast_shapes using stride_row strides of both shapes.
+
+            An expected_strides_* of None means that operand's strides
+            must come back unchanged.
+            """
+            strides_a = [stride(shape_a, i) for i in range(len(shape_a))]
+            strides_b = [stride(shape_b, i) for i in range(len(shape_b))]
+            if expected_strides_a is None:
+                expected_strides_a = strides_a[:]
+            if expected_strides_b is None:
+                expected_strides_b = strides_b[:]
+
+            result_shape, result_strides_a, result_strides_b = broadcast_shapes(shape_a, strides_a, shape_b, strides_b)
+            assert result_shape == expected_result
+            assert result_strides_a == expected_strides_a
+            assert result_strides_b == expected_strides_b
+
+        test([256, 256, 3], [3],
+             expected_result=[256, 256, 3],
+             expected_strides_b=[0, 0, 1])
+
+        test([3], [256, 256, 3],
+             expected_result=[256, 256, 3],
+             expected_strides_a=[0, 0, 1])
+
+        test([8, 1, 6, 1], [7, 1, 5],
+             expected_result=[8, 7, 6, 5],
+             expected_strides_b=[0, 5, 5, 1])
+
+        test([2, 3], [2, 3], expected_result=[2, 3])
+
+        # Mismatched trailing dimensions cannot be broadcast.
+        py.test.raises(ValueError, broadcast_shapes, [2], [1], [3], [1])
+
 class TestMicroArray(object):
     @py.test.mark.xfail # XXX: return types changed
     def test_index2strides(self, space):



More information about the Pypy-commit mailing list