[pypy-commit] pypy refactor-signature: Improve a bit on max() bridge

fijal noreply at buildbot.pypy.org
Thu Dec 15 10:19:56 CET 2011


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: refactor-signature
Changeset: r50536:dfe668607a47
Date: 2011-12-15 11:19 +0200
http://bitbucket.org/pypy/pypy/changeset/dfe668607a47/

Log:	Improve a bit on max() bridge

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -882,7 +882,8 @@
     def create_sig(self):
         if self.forced_result is not None:
             return self.forced_result.create_sig()
-        return signature.Call2(self.ufunc, self.name, self.left.create_sig(),
+        return signature.Call2(self.ufunc, self.name, self.calc_dtype,
+                               self.left.create_sig(),
                                self.right.create_sig())
 
 class ViewArray(BaseArray):
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -68,7 +68,7 @@
             promote_to_largest=True
         )
         shapelen = len(obj.shape)
-        sig = find_sig(ReduceSignature(self.func, self.name,
+        sig = find_sig(ReduceSignature(self.func, self.name, dtype,
                                        ScalarSignature(dtype),
                                        obj.create_sig()))
         frame = sig.create_frame(obj)
@@ -98,6 +98,8 @@
 class W_Ufunc1(W_Ufunc):
     argcount = 1
 
+    _immutable_fields_ = ["func", "name"]
+
     def __init__(self, func, name, promote_to_float=False, promote_bools=False,
         identity=None):
 
@@ -125,7 +127,7 @@
 
 
 class W_Ufunc2(W_Ufunc):
-    _immutable_fields_ = ["comparison_func", "func"]
+    _immutable_fields_ = ["comparison_func", "func", "name"]
     argcount = 2
 
     def __init__(self, func, name, promote_to_float=False, promote_bools=False,
@@ -158,7 +160,8 @@
             )
 
         new_shape = shape_agreement(space, w_lhs.shape, w_rhs.shape)
-        w_res = Call2(self.func, self.name, new_shape, calc_dtype,
+        w_res = Call2(self.func, self.name,
+                      new_shape, calc_dtype,
                       res_dtype, w_lhs, w_rhs)
         w_lhs.add_invalidates(w_res)
         w_rhs.add_invalidates(w_res)
diff --git a/pypy/module/micronumpy/signature.py b/pypy/module/micronumpy/signature.py
--- a/pypy/module/micronumpy/signature.py
+++ b/pypy/module/micronumpy/signature.py
@@ -61,6 +61,9 @@
             self.iterators[i] = self.iterators[i].next(shapelen)
 
 class Signature(object):
+    _attrs_ = ['iter_no']
+    _immutable_fields_ = ['iter_no']
+    
     def invent_numbering(self):
         cache = r_dict(sigeq, sighash)
         allnumbers = []
@@ -81,12 +84,15 @@
         return NumpyEvalFrame(iterlist)
 
 class ConcreteSignature(Signature):
+    _immutable_fields_ = ['dtype']
+
     def __init__(self, dtype):
         self.dtype = dtype
 
     def eq(self, other):
         if type(self) is not type(other):
             return False
+        assert isinstance(other, ConcreteSignature)
         return self.dtype is other.dtype
 
     def hash(self):
@@ -108,7 +114,7 @@
         arr = arr.get_concrete()
         assert isinstance(arr, W_NDimArray)
         iter = frame.iterators[self.iter_no]
-        return arr.dtype.getitem(arr.storage, iter.offset)
+        return self.dtype.getitem(arr.storage, iter.offset)
 
 class ScalarSignature(ConcreteSignature):
     def debug_repr(self):
@@ -125,12 +131,15 @@
         return arr.value
 
 class ViewSignature(Signature):
+    _immutable_fields_ = ['child']
+
     def __init__(self, child):
         self.child = child
     
     def eq(self, other):
         if type(self) is not type(other):
             return False
+        assert isinstance(other, ViewSignature)
         return self.child.eq(other.child)
 
     def hash(self):
@@ -164,6 +173,8 @@
         raise NotImplementedError
 
 class Call1(Signature):
+    _immutable_fields_ = ['unfunc', 'name', 'child']
+
     def __init__(self, func, name, child):
         self.unfunc = func
         self.child = child
@@ -175,6 +186,7 @@
     def eq(self, other):
         if type(self) is not type(other):
             return False
+        assert isinstance(other, Call1)
         return self.unfunc is other.unfunc and self.child.eq(other.child)
 
     def debug_repr(self):
@@ -195,11 +207,14 @@
         return self.unfunc(arr.res_dtype, v)
 
 class Call2(Signature):
-    def __init__(self, func, name, left, right):
+    _immutable_fields_ = ['binfunc', 'name', 'calc_dtype', 'left', 'right']
+    
+    def __init__(self, func, name, calc_dtype, left, right):
         self.binfunc = func
         self.left = left
         self.right = right
         self.name = name
+        self.calc_dtype = calc_dtype
 
     def hash(self):
         return (compute_hash(self.name) ^ (self.left.hash() << 1) ^
@@ -208,7 +223,9 @@
     def eq(self, other):
         if type(self) is not type(other):
             return False
+        assert isinstance(other, Call2)
         return (self.binfunc is other.binfunc and
+                self.calc_dtype is other.calc_dtype and
                 self.left.eq(other.left) and self.right.eq(other.right))
 
     def _invent_numbering(self, cache, allnumbers):
@@ -225,9 +242,9 @@
     def eval(self, frame, arr):
         from pypy.module.micronumpy.interp_numarray import Call2
         assert isinstance(arr, Call2)
-        lhs = self.left.eval(frame, arr.left).convert_to(arr.calc_dtype)
-        rhs = self.right.eval(frame, arr.right).convert_to(arr.calc_dtype)
-        return self.binfunc(arr.calc_dtype, lhs, rhs)
+        lhs = self.left.eval(frame, arr.left).convert_to(self.calc_dtype)
+        rhs = self.right.eval(frame, arr.right).convert_to(self.calc_dtype)
+        return self.binfunc(self.calc_dtype, lhs, rhs)
 
     def debug_repr(self):
         return 'Call2(%s, %s)' % (self.left.debug_repr(),
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -130,15 +130,18 @@
                                 "float_mul": 1, "int_add": 1,
                                 "int_ge": 1, "guard_false": 1, "jump": 1})
 
-    def test_max(self):
-        py.test.skip("broken, investigate")
-        result = self.run("""
+    def define_max():
+        return """
         a = |30|
         a[13] = 128
         b = a + a
         max(b)
-        """)
+        """
+
+    def test_max(self):
+        result = self.run("max")
         assert result == 256
+        py.test.skip("not there yet, getting though")
         self.check_simple_loop({"getinteriorfield_raw": 2, "float_add": 1,
                                 "float_mul": 1, "int_add": 1,
                                 "int_lt": 1, "guard_true": 1, "jump": 1})


More information about the pypy-commit mailing list