[pypy-svn] r78075 - in pypy/branch/micronumpy-resync/pypy/module/micronumpy: . test
dan at codespeak.net
dan at codespeak.net
Tue Oct 19 07:19:28 CEST 2010
Author: dan
Date: Tue Oct 19 07:19:25 2010
New Revision: 78075
Modified:
pypy/branch/micronumpy-resync/pypy/module/micronumpy/array.py
pypy/branch/micronumpy-resync/pypy/module/micronumpy/test/test_numpy.py
Log:
Added support code for array broadcasting, along with a small (unfinished) test.
Modified: pypy/branch/micronumpy-resync/pypy/module/micronumpy/array.py
==============================================================================
--- pypy/branch/micronumpy-resync/pypy/module/micronumpy/array.py (original)
+++ pypy/branch/micronumpy-resync/pypy/module/micronumpy/array.py Tue Oct 19 07:19:25 2010
@@ -32,6 +32,47 @@
else:
return 0
def broadcast_shapes(a_shape, a_strides, b_shape, b_strides):
    """Broadcast two shapes against each other, NumPy style.

    Dimensions are matched starting from the trailing (last) one.  Two
    dimensions are compatible when they are equal or when one of them is
    1; the result takes the non-1 size.  The shorter shape is treated as
    if it had extra leading dimensions, and its strides are padded with
    leading zeros so that both stride lists have the result's rank.

    :param a_shape: list of dimension sizes of the first operand.
    :param a_strides: list of strides, one per entry of ``a_shape``.
    :param b_shape: list of dimension sizes of the second operand.
    :param b_strides: list of strides, one per entry of ``b_shape``.
    :returns: ``(result_shape, a_strides, b_strides)``.  ``result_shape``
        is a new list; the input shape and stride lists are not modified.
    :raises ValueError: if a pair of matched dimensions is incompatible.

    NOTE(review): only the padded leading dimensions get a zero stride; a
    size-1 dimension that is stretched to a larger size keeps its original
    stride.  Iterating over the broadcast shape presumably needs a 0 stride
    there as well -- TODO confirm before relying on these strides.
    """
    a_dim = len(a_shape)
    b_dim = len(b_shape)

    # The longer shape supplies the result's leading dimensions.  Copy it:
    # the loop below overwrites the trailing entries in place, and the
    # caller's shape list must not change.
    if a_dim > b_dim:
        result = a_shape[:]
        smaller_dim = b_dim
    else:
        result = b_shape[:]
        smaller_dim = a_dim

    i_a = a_dim - 1
    i_b = b_dim - 1
    i_r = len(result) - 1
    for _ in range(smaller_dim):
        # These indices can never go negative inside this loop; the asserts
        # are presumably hints for the RPython annotator -- TODO confirm.
        assert i_a >= 0
        a = a_shape[i_a]

        assert i_b >= 0
        b = b_shape[i_b]

        # Pick the non-1 size explicitly rather than max(a, b), so that a
        # zero-length dimension broadcast against 1 stays 0.
        if a == b or b == 1:
            result[i_r] = a
        elif a == 1:
            result[i_r] = b
        else:
            raise ValueError("frames are not aligned") # FIXME: applevel?

        i_a -= 1
        i_b -= 1
        i_r -= 1

    # Pad the strides of the lower-rank operand with zeros for the
    # leading dimensions it does not have.
    if a_dim < b_dim:
        a_strides = [0] * (b_dim - a_dim) + a_strides
    elif b_dim < a_dim:
        b_strides = [0] * (a_dim - b_dim) + b_strides
    return result, a_strides, b_strides
+
def normalize_slice_starts(slice_starts, shape):
for i in range(len(slice_starts)):
if slice_starts[i] < 0:
Modified: pypy/branch/micronumpy-resync/pypy/module/micronumpy/test/test_numpy.py
==============================================================================
--- pypy/branch/micronumpy-resync/pypy/module/micronumpy/test/test_numpy.py (original)
+++ pypy/branch/micronumpy-resync/pypy/module/micronumpy/test/test_numpy.py Tue Oct 19 07:19:25 2010
@@ -425,6 +425,46 @@
for w_xs, typecode in data:
assert typecode == infer_from_iterable(space, w_xs).typecode
class TestArraySupport(object):
    def test_broadcast_shapes(self, space):
        from pypy.module.micronumpy.array import broadcast_shapes
        from pypy.module.micronumpy.array import stride_row as stride

        def test(shape_a, shape_b, expected_result,
                 expected_strides_a=None, expected_strides_b=None):
            """Broadcast shape_a against shape_b and check the outcome.

            Strides for each operand are computed with stride_row.  When an
            expected stride list is omitted (None), the corresponding
            strides must come back unchanged.
            """
            strides_a = [stride(shape_a, i) for i in range(len(shape_a))]
            strides_a_save = strides_a[:]

            strides_b = [stride(shape_b, i) for i in range(len(shape_b))]
            strides_b_save = strides_b[:]

            result_shape, result_strides_a, result_strides_b = \
                broadcast_shapes(shape_a, strides_a, shape_b, strides_b)
            assert result_shape == expected_result

            # Compare against None explicitly so that an expected empty
            # stride list is still checked rather than silently skipped.
            if expected_strides_a is not None:
                assert result_strides_a == expected_strides_a
            else:
                assert result_strides_a == strides_a_save

            if expected_strides_b is not None:
                assert result_strides_b == expected_strides_b
            else:
                assert result_strides_b == strides_b_save

        # Lower-rank operand on the right: its strides get leading zeros.
        test([256, 256, 3], [3],
             expected_result=[256, 256, 3],
             expected_strides_b=[0, 0, 1])

        # Same, with the lower-rank operand on the left.
        test([3], [256, 256, 3],
             expected_result=[256, 256, 3],
             expected_strides_a=[0, 0, 1])

        # Size-1 dimensions on both sides are stretched.
        test([8, 1, 6, 1], [7, 1, 5],
             expected_result=[8, 7, 6, 5],
             expected_strides_b=[0, 5, 5, 1])

        # Equal ranks: neither operand's strides are padded.
        test([2, 1], [1, 3], expected_result=[2, 3])

        # Incompatible trailing dimensions are rejected.
        py.test.raises(ValueError, broadcast_shapes, [2], [1], [3], [1])
class TestMicroArray(object):
@py.test.mark.xfail # XXX: return types changed
def test_index2strides(self, space):
More information about the Pypy-commit mailing list