[pypy-commit] pypy refactor-signature: Improve a bit on max() bridge
fijal
noreply at buildbot.pypy.org
Thu Dec 15 10:19:56 CET 2011
Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: refactor-signature
Changeset: r50536:dfe668607a47
Date: 2011-12-15 11:19 +0200
http://bitbucket.org/pypy/pypy/changeset/dfe668607a47/
Log: Improve a bit on max() bridge
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -882,7 +882,8 @@
def create_sig(self):
if self.forced_result is not None:
return self.forced_result.create_sig()
- return signature.Call2(self.ufunc, self.name, self.left.create_sig(),
+ return signature.Call2(self.ufunc, self.name, self.calc_dtype,
+ self.left.create_sig(),
self.right.create_sig())
class ViewArray(BaseArray):
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -68,7 +68,7 @@
promote_to_largest=True
)
shapelen = len(obj.shape)
- sig = find_sig(ReduceSignature(self.func, self.name,
+ sig = find_sig(ReduceSignature(self.func, self.name, dtype,
ScalarSignature(dtype),
obj.create_sig()))
frame = sig.create_frame(obj)
@@ -98,6 +98,8 @@
class W_Ufunc1(W_Ufunc):
argcount = 1
+ _immutable_fields_ = ["func", "name"]
+
def __init__(self, func, name, promote_to_float=False, promote_bools=False,
identity=None):
@@ -125,7 +127,7 @@
class W_Ufunc2(W_Ufunc):
- _immutable_fields_ = ["comparison_func", "func"]
+ _immutable_fields_ = ["comparison_func", "func", "name"]
argcount = 2
def __init__(self, func, name, promote_to_float=False, promote_bools=False,
@@ -158,7 +160,8 @@
)
new_shape = shape_agreement(space, w_lhs.shape, w_rhs.shape)
- w_res = Call2(self.func, self.name, new_shape, calc_dtype,
+ w_res = Call2(self.func, self.name,
+ new_shape, calc_dtype,
res_dtype, w_lhs, w_rhs)
w_lhs.add_invalidates(w_res)
w_rhs.add_invalidates(w_res)
diff --git a/pypy/module/micronumpy/signature.py b/pypy/module/micronumpy/signature.py
--- a/pypy/module/micronumpy/signature.py
+++ b/pypy/module/micronumpy/signature.py
@@ -61,6 +61,9 @@
self.iterators[i] = self.iterators[i].next(shapelen)
class Signature(object):
+ _attrs_ = ['iter_no']
+ _immutable_fields_ = ['iter_no']
+
def invent_numbering(self):
cache = r_dict(sigeq, sighash)
allnumbers = []
@@ -81,12 +84,15 @@
return NumpyEvalFrame(iterlist)
class ConcreteSignature(Signature):
+ _immutable_fields_ = ['dtype']
+
def __init__(self, dtype):
self.dtype = dtype
def eq(self, other):
if type(self) is not type(other):
return False
+ assert isinstance(other, ConcreteSignature)
return self.dtype is other.dtype
def hash(self):
@@ -108,7 +114,7 @@
arr = arr.get_concrete()
assert isinstance(arr, W_NDimArray)
iter = frame.iterators[self.iter_no]
- return arr.dtype.getitem(arr.storage, iter.offset)
+ return self.dtype.getitem(arr.storage, iter.offset)
class ScalarSignature(ConcreteSignature):
def debug_repr(self):
@@ -125,12 +131,15 @@
return arr.value
class ViewSignature(Signature):
+ _immutable_fields_ = ['child']
+
def __init__(self, child):
self.child = child
def eq(self, other):
if type(self) is not type(other):
return False
+ assert isinstance(other, ViewSignature)
return self.child.eq(other.child)
def hash(self):
@@ -164,6 +173,8 @@
raise NotImplementedError
class Call1(Signature):
+ _immutable_fields_ = ['unfunc', 'name', 'child']
+
def __init__(self, func, name, child):
self.unfunc = func
self.child = child
@@ -175,6 +186,7 @@
def eq(self, other):
if type(self) is not type(other):
return False
+ assert isinstance(other, Call1)
return self.unfunc is other.unfunc and self.child.eq(other.child)
def debug_repr(self):
@@ -195,11 +207,14 @@
return self.unfunc(arr.res_dtype, v)
class Call2(Signature):
- def __init__(self, func, name, left, right):
+ _immutable_fields_ = ['binfunc', 'name', 'calc_dtype', 'left', 'right']
+
+ def __init__(self, func, name, calc_dtype, left, right):
self.binfunc = func
self.left = left
self.right = right
self.name = name
+ self.calc_dtype = calc_dtype
def hash(self):
return (compute_hash(self.name) ^ (self.left.hash() << 1) ^
@@ -208,7 +223,9 @@
def eq(self, other):
if type(self) is not type(other):
return False
+ assert isinstance(other, Call2)
return (self.binfunc is other.binfunc and
+ self.calc_dtype is other.calc_dtype and
self.left.eq(other.left) and self.right.eq(other.right))
def _invent_numbering(self, cache, allnumbers):
@@ -225,9 +242,9 @@
def eval(self, frame, arr):
from pypy.module.micronumpy.interp_numarray import Call2
assert isinstance(arr, Call2)
- lhs = self.left.eval(frame, arr.left).convert_to(arr.calc_dtype)
- rhs = self.right.eval(frame, arr.right).convert_to(arr.calc_dtype)
- return self.binfunc(arr.calc_dtype, lhs, rhs)
+ lhs = self.left.eval(frame, arr.left).convert_to(self.calc_dtype)
+ rhs = self.right.eval(frame, arr.right).convert_to(self.calc_dtype)
+ return self.binfunc(self.calc_dtype, lhs, rhs)
def debug_repr(self):
return 'Call2(%s, %s)' % (self.left.debug_repr(),
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -130,15 +130,18 @@
"float_mul": 1, "int_add": 1,
"int_ge": 1, "guard_false": 1, "jump": 1})
- def test_max(self):
- py.test.skip("broken, investigate")
- result = self.run("""
+ def define_max():
+ return """
a = |30|
a[13] = 128
b = a + a
max(b)
- """)
+ """
+
+ def test_max(self):
+ result = self.run("max")
assert result == 256
+ py.test.skip("not there yet, getting though")
self.check_simple_loop({"getinteriorfield_raw": 2, "float_add": 1,
"float_mul": 1, "int_add": 1,
"int_lt": 1, "guard_true": 1, "jump": 1})
More information about the pypy-commit mailing list