[pypy-commit] pypy numpy-back-to-applevel: implement take for test_compile, also be more explicit about wrapping

fijal noreply at buildbot.pypy.org
Fri Jan 27 12:03:00 CET 2012


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-back-to-applevel
Changeset: r51848:2ba6d0106f54
Date: 2012-01-27 13:02 +0200
http://bitbucket.org/pypy/pypy/changeset/2ba6d0106f54/

Log:	implement take for test_compile, also be more explicit about
	wrapping

diff --git a/pypy/module/micronumpy/compile.py b/pypy/module/micronumpy/compile.py
--- a/pypy/module/micronumpy/compile.py
+++ b/pypy/module/micronumpy/compile.py
@@ -33,6 +33,7 @@
     pass
 
 SINGLE_ARG_FUNCTIONS = ["sum", "prod", "max", "min", "all", "any", "unegative"]
+TWO_ARG_FUNCTIONS = ['take']
 
 class FakeSpace(object):
     w_ValueError = None
@@ -372,12 +373,12 @@
                                                  for arg in self.args]))
 
     def execute(self, interp):
+        arr = self.args[0].execute(interp)
+        if not isinstance(arr, BaseArray):
+            raise ArgumentNotAnArray
         if self.name in SINGLE_ARG_FUNCTIONS:
             if len(self.args) != 1 and self.name != 'sum':
                 raise ArgumentMismatch
-            arr = self.args[0].execute(interp)
-            if not isinstance(arr, BaseArray):
-                raise ArgumentNotAnArray
             if self.name == "sum":
                 if len(self.args)>1:
                     w_res = arr.descr_sum(interp.space,
@@ -399,19 +400,27 @@
                 w_res = neg.call(interp.space, [arr])
             else:
                 assert False # unreachable code
-            if isinstance(w_res, BaseArray):
-                return w_res
-            if isinstance(w_res, FloatObject):
-                dtype = get_dtype_cache(interp.space).w_float64dtype
-            elif isinstance(w_res, BoolObject):
-                dtype = get_dtype_cache(interp.space).w_booldtype
-            elif isinstance(w_res, interp_boxes.W_GenericBox):
-                dtype = w_res.get_dtype(interp.space)
+        elif self.name in TWO_ARG_FUNCTIONS:
+            arg = self.args[1].execute(interp)
+            if not isinstance(arg, BaseArray):
+                raise ArgumentNotAnArray
+            if self.name == 'take':
+                w_res = arr.descr_take(interp.space, arg)
             else:
-                dtype = None
-            return scalar_w(interp.space, dtype, w_res)
+                assert False # unreachable
         else:
             raise WrongFunctionName
+        if isinstance(w_res, BaseArray):
+            return w_res
+        if isinstance(w_res, FloatObject):
+            dtype = get_dtype_cache(interp.space).w_float64dtype
+        elif isinstance(w_res, BoolObject):
+            dtype = get_dtype_cache(interp.space).w_booldtype
+        elif isinstance(w_res, interp_boxes.W_GenericBox):
+            dtype = w_res.get_dtype(interp.space)
+        else:
+            dtype = None
+        return scalar_w(interp.space, dtype, w_res)
 
 _REGEXES = [
     ('-?[\d\.]+', 'number'),
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -61,6 +61,11 @@
     reds=['idx', 'idxi', 'frame', 'arr'],
     name='numpy_filterset',
 )
+take_driver = jit.JitDriver(
+    greens=['shapelen', 'sig'],
+    reds=['index_i', 'res_i', 'concr'],
+    name='numpy_take',
+)
 
 class BaseArray(Wrappable):
     _attrs_ = ["invalidates", "shape", 'size']
@@ -611,9 +616,14 @@
         res_i = res.create_iter()
         longdtype = interp_dtype.get_dtype_cache(space).w_longdtype
         shapelen = len(index.shape)
+        sig = concr.find_sig()
         while not index_i.done():
+            take_driver.jit_merge_point(index_i=index_i,
+                                        res_i=res_i, concr=concr,
+                                        shapelen=shapelen, sig=sig)
             # XXX jitdriver + test_zjit
-            w_item = index.getitem(index_i.offset).convert_to(longdtype)
+            w_item = index.getitem(index_i.offset).convert_to(longdtype).item(
+                space)
             res.setitem(res_i.offset, concr.descr_getitem(space, w_item))
             index_i = index_i.next(shapelen)
             res_i = res_i.next(shapelen)
diff --git a/pypy/module/micronumpy/test/test_compile.py b/pypy/module/micronumpy/test/test_compile.py
--- a/pypy/module/micronumpy/test/test_compile.py
+++ b/pypy/module/micronumpy/test/test_compile.py
@@ -245,3 +245,11 @@
         a -> 3
         """)
         assert interp.results[0].value == 11
+
+    def test_take(self):
+        interp = self.run("""
+        a = |10|
+        b = take(a, [1, 1, 3, 2])
+        b -> 2
+        """)
+        assert interp.results[0].value == 3


More information about the pypy-commit mailing list