[pypy-commit] pypy default: add where to compile

fijal noreply at buildbot.pypy.org
Sun Apr 22 12:32:04 CEST 2012


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: 
Changeset: r54614:ee17da202bea
Date: 2012-04-22 12:31 +0200
http://bitbucket.org/pypy/pypy/changeset/ee17da202bea/

Log:	add support for the where() function to micronumpy's compile.py (with a test); also drop a duplicated isinstance check

diff --git a/pypy/module/micronumpy/compile.py b/pypy/module/micronumpy/compile.py
--- a/pypy/module/micronumpy/compile.py
+++ b/pypy/module/micronumpy/compile.py
@@ -10,6 +10,7 @@
 from pypy.module.micronumpy.interp_dtype import get_dtype_cache
 from pypy.module.micronumpy.interp_numarray import (Scalar, BaseArray,
      scalar_w, W_NDimArray, array)
+from pypy.module.micronumpy.interp_arrayops import where
 from pypy.module.micronumpy import interp_ufuncs
 from pypy.rlib.objectmodel import specialize, instantiate
 
@@ -35,6 +36,7 @@
 SINGLE_ARG_FUNCTIONS = ["sum", "prod", "max", "min", "all", "any",
                         "unegative", "flat", "tostring"]
 TWO_ARG_FUNCTIONS = ["dot", 'take']
+THREE_ARG_FUNCTIONS = ['where']
 
 class FakeSpace(object):
     w_ValueError = None
@@ -445,14 +447,25 @@
             arg = self.args[1].execute(interp)
             if not isinstance(arg, BaseArray):
                 raise ArgumentNotAnArray
-            if not isinstance(arg, BaseArray):
-                raise ArgumentNotAnArray
             if self.name == "dot":
                 w_res = arr.descr_dot(interp.space, arg)
             elif self.name == 'take':
                 w_res = arr.descr_take(interp.space, arg)
             else:
                 assert False # unreachable code
+        elif self.name in THREE_ARG_FUNCTIONS:
+            if len(self.args) != 3:
+                raise ArgumentMismatch
+            arg1 = self.args[1].execute(interp)
+            arg2 = self.args[2].execute(interp)
+            if not isinstance(arg1, BaseArray):
+                raise ArgumentNotAnArray
+            if not isinstance(arg2, BaseArray):
+                raise ArgumentNotAnArray
+            if self.name == "where":
+                w_res = where(interp.space, arr, arg1, arg2)
+            else:
+                assert False
         else:
             raise WrongFunctionName
         if isinstance(w_res, BaseArray):
diff --git a/pypy/module/micronumpy/test/test_compile.py b/pypy/module/micronumpy/test/test_compile.py
--- a/pypy/module/micronumpy/test/test_compile.py
+++ b/pypy/module/micronumpy/test/test_compile.py
@@ -270,3 +270,13 @@
         b -> 2
         """)
         assert interp.results[0].value == 3
+
+    def test_where(self):
+        interp = self.run('''
+        a = [1, 0, 3, 0]
+        b = [1, 1, 1, 1]
+        c = [0, 0, 0, 0]
+        d = where(a, b, c)
+        d -> 1
+        ''')
+        assert interp.results[0].value == 0


More information about the pypy-commit mailing list