[pypy-svn] r66429 - in pypy/branch/pyjitpl5/pypy/module/micronumpy: . test

fijal at codespeak.net fijal at codespeak.net
Mon Jul 20 11:00:48 CEST 2009


Author: fijal
Date: Mon Jul 20 11:00:47 2009
New Revision: 66429

Modified:
   pypy/branch/pyjitpl5/pypy/module/micronumpy/numarray.py
   pypy/branch/pyjitpl5/pypy/module/micronumpy/test/test_numpy.py
Log:
Implement some basics of multi-dim arrays


Modified: pypy/branch/pyjitpl5/pypy/module/micronumpy/numarray.py
==============================================================================
--- pypy/branch/pyjitpl5/pypy/module/micronumpy/numarray.py	(original)
+++ pypy/branch/pyjitpl5/pypy/module/micronumpy/numarray.py	Mon Jul 20 11:00:47 2009
@@ -5,13 +5,15 @@
 from pypy.interpreter.gateway import interp2app, NoneNotWrapped
 from pypy.rlib.debug import make_sure_not_resized
 
-class NumArray(Wrappable):
+class BaseNumArray(Wrappable):  # common base class of NumArray and MultiDimArray
+    pass
+
+class NumArray(BaseNumArray):  # flat array; zeros() uses it for one-dimensional shapes
     def __init__(self, space, dim, dtype):
         self.dim = dim
         self.space = space
         # ignore dtype for now
-        assert len(dim) == 1
-        self.storage = [0] * dim[0]
+        self.storage = [0] * dim  # dim is now the length itself, not a one-element list
         make_sure_not_resized(self.storage)
 
     def descr_getitem(self, index):
@@ -44,11 +46,114 @@
     __len__     = interp2app(NumArray.descr_len),
 )
 
+def compute_pos(space, indexes, dim):
+    """Flatten an n-dimensional index into an offset in a flat storage list.
+
+    indexes and dim are lists of ints; callers must pass one index per
+    dimension (MultiDimArray._unpack_indexes checks this).  A negative
+    index counts from the end of its axis.  The first index has stride 1
+    and each later index is scaled by the product of the dimensions
+    before it, i.e. the layout is column-major.
+
+    Raises an app-level IndexError if an index is out of range.
+    """
+    current = 1    # stride of axis i
+    pos = 0
+    for i in range(len(indexes)):
+        index = indexes[i]
+        d = dim[i]
+        if index >= d or index <= -d - 1:
+            raise OperationError(space.w_IndexError,
+                                 space.wrap("invalid index"))
+        if index < 0:
+            index = d + index
+        pos += index * current
+        current *= d
+    return pos
+
+class MultiDimArray(BaseNumArray):
+    """An array with any number of dimensions backed by one flat list.
+
+    Element positions are computed by compute_pos.  dtype is accepted but
+    ignored for now; elements are ints initialised to 0.
+    """
+    def __init__(self, space, dim, dtype):
+        self.dim = dim    # list of axis lengths, as returned by unpack_dim
+        self.space = space
+        # ignore dtype for now
+        size = 1
+        for el in dim:
+            size *= el
+        self.storage = [0] * size
+        make_sure_not_resized(self.storage)
+
+    def _unpack_indexes(self, space, w_index):
+        """Convert an app-level index tuple into a list of ints.
+
+        Raises an app-level IndexError unless exactly one index per
+        dimension is given.
+        """
+        indexes = [space.int_w(w_i) for w_i in space.viewiterable(w_index)]
+        if len(indexes) != len(self.dim):
+            raise OperationError(space.w_IndexError, space.wrap(
+                'Wrong index'))
+        return indexes
+
+    def descr_getitem(self, w_index):
+        """Implement ar[i, j, ...]: return the element at that position."""
+        space = self.space
+        indexes = self._unpack_indexes(space, w_index)
+        pos = compute_pos(space, indexes, self.dim)
+        return space.wrap(self.storage[pos])
+    descr_getitem.unwrap_spec = ['self', W_Root]
+
+    def descr_setitem(self, w_index, value):
+        """Implement ar[i, j, ...] = value for an int value."""
+        space = self.space
+        indexes = self._unpack_indexes(space, w_index)
+        pos = compute_pos(space, indexes, self.dim)
+        self.storage[pos] = value
+        return space.w_None
+    descr_setitem.unwrap_spec = ['self', W_Root, int]
+
+    def descr_len(self):
+        """Implement len(ar): the length of the first dimension.
+
+        A 0-dimensional array (zeros(()) reaches here with an empty dim
+        list) has no first dimension, so raise an app-level TypeError
+        instead of indexing past the end of self.dim.
+        """
+        space = self.space
+        if len(self.dim) == 0:
+            raise OperationError(space.w_TypeError, space.wrap(
+                "len() of unsized object"))
+        return space.wrap(self.dim[0])
+    descr_len.unwrap_spec = ['self']
+
+# NOTE(review): exposed under the same app-level name as the 1-d
+# NumArray type; confirm this is intended.
+MultiDimArray.typedef = TypeDef(
+    'NumArray',
+    __getitem__ = interp2app(MultiDimArray.descr_getitem),
+    __setitem__ = interp2app(MultiDimArray.descr_setitem),
+    __len__     = interp2app(MultiDimArray.descr_len),
+)
+
 def unpack_dim(space, w_dim):
+    """Return the shape w_dim (an int or an iterable of ints) as a list.
+
+    Raises an app-level ValueError for a negative dimension, which would
+    otherwise produce an array whose storage does not match its shape.
+    """
     if space.is_true(space.isinstance(w_dim, space.w_int)):
-        return [space.int_w(w_dim)]
-    else:
-        raise NotImplementedError
+        dim = [space.int_w(w_dim)]
+    else:
+        dim = [space.int_w(w_i) for w_i in space.viewiterable(w_dim)]
+    for d in dim:
+        if d < 0:
+            raise OperationError(space.w_ValueError, space.wrap(
+                "negative dimensions are not allowed"))
+    return dim
 
 def unpack_dtype(space, w_dtype):
     if space.is_w(w_dtype, space.w_int):
@@ -59,5 +164,12 @@
 def zeros(space, w_dim, w_dtype):
+    """Create a new array of shape w_dim filled with zeros.
+
+    A shape with exactly one dimension gives a NumArray, any other shape
+    a MultiDimArray."""
     dim = unpack_dim(space, w_dim)
     dtype = unpack_dtype(space, w_dtype)
-    return space.wrap(NumArray(space, dim, dtype))
+    if len(dim) == 1:
+        return space.wrap(NumArray(space, dim[0], dtype))
+    else:
+        return space.wrap(MultiDimArray(space, dim, dtype))
 zeros.unwrap_spec = [ObjSpace, W_Root, W_Root]

Modified: pypy/branch/pyjitpl5/pypy/module/micronumpy/test/test_numpy.py
==============================================================================
--- pypy/branch/pyjitpl5/pypy/module/micronumpy/test/test_numpy.py	(original)
+++ pypy/branch/pyjitpl5/pypy/module/micronumpy/test/test_numpy.py	Mon Jul 20 11:00:47 2009
@@ -39,3 +39,27 @@
         assert x[4] == 0
         assert len(x) == 5
         raises(ValueError, minimum, ar, zeros(3, dtype=int))
+
+class AppTestMultiDim(object):  # app-level tests for arrays built from multi-element shapes
+    def setup_class(cls):
+        cls.space = gettestobjspace(usemodules=('micronumpy',))  # object space with the micronumpy module
+
+    def test_multidim(self):
+        from numpy import zeros
+        ar = zeros((3, 3), dtype=int)
+        assert ar[0, 2] == 0
+        raises(IndexError, ar.__getitem__, (3, 0))  # 3 is one past the end of axis 0
+        assert ar[-2, 1] == 0  # negative index counts from the end of its axis
+
+    def test_multidim_getset(self):
+        from numpy import zeros
+        ar = zeros((3, 3, 3), dtype=int)
+        ar[1, 2, 1] = 3
+        assert ar[1, 2, 1] == 3
+        assert ar[-2, 2, 1] == 3  # -2 on an axis of length 3 is index 1
+        assert ar[2, 2, 1] == 0  # neighbouring element is untouched
+        assert ar[-2, 2, -2] == 3  # negative indexes on several axes at once
+
+    def test_len(self):
+        from numpy import zeros
+        assert len(zeros((3, 2, 1), dtype=int)) == 3  # len() reports the first dimension



More information about the Pypy-commit mailing list