[pypy-commit] pypy default: implement searchsorted in rpython with jitdriver

bdkearns noreply at buildbot.pypy.org
Fri Oct 10 06:25:45 CEST 2014


Author: Brian Kearns <bdkearns at gmail.com>
Branch: 
Changeset: r73877:0462e4a83ff1
Date: 2014-10-09 18:03 -0400
http://bitbucket.org/pypy/pypy/changeset/0462e4a83ff1/

Log:	implement searchsorted in rpython with jitdriver

diff --git a/pypy/module/micronumpy/compile.py b/pypy/module/micronumpy/compile.py
--- a/pypy/module/micronumpy/compile.py
+++ b/pypy/module/micronumpy/compile.py
@@ -36,7 +36,7 @@
 SINGLE_ARG_FUNCTIONS = ["sum", "prod", "max", "min", "all", "any",
                         "unegative", "flat", "tostring","count_nonzero",
                         "argsort"]
-TWO_ARG_FUNCTIONS = ["dot", 'take']
+TWO_ARG_FUNCTIONS = ["dot", 'take', 'searchsorted']
 TWO_ARG_FUNCTIONS_OR_NONE = ['view', 'astype']
 THREE_ARG_FUNCTIONS = ['where']
 
@@ -109,6 +109,9 @@
             if stop < 0:
                 stop += size + 1
             if step < 0:
+                start, stop = stop, start
+                start -= 1
+                stop -= 1
                 lgt = (stop - start + 1) / step + 1
             else:
                 lgt = (stop - start - 1) / step + 1
@@ -475,7 +478,6 @@
 
 class SliceConstant(Node):
     def __init__(self, start, stop, step):
-        # no negative support for now
         self.start = start
         self.stop = stop
         self.step = step
@@ -582,6 +584,9 @@
                 w_res = arr.descr_dot(interp.space, arg)
             elif self.name == 'take':
                 w_res = arr.descr_take(interp.space, arg)
+            elif self.name == "searchsorted":
+                w_res = arr.descr_searchsorted(interp.space, arg,
+                                               interp.space.wrap('left'))
             else:
                 assert False # unreachable code
         elif self.name in THREE_ARG_FUNCTIONS:
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -700,3 +700,43 @@
         out_iter.setitem(out_state, arr.getitem_index(space, indexes))
         iter.next()
         out_state = out_iter.next(out_state)
+
+def _new_binsearch(side, op_name):
+    """Build a binary-search kernel for ndarray.searchsorted.
+
+    ``side`` is only used to name the JitDriver; ``op_name`` names the
+    itemtype comparison ('lt' for the left side, 'le' for the right side)
+    that decides whether the search moves past an array element.
+    """
+    binsearch_driver = jit.JitDriver(name='numpy_binsearch_' + side,
+                                     greens=['dtype'],
+                                     reds='auto')
+
+    def binsearch(space, arr, key, ret):
+        """For each element of ``key``, store into ``ret`` its insertion
+        index in the 1-d array ``arr`` (assumed sorted -- not checked here).
+
+        Array elements are converted to the key's dtype before comparison.
+        """
+        assert len(arr.get_shape()) == 1
+        dtype = key.get_dtype()
+        op = getattr(dtype.itemtype, op_name)
+        key_iter, key_state = key.create_iter()
+        if key_iter.done(key_state):
+            # Empty key array: nothing to write.  Bail out before seeding
+            # last_key_val, which would otherwise read an element that does
+            # not exist (the old app-level version caught IndexError here).
+            return
+        ret_iter, ret_state = ret.create_iter()
+        ret_iter.track_index = False
+        size = arr.get_size()
+        min_idx = 0
+        max_idx = size
+        last_key_val = key_iter.getitem(key_state)
+        while not key_iter.done(key_state):
+            key_val = key_iter.getitem(key_state)
+            # Reuse the previous search's bounds: for an increasing key the
+            # previous result is still a valid lower bound, so only the upper
+            # bound is reset; otherwise restart from 0 and widen the previous
+            # upper bound by one (capped at size).
+            if dtype.itemtype.lt(last_key_val, key_val):
+                max_idx = size
+            else:
+                min_idx = 0
+                max_idx = max_idx + 1 if max_idx < size else size
+            last_key_val = key_val
+            while min_idx < max_idx:
+                binsearch_driver.jit_merge_point(dtype=dtype)
+                mid_idx = min_idx + ((max_idx - min_idx) >> 1)
+                mid_val = arr.getitem(space, [mid_idx]).convert_to(space, dtype)
+                if op(mid_val, key_val):
+                    min_idx = mid_idx + 1
+                else:
+                    max_idx = mid_idx
+            ret_iter.setitem(ret_state, ret.get_dtype().box(min_idx))
+            ret_state = ret_iter.next(ret_state)
+            key_state = key_iter.next(key_state)
+    return binsearch
+
+binsearch_left = _new_binsearch('left', 'lt')
+binsearch_right = _new_binsearch('right', 'le')
diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py
--- a/pypy/module/micronumpy/ndarray.py
+++ b/pypy/module/micronumpy/ndarray.py
@@ -20,7 +20,6 @@
 from pypy.module.micronumpy.flagsobj import W_FlagsObject
 from pypy.module.micronumpy.strides import get_shape_from_iterable, \
     shape_agreement, shape_agreement_multiple
-from .selection import app_searchsort
 
 
 def _match_dot_shapes(space, left, right):
@@ -740,7 +739,11 @@
         v = convert_to_array(space, w_v)
         ret = W_NDimArray.from_shape(
             space, v.get_shape(), descriptor.get_dtype_cache(space).w_longdtype)
-        app_searchsort(space, self, v, space.wrap(side), ret)
+        if side == NPY.SEARCHLEFT:
+            binsearch = loop.binsearch_left
+        else:
+            binsearch = loop.binsearch_right
+        binsearch(space, self, v, ret)
         if ret.is_scalar():
             return ret.get_scalar_value()
         return ret
diff --git a/pypy/module/micronumpy/selection.py b/pypy/module/micronumpy/selection.py
--- a/pypy/module/micronumpy/selection.py
+++ b/pypy/module/micronumpy/selection.py
@@ -1,5 +1,4 @@
 from pypy.interpreter.error import oefmt
-from pypy.interpreter.gateway import applevel
 from rpython.rlib.listsort import make_timsort_class
 from rpython.rlib.objectmodel import specialize
 from rpython.rlib.rarithmetic import widen
@@ -354,39 +353,3 @@
                 cache[cls] = make_sort_function(space, cls, it)
         self.cache = cache
         self._lookup = specialize.memo()(lambda tp: cache[tp[0]])
-
-
-app_searchsort = applevel(r"""
-    import operator
-
-    def searchsort(arr, val, side, res):
-        val = val.flat
-        res = res.flat
-        if side == 0:
-            op = operator.lt
-        else:
-            op = operator.le
-
-        size = arr.size
-        imin = 0
-        imax = size
-        try:
-            last = val[0]
-        except IndexError:
-            return
-        for i in xrange(len(val)):
-            key = val[i]
-            if last < key:
-                imax = size
-            else:
-                imin = 0
-                imax = imax + 1 if imax < size else size
-            last = key
-            while imin < imax:
-                imid = imin + ((imax - imin) >> 1)
-                if op(arr[imid], key):
-                    imin = imid + 1
-                else:
-                    imax = imid
-            res[i] = imin
-""", filename=__file__).interphook('searchsort')
diff --git a/pypy/module/micronumpy/test/test_compile.py b/pypy/module/micronumpy/test/test_compile.py
--- a/pypy/module/micronumpy/test/test_compile.py
+++ b/pypy/module/micronumpy/test/test_compile.py
@@ -330,3 +330,12 @@
         results = interp.results[0]
         assert isinstance(results, W_NDimArray)
         assert results.get_dtype().is_int()
+
+    def test_searchsorted(self):
+        # b is presumably the range 29, 28, ..., 0; its last key (0) sorts
+        # before every element of a, so the final insertion index is 0.
+        interp = self.run('''
+        a = [1, 4, 5, 6, 9]
+        b = |30| -> ::-1
+        c = searchsorted(a, b)
+        c -> -1
+        ''')
+        last_index = interp.results[0]
+        assert last_index.value == 0
diff --git a/pypy/module/micronumpy/test/test_selection.py b/pypy/module/micronumpy/test/test_selection.py
--- a/pypy/module/micronumpy/test/test_selection.py
+++ b/pypy/module/micronumpy/test/test_selection.py
@@ -382,6 +382,9 @@
         assert ret == 3
         assert isinstance(ret, np.generic)
 
+        assert a.searchsorted(3.1) == 3
+        assert a.searchsorted(3.9) == 3
+
         exc = raises(ValueError, a.searchsorted, 3, side=None)
         assert str(exc.value) == "expected nonempty string for keyword 'side'"
         exc = raises(ValueError, a.searchsorted, 3, side='')
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -51,7 +51,9 @@
                 w_res = i.getitem(s)
             if isinstance(w_res, boxes.W_Float64Box):
                 return w_res.value
-            if isinstance(w_res, boxes.W_Int64Box):
+            elif isinstance(w_res, boxes.W_Int64Box):
+                return float(w_res.value)
+            elif isinstance(w_res, boxes.W_LongBox):
                 return float(w_res.value)
             elif isinstance(w_res, boxes.W_BoolBox):
                 return float(w_res.value)
@@ -660,3 +662,30 @@
             'raw_load': 2,
             'raw_store': 1,
         })
+
+    def define_searchsorted():
+        # Program: search every key of a reversed 30-element range within a
+        # small sorted array, then read back the last insertion index.
+        program = """
+        a = [1, 4, 5, 6, 9]
+        b = |30| -> ::-1
+        c = searchsorted(a, b)
+        c -> -1
+        """
+        return program
+
+    def test_searchsorted(self):
+        # Expected operations in the traced binary-search loop: a single
+        # float comparison and a single raw_load per iteration.
+        expected_loop_ops = {
+            'float_lt': 1,
+            'guard_false': 2,
+            'guard_not_invalidated': 1,
+            'guard_true': 2,
+            'int_add': 3,
+            'int_ge': 1,
+            'int_lt': 2,
+            'int_mul': 1,
+            'int_rshift': 1,
+            'int_sub': 1,
+            'jump': 1,
+            'raw_load': 1,
+        }
+        result = self.run("searchsorted")
+        assert result == 0
+        self.check_trace_count(6)
+        self.check_simple_loop(expected_loop_ops)


More information about the pypy-commit mailing list