[pypy-commit] pypy numpy-refactor: fixes for where
fijal
noreply at buildbot.pypy.org
Tue Sep 11 14:30:12 CEST 2012
Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-refactor
Changeset: r57266:1efcfcf768c7
Date: 2012-09-11 14:25 +0200
http://bitbucket.org/pypy/pypy/changeset/1efcfcf768c7/
Log: fixes for where
diff --git a/pypy/module/micronumpy/interp_arrayops.py b/pypy/module/micronumpy/interp_arrayops.py
--- a/pypy/module/micronumpy/interp_arrayops.py
+++ b/pypy/module/micronumpy/interp_arrayops.py
@@ -2,6 +2,7 @@
from pypy.module.micronumpy.base import convert_to_array, W_NDimArray
from pypy.module.micronumpy import loop, interp_ufuncs
from pypy.module.micronumpy.iter import Chunk, Chunks
+from pypy.module.micronumpy.strides import shape_agreement
from pypy.interpreter.error import OperationError, operationerrfmt
from pypy.interpreter.gateway import unwrap_spec
@@ -65,15 +66,21 @@
NOTE: support for not passing x and y is unsupported
"""
- if space.is_w(w_x, space.w_None) or space.is_w(w_y, space.w_None):
- raise OperationError(space.w_NotImplementedError, space.wrap(
- "1-arg where unsupported right now"))
+ if space.is_w(w_y, space.w_None):
+ if space.is_w(w_x, space.w_None):
+ raise OperationError(space.w_NotImplementedError, space.wrap(
+ "1-arg where unsupported right now"))
+ raise OperationError(space.w_ValueError, space.wrap(
+ "Where should be called with either 1 or 3 arguments"))
arr = convert_to_array(space, w_arr)
x = convert_to_array(space, w_x)
y = convert_to_array(space, w_y)
- dtype = arr.get_dtype()
- out = W_NDimArray.from_shape(arr.get_shape(), dtype)
- return loop.where(out, arr, x, y, dtype)
+ dtype = interp_ufuncs.find_binop_result_dtype(space, x.get_dtype(),
+ y.get_dtype())
+ shape = shape_agreement(space, arr.get_shape(), x)
+ shape = shape_agreement(space, shape, y)
+ out = W_NDimArray.from_shape(shape, dtype)
+ return loop.where(out, shape, arr, x, y, dtype)
def dot(space, w_obj1, w_obj2):
w_arr = convert_to_array(space, w_obj1)
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -69,18 +69,20 @@
arr_iter.setitem(box)
arr_iter.next()
-def where(out, arr, x, y, dtype):
- out_iter = out.create_iter()
- arr_iter = arr.create_iter()
- x_iter = x.create_iter()
- y_iter = y.create_iter()
- while not arr_iter.done():
+def where(out, shape, arr, x, y, dtype):
+ out_iter = out.create_iter(shape)
+ arr_iter = arr.create_iter(shape)
+ arr_dtype = arr.get_dtype()
+ x_iter = x.create_iter(shape)
+ y_iter = y.create_iter(shape)
+ while not x_iter.done():
w_cond = arr_iter.getitem()
- if dtype.itemtype.bool(w_cond):
+ if arr_dtype.itemtype.bool(w_cond):
w_val = x_iter.getitem().convert_to(dtype)
else:
w_val = y_iter.getitem().convert_to(dtype)
out_iter.setitem(w_val)
+ out_iter.next()
arr_iter.next()
x_iter.next()
y_iter.next()
diff --git a/pypy/module/micronumpy/test/test_arrayops.py b/pypy/module/micronumpy/test/test_arrayops.py
--- a/pypy/module/micronumpy/test/test_arrayops.py
+++ b/pypy/module/micronumpy/test/test_arrayops.py
@@ -6,17 +6,28 @@
from _numpypy import where, ones, zeros, array
a = [1, 2, 3, 0, -3]
a = where(array(a) > 0, ones(5), zeros(5))
- print a
assert (a == [1, 1, 1, 0, 0]).all()
def test_where_differing_dtypes(self):
- xxx
+ from _numpypy import array, ones, zeros, where
+ a = [1, 2, 3, 0, -3]
+ a = where(array(a) > 0, ones(5, dtype=int), zeros(5, dtype=float))
+ assert (a == [1, 1, 1, 0, 0]).all()
+
+ def test_where_broadcast(self):
+ from _numpypy import array, where
+ a = where(array([[1, 2, 3], [4, 5, 6]]) > 3, [1, 1, 1], 2)
+ assert (a == [[2, 2, 2], [1, 1, 1]]).all()
+ a = where(True, [1, 1, 1], 2)
+ assert (a == [1, 1, 1]).all()
def test_where_errors(self):
- xxx
+ from _numpypy import where, array
+ raises(ValueError, "where([1, 2, 3], [3, 4, 5])")
+ raises(ValueError, "where([1, 2, 3], [3, 4, 5], [6, 7])")
- def test_where_1_arg(self):
- xxx
+ #def test_where_1_arg(self):
+ # xxx
def test_where_invalidates(self):
from _numpypy import where, ones, zeros, array
More information about the pypy-commit mailing list