[pypy-commit] pypy numpy-multidim-shards: in-progress. Get this into some shape so we can run tests

fijal noreply at buildbot.pypy.org
Sat Nov 19 16:41:17 CET 2011


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-multidim-shards
Changeset: r49547:638b988b580e
Date: 2011-11-19 17:40 +0200
http://bitbucket.org/pypy/pypy/changeset/638b988b580e/

Log:	in-progress. Get this into some shape so we can run tests

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -39,12 +39,19 @@
         shape.append(size)
         batch = new_batch
 
+class BroadcastDescription(object):
+    def __init__(self, shape, indices1, indices2):
+        self.shape = shape
+        self.indices1 = indices1
+        self.indices2 = indices2
+
 def shape_agreement(space, shape1, shape2):
     """ Checks agreement about two shapes with respect to broadcasting. Returns
     the resulting shape.
     """
     lshift = 0
     rshift = 0
+    adjustment = False
     if len(shape1) > len(shape2):
         m = len(shape1)
         n = len(shape2)
@@ -56,21 +63,35 @@
         lshift = len(shape1) - len(shape2)
         remainder = shape2
     endshape = [0] * m
+    indices1 = [True] * m
+    indices2 = [True] * m
     for i in range(m - 1, m - n - 1, -1):
         left = shape1[i + lshift]
         right = shape2[i + rshift]
         if left == right:
             endshape[i] = left
         elif left == 1:
+            adjustment = True
             endshape[i] = right
+            indices1[i + lshift] = False
         elif right == 1:
+            adjustment = True
             endshape[i] = left
+            indices2[i + rshift] = False
         else:
             raise OperationError(space.w_ValueError, space.wrap(
                 "frames are not aligned"))
     for i in range(m - n):
+        adjustment = True
         endshape[i] = remainder[i]
+        #if len(shape1) > len(shape2):
+        #    xxx
+        #else:
+        #    xxx
+    #if not adjustment:
+    #    return None
     return endshape
+    return BroadcastDescription(endshape, indices1, indices2)
 
 def descr_new_array(space, w_subtype, w_item_or_iterable, w_dtype=None,
                     w_order=NoneNotWrapped):
@@ -105,7 +126,7 @@
         space.call_function(space.gettypefor(interp_dtype.W_Dtype), w_dtype)
     )
     arr = NDimArray(size, shape[:], dtype=dtype, order=order)
-    arr_iter = arr.start_iter()
+    arr_iter = arr.start_iter(arr.shape)
     for i in range(len(elems_w)):
         w_elem = elems_w[i]
         dtype.setitem_w(space, arr.storage, arr_iter.offset, w_elem)
@@ -123,12 +144,13 @@
         raise NotImplementedError
 
 class ArrayIterator(BaseIterator):
-    def __init__(self, size, offset=0):
-        self.offset = offset
+    def __init__(self, size):
+        self.offset = 0
         self.size   = size
 
     def next(self):
-        return ArrayIterator(self.size, self.offset + 1)
+        self.offset += 1
+        return self
 
     def done(self):
         return self.offset >= self.size
@@ -137,34 +159,25 @@
         return self.offset
 
 class ViewIterator(BaseIterator):
-    def __init__(self, arr, offset=0, indices=None, done=False):
-        if indices is None:
-            self.indices = [0] * len(arr.shape)
-            self.offset  = arr.start
-        else:
-            self.offset  = offset
-            self.indices = indices
-        self.arr   = arr
-        self._done = done
+    def __init__(self, arr):
+        self.indices = [0] * len(arr.shape)
+        self.offset  = arr.start
+        self.arr     = arr
+        self._done   = False
 
     @jit.unroll_safe
     def next(self):
-        indices = [0] * len(self.arr.shape)
-        for i in range(len(self.arr.shape)):
-            indices[i] = self.indices[i]
-        done = False
-        offset = self.offset
         for i in range(len(self.arr.shape) -1, -1, -1):
-            if indices[i] < self.arr.shape[i] - 1:
-                indices[i] += 1
-                offset += self.arr.shards[i]
+            if self.indices[i] < self.arr.shape[i] - 1:
+                self.indices[i] += 1
+                self.offset += self.arr.shards[i]
                 break
             else:
-                indices[i] = 0
-                offset -= self.arr.backshards[i]
+                self.indices[i] = 0
+                self.offset -= self.arr.backshards[i]
         else:
-            done = True
-        return ViewIterator(self.arr, offset, indices, done)
+            self._done = True
+        return self
 
     def done(self):
         return self._done
@@ -172,13 +185,43 @@
     def get_offset(self):
         return self.offset
 
+class ResizingIterator(object):
+    def __init__(self, iter, shape, orig_indices):
+        self.shape = shape
+        self.indices = [0] * len(shape)
+        self.orig_indices = orig_indices
+        self.iter = iter
+        self._done = False
+
+    @jit.unroll_safe
+    def next(self):
+        for i in range(len(self.shape) -1, -1, -1):
+            if self.indices[i] < self.shape[i] - 1:
+                self.indices[i] += 1
+                if self.orig_indices[i]:
+                    self.iter.next()
+                break
+            else:
+                self.indices[i] = 0
+        else:
+            self._done = True
+        return self
+
+    def get_offset(self):
+        return self.iter.get_offset()
+
+    def done(self):
+        return self._done
+
 class Call2Iterator(BaseIterator):
     def __init__(self, left, right):
         self.left = left
         self.right = right
 
     def next(self):
-        return Call2Iterator(self.left.next(), self.right.next())
+        self.left.next()
+        self.right.next()
+        return self
 
     def done(self):
         return self.left.done() or self.right.done()
@@ -193,7 +236,8 @@
         self.child = child
 
     def next(self):
-        return Call1Iterator(self.child.next())
+        self.child.next()
+        return self
 
     def done(self):
         return self.child.done()
@@ -312,7 +356,7 @@
         reduce_driver = jit.JitDriver(greens=['signature'],
                          reds = ['i', 'result', 'self', 'cur_best', 'dtype'])
         def loop(self):
-            i = self.start_iter()
+            i = self.start_iter(self.shape)
             result = i.get_offset()
             cur_best = self.eval(i)
             i.next()
@@ -339,7 +383,7 @@
 
     def _all(self):
         dtype = self.find_dtype()
-        i = self.start_iter()
+        i = self.start_iter(self.shape)
         while not i.done():
             all_driver.jit_merge_point(signature=self.signature, self=self, dtype=dtype, i=i)
             if not dtype.bool(self.eval(i)):
@@ -351,7 +395,7 @@
 
     def _any(self):
         dtype = self.find_dtype()
-        i = self.start_iter()
+        i = self.start_iter(self.shape)
         while not i.done():
             any_driver.jit_merge_point(signature=self.signature, self=self,
                                        dtype=dtype, i=i)
@@ -403,7 +447,7 @@
                 res.append_slice(str(self_shape), 1, len(self_shape) - 1)
                 res.append(')')
         else:
-            self.to_str(space, 1, res, indent='       ')
+            concrete.to_str(space, 1, res, indent='       ')
         if (dtype is not space.fromcache(interp_dtype.W_Float64Dtype) and
             dtype is not space.fromcache(interp_dtype.W_Int64Dtype)) or \
             not self.find_size():
@@ -488,7 +532,8 @@
 
     def descr_str(self, space):
         ret = StringBuilder()
-        self.to_str(space, 0, ret, ' ')
+        concrete = self.get_concrete()
+        concrete.to_str(space, 0, ret, ' ')
         return space.wrap(ret.build())
 
     def _index_of_single_item(self, space, w_idx):
@@ -633,12 +678,12 @@
         except ValueError:
             pass
         return space.wrap(space.is_true(self.get_concrete().eval(
-            self.start_iter()).wrap(space)))
+            self.start_iter(self.shape)).wrap(space)))
 
     def getitem(self, item):
         raise NotImplementedError
 
-    def start_iter(self):
+    def start_iter(self, res_shape=None):
         raise NotImplementedError
 
     def compute_index(self, space, offset):
@@ -697,7 +742,7 @@
     def eval(self, iter):
         return self.value
 
-    def start_iter(self):
+    def start_iter(self, res_shape=None):
         return ConstantIterator()
 
     def to_str(self, space, comma, builder, indent=' '):
@@ -787,10 +832,10 @@
         assert isinstance(call_sig, signature.Call1)
         return call_sig.func(self.res_dtype, val)
 
-    def start_iter(self):
+    def start_iter(self, res_shape=None):
         if self.forced_result is not None:
-            return self.forced_result.start_iter()
-        return Call1Iterator(self.values.start_iter())
+            return self.forced_result.start_iter(res_shape)
+        return Call1Iterator(self.values.start_iter(res_shape))
 
 class Call2(VirtualArray):
     """
@@ -814,10 +859,13 @@
             pass
         return self.right.find_size()
 
-    def start_iter(self):
+    def start_iter(self, res_shape=None):
         if self.forced_result is not None:
-            return self.forced_result.start_iter()
-        return Call2Iterator(self.left.start_iter(), self.right.start_iter())
+            return self.forced_result.start_iter(res_shape)
+        if res_shape is None:
+            res_shape = self.shape # we still force the shape on children
+        return Call2Iterator(self.left.start_iter(res_shape),
+                             self.right.start_iter(res_shape))
 
     def _eval(self, iter):
         assert isinstance(iter, Call2Iterator)
@@ -895,15 +943,12 @@
         return self.parent.find_dtype()
 
     def setslice(self, space, w_value):
-        if isinstance(w_value, NDimArray):
-            if self.shape != w_value.shape:
-                raise OperationError(space.w_TypeError, space.wrap(
-                    "wrong assignment"))
-        self._sliceloop(w_value)
+        res_shape = shape_agreement(space, self.shape, w_value.shape)
+        self._sliceloop(w_value, res_shape)
 
-    def _sliceloop(self, source):
-        source_iter = source.start_iter()
-        res_iter = self.start_iter()
+    def _sliceloop(self, source, res_shape):
+        source_iter = source.start_iter(res_shape)
+        res_iter = self.start_iter(res_shape)
         while not res_iter.done():
             slice_driver.jit_merge_point(signature=source.signature,
                                          self=self, source=source,
@@ -914,8 +959,11 @@
             source_iter = source_iter.next()
             res_iter = res_iter.next()
 
-    def start_iter(self, offset=0, indices=None):
-        return ViewIterator(self, offset=offset, indices=indices)
+    def start_iter(self, res_shape=None):
+        if res_shape is not None and res_shape != self.shape:
+            raise NotImplementedError # xxx
+            #return ResizingIterator(ViewIterator(self), res_shape, orig_indices)
+        return ViewIterator(self)
 
     def setitem(self, item, value):
         self.parent.setitem(item, value)
@@ -967,9 +1015,11 @@
         self.invalidated()
         self.dtype.setitem(self.storage, item, value)
 
-    def start_iter(self, offset=0, indices=None):
+    def start_iter(self, res_shape=None):
         if self.order == 'C':
-            return ArrayIterator(self.size, offset=offset)
+            if res_shape is not None and res_shape != self.shape:
+                raise NotImplementedError # xxx
+            return ArrayIterator(self.size)
         raise NotImplementedError  # use ViewIterator simply, test it
 
     def __del__(self):
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -56,7 +56,7 @@
             space, obj.find_dtype(),
             promote_to_largest=True
         )
-        start = obj.start_iter()
+        start = obj.start_iter(obj.shape)
         if self.identity is None:
             if size == 0:
                 raise operationerrfmt(space.w_ValueError, "zero-size array to "
@@ -123,7 +123,7 @@
 
     def call(self, space, args_w):
         from pypy.module.micronumpy.interp_numarray import (Call2,
-            convert_to_array, Scalar)
+            convert_to_array, Scalar, shape_agreement)
 
         [w_lhs, w_rhs] = args_w
         w_lhs = convert_to_array(space, w_lhs)
@@ -146,7 +146,8 @@
         new_sig = signature.Signature.find_sig([
             self.signature, w_lhs.signature, w_rhs.signature
         ])
-        w_res = Call2(new_sig, w_lhs.shape or w_rhs.shape, calc_dtype,
+        new_shape = shape_agreement(space, w_lhs.shape, w_rhs.shape)
+        w_res = Call2(new_sig, new_shape, calc_dtype,
                       res_dtype, w_lhs, w_rhs)
         w_lhs.add_invalidates(w_res)
         w_rhs.add_invalidates(w_res)
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -843,8 +843,17 @@
         c = b + b
         assert c[1][1] == 12
 
-    def test_broadcast(self):
-        skip("not working")
+    def test_broadcast_ufunc(self):
+        from numpy import array
+        a = array([[1, 2], [3, 4], [5, 6]])
+        b = array([5, 6])
+        #print a + b
+        c = ((a + b) == [[1+5, 2+6], [3+5, 4+6], [5+5, 6+6]])
+        print c
+        print c.all()
+        assert c.all()
+
+    def test_broadcast_setslice(self):
         import numpy
         a = numpy.zeros((100, 100))
         b = numpy.ones(100)


More information about the pypy-commit mailing list