[pypy-commit] pypy refactor-signature: unify Broadcast and View

fijal noreply at buildbot.pypy.org
Mon Dec 19 13:19:26 CET 2011


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: refactor-signature
Changeset: r50702:2d23987ef2d2
Date: 2011-12-19 14:18 +0200
http://bitbucket.org/pypy/pypy/changeset/2d23987ef2d2/

Log:	unify Broadcast and View

diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -61,11 +61,28 @@
         return self.offset
 
 class ViewIterator(BaseIterator):
-    def __init__(self, arr):
+    def __init__(self, arr, res_shape=None):
         self.indices = [0] * len(arr.shape)
         self.offset  = arr.start
-        self.arr     = arr
         self._done   = False
+        if res_shape is not None and res_shape != arr.shape:
+            self.strides = []
+            self.backstrides = []
+            for i in range(len(arr.shape)):
+                if arr.shape[i] == 1:
+                    self.strides.append(0)
+                    self.backstrides.append(0)
+                else:
+                    self.strides.append(arr.strides[i])
+                    self.backstrides.append(arr.backstrides[i])
+            self.strides = [0] * (len(res_shape) - len(arr.shape)) + self.strides
+            self.backstrides = [0] * (len(res_shape) - len(arr.shape)) + self.backstrides
+            self.res_shape = res_shape
+        else:
+            self.strides = arr.strides
+            self.backstrides = arr.backstrides
+            self.res_shape = arr.shape
+
 
     @jit.unroll_safe
     def next(self, shapelen):
@@ -75,59 +92,6 @@
             indices[i] = self.indices[i]
         done = False
         for i in range(shapelen - 1, -1, -1):
-            if indices[i] < self.arr.shape[i] - 1:
-                indices[i] += 1
-                offset += self.arr.strides[i]
-                break
-            else:
-                indices[i] = 0
-                offset -= self.arr.backstrides[i]
-        else:
-            done = True
-        res = instantiate(ViewIterator)
-        res.offset = offset
-        res.indices = indices
-        res.arr = self.arr
-        res._done = done
-        return res
-
-    def done(self):
-        return self._done
-
-    def get_offset(self):
-        return self.offset
-
-class BroadcastIterator(BaseIterator):
-    '''Like a view iterator, but will repeatedly access values
-       for all iterations across a res_shape, folding the offset
-       using mod() arithmetic
-    '''
-    def __init__(self, arr, res_shape):
-        self.indices = [0] * len(res_shape)
-        self.offset  = arr.start
-        #strides are 0 where original shape==1
-        self.strides = []
-        self.backstrides = []
-        for i in range(len(arr.shape)):
-            if arr.shape[i] == 1:
-                self.strides.append(0)
-                self.backstrides.append(0)
-            else:
-                self.strides.append(arr.strides[i])
-                self.backstrides.append(arr.backstrides[i])
-        self.res_shape = res_shape
-        self.strides = [0] * (len(res_shape) - len(arr.shape)) + self.strides
-        self.backstrides = [0] * (len(res_shape) - len(arr.shape)) + self.backstrides
-        self._done = False
-
-    @jit.unroll_safe
-    def next(self, shapelen):
-        offset = self.offset
-        indices = [0] * shapelen
-        _done = False
-        for i in range(shapelen):
-            indices[i] = self.indices[i]
-        for i in range(shapelen - 1, -1, -1):
             if indices[i] < self.res_shape[i] - 1:
                 indices[i] += 1
                 offset += self.strides[i]
@@ -136,14 +100,14 @@
                 indices[i] = 0
                 offset -= self.backstrides[i]
         else:
-            _done = True
-        res = instantiate(BroadcastIterator)
+            done = True
+        res = instantiate(ViewIterator)
+        res.offset = offset
         res.indices = indices
-        res.offset = offset
-        res._done = _done
         res.strides = self.strides
         res.backstrides = self.backstrides
         res.res_shape = self.res_shape
+        res._done = done
         return res
 
     def done(self):
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -875,6 +875,8 @@
 class ConcreteArray(BaseArray):
     """ An array that have actual storage, whether owned or not
     """
+    _immutable_fields_ = ['storage']
+
     def __init__(self, size, shape, dtype, order='C', parent=None):
         self.size = size
         self.parent = parent
@@ -1010,8 +1012,6 @@
     """ A class representing contiguous array. We know that each iteration
     by say ufunc will increase the data index by one
     """
-    _immutable_fields_ = ['storage']
-
     def copy(self):
         array = W_NDimArray(self.size, self.shape[:], self.dtype, self.order)
         rffi.c_memcpy(
diff --git a/pypy/module/micronumpy/signature.py b/pypy/module/micronumpy/signature.py
--- a/pypy/module/micronumpy/signature.py
+++ b/pypy/module/micronumpy/signature.py
@@ -1,25 +1,10 @@
 from pypy.rlib.objectmodel import r_dict, compute_identity_hash, compute_hash
 from pypy.rlib.rarithmetic import intmask
 from pypy.module.micronumpy.interp_iter import ViewIterator, ArrayIterator, \
-     BroadcastIterator, OneDimIterator, ConstantIterator
+     OneDimIterator, ConstantIterator
 from pypy.rpython.lltypesystem.llmemory import cast_ptr_to_adr
 from pypy.rlib.jit import hint, unroll_safe, promote
 
-# def components_eq(lhs, rhs):
-#     if len(lhs) != len(rhs):
-#         return False
-#     for i in range(len(lhs)):
-#         v1, v2 = lhs[i], rhs[i]
-#         if type(v1) is not type(v2) or not v1.eq(v2):
-#             return False
-#     return True
-
-# def components_hash(components):
-#     res = 0x345678
-#     for component in components:
-#         res = intmask((1000003 * res) ^ component.hash())
-#     return res
-
 def sigeq(one, two):
     return one.eq(two)
 


More information about the pypy-commit mailing list