[pypy-commit] pypy numpy-dtype-alt: fix for type coercion on unary ufuncs
alex_gaynor
noreply at buildbot.pypy.org
Sat Aug 20 23:15:04 CEST 2011
Author: Alex Gaynor <alex.gaynor at gmail.com>
Branch: numpy-dtype-alt
Changeset: r46674:612e65b4b4cb
Date: 2011-08-20 16:19 -0500
http://bitbucket.org/pypy/pypy/changeset/612e65b4b4cb/
Log: fix for type coercion on unary ufuncs
diff --git a/TODO.txt b/TODO.txt
--- a/TODO.txt
+++ b/TODO.txt
@@ -1,9 +1,9 @@
TODO for mering numpy-dtype-alt
===============================
-* Fix sin(<bool array>)
* Fix for raw memory with the JIT under llgraph.
* More operations on more dtypes
+* dtype guessing
* Any more attributes that need to be exposed at app-level
For later
diff --git a/pypy/module/micronumpy/interp_dtype.py b/pypy/module/micronumpy/interp_dtype.py
--- a/pypy/module/micronumpy/interp_dtype.py
+++ b/pypy/module/micronumpy/interp_dtype.py
@@ -69,7 +69,6 @@
def unerase(self, storage):
return rffi.cast(TP, storage)
- @specialize.argtype(1)
@enforceargs(None, valtype)
def box(self, value):
return Box(value)
@@ -286,21 +285,10 @@
def str_format(self, item):
return float2string(self.unbox(item), 'g', rfloat.DTSF_STR_PRECISION)
-W_Float16Dtype = create_low_level_dtype(
- num = 23, kind = FLOATINGLTR, name = "float16",
- aliases = [],
- applevel_types =[],
- T = rffi.USHORT,
- valtype = rffi.USHORT._type,
-)
-class W_Float16Dtype(W_Float16Dtype):
- def unwrap(self, space, w_item):
- return self.adapt_val(space.float_w(space.float(w_item)))
-
ALL_DTYPES = [
W_BoolDtype,
W_Int8Dtype, W_Int32Dtype, W_Int64Dtype,
- W_Float64Dtype, W_Float16Dtype
+ W_Float64Dtype
]
dtypes_by_alias = unrolling_iterable([
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -310,7 +310,7 @@
@specialize.arg(1)
def scalar_w(space, dtype, w_obj):
- return Scalar(scalar(space, dtype, w_obj))
+ return Scalar(space.fromcache(dtype), scalar(space, dtype, w_obj))
@specialize.arg(1)
def scalar(space, dtype, w_obj):
@@ -323,15 +323,16 @@
"""
signature = signature.BaseSignature()
- def __init__(self, value):
+ def __init__(self, dtype, value):
BaseArray.__init__(self)
+ self.dtype = dtype
self.value = value
def find_size(self):
raise ValueError
def find_dtype(self):
- raise ValueError
+ return self.dtype
def eval(self, i):
return self.value
@@ -340,10 +341,11 @@
"""
Class for representing virtual arrays, such as binary ops or ufuncs
"""
- def __init__(self, signature):
+ def __init__(self, signature, res_dtype):
BaseArray.__init__(self)
self.forced_result = None
self.signature = signature
+ self.res_dtype = res_dtype
def _del_sources(self):
# Function for deleting references to source arrays, to allow garbage-collecting them
@@ -383,14 +385,12 @@
return self._find_size()
def find_dtype(self):
- if self.forced_result is not None:
- return self.forced_result.find_dtype()
- return self._find_dtype()
+ return self.res_dtype
class Call1(VirtualArray):
- def __init__(self, signature, values):
- VirtualArray.__init__(self, signature)
+ def __init__(self, signature, res_dtype, values):
+ VirtualArray.__init__(self, signature, res_dtype)
self.values = values
def _del_sources(self):
@@ -400,42 +400,23 @@
return self.values.find_size()
def _find_dtype(self):
- return self.values.find_dtype()
+ return self.res_dtype
def _eval(self, i):
call_sig = self.signature.components[0]
assert isinstance(call_sig, signature.Call1)
- return call_sig.func(self.find_dtype(), self.values.eval(i))
+ val = self.values.eval(i).convert_to(self.res_dtype)
+ return call_sig.func(self.res_dtype, val)
class Call2(VirtualArray):
"""
Intermediate class for performing binary operations.
"""
- def __init__(self, space, signature, left, right):
- VirtualArray.__init__(self, signature)
+ def __init__(self, signature, res_dtype, left, right):
+ VirtualArray.__init__(self, signature, res_dtype)
self.left = left
self.right = right
- lhs_dtype = rhs_dtype = None
- try:
- lhs_dtype = self.left.find_dtype()
- except ValueError:
- pass
- try:
- rhs_dtype = self.right.find_dtype()
- except ValueError:
- pass
- if lhs_dtype is not None and rhs_dtype is not None:
- self.res_dtype = interp_ufuncs.find_binop_result_dtype(space,
- lhs_dtype, rhs_dtype
- )
- elif lhs_dtype is not None:
- self.res_dtype = lhs_dtype
- elif rhs_dtype is not None:
- self.res_dtype = rhs_dtype
- else:
- self.res_dtype = None
-
def _del_sources(self):
self.left = None
self.right = None
@@ -448,17 +429,11 @@
return self.right.find_size()
def _eval(self, i):
- dtype = self.find_dtype()
- lhs, rhs = self.left.eval(i), self.right.eval(i)
- lhs, rhs = lhs.convert_to(dtype), rhs.convert_to(dtype)
+ lhs = self.left.eval(i).convert_to(self.res_dtype)
+ rhs = self.right.eval(i).convert_to(self.res_dtype)
call_sig = self.signature.components[0]
assert isinstance(call_sig, signature.Call2)
- return call_sig.func(dtype, lhs, rhs)
-
- def _find_dtype(self):
- if self.res_dtype is not None:
- return self.res_dtype
- raise ValueError
+ return call_sig.func(self.res_dtype, lhs, rhs)
class ViewArray(BaseArray):
"""
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -19,7 +19,11 @@
return func(res_dtype, w_obj.value).wrap(space)
new_sig = signature.Signature.find_sig([call_sig, w_obj.signature])
- w_res = Call1(new_sig, w_obj)
+ res_dtype = find_unaryop_result_dtype(space,
+ w_obj.find_dtype(),
+ promote_to_float=promote_to_float,
+ )
+ w_res = Call1(new_sig, res_dtype, w_obj)
w_obj.add_invalidates(w_res)
return w_res
return func_with_new_name(impl, "%s_dispatcher" % func.__name__)
@@ -40,7 +44,11 @@
return func(res_dtype, w_lhs.value, w_rhs.value).wrap(space)
new_sig = signature.Signature.find_sig([call_sig, w_lhs.signature, w_rhs.signature])
- w_res = Call2(space, new_sig, w_lhs, w_rhs)
+ res_dtype = find_binop_result_dtype(space,
+ w_lhs.find_dtype(), w_rhs.find_dtype(),
+ promote_to_float=promote_to_float,
+ )
+ w_res = Call2(new_sig, res_dtype, w_lhs, w_rhs)
w_lhs.add_invalidates(w_res)
w_rhs.add_invalidates(w_res)
return w_res
diff --git a/pypy/module/micronumpy/test/test_base.py b/pypy/module/micronumpy/test/test_base.py
--- a/pypy/module/micronumpy/test/test_base.py
+++ b/pypy/module/micronumpy/test/test_base.py
@@ -11,11 +11,13 @@
class TestSignature(object):
def test_binop_signature(self, space):
- ar = SingleDimArray(10, dtype=space.fromcache(interp_dtype.W_Float64Dtype))
+ float64_dtype = space.fromcache(interp_dtype.W_Float64Dtype)
+
+ ar = SingleDimArray(10, dtype=float64_dtype)
v1 = ar.descr_add(space, ar)
- v2 = ar.descr_add(space, Scalar(2.0))
+ v2 = ar.descr_add(space, Scalar(float64_dtype, 2.0))
assert v1.signature is not v2.signature
- v3 = ar.descr_add(space, Scalar(1.0))
+ v3 = ar.descr_add(space, Scalar(float64_dtype, 1.0))
assert v2.signature is v3.signature
v4 = ar.descr_add(space, ar)
assert v1.signature is v4.signature
@@ -63,7 +65,6 @@
bool_dtype = space.fromcache(interp_dtype.W_BoolDtype)
int8_dtype = space.fromcache(interp_dtype.W_Int8Dtype)
int32_dtype = space.fromcache(interp_dtype.W_Int32Dtype)
- float16_dtype = space.fromcache(interp_dtype.W_Float16Dtype)
float64_dtype = space.fromcache(interp_dtype.W_Float64Dtype)
# Normal rules, everythign returns itself
@@ -72,8 +73,9 @@
assert find_unaryop_result_dtype(space, int32_dtype) is int32_dtype
assert find_unaryop_result_dtype(space, float64_dtype) is float64_dtype
- # Coerce to floats
- assert find_unaryop_result_dtype(space, bool_dtype, promote_to_float=True) is float16_dtype
- assert find_unaryop_result_dtype(space, int8_dtype, promote_to_float=True) is float16_dtype
+ # Coerce to floats, some of these will eventually be float16, or
+ # whatever our smallest float type is.
+ assert find_unaryop_result_dtype(space, bool_dtype, promote_to_float=True) is float64_dtype
+ assert find_unaryop_result_dtype(space, int8_dtype, promote_to_float=True) is float64_dtype
assert find_unaryop_result_dtype(space, int32_dtype, promote_to_float=True) is float64_dtype
assert find_unaryop_result_dtype(space, float64_dtype, promote_to_float=True) is float64_dtype
\ No newline at end of file
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -193,6 +193,10 @@
for i in range(len(a)):
assert b[i] == math.sin(a[i])
+ a = sin(array([True, False], dtype=bool))
+ assert a[0] == sin(1)
+ assert a[1] == 0.0
+
def test_cos(self):
import math
from numpy import array, cos
More information about the pypy-commit
mailing list