[pypy-commit] pypy numpy-comparison: Initial implementation (tests pass, translation fails)
snus_mumrik
noreply at buildbot.pypy.org
Fri Sep 2 14:45:55 CEST 2011
Author: Ilya Osadchiy <osadchiy.ilya at gmail.com>
Branch: numpy-comparison
Changeset: r47014:edb6c31894de
Date: 2011-09-02 10:37 +0300
http://bitbucket.org/pypy/pypy/changeset/edb6c31894de/
Log: Initial implementation (tests pass, translation fails)
diff --git a/pypy/module/micronumpy/__init__.py b/pypy/module/micronumpy/__init__.py
--- a/pypy/module/micronumpy/__init__.py
+++ b/pypy/module/micronumpy/__init__.py
@@ -25,17 +25,18 @@
'floor': 'interp_ufuncs.floor',
'maximum': 'interp_ufuncs.maximum',
'minimum': 'interp_ufuncs.minimum',
- 'multiply': 'interp_ufuncs.multiply',
- 'negative': 'interp_ufuncs.negative',
- 'reciprocal': 'interp_ufuncs.reciprocal',
- 'sign': 'interp_ufuncs.sign',
- 'subtract': 'interp_ufuncs.subtract',
- 'sin': 'interp_ufuncs.sin',
- 'cos': 'interp_ufuncs.cos',
- 'tan': 'interp_ufuncs.tan',
- 'arcsin': 'interp_ufuncs.arcsin',
- 'arccos': 'interp_ufuncs.arccos',
- 'arctan': 'interp_ufuncs.arctan',
+ 'multiply': 'interp_ufuncs.multiply',
+ 'negative': 'interp_ufuncs.negative',
+ 'reciprocal': 'interp_ufuncs.reciprocal',
+ 'sign': 'interp_ufuncs.sign',
+ 'subtract': 'interp_ufuncs.subtract',
+ 'sin': 'interp_ufuncs.sin',
+ 'cos': 'interp_ufuncs.cos',
+ 'tan': 'interp_ufuncs.tan',
+ 'arcsin': 'interp_ufuncs.arcsin',
+ 'arccos': 'interp_ufuncs.arccos',
+ 'arctan': 'interp_ufuncs.arctan',
+ 'equal': 'interp_ufuncs.equal',
}
appleveldefs = {
diff --git a/pypy/module/micronumpy/interp_dtype.py b/pypy/module/micronumpy/interp_dtype.py
--- a/pypy/module/micronumpy/interp_dtype.py
+++ b/pypy/module/micronumpy/interp_dtype.py
@@ -125,6 +125,15 @@
))
return impl
+def bool_binop(func):
+ @functools.wraps(func)
+ def impl(self, v1, v2):
+ return self.box(func(self,
+ self.for_computation(self.unbox(v1)),
+ self.for_computation(self.unbox(v2)),
+ ))
+ return impl
+
def unaryop(func):
@functools.wraps(func)
def impl(self, v):
@@ -147,6 +156,25 @@
def div(self, v1, v2):
return v1 / v2
+ @bool_binop
+ def eq(self, v1, v2):
+ return v1 == v2
+ @bool_binop
+ def ne(self, v1, v2):
+ return v1 != v2
+ @bool_binop
+ def lt(self, v1, v2):
+ return v1 < v2
+ @bool_binop
+ def le(self, v1, v2):
+ return v1 <= v2
+ @bool_binop
+ def gt(self, v1, v2):
+ return v1 > v2
+ @bool_binop
+ def ge(self, v1, v2):
+ return v1 >= v2
+
@unaryop
def pos(self, v):
return +v
@@ -166,8 +194,8 @@
def bool(self, v):
return bool(self.for_computation(self.unbox(v)))
- def ne(self, v1, v2):
- return self.for_computation(self.unbox(v1)) != self.for_computation(self.unbox(v2))
+# def ne(self, v1, v2):
+# return self.for_computation(self.unbox(v1)) != self.for_computation(self.unbox(v2))
class FloatArithmeticDtype(ArithmaticTypeMixin):
@@ -355,4 +383,4 @@
num = interp_attrproperty("num", cls=W_Dtype),
kind = interp_attrproperty("kind", cls=W_Dtype),
shape = GetSetProperty(W_Dtype.descr_get_shape),
-)
\ No newline at end of file
+)
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -74,6 +74,13 @@
descr_pow = _binop_impl(interp_ufuncs.power)
descr_mod = _binop_impl(interp_ufuncs.mod)
+ descr_eq = _binop_impl(interp_ufuncs.equal)
+ descr_ne = _binop_impl(interp_ufuncs.not_equal)
+ descr_lt = _binop_impl(interp_ufuncs.less)
+ descr_le = _binop_impl(interp_ufuncs.less_equal)
+ descr_gt = _binop_impl(interp_ufuncs.greater)
+ descr_ge = _binop_impl(interp_ufuncs.greater_equal)
+
def _binop_right_impl(w_ufunc):
def impl(self, space, w_other):
w_other = scalar_w(space,
@@ -152,7 +159,7 @@
size=size, i=i, result=result,
cur_best=cur_best)
new_best = getattr(dtype, op_name)(cur_best, self.eval(i))
- if dtype.ne(new_best, cur_best):
+ if dtype.unbox(dtype.ne(new_best, cur_best)):
result = i
cur_best = new_best
i += 1
@@ -350,11 +357,12 @@
"""
Class for representing virtual arrays, such as binary ops or ufuncs
"""
- def __init__(self, signature, res_dtype):
+ def __init__(self, signature, res_dtype, calc_dtype):
BaseArray.__init__(self)
self.forced_result = None
self.signature = signature
self.res_dtype = res_dtype
+ self.calc_dtype = calc_dtype
def _del_sources(self):
# Function for deleting references to source arrays, to allow garbage-collecting them
@@ -402,7 +410,7 @@
class Call1(VirtualArray):
def __init__(self, signature, res_dtype, values):
- VirtualArray.__init__(self, signature, res_dtype)
+ VirtualArray.__init__(self, signature, res_dtype, res_dtype)
self.values = values
def _del_sources(self):
@@ -427,8 +435,8 @@
"""
Intermediate class for performing binary operations.
"""
- def __init__(self, signature, res_dtype, left, right):
- VirtualArray.__init__(self, signature, res_dtype)
+ def __init__(self, signature, res_dtype, calc_dtype, left, right):
+ VirtualArray.__init__(self, signature, res_dtype, calc_dtype)
self.left = left
self.right = right
@@ -444,14 +452,14 @@
return self.right.find_size()
def _eval(self, i):
- lhs = self.left.eval(i).convert_to(self.res_dtype)
- rhs = self.right.eval(i).convert_to(self.res_dtype)
+ lhs = self.left.eval(i).convert_to(self.calc_dtype)
+ rhs = self.right.eval(i).convert_to(self.calc_dtype)
sig = jit.promote(self.signature)
assert isinstance(sig, signature.Signature)
call_sig = sig.components[0]
assert isinstance(call_sig, signature.Call2)
- return call_sig.func(self.res_dtype, lhs, rhs)
+ return call_sig.func(self.calc_dtype, lhs, rhs).convert_to(self.res_dtype)
class ViewArray(BaseArray):
"""
@@ -610,6 +618,13 @@
__repr__ = interp2app(BaseArray.descr_repr),
__str__ = interp2app(BaseArray.descr_str),
+ __eq__ = interp2app(BaseArray.descr_eq),
+ __ne__ = interp2app(BaseArray.descr_ne),
+ __lt__ = interp2app(BaseArray.descr_lt),
+ __le__ = interp2app(BaseArray.descr_le),
+ __gt__ = interp2app(BaseArray.descr_gt),
+ __ge__ = interp2app(BaseArray.descr_ge),
+
dtype = GetSetProperty(BaseArray.descr_get_dtype),
shape = GetSetProperty(BaseArray.descr_get_shape),
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -24,9 +24,9 @@
return w_res
return func_with_new_name(impl, "%s_dispatcher" % func.__name__)
-def ufunc2(func=None, promote_to_float=False):
+def ufunc2(func=None, promote_to_float=False, bool_result=False):
if func is None:
- return lambda func: ufunc2(func, promote_to_float)
+ return lambda func: ufunc2(func, promote_to_float, bool_result)
call_sig = signature.Call2(func)
def impl(space, w_lhs, w_rhs):
@@ -35,17 +35,25 @@
w_lhs = convert_to_array(space, w_lhs)
w_rhs = convert_to_array(space, w_rhs)
- res_dtype = find_binop_result_dtype(space,
+ calc_dtype = find_binop_result_dtype(space,
w_lhs.find_dtype(), w_rhs.find_dtype(),
promote_to_float=promote_to_float,
)
+ # Some operations return bool regardless of input type
+ if bool_result:
+ res_dtype = space.fromcache(interp_dtype.W_BoolDtype)
+ else:
+ res_dtype = calc_dtype
if isinstance(w_lhs, Scalar) and isinstance(w_rhs, Scalar):
- return func(res_dtype, w_lhs.value, w_rhs.value).wrap(space)
+ lhs = w_lhs.value.convert_to(calc_dtype)
+ rhs = w_rhs.value.convert_to(calc_dtype)
+ interm_res = func(calc_dtype, lhs, rhs)
+ return interm_res.convert_to(res_dtype).wrap(space)
new_sig = signature.Signature.find_sig([
call_sig, w_lhs.signature, w_rhs.signature
])
- w_res = Call2(new_sig, res_dtype, w_lhs, w_rhs)
+ w_res = Call2(new_sig, res_dtype, calc_dtype, w_lhs, w_rhs)
w_lhs.add_invalidates(w_res)
w_rhs.add_invalidates(w_res)
return w_res
@@ -123,6 +131,13 @@
("maximum", "max", 2),
("minimum", "min", 2),
+ ("equal", "eq", 2, {"bool_result": True}),
+ ("not_equal", "ne", 2, {"bool_result": True}),
+ ("less", "lt", 2, {"bool_result": True}),
+ ("less_equal", "le", 2, {"bool_result": True}),
+ ("greater", "gt", 2, {"bool_result": True}),
+ ("greater_equal", "ge", 2, {"bool_result": True}),
+
("copysign", "copysign", 2, {"promote_to_float": True}),
("positive", "pos", 1),
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -510,6 +510,34 @@
assert array([1.2, 5]).dtype is dtype(float)
assert array([]).dtype is dtype(float)
+ def test_comparison(self):
+ from numpy import array, dtype
+ a = array(range(5))
+ b = array(range(5), dtype=float)
+ for func in [
+ lambda x, y: x == y,
+ lambda x, y: x != y,
+ lambda x, y: x < y,
+ lambda x, y: x <= y,
+ lambda x, y: x > y,
+ lambda x, y: x >= y,
+ ]:
+ _a3 = func (a, 3)
+ assert _a3.dtype is dtype(bool)
+ for i in xrange(5):
+ assert _a3[i] == (True if func(a[i], 3) else False)
+ _b3 = func (b, 3)
+ assert _b3.dtype is dtype(bool)
+ for i in xrange(5):
+ assert _b3[i] == (True if func(b[i], 3) else False)
+ _3a = func (3, a)
+ assert _3a.dtype is dtype(bool)
+ for i in xrange(5):
+ assert _3a[i] == (True if func(3, a[i]) else False)
+ _3b = func (3, b)
+ assert _3b.dtype is dtype(bool)
+ for i in xrange(5):
+ assert _3b[i] == (True if func(3, b[i]) else False)
class AppTestSupport(object):
def setup_class(cls):
@@ -522,4 +550,4 @@
a = fromstring(self.data)
for i in range(4):
assert a[i] == i + 1
- raises(ValueError, fromstring, "abc")
\ No newline at end of file
+ raises(ValueError, fromstring, "abc")
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -267,3 +267,11 @@
b = arctan(a)
assert math.isnan(b[0])
+ def test_comparison(self):
+ from numpy import array, dtype, equal
+ assert equal(3, 3) is True
+ assert equal(3, 4) is False
+ assert equal(3.0, 3.0) is True
+ assert equal(3.0, 3.5) is False
+ assert equal(3.0, 3) is True
+ assert equal(3.0, 4) is False
More information about the pypy-commit mailing list