[pypy-commit] pypy numpy-refactor: dot
fijal
noreply at buildbot.pypy.org
Wed Sep 5 18:36:14 CEST 2012
Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-refactor
Changeset: r57155:bc065d195592
Date: 2012-09-05 18:35 +0200
http://bitbucket.org/pypy/pypy/changeset/bc065d195592/
Log: dot
diff --git a/pypy/module/micronumpy/__init__.py b/pypy/module/micronumpy/__init__.py
--- a/pypy/module/micronumpy/__init__.py
+++ b/pypy/module/micronumpy/__init__.py
@@ -25,7 +25,7 @@
'zeros': 'interp_numarray.zeros',
'empty': 'interp_numarray.zeros',
'ones': 'interp_numarray.ones',
- 'dot': 'interp_numarray.dot',
+ 'dot': 'interp_arrayops.dot',
'fromstring': 'interp_support.fromstring',
'flatiter': 'interp_flatiter.W_FlatIterator',
'isna': 'interp_numarray.isna',
diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -3,7 +3,7 @@
from pypy.module.micronumpy import support, loop
from pypy.module.micronumpy.base import convert_to_array
from pypy.module.micronumpy.strides import calc_new_strides, shape_agreement,\
- calculate_broadcast_strides
+ calculate_broadcast_strides, calculate_dot_strides
from pypy.module.micronumpy.iter import Chunk, Chunks, NewAxisChunk, RecordChunk
from pypy.interpreter.error import OperationError, operationerrfmt
from pypy.rlib import jit
@@ -287,6 +287,11 @@
def create_axis_iter(self, shape, dim):
return AxisIterator(self, shape, dim)
+ def create_dot_iter(self, shape, skip):
+ r = calculate_dot_strides(self.strides, self.backstrides,
+ shape, skip)
+ return MultiDimViewIterator(self, self.start, r[0], r[1], shape)
+
class ConcreteArray(BaseConcreteArray):
def __init__(self, shape, dtype, order, strides, backstrides):
self.shape = shape
diff --git a/pypy/module/micronumpy/base.py b/pypy/module/micronumpy/base.py
--- a/pypy/module/micronumpy/base.py
+++ b/pypy/module/micronumpy/base.py
@@ -20,10 +20,6 @@
return W_NDimArray(impl)
@classmethod
- def from_strides(cls):
- xxx
-
- @classmethod
def new_slice(cls, offset, strides, backstrides, shape, parent):
from pypy.module.micronumpy.arrayimpl import concrete
diff --git a/pypy/module/micronumpy/dot.py b/pypy/module/micronumpy/dot.py
--- a/pypy/module/micronumpy/dot.py
+++ b/pypy/module/micronumpy/dot.py
@@ -1,6 +1,5 @@
from pypy.module.micronumpy.strides import calculate_dot_strides
from pypy.interpreter.error import OperationError
-from pypy.module.micronumpy.interp_iter import ViewIterator
from pypy.rlib import jit
def dot_printable_location(shapelen):
@@ -15,71 +14,23 @@
)
def match_dot_shapes(space, left, right):
- my_critical_dim_size = left.shape[-1]
- right_critical_dim_size = right.shape[0]
+ left_shape = left.get_shape()
+ right_shape = right.get_shape()
+ my_critical_dim_size = left_shape[-1]
+ right_critical_dim_size = right_shape[0]
right_critical_dim = 0
out_shape = []
- if len(right.shape) > 1:
- right_critical_dim = len(right.shape) - 2
- right_critical_dim_size = right.shape[right_critical_dim]
+ if len(right_shape) > 1:
+ right_critical_dim = len(right_shape) - 2
+ right_critical_dim_size = right_shape[right_critical_dim]
assert right_critical_dim >= 0
- out_shape += left.shape[:-1] + \
- right.shape[0:right_critical_dim] + \
- right.shape[right_critical_dim + 1:]
- elif len(right.shape) > 0:
+ out_shape += left_shape[:-1] + \
+ right_shape[0:right_critical_dim] + \
+ right_shape[right_critical_dim + 1:]
+ elif len(right_shape) > 0:
#dot does not reduce for scalars
- out_shape += left.shape[:-1]
+ out_shape += left_shape[:-1]
if my_critical_dim_size != right_critical_dim_size:
raise OperationError(space.w_ValueError, space.wrap(
"objects are not aligned"))
return out_shape, right_critical_dim
-
-def multidim_dot(space, left, right, result, dtype, right_critical_dim):
- ''' assumes left, right are concrete arrays
- given left.shape == [3, 5, 7],
- right.shape == [2, 7, 4]
- then
- result.shape == [3, 5, 2, 4]
- broadcast shape should be [3, 5, 2, 7, 4]
- result should skip dims 3 which is len(result_shape) - 1
- (note that if right is 1d, result should
- skip len(result_shape))
- left should skip 2, 4 which is a.ndims-1 + range(right.ndims)
- except where it==(right.ndims-2)
- right should skip 0, 1
- '''
- broadcast_shape = left.shape[:-1] + right.shape
- shapelen = len(broadcast_shape)
- left_skip = [len(left.shape) - 1 + i for i in range(len(right.shape))
- if i != right_critical_dim]
- right_skip = range(len(left.shape) - 1)
- result_skip = [len(result.shape) - (len(right.shape) > 1)]
- _r = calculate_dot_strides(result.strides, result.backstrides,
- broadcast_shape, result_skip)
- outi = ViewIterator(result.start, _r[0], _r[1], broadcast_shape)
- _r = calculate_dot_strides(left.strides, left.backstrides,
- broadcast_shape, left_skip)
- lefti = ViewIterator(left.start, _r[0], _r[1], broadcast_shape)
- _r = calculate_dot_strides(right.strides, right.backstrides,
- broadcast_shape, right_skip)
- righti = ViewIterator(right.start, _r[0], _r[1], broadcast_shape)
- while not outi.done():
- dot_driver.jit_merge_point(left=left,
- right=right,
- shapelen=shapelen,
- lefti=lefti,
- righti=righti,
- outi=outi,
- result=result,
- dtype=dtype,
- )
- lval = left.getitem(lefti.offset).convert_to(dtype)
- rval = right.getitem(righti.offset).convert_to(dtype)
- outval = result.getitem(outi.offset).convert_to(dtype)
- v = dtype.itemtype.mul(lval, rval)
- value = dtype.itemtype.add(v, outval).convert_to(dtype)
- result.setitem(outi.offset, value)
- outi = outi.next(shapelen)
- righti = righti.next(shapelen)
- lefti = lefti.next(shapelen)
- return result
diff --git a/pypy/module/micronumpy/interp_arrayops.py b/pypy/module/micronumpy/interp_arrayops.py
--- a/pypy/module/micronumpy/interp_arrayops.py
+++ b/pypy/module/micronumpy/interp_arrayops.py
@@ -72,3 +72,9 @@
dtype = arr.get_dtype()
out = W_NDimArray.from_shape(arr.get_shape(), dtype)
return loop.where(out, arr, x, y, dtype)
+
+def dot(space, w_obj1, w_obj2):
+ w_arr = convert_to_array(space, w_obj1)
+ if w_arr.is_scalar():
+ return convert_to_array(space, w_obj2).descr_dot(space, w_arr)
+ return w_arr.descr_dot(space, w_obj2)
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -2,13 +2,14 @@
from pypy.interpreter.error import operationerrfmt, OperationError
from pypy.interpreter.typedef import TypeDef, GetSetProperty
from pypy.interpreter.gateway import interp2app, unwrap_spec
-from pypy.module.micronumpy.base import W_NDimArray
+from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
from pypy.module.micronumpy import interp_dtype, interp_ufuncs, support
from pypy.module.micronumpy.strides import find_shape_and_elems,\
get_shape_from_iterable
from pypy.module.micronumpy.interp_support import unwrap_axis_arg
from pypy.module.micronumpy.appbridge import get_appbridge_cache
from pypy.module.micronumpy import loop
+from pypy.module.micronumpy.dot import match_dot_shapes
from pypy.tool.sourcetools import func_with_new_name
from pypy.rlib import jit
from pypy.rlib.rstring import StringBuilder
@@ -94,6 +95,9 @@
def create_axis_iter(self, shape, dim):
return self.implementation.create_axis_iter(shape, dim)
+ def create_dot_iter(self, shape, skip):
+ return self.implementation.create_dot_iter(shape, skip)
+
def is_scalar(self):
return self.implementation.is_scalar()
@@ -228,6 +232,34 @@
w_remainder = self.descr_rmod(space, w_other)
return space.newtuple([w_quotient, w_remainder])
+ def descr_dot(self, space, w_other):
+ other = convert_to_array(space, w_other)
+ if other.is_scalar():
+ #Note: w_out is not modified, this is numpy compliant.
+ return self.descr_mul(space, other)
+ elif len(self.get_shape()) < 2 and len(other.get_shape()) < 2:
+ w_res = self.descr_mul(space, other)
+ return w_res.descr_sum(space, space.wrap(-1))
+ dtype = interp_ufuncs.find_binop_result_dtype(space,
+ self.get_dtype(), other.get_dtype())
+ if self.get_size() < 1 and other.get_size() < 1:
+ # numpy compatability
+ return W_NDimArray.new_scalar(space, dtype, space.wrap(0))
+ # Do the dims match?
+ out_shape, other_critical_dim = match_dot_shapes(space, self, other)
+ result = W_NDimArray.from_shape(out_shape, dtype)
+ # This is the place to add fpypy and blas
+ return loop.multidim_dot(space, self, other, result, dtype,
+ other_critical_dim)
+
+ def descr_var(self, space, w_axis=None):
+ return get_appbridge_cache(space).call_method(space, '_var', self,
+ w_axis)
+
+ def descr_std(self, space, w_axis=None):
+ return get_appbridge_cache(space).call_method(space, '_std', self,
+ w_axis)
+
# ----------------------- reduce -------------------------------
def _reduce_ufunc_impl(ufunc_name, promote_to_largest=False):
@@ -355,9 +387,9 @@
argmin = interp2app(W_NDimArray.descr_argmin),
all = interp2app(W_NDimArray.descr_all),
any = interp2app(W_NDimArray.descr_any),
- #dot = interp2app(W_NDimArray.descr_dot),
- #var = interp2app(W_NDimArray.descr_var),
- #std = interp2app(W_NDimArray.descr_std),
+ dot = interp2app(W_NDimArray.descr_dot),
+ var = interp2app(W_NDimArray.descr_var),
+ std = interp2app(W_NDimArray.descr_std),
copy = interp2app(W_NDimArray.descr_copy),
reshape = interp2app(W_NDimArray.descr_reshape),
@@ -436,10 +468,6 @@
arr.fill(one)
return space.wrap(arr)
-
-def dot(space):
- pass
-
def isna(space):
pass
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -119,3 +119,39 @@
iter.next()
idx += 1
return result
+
+def multidim_dot(space, left, right, result, dtype, right_critical_dim):
+ ''' assumes left, right are concrete arrays
+ given left.shape == [3, 5, 7],
+ right.shape == [2, 7, 4]
+ then
+ result.shape == [3, 5, 2, 4]
+ broadcast shape should be [3, 5, 2, 7, 4]
+ result should skip dims 3 which is len(result_shape) - 1
+ (note that if right is 1d, result should
+ skip len(result_shape))
+ left should skip 2, 4 which is a.ndims-1 + range(right.ndims)
+ except where it==(right.ndims-2)
+ right should skip 0, 1
+ '''
+ left_shape = left.get_shape()
+ right_shape = right.get_shape()
+ broadcast_shape = left_shape[:-1] + right_shape
+ left_skip = [len(left_shape) - 1 + i for i in range(len(right_shape))
+ if i != right_critical_dim]
+ right_skip = range(len(left_shape) - 1)
+ result_skip = [len(result.get_shape()) - (len(right_shape) > 1)]
+ outi = result.create_dot_iter(broadcast_shape, result_skip)
+ lefti = left.create_dot_iter(broadcast_shape, left_skip)
+ righti = right.create_dot_iter(broadcast_shape, right_skip)
+ while not outi.done():
+ lval = lefti.getitem().convert_to(dtype)
+ rval = righti.getitem().convert_to(dtype)
+ outval = outi.getitem().convert_to(dtype)
+ v = dtype.itemtype.mul(lval, rval)
+ value = dtype.itemtype.add(v, outval).convert_to(dtype)
+ outi.setitem(value)
+ outi.next()
+ righti.next()
+ lefti.next()
+ return result
More information about the pypy-commit mailing list