[pypy-svn] r66429 - in pypy/branch/pyjitpl5/pypy/module/micronumpy: . test
fijal at codespeak.net
fijal at codespeak.net
Mon Jul 20 11:00:48 CEST 2009
Author: fijal
Date: Mon Jul 20 11:00:47 2009
New Revision: 66429
Modified:
pypy/branch/pyjitpl5/pypy/module/micronumpy/numarray.py
pypy/branch/pyjitpl5/pypy/module/micronumpy/test/test_numpy.py
Log:
Implement some basics of multi-dim arrays
Modified: pypy/branch/pyjitpl5/pypy/module/micronumpy/numarray.py
==============================================================================
--- pypy/branch/pyjitpl5/pypy/module/micronumpy/numarray.py (original)
+++ pypy/branch/pyjitpl5/pypy/module/micronumpy/numarray.py Mon Jul 20 11:00:47 2009
@@ -5,13 +5,15 @@
from pypy.interpreter.gateway import interp2app, NoneNotWrapped
from pypy.rlib.debug import make_sure_not_resized
-class NumArray(Wrappable):
+class BaseNumArray(Wrappable):
+ pass
+
+class NumArray(BaseNumArray):
def __init__(self, space, dim, dtype):
self.dim = dim
self.space = space
# ignore dtype for now
- assert len(dim) == 1
- self.storage = [0] * dim[0]
+ self.storage = [0] * dim
make_sure_not_resized(self.storage)
def descr_getitem(self, index):
@@ -44,11 +46,70 @@
__len__ = interp2app(NumArray.descr_len),
)
+def compute_pos(space, indexes, dim):
+ current = 1
+ pos = 0
+ for i in range(len(indexes)):
+ index = indexes[i]
+ d = dim[i]
+ if index >= d or index <= -d - 1:
+ raise OperationError(space.w_IndexError,
+ space.wrap("invalid index"))
+ if index < 0:
+ index = d + index
+ pos += index * current
+ current *= d
+ return pos
+
+class MultiDimArray(BaseNumArray):
+ def __init__(self, space, dim, dtype):
+ self.dim = dim
+ self.space = space
+ # ignore dtype for now
+ size = 1
+ for el in dim:
+ size *= el
+ self.storage = [0] * size
+ make_sure_not_resized(self.storage)
+
+ def _unpack_indexes(self, space, w_index):
+ indexes = [space.int_w(w_i) for w_i in space.viewiterable(w_index)]
+ if len(indexes) != len(self.dim):
+ raise OperationError(space.w_IndexError, space.wrap(
+ 'Wrong index'))
+ return indexes
+
+ def descr_getitem(self, w_index):
+ space = self.space
+ indexes = self._unpack_indexes(space, w_index)
+ pos = compute_pos(space, indexes, self.dim)
+ return space.wrap(self.storage[pos])
+ descr_getitem.unwrap_spec = ['self', W_Root]
+
+ def descr_setitem(self, w_index, value):
+ space = self.space
+ indexes = self._unpack_indexes(space, w_index)
+ pos = compute_pos(space, indexes, self.dim)
+ self.storage[pos] = value
+ return space.w_None
+ descr_setitem.unwrap_spec = ['self', W_Root, int]
+
+ def descr_len(self):
+ return self.space.wrap(self.dim[0])
+ descr_len.unwrap_spec = ['self']
+
+MultiDimArray.typedef = TypeDef(
+ 'NumArray',
+ __getitem__ = interp2app(MultiDimArray.descr_getitem),
+ __setitem__ = interp2app(MultiDimArray.descr_setitem),
+ __len__ = interp2app(MultiDimArray.descr_len),
+)
+
def unpack_dim(space, w_dim):
if space.is_true(space.isinstance(w_dim, space.w_int)):
return [space.int_w(w_dim)]
- else:
- raise NotImplementedError
+ dim_w = space.viewiterable(w_dim)
+ return [space.int_w(w_i) for w_i in dim_w]
def unpack_dtype(space, w_dtype):
if space.is_w(w_dtype, space.w_int):
@@ -59,5 +120,8 @@
def zeros(space, w_dim, w_dtype):
dim = unpack_dim(space, w_dim)
dtype = unpack_dtype(space, w_dtype)
- return space.wrap(NumArray(space, dim, dtype))
+ if len(dim) == 1:
+ return space.wrap(NumArray(space, dim[0], dtype))
+ else:
+ return space.wrap(MultiDimArray(space, dim, dtype))
zeros.unwrap_spec = [ObjSpace, W_Root, W_Root]
Modified: pypy/branch/pyjitpl5/pypy/module/micronumpy/test/test_numpy.py
==============================================================================
--- pypy/branch/pyjitpl5/pypy/module/micronumpy/test/test_numpy.py (original)
+++ pypy/branch/pyjitpl5/pypy/module/micronumpy/test/test_numpy.py Mon Jul 20 11:00:47 2009
@@ -39,3 +39,27 @@
assert x[4] == 0
assert len(x) == 5
raises(ValueError, minimum, ar, zeros(3, dtype=int))
+
+class AppTestMultiDim(object):
+ def setup_class(cls):
+ cls.space = gettestobjspace(usemodules=('micronumpy',))
+
+ def test_multidim(self):
+ from numpy import zeros
+ ar = zeros((3, 3), dtype=int)
+ assert ar[0, 2] == 0
+ raises(IndexError, ar.__getitem__, (3, 0))
+ assert ar[-2, 1] == 0
+
+ def test_multidim_getset(self):
+ from numpy import zeros
+ ar = zeros((3, 3, 3), dtype=int)
+ ar[1, 2, 1] = 3
+ assert ar[1, 2, 1] == 3
+ assert ar[-2, 2, 1] == 3
+ assert ar[2, 2, 1] == 0
+ assert ar[-2, 2, -2] == 3
+
+ def test_len(self):
+ from numpy import zeros
+ assert len(zeros((3, 2, 1), dtype=int)) == 3
More information about the Pypy-commit mailing list