[pypy-commit] pypy refactor-signature: start refactoring - change test_base and make it pass

fijal noreply at buildbot.pypy.org
Wed Dec 14 11:02:56 CET 2011


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: refactor-signature
Changeset: r50488:9d888e0a1e4b
Date: 2011-12-14 12:02 +0200
http://bitbucket.org/pypy/pypy/changeset/9d888e0a1e4b/

Log:	start refactoring - change test_base and make it pass

diff --git a/pypy/module/micronumpy/interp_dtype.py b/pypy/module/micronumpy/interp_dtype.py
--- a/pypy/module/micronumpy/interp_dtype.py
+++ b/pypy/module/micronumpy/interp_dtype.py
@@ -28,11 +28,6 @@
         self.char = char
         self.w_box_type = w_box_type
         self.alternate_constructors = alternate_constructors
-        self.array_signature = signature.ArraySignature()
-        self.scalar_signature = signature.ScalarSignature()
-        self.forced_signature = signature.ForcedSignature()
-        #self.flatiter_signature = signature.FlatiterSignature()
-        #self.view_signature = signature.ViewSignature()
 
     def malloc(self, length):
         # XXX find out why test_zjit explodes with tracking of allocations
diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -2,6 +2,13 @@
 from pypy.rlib import jit
 from pypy.rlib.objectmodel import instantiate
 
+class NumpyEvalFrame(object):
+    def __init__(self, iterators):
+        self.iterators = iterators
+
+    def next(self, shapelen):
+        xxx
+
 # Iterators for arrays
 # --------------------
 # all those iterators with the exception of BroadcastIterator iterate over the
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -7,8 +7,7 @@
 from pypy.rpython.lltypesystem import lltype, rffi
 from pypy.tool.sourcetools import func_with_new_name
 from pypy.rlib.rstring import StringBuilder
-from pypy.rlib.objectmodel import instantiate
-
+from pypy.module.micronumpy.interp_iter import NumpyEvalFrame, ArrayIterator
 
 numpy_driver = jit.JitDriver(
     greens=['shapelen', 'signature'],
@@ -199,7 +198,7 @@
     return new_strides
 
 class BaseArray(Wrappable):
-    _attrs_ = ["invalidates", "signature", "shape", "strides", "backstrides",
+    _attrs_ = ["invalidates", "shape", "strides", "backstrides",
                "start", 'order']
 
     _immutable_fields_ = ['start', "order"]
@@ -310,7 +309,7 @@
             reds=['result', 'idx', 'i', 'self', 'cur_best', 'dtype']
         )
         def loop(self):
-            i = self.start_iter()
+            i = self.signature.create_iter(self, {})
             cur_best = self.eval(i)
             shapelen = len(self.shape)
             i = i.next(shapelen)
@@ -709,11 +708,19 @@
         raise NotImplementedError
 
     def start_iter(self, res_shape=None):
-        raise NotImplementedError
+        all_iters = self.signature.create_iter(self, {}, res_shape)
+        return NumpyEvalFrame(all_iters)
 
     def descr_debug_repr(self, space):
         return space.wrap(self.signature.debug_repr())
 
+    def find_sig(self):
+        """ find a correct signature for the array
+        """
+        sig = self.create_sig()
+        sig.invent_numbering()
+        return signature.find_sig(sig)
+
 def convert_to_array(space, w_obj):
     if isinstance(w_obj, BaseArray):
         return w_obj
@@ -739,7 +746,6 @@
         BaseArray.__init__(self, [], 'C')
         self.dtype = dtype
         self.value = value
-        self.signature = dtype.scalar_signature
 
     def find_size(self):
         return 1
@@ -756,9 +762,6 @@
     def eval(self, iter):
         return self.value
 
-    def start_iter(self, res_shape=None):
-        return ConstantIterator()
-
     def to_str(self, space, comma, builder, indent=' ', use_ellipsis=False):
         builder.append(self.dtype.itemtype.str_format(self.value))
 
@@ -770,14 +773,16 @@
         # so in order to have a consistent API, let it go through.
         pass
 
+    def create_sig(self):
+        return signature.ScalarSignature(self.dtype)
+
 class VirtualArray(BaseArray):
     """
     Class for representing virtual arrays, such as binary ops or ufuncs
     """
-    def __init__(self, signature, shape, res_dtype, order):
+    def __init__(self, shape, res_dtype, order):
         BaseArray.__init__(self, shape, order)
         self.forced_result = None
-        self.signature = signature
         self.res_dtype = res_dtype
 
     def _del_sources(self):
@@ -786,10 +791,10 @@
 
     def compute(self):
         i = 0
-        signature = self.signature
         result_size = self.find_size()
         result = W_NDimArray(result_size, self.shape, self.find_dtype())
         shapelen = len(self.shape)
+        xxx
         i = self.start_iter()
         ri = result.start_iter()
         while not ri.done():
@@ -805,7 +810,6 @@
     def force_if_needed(self):
         if self.forced_result is None:
             self.forced_result = self.compute()
-            self.signature = self.find_dtype().forced_signature
             self._del_sources()
 
     def get_concrete(self):
@@ -834,10 +838,11 @@
 
 
 class Call1(VirtualArray):
-    def __init__(self, signature, shape, res_dtype, values, order):
-        VirtualArray.__init__(self, signature, shape, res_dtype,
+    def __init__(self, ufunc, shape, res_dtype, values, order):
+        VirtualArray.__init__(self, shape, res_dtype,
                               values.order)
         self.values = values
+        self.ufunc = ufunc
 
     def _del_sources(self):
         self.values = None
@@ -855,18 +860,16 @@
         assert isinstance(sig, signature.Call1)
         return sig.unfunc(self.res_dtype, val)
 
-    def start_iter(self, res_shape=None):
-        if self.forced_result is not None:
-            return self.forced_result.start_iter(res_shape)
-        return Call1Iterator(self.values.start_iter(res_shape))
+    def create_sig(self):
+        return signature.Call1(self.ufunc, self.values.create_sig())
 
 class Call2(VirtualArray):
     """
     Intermediate class for performing binary operations.
     """
-    def __init__(self, signature, shape, calc_dtype, res_dtype, left, right):
-        # XXX do something if left.order != right.order
-        VirtualArray.__init__(self, signature, shape, res_dtype, left.order)
+    def __init__(self, ufunc, shape, calc_dtype, res_dtype, left, right):
+        VirtualArray.__init__(self, shape, res_dtype, left.order)
+        self.ufunc = ufunc
         self.left = left
         self.right = right
         self.calc_dtype = calc_dtype
@@ -881,14 +884,6 @@
     def _find_size(self):
         return self.size
 
-    def start_iter(self, res_shape=None):
-        if self.forced_result is not None:
-            return self.forced_result.start_iter(res_shape)
-        if res_shape is None:
-            res_shape = self.shape  # we still force the shape on children
-        return Call2Iterator(self.left.start_iter(res_shape),
-                             self.right.start_iter(res_shape))
-
     def _eval(self, iter):
         assert isinstance(iter, Call2Iterator)
         lhs = self.left.eval(iter.left).convert_to(self.calc_dtype)
@@ -897,6 +892,10 @@
         assert isinstance(sig, signature.Call2)
         return sig.binfunc(self.calc_dtype, lhs, rhs)
 
+    def create_sig(self):
+        return signature.Call2(self.ufunc, self.left.create_sig(),
+                               self.right.create_sig())
+
 class ViewArray(BaseArray):
     """
     Class for representing views of arrays, they will reflect changes of parent
@@ -974,7 +973,6 @@
         if isinstance(parent, W_NDimSlice):
             parent = parent.parent
         ViewArray.__init__(self, parent, strides, backstrides, shape)
-        self.signature = signature.find_sig(signature.ViewSignature(parent.signature))
         self.start = start
         self.size = 1
         for sh in shape:
@@ -1005,12 +1003,12 @@
             source_iter = source_iter.next(shapelen)
             res_iter = res_iter.next(shapelen)
 
-    def start_iter(self, res_shape=None):
-        if res_shape is not None and res_shape != self.shape:
-            return BroadcastIterator(self, res_shape)
-        if len(self.shape) == 1:
-            return OneDimIterator(self.start, self.strides[0], self.shape[0])
-        return ViewIterator(self)
+    # def start_iter(self, res_shape=None):
+    #     if res_shape is not None and res_shape != self.shape:
+    #         return BroadcastIterator(self, res_shape)
+    #     if len(self.shape) == 1:
+    #         return OneDimIterator(self.start, self.strides[0], self.shape[0])
+    #     return ViewIterator(self)
 
     def setitem(self, item, value):
         self.parent.setitem(item, value)
@@ -1025,6 +1023,9 @@
             a_iter = a_iter.next(len(array.shape))
         return array
 
+    def create_sig(self):
+        return signature.ViewSignature(self.parent.create_sig())
+
 class W_NDimArray(BaseArray):
     """ A class representing contiguous array. We know that each iteration
     by say ufunc will increase the data index by one
@@ -1034,7 +1035,6 @@
         self.size = size
         self.dtype = dtype
         self.storage = dtype.malloc(size)
-        self.signature = dtype.array_signature
 
     def get_concrete(self):
         return self
@@ -1073,17 +1073,20 @@
         self.invalidated()
         self.dtype.setitem(self.storage, item, value)
 
-    def start_iter(self, res_shape=None):
-        if self.order == 'C':
-            if res_shape is not None and res_shape != self.shape:
-                return BroadcastIterator(self, res_shape)
-            return ArrayIterator(self.size)
-        raise NotImplementedError  # use ViewIterator simply, test it
+    # def start_iter(self, res_shape=None):
+    #     if self.order == 'C':
+    #         if res_shape is not None and res_shape != self.shape:
+    #             return BroadcastIterator(self, res_shape)
+    #         return ArrayIterator(self.size)
+    #     raise NotImplementedError  # use ViewIterator simply, test it
 
     def setshape(self, space, new_shape):
         self.shape = new_shape
         self.calc_strides(new_shape)
 
+    def create_sig(self):
+        return signature.ArraySignature(self.dtype)
+
     def __del__(self):
         lltype.free(self.storage, flavor='raw', track_allocation=False)
 
@@ -1134,11 +1137,11 @@
     )
     arr = W_NDimArray(size, shape[:], dtype=dtype, order=order)
     shapelen = len(shape)
-    iters = arr.signature.create_iterator()
-    arr_iter = arr.start_iter(arr.shape)
+    arr_iter = ArrayIterator(arr)
     for i in range(len(elems_w)):
         w_elem = elems_w[i]
-        dtype.setitem(arr.storage, arr_iter.offset, dtype.coerce(space, w_elem))
+        dtype.setitem(arr.storage, arr_iter.offset,
+                      dtype.coerce(space, w_elem))
         arr_iter = arr_iter.next(shapelen)
     return arr
 
@@ -1241,14 +1244,12 @@
         self.shapelen = len(arr.shape)
         self.arr = arr
         self.iter = self.start_iter()
-        self.signature = signature.find_sig(signature.FlatiterSignature(
-            arr.signature))
 
-    def start_iter(self, res_shape=None):
-        if res_shape is not None and res_shape != self.shape:
-            return BroadcastIterator(self, res_shape)
-        return OneDimIterator(self.arr.start, self.strides[0],
-                              self.shape[0])
+    # def start_iter(self, res_shape=None):
+    #     if res_shape is not None and res_shape != self.shape:
+    #         return BroadcastIterator(self, res_shape)
+    #     return OneDimIterator(self.arr.start, self.strides[0],
+    #                           self.shape[0])
 
     def find_dtype(self):
         return self.arr.find_dtype()
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -117,10 +117,7 @@
         if isinstance(w_obj, Scalar):
             return self.func(res_dtype, w_obj.value.convert_to(res_dtype))
 
-        new_sig = signature.find_sig(signature.Call1(self.func,
-                                                     self.name,
-                                                     w_obj.signature))
-        w_res = Call1(new_sig, w_obj.shape, res_dtype, w_obj, w_obj.order)
+        w_res = Call1(self.func, w_obj.shape, res_dtype, w_obj, w_obj.order)
         w_obj.add_invalidates(w_res)
         return w_res
 
@@ -158,12 +155,8 @@
                 w_rhs.value.convert_to(calc_dtype)
             )
 
-        new_sig = signature.find_sig(signature.Call2(self.func,
-                                                     self.name,
-                                                     w_lhs.signature,
-                                                     w_rhs.signature))
         new_shape = shape_agreement(space, w_lhs.shape, w_rhs.shape)
-        w_res = Call2(new_sig, new_shape, calc_dtype,
+        w_res = Call2(self.func, new_shape, calc_dtype,
                       res_dtype, w_lhs, w_rhs)
         w_lhs.add_invalidates(w_res)
         w_rhs.add_invalidates(w_res)
diff --git a/pypy/module/micronumpy/signature.py b/pypy/module/micronumpy/signature.py
--- a/pypy/module/micronumpy/signature.py
+++ b/pypy/module/micronumpy/signature.py
@@ -1,7 +1,7 @@
 from pypy.rlib.objectmodel import r_dict, compute_identity_hash, compute_hash
 from pypy.rlib.rarithmetic import intmask
 from pypy.module.micronumpy.interp_iter import ViewIterator, ArrayIterator, \
-     BroadcastIterator, OneDimIterator
+     BroadcastIterator, OneDimIterator, ConstantIterator
 
 
 # def components_eq(lhs, rhs):
@@ -31,14 +31,11 @@
     return known_sigs.setdefault(sig, sig)
 
 class Signature(object):
-    def eq(self, other):
-        return self is other
+    def create_iter(self, array, cache, res_shape=None):
+        raise NotImplementedError
 
-    def hash(self):
-        return compute_hash(self)
-
-    def create_iter(self, array, cache):
-        raise NotImplementedError
+    def invent_numbering(self):
+        pass # XXX
 
 class ViewSignature(Signature):
     def __init__(self, child):
@@ -55,17 +52,33 @@
     def debug_repr(self):
         return 'Slice(%s)' % self.child.debug_repr()
 
-    def create_iter(self, array, cache):
-        xxxx
+class ArraySignature(Signature):
+    def __init__(self, dtype):
+        self.dtype = dtype
 
-class ArraySignature(Signature):
+    def eq(self, other):
+        if type(self) is not type(other):
+            return False
+        return self.dtype is other.dtype
+
+    def hash(self):
+        return compute_identity_hash(self.dtype)
+
     def debug_repr(self):
         return 'Array'
 
-    def create_iter(self, array, cache):
-        xxx
+class ScalarSignature(Signature):
+    def __init__(self, dtype):
+        self.dtype = dtype
 
-class ScalarSignature(Signature):
+    def eq(self, other):
+        if type(self) is not type(other):
+            return False
+        return self.dtype is other.dtype
+
+    def hash(self):
+        return compute_identity_hash(self.dtype)
+
     def debug_repr(self):
         return 'Scalar'
 
@@ -74,16 +87,15 @@
         return 'FlatIter(%s)' % self.child.debug_repr()
 
 class Call1(Signature):
-    def __init__(self, func, name, child):
+    def __init__(self, func, child):
         self.unfunc = func
-        self.name = name
         self.child = child
 
     def hash(self):
-        return compute_hash(self.name) ^ self.child.hash() << 1
+        return compute_identity_hash(self.unfunc) ^ self.child.hash() << 1
 
     def eq(self, other):
-        if type(other) is not type(self):
+        if type(self) is not type(other):
             return False
         return self.unfunc is other.unfunc and self.child.eq(other.child)
 
@@ -92,18 +104,17 @@
                                   self.child.debug_repr())
 
 class Call2(Signature):
-    def __init__(self, func, name, left, right):
+    def __init__(self, func, left, right):
         self.binfunc = func
-        self.name = name
         self.left = left
         self.right = right
 
     def hash(self):
-        return (compute_hash(self.name) ^ (self.left.hash() << 1) ^
+        return (compute_identity_hash(self.binfunc) ^ (self.left.hash() << 1) ^
                 (self.right.hash() << 2))
 
     def eq(self, other):
-        if type(other) is not type(self):
+        if type(self) is not type(other):
             return False
         return (self.binfunc is other.binfunc and
                 self.left.eq(other.left) and self.right.eq(other.right))
@@ -113,36 +124,5 @@
                                       self.left.debug_repr(),
                                       self.right.debug_repr())
 
-class ForcedSignature(Signature):
-    def debug_repr(self):
-        return 'Forced'
-
 class ReduceSignature(Call2):
     pass
-
-# class Signature(BaseSignature):
-#     _known_sigs = r_dict(components_eq, components_hash)
-
-#     _attrs_ = ["components"]
-#     _immutable_fields_ = ["components[*]"]
-
-#     def __init__(self, components):
-#         self.components = components
-
-#     @staticmethod
-#     def find_sig(components):
-#         return Signature._known_sigs.setdefault(components, Signature(components))
-
-# class Call1(BaseSignature):
-#     _immutable_fields_ = ["func", "name"]
-
-#     def __init__(self, func):
-#         self.func = func
-#         self.name = func.func_name
-
-# class Call2(BaseSignature):
-#     _immutable_fields_ = ["func", "name"]
-
-#     def __init__(self, func):
-#         self.func = func
-#         self.name = func.func_name
diff --git a/pypy/module/micronumpy/test/test_base.py b/pypy/module/micronumpy/test/test_base.py
--- a/pypy/module/micronumpy/test/test_base.py
+++ b/pypy/module/micronumpy/test/test_base.py
@@ -17,18 +17,18 @@
         ar = W_NDimArray(10, [10], dtype=float64_dtype)
         v1 = ar.descr_add(space, ar)
         v2 = ar.descr_add(space, Scalar(float64_dtype, 2.0))
-        assert v1.signature is not v2.signature
+        assert v1.find_sig() is not v2.find_sig()
         v3 = ar.descr_add(space, Scalar(float64_dtype, 1.0))
-        assert v2.signature is v3.signature
+        assert v2.find_sig() is v3.find_sig()
         v4 = ar.descr_add(space, ar)
-        assert v1.signature is v4.signature
+        assert v1.find_sig() is v4.find_sig()
 
         bool_ar = W_NDimArray(10, [10], dtype=bool_dtype)
         v5 = ar.descr_add(space, bool_ar)
-        assert v5.signature is not v1.signature
-        assert v5.signature is not v2.signature
+        assert v5.find_sig() is not v1.find_sig()
+        assert v5.find_sig() is not v2.find_sig()
         v6 = ar.descr_add(space, bool_ar)
-        assert v5.signature is v6.signature
+        assert v5.find_sig() is v6.find_sig()
 
     def test_slice_signature(self, space):
         float64_dtype = get_dtype_cache(space).w_float64dtype
@@ -36,11 +36,11 @@
         ar = W_NDimArray(10, [10], dtype=float64_dtype)
         v1 = ar.descr_getitem(space, space.wrap(slice(1, 3, 1)))
         v2 = ar.descr_getitem(space, space.wrap(slice(4, 6, 1)))
-        assert v1.signature is v2.signature
+        assert v1.find_sig() is v2.find_sig()
 
         v3 = v2.descr_add(space, v1)
         v4 = v1.descr_add(space, v2)
-        assert v3.signature is v4.signature
+        assert v3.find_sig() is v4.find_sig()
 
 class TestUfuncCoerscion(object):
     def test_binops(self, space):
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -8,9 +8,6 @@
 
 
 class MockDtype(object):
-    array_signature = signature.ArraySignature()
-    scalar_signature = signature.ScalarSignature()
-
     def malloc(self, size):
         return None
 


More information about the pypy-commit mailing list