[pypy-commit] pypy default: add where to compile
fijal
noreply at buildbot.pypy.org
Sun Apr 22 12:32:04 CEST 2012
Author: Maciej Fijalkowski <fijall at gmail.com>
Branch:
Changeset: r54614:ee17da202bea
Date: 2012-04-22 12:31 +0200
http://bitbucket.org/pypy/pypy/changeset/ee17da202bea/
Log: add where to compile
diff --git a/pypy/module/micronumpy/compile.py b/pypy/module/micronumpy/compile.py
--- a/pypy/module/micronumpy/compile.py
+++ b/pypy/module/micronumpy/compile.py
@@ -10,6 +10,7 @@
from pypy.module.micronumpy.interp_dtype import get_dtype_cache
from pypy.module.micronumpy.interp_numarray import (Scalar, BaseArray,
scalar_w, W_NDimArray, array)
+from pypy.module.micronumpy.interp_arrayops import where
from pypy.module.micronumpy import interp_ufuncs
from pypy.rlib.objectmodel import specialize, instantiate
@@ -35,6 +36,7 @@
SINGLE_ARG_FUNCTIONS = ["sum", "prod", "max", "min", "all", "any",
"unegative", "flat", "tostring"]
TWO_ARG_FUNCTIONS = ["dot", 'take']
+THREE_ARG_FUNCTIONS = ['where']
class FakeSpace(object):
w_ValueError = None
@@ -445,14 +447,25 @@
arg = self.args[1].execute(interp)
if not isinstance(arg, BaseArray):
raise ArgumentNotAnArray
- if not isinstance(arg, BaseArray):
- raise ArgumentNotAnArray
if self.name == "dot":
w_res = arr.descr_dot(interp.space, arg)
elif self.name == 'take':
w_res = arr.descr_take(interp.space, arg)
else:
assert False # unreachable code
+ elif self.name in THREE_ARG_FUNCTIONS:
+ if len(self.args) != 3:
+ raise ArgumentMismatch
+ arg1 = self.args[1].execute(interp)
+ arg2 = self.args[2].execute(interp)
+ if not isinstance(arg1, BaseArray):
+ raise ArgumentNotAnArray
+ if not isinstance(arg2, BaseArray):
+ raise ArgumentNotAnArray
+ if self.name == "where":
+ w_res = where(interp.space, arr, arg1, arg2)
+ else:
+ assert False
else:
raise WrongFunctionName
if isinstance(w_res, BaseArray):
diff --git a/pypy/module/micronumpy/test/test_compile.py b/pypy/module/micronumpy/test/test_compile.py
--- a/pypy/module/micronumpy/test/test_compile.py
+++ b/pypy/module/micronumpy/test/test_compile.py
@@ -270,3 +270,13 @@
b -> 2
""")
assert interp.results[0].value == 3
+
+ def test_where(self):
+ interp = self.run('''
+ a = [1, 0, 3, 0]
+ b = [1, 1, 1, 1]
+ c = [0, 0, 0, 0]
+ d = where(a, b, c)
+ d -> 1
+ ''')
+ assert interp.results[0].value == 0
More information about the pypy-commit mailing list