[pypy-commit] pypy numpy-back-to-applevel: implement take for test_compile, also be more explicit about wrapping
fijal
noreply at buildbot.pypy.org
Fri Jan 27 12:03:00 CET 2012
Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-back-to-applevel
Changeset: r51848:2ba6d0106f54
Date: 2012-01-27 13:02 +0200
http://bitbucket.org/pypy/pypy/changeset/2ba6d0106f54/
Log: implement take for test_compile, also be more explicit about wrapping
diff --git a/pypy/module/micronumpy/compile.py b/pypy/module/micronumpy/compile.py
--- a/pypy/module/micronumpy/compile.py
+++ b/pypy/module/micronumpy/compile.py
@@ -33,6 +33,7 @@
pass
SINGLE_ARG_FUNCTIONS = ["sum", "prod", "max", "min", "all", "any", "unegative"]
+TWO_ARG_FUNCTIONS = ['take']
class FakeSpace(object):
w_ValueError = None
@@ -372,12 +373,12 @@
for arg in self.args]))
def execute(self, interp):
+ arr = self.args[0].execute(interp)
+ if not isinstance(arr, BaseArray):
+ raise ArgumentNotAnArray
if self.name in SINGLE_ARG_FUNCTIONS:
if len(self.args) != 1 and self.name != 'sum':
raise ArgumentMismatch
- arr = self.args[0].execute(interp)
- if not isinstance(arr, BaseArray):
- raise ArgumentNotAnArray
if self.name == "sum":
if len(self.args)>1:
w_res = arr.descr_sum(interp.space,
@@ -399,19 +400,27 @@
w_res = neg.call(interp.space, [arr])
else:
assert False # unreachable code
- if isinstance(w_res, BaseArray):
- return w_res
- if isinstance(w_res, FloatObject):
- dtype = get_dtype_cache(interp.space).w_float64dtype
- elif isinstance(w_res, BoolObject):
- dtype = get_dtype_cache(interp.space).w_booldtype
- elif isinstance(w_res, interp_boxes.W_GenericBox):
- dtype = w_res.get_dtype(interp.space)
+ elif self.name in TWO_ARG_FUNCTIONS:
+ arg = self.args[1].execute(interp)
+ if not isinstance(arg, BaseArray):
+ raise ArgumentNotAnArray
+ if self.name == 'take':
+ w_res = arr.descr_take(interp.space, arg)
else:
- dtype = None
- return scalar_w(interp.space, dtype, w_res)
+ assert False # unreachable
else:
raise WrongFunctionName
+ if isinstance(w_res, BaseArray):
+ return w_res
+ if isinstance(w_res, FloatObject):
+ dtype = get_dtype_cache(interp.space).w_float64dtype
+ elif isinstance(w_res, BoolObject):
+ dtype = get_dtype_cache(interp.space).w_booldtype
+ elif isinstance(w_res, interp_boxes.W_GenericBox):
+ dtype = w_res.get_dtype(interp.space)
+ else:
+ dtype = None
+ return scalar_w(interp.space, dtype, w_res)
_REGEXES = [
('-?[\d\.]+', 'number'),
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -61,6 +61,11 @@
reds=['idx', 'idxi', 'frame', 'arr'],
name='numpy_filterset',
)
+take_driver = jit.JitDriver(
+ greens=['shapelen', 'sig'],
+ reds=['index_i', 'res_i', 'concr'],
+ name='numpy_take',
+)
class BaseArray(Wrappable):
_attrs_ = ["invalidates", "shape", 'size']
@@ -611,9 +616,14 @@
res_i = res.create_iter()
longdtype = interp_dtype.get_dtype_cache(space).w_longdtype
shapelen = len(index.shape)
+ sig = concr.find_sig()
while not index_i.done():
+ take_driver.jit_merge_point(index_i=index_i,
+ res_i=res_i, concr=concr,
+ shapelen=shapelen, sig=sig)
# XXX jitdriver + test_zjit
- w_item = index.getitem(index_i.offset).convert_to(longdtype)
+ w_item = index.getitem(index_i.offset).convert_to(longdtype).item(
+ space)
res.setitem(res_i.offset, concr.descr_getitem(space, w_item))
index_i = index_i.next(shapelen)
res_i = res_i.next(shapelen)
diff --git a/pypy/module/micronumpy/test/test_compile.py b/pypy/module/micronumpy/test/test_compile.py
--- a/pypy/module/micronumpy/test/test_compile.py
+++ b/pypy/module/micronumpy/test/test_compile.py
@@ -245,3 +245,11 @@
a -> 3
""")
assert interp.results[0].value == 11
+
+ def test_take(self):
+ interp = self.run("""
+ a = |10|
+ b = take(a, [1, 1, 3, 2])
+ b -> 2
+ """)
+ assert interp.results[0].value == 3
More information about the pypy-commit mailing list