[pypy-commit] pypy numpy-refactor: dot

fijal noreply at buildbot.pypy.org
Wed Sep 5 18:36:14 CEST 2012


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-refactor
Changeset: r57155:bc065d195592
Date: 2012-09-05 18:35 +0200
http://bitbucket.org/pypy/pypy/changeset/bc065d195592/

Log:	dot

diff --git a/pypy/module/micronumpy/__init__.py b/pypy/module/micronumpy/__init__.py
--- a/pypy/module/micronumpy/__init__.py
+++ b/pypy/module/micronumpy/__init__.py
@@ -25,7 +25,7 @@
         'zeros': 'interp_numarray.zeros',
         'empty': 'interp_numarray.zeros',
         'ones': 'interp_numarray.ones',
-        'dot': 'interp_numarray.dot',
+        'dot': 'interp_arrayops.dot',
         'fromstring': 'interp_support.fromstring',
         'flatiter': 'interp_flatiter.W_FlatIterator',
         'isna': 'interp_numarray.isna',
diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -3,7 +3,7 @@
 from pypy.module.micronumpy import support, loop
 from pypy.module.micronumpy.base import convert_to_array
 from pypy.module.micronumpy.strides import calc_new_strides, shape_agreement,\
-     calculate_broadcast_strides
+     calculate_broadcast_strides, calculate_dot_strides
 from pypy.module.micronumpy.iter import Chunk, Chunks, NewAxisChunk, RecordChunk
 from pypy.interpreter.error import OperationError, operationerrfmt
 from pypy.rlib import jit
@@ -287,6 +287,11 @@
     def create_axis_iter(self, shape, dim):
         return AxisIterator(self, shape, dim)
 
+    def create_dot_iter(self, shape, skip):
+        r = calculate_dot_strides(self.strides, self.backstrides,
+                                  shape, skip)
+        return MultiDimViewIterator(self, self.start, r[0], r[1], shape)
+
 class ConcreteArray(BaseConcreteArray):
     def __init__(self, shape, dtype, order, strides, backstrides):
         self.shape = shape
diff --git a/pypy/module/micronumpy/base.py b/pypy/module/micronumpy/base.py
--- a/pypy/module/micronumpy/base.py
+++ b/pypy/module/micronumpy/base.py
@@ -20,10 +20,6 @@
         return W_NDimArray(impl)
 
     @classmethod
-    def from_strides(cls):
-        xxx
-
-    @classmethod
     def new_slice(cls, offset, strides, backstrides, shape, parent):
         from pypy.module.micronumpy.arrayimpl import concrete
 
diff --git a/pypy/module/micronumpy/dot.py b/pypy/module/micronumpy/dot.py
--- a/pypy/module/micronumpy/dot.py
+++ b/pypy/module/micronumpy/dot.py
@@ -1,6 +1,5 @@
 from pypy.module.micronumpy.strides import calculate_dot_strides
 from pypy.interpreter.error import OperationError
-from pypy.module.micronumpy.interp_iter import ViewIterator
 from pypy.rlib import jit
 
 def dot_printable_location(shapelen):
@@ -15,71 +14,23 @@
 )
 
 def match_dot_shapes(space, left, right):
-    my_critical_dim_size = left.shape[-1]
-    right_critical_dim_size = right.shape[0]
+    left_shape = left.get_shape()
+    right_shape = right.get_shape()
+    my_critical_dim_size = left_shape[-1]
+    right_critical_dim_size = right_shape[0]
     right_critical_dim = 0
     out_shape = []
-    if len(right.shape) > 1:
-        right_critical_dim = len(right.shape) - 2
-        right_critical_dim_size = right.shape[right_critical_dim]
+    if len(right_shape) > 1:
+        right_critical_dim = len(right_shape) - 2
+        right_critical_dim_size = right_shape[right_critical_dim]
         assert right_critical_dim >= 0
-        out_shape += left.shape[:-1] + \
-                     right.shape[0:right_critical_dim] + \
-                     right.shape[right_critical_dim + 1:]
-    elif len(right.shape) > 0:
+        out_shape += left_shape[:-1] + \
+                     right_shape[0:right_critical_dim] + \
+                     right_shape[right_critical_dim + 1:]
+    elif len(right_shape) > 0:
         #dot does not reduce for scalars
-        out_shape += left.shape[:-1]
+        out_shape += left_shape[:-1]
     if my_critical_dim_size != right_critical_dim_size:
         raise OperationError(space.w_ValueError, space.wrap(
                                         "objects are not aligned"))
     return out_shape, right_critical_dim
-
-def multidim_dot(space, left, right, result, dtype, right_critical_dim):
-    ''' assumes left, right are concrete arrays
-    given left.shape == [3, 5, 7],
-          right.shape == [2, 7, 4]
-    then
-     result.shape == [3, 5, 2, 4]
-     broadcast shape should be [3, 5, 2, 7, 4]
-     result should skip dims 3 which is len(result_shape) - 1
-        (note that if right is 1d, result should 
-                  skip len(result_shape))
-     left should skip 2, 4 which is a.ndims-1 + range(right.ndims)
-          except where it==(right.ndims-2)
-     right should skip 0, 1
-    '''
-    broadcast_shape = left.shape[:-1] + right.shape
-    shapelen = len(broadcast_shape)
-    left_skip = [len(left.shape) - 1 + i for i in range(len(right.shape))
-                                         if i != right_critical_dim]
-    right_skip = range(len(left.shape) - 1)
-    result_skip = [len(result.shape) - (len(right.shape) > 1)]
-    _r = calculate_dot_strides(result.strides, result.backstrides,
-                                  broadcast_shape, result_skip)
-    outi = ViewIterator(result.start, _r[0], _r[1], broadcast_shape)
-    _r = calculate_dot_strides(left.strides, left.backstrides,
-                                  broadcast_shape, left_skip)
-    lefti = ViewIterator(left.start, _r[0], _r[1], broadcast_shape)
-    _r = calculate_dot_strides(right.strides, right.backstrides,
-                                  broadcast_shape, right_skip)
-    righti = ViewIterator(right.start, _r[0], _r[1], broadcast_shape)
-    while not outi.done():
-        dot_driver.jit_merge_point(left=left,
-                                   right=right,
-                                   shapelen=shapelen,
-                                   lefti=lefti,
-                                   righti=righti,
-                                   outi=outi,
-                                   result=result,
-                                   dtype=dtype,
-                                  )
-        lval = left.getitem(lefti.offset).convert_to(dtype) 
-        rval = right.getitem(righti.offset).convert_to(dtype) 
-        outval = result.getitem(outi.offset).convert_to(dtype) 
-        v = dtype.itemtype.mul(lval, rval)
-        value = dtype.itemtype.add(v, outval).convert_to(dtype)
-        result.setitem(outi.offset, value)
-        outi = outi.next(shapelen)
-        righti = righti.next(shapelen)
-        lefti = lefti.next(shapelen)
-    return result
diff --git a/pypy/module/micronumpy/interp_arrayops.py b/pypy/module/micronumpy/interp_arrayops.py
--- a/pypy/module/micronumpy/interp_arrayops.py
+++ b/pypy/module/micronumpy/interp_arrayops.py
@@ -72,3 +72,9 @@
     dtype = arr.get_dtype()
     out = W_NDimArray.from_shape(arr.get_shape(), dtype)
     return loop.where(out, arr, x, y, dtype)
+
+def dot(space, w_obj1, w_obj2):
+    w_arr = convert_to_array(space, w_obj1)
+    if w_arr.is_scalar():
+        return convert_to_array(space, w_obj2).descr_dot(space, w_arr)
+    return w_arr.descr_dot(space, w_obj2)
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -2,13 +2,14 @@
 from pypy.interpreter.error import operationerrfmt, OperationError
 from pypy.interpreter.typedef import TypeDef, GetSetProperty
 from pypy.interpreter.gateway import interp2app, unwrap_spec
-from pypy.module.micronumpy.base import W_NDimArray
+from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
 from pypy.module.micronumpy import interp_dtype, interp_ufuncs, support
 from pypy.module.micronumpy.strides import find_shape_and_elems,\
      get_shape_from_iterable
 from pypy.module.micronumpy.interp_support import unwrap_axis_arg
 from pypy.module.micronumpy.appbridge import get_appbridge_cache
 from pypy.module.micronumpy import loop
+from pypy.module.micronumpy.dot import match_dot_shapes
 from pypy.tool.sourcetools import func_with_new_name
 from pypy.rlib import jit
 from pypy.rlib.rstring import StringBuilder
@@ -94,6 +95,9 @@
     def create_axis_iter(self, shape, dim):
         return self.implementation.create_axis_iter(shape, dim)
 
+    def create_dot_iter(self, shape, skip):
+        return self.implementation.create_dot_iter(shape, skip)
+
     def is_scalar(self):
         return self.implementation.is_scalar()
 
@@ -228,6 +232,34 @@
         w_remainder = self.descr_rmod(space, w_other)
         return space.newtuple([w_quotient, w_remainder])
 
+    def descr_dot(self, space, w_other):
+        other = convert_to_array(space, w_other)
+        if other.is_scalar():
+            #Note: w_out is not modified, this is numpy compliant.
+            return self.descr_mul(space, other)
+        elif len(self.get_shape()) < 2 and len(other.get_shape()) < 2:
+            w_res = self.descr_mul(space, other)
+            return w_res.descr_sum(space, space.wrap(-1))
+        dtype = interp_ufuncs.find_binop_result_dtype(space,
+                                     self.get_dtype(), other.get_dtype())
+        if self.get_size() < 1 and other.get_size() < 1:
+            # numpy compatability
+            return W_NDimArray.new_scalar(space, dtype, space.wrap(0))
+        # Do the dims match?
+        out_shape, other_critical_dim = match_dot_shapes(space, self, other)
+        result = W_NDimArray.from_shape(out_shape, dtype)
+        # This is the place to add fpypy and blas
+        return loop.multidim_dot(space, self, other,  result, dtype,
+                                 other_critical_dim)
+
+    def descr_var(self, space, w_axis=None):
+        return get_appbridge_cache(space).call_method(space, '_var', self,
+                                                      w_axis)
+
+    def descr_std(self, space, w_axis=None):
+        return get_appbridge_cache(space).call_method(space, '_std', self,
+                                                      w_axis)
+
     # ----------------------- reduce -------------------------------
 
     def _reduce_ufunc_impl(ufunc_name, promote_to_largest=False):
@@ -355,9 +387,9 @@
     argmin = interp2app(W_NDimArray.descr_argmin),
     all = interp2app(W_NDimArray.descr_all),
     any = interp2app(W_NDimArray.descr_any),
-    #dot = interp2app(W_NDimArray.descr_dot),
-    #var = interp2app(W_NDimArray.descr_var),
-    #std = interp2app(W_NDimArray.descr_std),
+    dot = interp2app(W_NDimArray.descr_dot),
+    var = interp2app(W_NDimArray.descr_var),
+    std = interp2app(W_NDimArray.descr_std),
 
     copy = interp2app(W_NDimArray.descr_copy),
     reshape = interp2app(W_NDimArray.descr_reshape),
@@ -436,10 +468,6 @@
     arr.fill(one)
     return space.wrap(arr)
 
-
-def dot(space):
-    pass
-
 def isna(space):
     pass
 
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -119,3 +119,39 @@
         iter.next()
         idx += 1
     return result
+
+def multidim_dot(space, left, right, result, dtype, right_critical_dim):
+    ''' assumes left, right are concrete arrays
+    given left.shape == [3, 5, 7],
+          right.shape == [2, 7, 4]
+    then
+     result.shape == [3, 5, 2, 4]
+     broadcast shape should be [3, 5, 2, 7, 4]
+     result should skip dims 3 which is len(result_shape) - 1
+        (note that if right is 1d, result should 
+                  skip len(result_shape))
+     left should skip 2, 4 which is a.ndims-1 + range(right.ndims)
+          except where it==(right.ndims-2)
+     right should skip 0, 1
+    '''
+    left_shape = left.get_shape()
+    right_shape = right.get_shape()
+    broadcast_shape = left_shape[:-1] + right_shape
+    left_skip = [len(left_shape) - 1 + i for i in range(len(right_shape))
+                                         if i != right_critical_dim]
+    right_skip = range(len(left_shape) - 1)
+    result_skip = [len(result.get_shape()) - (len(right_shape) > 1)]
+    outi = result.create_dot_iter(broadcast_shape, result_skip)
+    lefti = left.create_dot_iter(broadcast_shape, left_skip)
+    righti = right.create_dot_iter(broadcast_shape, right_skip)
+    while not outi.done():
+        lval = lefti.getitem().convert_to(dtype) 
+        rval = righti.getitem().convert_to(dtype) 
+        outval = outi.getitem().convert_to(dtype) 
+        v = dtype.itemtype.mul(lval, rval)
+        value = dtype.itemtype.add(v, outval).convert_to(dtype)
+        outi.setitem(value)
+        outi.next()
+        righti.next()
+        lefti.next()
+    return result


More information about the pypy-commit mailing list