[pypy-commit] pypy matrixmath-dot: whoops

mattip noreply at buildbot.pypy.org
Fri Jan 20 09:23:07 CET 2012


Author: mattip
Branch: matrixmath-dot
Changeset: r51507:600fcfb76aab
Date: 2012-01-20 10:22 +0200
http://bitbucket.org/pypy/pypy/changeset/600fcfb76aab/

Log:	whoops

diff --git a/pypy/module/micronumpy/dot.py b/pypy/module/micronumpy/dot.py
--- a/pypy/module/micronumpy/dot.py
+++ b/pypy/module/micronumpy/dot.py
@@ -2,8 +2,17 @@
 from pypy.module.micronumpy.strides import calculate_dot_strides
 from pypy.interpreter.error import OperationError, operationerrfmt
 from pypy.module.micronumpy.interp_iter import ViewIterator
+from pypy.module.micronumpy.signature import new_printable_location
+from pypy.rlib import jit
 
 
+dot_driver = jit.JitDriver(
+    greens=['shapelen', 'left', 'right'],
+    reds=['lefti', 'righti', 'outi', 'result'],
+    get_printable_location=new_printable_location('dot'),
+    name='dot',
+)
+
 def match_dot_shapes(space, left, right):
     my_critical_dim_size = left.shape[-1]
     right_critical_dim_size = right.shape[0]
@@ -27,6 +36,7 @@
     return out_shape, right_critical_dim
 
 
+ at jit.unroll_safe
 def multidim_dot(space, left, right, result, dtype, right_critical_dim):
     ''' assumes left, right are concrete arrays
     given left.shape == [3, 5, 7],
@@ -56,6 +66,14 @@
                                   broadcast_shape, right_skip)
     righti = ViewIterator(0, _r[0], _r[1], broadcast_shape)
     while not outi.done():
+        dot_driver.jit_merge_point(left=left,
+                                   right=right,
+                                   shape_len=shape_len,
+                                   lefti=lefti,
+                                   righti=righti,
+                                   outi=outi,
+                                   result=result,
+                                  )
         v = mul(dtype, left.getitem(lefti.offset),
                        right.getitem(righti.offset))
         value = add(dtype, v, result.getitem(outi.offset))
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -10,7 +10,7 @@
 from pypy.rlib.rstring import StringBuilder
 from pypy.module.micronumpy.interp_iter import ArrayIterator, OneDimIterator,\
      SkipLastAxisIterator
-from pypy.module.micronumpy.dot import multidim_dot, match_dot_shapes, dot_docstring
+from pypy.module.micronumpy.dot import multidim_dot, match_dot_shapes
 
 
 numpy_driver = jit.JitDriver(
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -368,6 +368,20 @@
                                 'int_ge': 1, 'guard_false': 1, 'jump': 1,
                                 'arraylen_gc': 1})
 
+    def ddefine_dot():  # program for test_dot (3x4 . 4x3); NOTE(review): 'ddefine_' prefix may deliberately keep it out of the define_* programs - confirm before renaming
+        return """
+        a = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
+        b=[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
+        c = a.dot(b)
+        c -> 1 -> 2
+        """
+
+    def test_dot(self):
+        py.test.skip("not yet")  # NOTE(review): self.run("dot") presumably also needs ddefine_dot renamed to define_dot
+        result = self.run("dot")
+        assert result == 184  # c[1][2] = 5*2 + 6*5 + 7*8 + 8*11
+        self.check_simple_loop({})  # NOTE(review): empty dict looks like a placeholder for the expected op counts
+
 
 class TestNumpyOld(LLJitMixin):
     def setup_class(cls):


More information about the pypy-commit mailing list