[pypy-commit] pypy matrixmath-dot: passes a test, needs cleanup

mattip noreply at buildbot.pypy.org
Wed Jan 18 01:22:48 CET 2012


Author: mattip
Branch: matrixmath-dot
Changeset: r51432:f62709780578
Date: 2012-01-18 02:22 +0200
http://bitbucket.org/pypy/pypy/changeset/f62709780578/

Log:	passes a test, needs cleanup

diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -2,7 +2,7 @@
 from pypy.rlib import jit
 from pypy.rlib.objectmodel import instantiate
 from pypy.module.micronumpy.strides import calculate_broadcast_strides,\
-     calculate_slice_strides
+     calculate_slice_strides, calculate_dot_strides
 
 class BaseTransform(object):
     pass
@@ -16,6 +16,11 @@
     def __init__(self, res_shape):
         self.res_shape = res_shape
 
+class DotTransform(BaseTransform):  # built by DotSignature._create_iter
+    def __init__(self, res_shape, skip_dims):
+        self.res_shape = res_shape  # the dot's full broadcast shape
+        self.skip_dims = skip_dims  # res_shape dims this operand lacks
+
 class BaseIterator(object):
     def next(self, shapelen):
         raise NotImplementedError
@@ -85,6 +90,10 @@
                                         self.strides,
                                         self.backstrides, t.chunks)
             return ViewIterator(r[1], r[2], r[3], r[0])
+        elif isinstance(t, DotTransform):  # stride 0 on t.skip_dims
+            r = calculate_dot_strides(self.strides, self.backstrides,
+                                     t.res_shape, t.skip_dims)
+            return ViewIterator(self.offset, r[0], r[1], t.res_shape)
 
     @jit.unroll_safe
     def next(self, shapelen):
@@ -130,6 +139,7 @@
     def transform(self, arr, t):
         pass
 
+
 class AxisIterator(BaseIterator):
     def __init__(self, start, dim, shape, strides, backstrides):
         self.res_shape = shape[:]
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -3,13 +3,14 @@
 from pypy.interpreter.gateway import interp2app, NoneNotWrapped
 from pypy.interpreter.typedef import TypeDef, GetSetProperty
 from pypy.module.micronumpy import interp_ufuncs, interp_dtype, signature
-from pypy.module.micronumpy.strides import calculate_slice_strides
+from pypy.module.micronumpy.strides import calculate_slice_strides,\
+                                           calculate_dot_strides  # do_dot_loop
 from pypy.rlib import jit
 from pypy.rpython.lltypesystem import lltype, rffi
 from pypy.tool.sourcetools import func_with_new_name
 from pypy.rlib.rstring import StringBuilder
 from pypy.module.micronumpy.interp_iter import ArrayIterator, OneDimIterator,\
-     SkipLastAxisIterator
+     SkipLastAxisIterator, ViewIterator  # ViewIterator: do_dot_loop result
 
 numpy_driver = jit.JitDriver(
     greens=['shapelen', 'sig'],
@@ -211,6 +212,28 @@
                 n_old_elems_to_use *= old_shape[oldI]
     return new_strides
 
+def match_dot_shapes(space, self, other):
+    """Return (out_shape, other_critical_dim) for dot(self, other); reads
+    only .shape. Raises app-level ValueError if the dims do not align."""
+    my_critical_dim_size = self.shape[-1]
+    other_critical_dim_size = other.shape[0]
+    other_critical_dim = 0
+    out_shape = []
+    if len(other.shape) > 1:
+        other_critical_dim = len(other.shape) - 2  # second-to-last axis
+        other_critical_dim_size = other.shape[other_critical_dim]
+        assert other_critical_dim >= 0
+        out_shape += self.shape[:-1] + \
+                     other.shape[0:other_critical_dim] + \
+                     other.shape[other_critical_dim + 1:]
+    elif len(other.shape) > 0:
+        # 1-D other: only self's last axis is contracted away
+        out_shape += self.shape[:-1]
+    if my_critical_dim_size != other_critical_dim_size:
+        raise OperationError(space.w_ValueError, space.wrap(
+                                        "objects are not aligned"))
+    return out_shape, other_critical_dim
+
 class BaseArray(Wrappable):
     _attrs_ = ["invalidates", "shape", 'size']
 
@@ -384,70 +407,62 @@
     the second-to-last of `b`::
 
         dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])'''
-        w_other = convert_to_array(space, w_other)
-        if isinstance(w_other, Scalar):
-            return self.descr_mul(space, w_other)
-        elif len(self.shape) < 2 and len(w_other.shape) < 2:
-            w_res = self.descr_mul(space, w_other)
+        other = convert_to_array(space, w_other)
+        if isinstance(other, Scalar):
+            return self.descr_mul(space, other)
+        elif len(self.shape) < 2 and len(other.shape) < 2:
+            w_res = self.descr_mul(space, other)
             assert isinstance(w_res, BaseArray)
             return w_res.descr_sum(space, space.wrap(-1))
         dtype = interp_ufuncs.find_binop_result_dtype(space,
-                                     self.find_dtype(), w_other.find_dtype())
-        if self.size < 1 and w_other.size < 1:
+                                     self.find_dtype(), other.find_dtype())
+        if self.size < 1 and other.size < 1:
             #numpy compatability
             return scalar_w(space, dtype, space.wrap(0))
         #Do the dims match?
-        my_critical_dim_size = self.shape[-1]
-        other_critical_dim_size = w_other.shape[0]
-        other_critical_dim = 0
-        other_critical_dim_stride = w_other.strides[0]
-        out_shape = []
-        if len(w_other.shape) > 1:
-            other_critical_dim = len(w_other.shape) - 2
-            other_critical_dim_size = w_other.shape[other_critical_dim]
-            other_critical_dim_stride = w_other.strides[other_critical_dim]
-            assert other_critical_dim >= 0
-            out_shape += self.shape[:-1] + \
-                         w_other.shape[0:other_critical_dim] + \
-                         w_other.shape[other_critical_dim + 1:]
-        elif len(w_other.shape) > 0:
-            #dot does not reduce for scalars
-            out_shape += self.shape[:-1]
-        if my_critical_dim_size != other_critical_dim_size:
-            raise OperationError(space.w_ValueError, space.wrap(
-                                            "objects are not aligned"))
+        out_shape, other_critical_dim = match_dot_shapes(space, self, other)
         out_size = 1
-        for os in out_shape:
-            out_size *= os
-        out_ndims = len(out_shape)
-        # TODO: what should the order be? C or F?
-        arr = W_NDimArray(out_size, out_shape, dtype=dtype)
-        # TODO: this is all a bogus mess of previous work, 
-        # rework within the context of transformations
-        '''
-        out_iter = ViewIterator(arr.start, arr.strides, arr.backstrides, arr.shape)
-        # TODO: invalidate self, w_other with arr ?
-        while not out_iter.done():
-            my_index = self.start
-            other_index = w_other.start
-            i = 0
-            while i < len(self.shape) - 1:
-                my_index += out_iter.indices[i] * self.strides[i]
-                i += 1
-            for j in range(len(w_other.shape) - 2):
-                other_index += out_iter.indices[i] * w_other.strides[j]
-            other_index += out_iter.indices[-1] * w_other.strides[-1]
-            w_ssd = space.newlist([space.wrap(my_index),
-                                   space.wrap(len(self.shape) - 1)])
-            w_osd = space.newlist([space.wrap(other_index),
-                                   space.wrap(other_critical_dim)])
-            w_res = self.descr_mul(space, w_other)
-            assert isinstance(w_res, BaseArray)
-            value = w_res.descr_sum(space)
-            arr.setitem(out_iter.get_offset(), value)
-            out_iter = out_iter.next(out_ndims)
-        '''
-        return arr
+        for o in out_shape:
+            out_size *= o
+        result = W_NDimArray(out_size, out_shape, dtype)
+        # given a.shape == [3, 5, 7],
+        #       b.shape == [2, 7, 4]
+        #  result.shape == [3, 5, 2, 4]
+        # all iterators shapes should be [3, 5, 2, 7, 4]
+        # result skips the contracted dim: a.ndims-1 + other_critical_dim
+        #       (3 here; unlike len(out_shape) - 1 this also holds for 1-D b)
+        # a skips a.ndims-1 + i for i != other_critical_dim (2, 4 here)
+        # b skips range(a.ndims - 1), i.e. 0, 1
+        mul = interp_ufuncs.get(space).multiply.func
+        add = interp_ufuncs.get(space).add.func
+        broadcast_shape = self.shape[:-1] + other.shape
+        # Offsets into broadcast_shape that each operand does not span:
+        left_skip = [len(self.shape) - 1 + i for i in range(len(other.shape)) if i != other_critical_dim]
+        right_skip = range(len(self.shape) - 1)
+        arr = DotArray(mul, 'DotName', out_shape, dtype, self, other,
+                                        left_skip, right_skip)
+        arr.broadcast_shape = broadcast_shape
+        arr.result_skip = [len(self.shape) - 1 + other_critical_dim]
+        #Make this lazy someday...
+        sig = signature.find_sig(signature.DotSignature(mul, 'dot', dtype,
+                                  self.create_sig(), other.create_sig()), arr)
+        assert isinstance(sig, signature.DotSignature)
+        self.do_dot_loop(sig, result, arr, add)
+        return result
+
+    def do_dot_loop(self, sig, result, arr, add):
+        frame = sig.create_frame(arr)
+        shapelen = len(arr.broadcast_shape)
+        _r = calculate_dot_strides(result.strides, result.backstrides,
+                                      arr.broadcast_shape, arr.result_skip)
+        ri = ViewIterator(0, _r[0], _r[1], arr.broadcast_shape)
+        while not frame.done():
+            v = sig.eval(frame, arr).convert_to(sig.calc_dtype)
+            cur = result.getitem(ri.offset)
+            value = add(sig.calc_dtype, v, cur)
+            result.setitem(ri.offset, value)
+            frame.next(shapelen)
+            ri = ri.next(shapelen)
 
     def get_concrete(self):
         raise NotImplementedError
@@ -919,6 +934,23 @@
                        left, right)
         self.dim = dim
 
+class DotArray(Call2):
+    """ NOTE: this is only used as a container for the operands of the
+    (eager) dot product; you should never encounter one in the wild.
+    Revisit this class, and create_sig, once dot is made lazy.
+    """
+    _immutable_fields_ = ['left', 'right']
+
+    def __init__(self, ufunc, name, shape, dtype, left, right, left_skip, right_skip):
+        Call2.__init__(self, ufunc, name, shape, dtype, dtype,
+                       left, right)
+        self.left_skip = left_skip
+        self.right_skip = right_skip
+
+    def create_sig(self):
+        # dot is computed eagerly; fail loudly instead of returning None
+        raise NotImplementedError
+
 class ConcreteArray(BaseArray):
     """ An array that have actual storage, whether owned or not
     """
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -192,17 +192,17 @@
                                               sig=sig,
                                               identity=identity,
                                               shapelen=shapelen, arr=arr)
-            iter = frame.get_final_iter()
+            iterator = frame.get_final_iter()  # avoids shadowing builtin iter
             v = sig.eval(frame, arr).convert_to(sig.calc_dtype)
-            if iter.first_line:
+            if iterator.first_line:
                 if identity is not None:
                     value = self.func(sig.calc_dtype, identity, v)
                 else:
                     value = v
             else:
-                cur = arr.left.getitem(iter.offset)
+                cur = arr.left.getitem(iterator.offset)
                 value = self.func(sig.calc_dtype, cur, v)
-            arr.left.setitem(iter.offset, value)
+            arr.left.setitem(iterator.offset, value)
             frame.next(shapelen)
 
     def reduce_loop(self, shapelen, sig, frame, value, obj, dtype):
diff --git a/pypy/module/micronumpy/signature.py b/pypy/module/micronumpy/signature.py
--- a/pypy/module/micronumpy/signature.py
+++ b/pypy/module/micronumpy/signature.py
@@ -2,7 +2,7 @@
 from pypy.rlib.rarithmetic import intmask
 from pypy.module.micronumpy.interp_iter import ViewIterator, ArrayIterator, \
      ConstantIterator, AxisIterator, ViewTransform,\
-     BroadcastTransform
+     BroadcastTransform, DotTransform
 from pypy.rlib.jit import hint, unroll_safe, promote
 
 """ Signature specifies both the numpy expression that has been constructed
@@ -331,7 +331,6 @@
         assert isinstance(arr, Call2)
         lhs = self.left.eval(frame, arr.left).convert_to(self.calc_dtype)
         rhs = self.right.eval(frame, arr.right).convert_to(self.calc_dtype)
-        
         return self.binfunc(self.calc_dtype, lhs, rhs)
 
     def debug_repr(self):
@@ -450,3 +449,21 @@
     
     def debug_repr(self):
         return 'AxisReduceSig(%s, %s)' % (self.name, self.right.debug_repr())
+
+class DotSignature(Call2):
+    def _invent_numbering(self, cache, allnumbers):
+        self.left._invent_numbering(new_cache(), allnumbers)
+        self.right._invent_numbering(new_cache(), allnumbers)
+
+    def _create_iter(self, iterlist, arraylist, arr, transforms):
+        from pypy.module.micronumpy.interp_numarray import DotArray
+        # each operand walks arr.broadcast_shape, stride 0 on its skip dims
+        assert isinstance(arr, DotArray)
+        rtransforms = transforms + [DotTransform(arr.broadcast_shape, arr.right_skip)]
+        ltransforms = transforms + [DotTransform(arr.broadcast_shape, arr.left_skip)]
+        self.left._create_iter(iterlist, arraylist, arr.left, ltransforms)
+        self.right._create_iter(iterlist, arraylist, arr.right, rtransforms)
+
+    def debug_repr(self):
+        return 'DotSig(%s, %s, %s)' % (self.name, self.left.debug_repr(),
+                                       self.right.debug_repr())
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -37,3 +37,17 @@
     rstrides = [0] * (len(res_shape) - len(orig_shape)) + rstrides
     rbackstrides = [0] * (len(res_shape) - len(orig_shape)) + rbackstrides
     return rstrides, rbackstrides
+
+def calculate_dot_strides(strides, backstrides, res_shape, skip_dims):
+    """Lay strides over res_shape, using stride 0 for dims in skip_dims."""
+    rstrides, rbackstrides = [], []
+    src = 0
+    for dim in range(len(res_shape)):
+        if dim in skip_dims:
+            rstrides.append(0)
+            rbackstrides.append(0)
+            continue
+        rstrides.append(strides[src])
+        rbackstrides.append(backstrides[src])
+        src += 1
+    return rstrides, rbackstrides
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -867,7 +867,7 @@
         assert c.any() == False
 
     def test_dot(self):
-        from _numpypy import array, dot
+        from _numpypy import array, dot, arange
         a = array(range(5))
         assert a.dot(a) == 30.0
 
@@ -876,13 +876,12 @@
         assert dot(range(5), range(5)) == 30
         assert (dot(5, [1, 2, 3]) == [5, 10, 15]).all()
 
-        a = array([range(4), range(4, 8), range(8, 12)])
-        b = array([range(3), range(3, 6), range(6, 9), range(9, 12)])
+        a = arange(12).reshape(3, 4)  # same values as the range() rows
+        b = arange(12).reshape(4, 3)
         c = a.dot(b)
         assert (c == [[ 42, 48, 54], [114, 136, 158], [186, 224, 262]]).all()
 
-        a = array([[range(4), range(4, 8), range(8, 12)],
-                   [range(12, 16), range(16, 20), range(20, 24)]])
+        a = arange(24).reshape(2, 3, 4)  # last dim 4 != 2nd-to-last dim 3
         raises(ValueError, "a.dot(a)")
         b = a[0, :, :].T
         #Superfluous shape test makes the intention of the test clearer


More information about the pypy-commit mailing list