[pypy-commit] pypy default: implement searchsorted in rpython with jitdriver
bdkearns
noreply at buildbot.pypy.org
Fri Oct 10 06:25:45 CEST 2014
Author: Brian Kearns <bdkearns at gmail.com>
Branch: default
Changeset: r73877:0462e4a83ff1
Date: 2014-10-09 18:03 -0400
http://bitbucket.org/pypy/pypy/changeset/0462e4a83ff1/
Log: implement searchsorted in rpython with jitdriver
diff --git a/pypy/module/micronumpy/compile.py b/pypy/module/micronumpy/compile.py
--- a/pypy/module/micronumpy/compile.py
+++ b/pypy/module/micronumpy/compile.py
@@ -36,7 +36,7 @@
SINGLE_ARG_FUNCTIONS = ["sum", "prod", "max", "min", "all", "any",
"unegative", "flat", "tostring","count_nonzero",
"argsort"]
-TWO_ARG_FUNCTIONS = ["dot", 'take']
+TWO_ARG_FUNCTIONS = ["dot", 'take', 'searchsorted']
TWO_ARG_FUNCTIONS_OR_NONE = ['view', 'astype']
THREE_ARG_FUNCTIONS = ['where']
@@ -109,6 +109,9 @@
if stop < 0:
stop += size + 1
if step < 0:
+ start, stop = stop, start
+ start -= 1
+ stop -= 1
lgt = (stop - start + 1) / step + 1
else:
lgt = (stop - start - 1) / step + 1
@@ -475,7 +478,6 @@
class SliceConstant(Node):
def __init__(self, start, stop, step):
- # no negative support for now
self.start = start
self.stop = stop
self.step = step
@@ -582,6 +584,9 @@
w_res = arr.descr_dot(interp.space, arg)
elif self.name == 'take':
w_res = arr.descr_take(interp.space, arg)
+ elif self.name == "searchsorted":
+ w_res = arr.descr_searchsorted(interp.space, arg,
+ interp.space.wrap('left'))
else:
assert False # unreachable code
elif self.name in THREE_ARG_FUNCTIONS:
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -700,3 +700,43 @@
out_iter.setitem(out_state, arr.getitem_index(space, indexes))
iter.next()
out_state = out_iter.next(out_state)
+
+def _new_binsearch(side, op_name):
+ binsearch_driver = jit.JitDriver(name='numpy_binsearch_' + side,
+ greens=['dtype'],
+ reds='auto')
+
+ def binsearch(space, arr, key, ret):
+ assert len(arr.get_shape()) == 1
+ dtype = key.get_dtype()
+ op = getattr(dtype.itemtype, op_name)
+ key_iter, key_state = key.create_iter()
+ ret_iter, ret_state = ret.create_iter()
+ ret_iter.track_index = False
+ size = arr.get_size()
+ min_idx = 0
+ max_idx = size
+ last_key_val = key_iter.getitem(key_state)
+ while not key_iter.done(key_state):
+ key_val = key_iter.getitem(key_state)
+ if dtype.itemtype.lt(last_key_val, key_val):
+ max_idx = size
+ else:
+ min_idx = 0
+ max_idx = max_idx + 1 if max_idx < size else size
+ last_key_val = key_val
+ while min_idx < max_idx:
+ binsearch_driver.jit_merge_point(dtype=dtype)
+ mid_idx = min_idx + ((max_idx - min_idx) >> 1)
+ mid_val = arr.getitem(space, [mid_idx]).convert_to(space, dtype)
+ if op(mid_val, key_val):
+ min_idx = mid_idx + 1
+ else:
+ max_idx = mid_idx
+ ret_iter.setitem(ret_state, ret.get_dtype().box(min_idx))
+ ret_state = ret_iter.next(ret_state)
+ key_state = key_iter.next(key_state)
+ return binsearch
+
+binsearch_left = _new_binsearch('left', 'lt')
+binsearch_right = _new_binsearch('right', 'le')
diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py
--- a/pypy/module/micronumpy/ndarray.py
+++ b/pypy/module/micronumpy/ndarray.py
@@ -20,7 +20,6 @@
from pypy.module.micronumpy.flagsobj import W_FlagsObject
from pypy.module.micronumpy.strides import get_shape_from_iterable, \
shape_agreement, shape_agreement_multiple
-from .selection import app_searchsort
def _match_dot_shapes(space, left, right):
@@ -740,7 +739,11 @@
v = convert_to_array(space, w_v)
ret = W_NDimArray.from_shape(
space, v.get_shape(), descriptor.get_dtype_cache(space).w_longdtype)
- app_searchsort(space, self, v, space.wrap(side), ret)
+ if side == NPY.SEARCHLEFT:
+ binsearch = loop.binsearch_left
+ else:
+ binsearch = loop.binsearch_right
+ binsearch(space, self, v, ret)
if ret.is_scalar():
return ret.get_scalar_value()
return ret
diff --git a/pypy/module/micronumpy/selection.py b/pypy/module/micronumpy/selection.py
--- a/pypy/module/micronumpy/selection.py
+++ b/pypy/module/micronumpy/selection.py
@@ -1,5 +1,4 @@
from pypy.interpreter.error import oefmt
-from pypy.interpreter.gateway import applevel
from rpython.rlib.listsort import make_timsort_class
from rpython.rlib.objectmodel import specialize
from rpython.rlib.rarithmetic import widen
@@ -354,39 +353,3 @@
cache[cls] = make_sort_function(space, cls, it)
self.cache = cache
self._lookup = specialize.memo()(lambda tp: cache[tp[0]])
-
-
-app_searchsort = applevel(r"""
- import operator
-
- def searchsort(arr, val, side, res):
- val = val.flat
- res = res.flat
- if side == 0:
- op = operator.lt
- else:
- op = operator.le
-
- size = arr.size
- imin = 0
- imax = size
- try:
- last = val[0]
- except IndexError:
- return
- for i in xrange(len(val)):
- key = val[i]
- if last < key:
- imax = size
- else:
- imin = 0
- imax = imax + 1 if imax < size else size
- last = key
- while imin < imax:
- imid = imin + ((imax - imin) >> 1)
- if op(arr[imid], key):
- imin = imid + 1
- else:
- imax = imid
- res[i] = imin
-""", filename=__file__).interphook('searchsort')
diff --git a/pypy/module/micronumpy/test/test_compile.py b/pypy/module/micronumpy/test/test_compile.py
--- a/pypy/module/micronumpy/test/test_compile.py
+++ b/pypy/module/micronumpy/test/test_compile.py
@@ -330,3 +330,12 @@
results = interp.results[0]
assert isinstance(results, W_NDimArray)
assert results.get_dtype().is_int()
+
+ def test_searchsorted(self):
+ interp = self.run('''
+ a = [1, 4, 5, 6, 9]
+ b = |30| -> ::-1
+ c = searchsorted(a, b)
+ c -> -1
+ ''')
+ assert interp.results[0].value == 0
diff --git a/pypy/module/micronumpy/test/test_selection.py b/pypy/module/micronumpy/test/test_selection.py
--- a/pypy/module/micronumpy/test/test_selection.py
+++ b/pypy/module/micronumpy/test/test_selection.py
@@ -382,6 +382,9 @@
assert ret == 3
assert isinstance(ret, np.generic)
+ assert a.searchsorted(3.1) == 3
+ assert a.searchsorted(3.9) == 3
+
exc = raises(ValueError, a.searchsorted, 3, side=None)
assert str(exc.value) == "expected nonempty string for keyword 'side'"
exc = raises(ValueError, a.searchsorted, 3, side='')
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -51,7 +51,9 @@
w_res = i.getitem(s)
if isinstance(w_res, boxes.W_Float64Box):
return w_res.value
- if isinstance(w_res, boxes.W_Int64Box):
+ elif isinstance(w_res, boxes.W_Int64Box):
+ return float(w_res.value)
+ elif isinstance(w_res, boxes.W_LongBox):
return float(w_res.value)
elif isinstance(w_res, boxes.W_BoolBox):
return float(w_res.value)
@@ -660,3 +662,30 @@
'raw_load': 2,
'raw_store': 1,
})
+
+ def define_searchsorted():
+ return """
+ a = [1, 4, 5, 6, 9]
+ b = |30| -> ::-1
+ c = searchsorted(a, b)
+ c -> -1
+ """
+
+ def test_searchsorted(self):
+ result = self.run("searchsorted")
+ assert result == 0
+ self.check_trace_count(6)
+ self.check_simple_loop({
+ 'float_lt': 1,
+ 'guard_false': 2,
+ 'guard_not_invalidated': 1,
+ 'guard_true': 2,
+ 'int_add': 3,
+ 'int_ge': 1,
+ 'int_lt': 2,
+ 'int_mul': 1,
+ 'int_rshift': 1,
+ 'int_sub': 1,
+ 'jump': 1,
+ 'raw_load': 1,
+ })
More information about the pypy-commit
mailing list