[pypy-commit] pypy default: implement comparison funcs for record types
bdkearns
noreply at buildbot.pypy.org
Sat Feb 22 22:15:04 CET 2014
Author: Brian Kearns <bdkearns at gmail.com>
Branch:
Changeset: r69273:9ba1d3bb478e
Date: 2014-02-22 15:52 -0500
http://bitbucket.org/pypy/pypy/changeset/9ba1d3bb478e/
Log: implement comparison funcs for record types
diff --git a/pypy/module/micronumpy/compile.py b/pypy/module/micronumpy/compile.py
--- a/pypy/module/micronumpy/compile.py
+++ b/pypy/module/micronumpy/compile.py
@@ -5,6 +5,7 @@
import re
+from pypy.interpreter import special
from pypy.interpreter.baseobjspace import InternalSpaceCache, W_Root
from pypy.interpreter.error import OperationError
from pypy.module.micronumpy import interp_boxes
@@ -74,6 +75,7 @@
def __init__(self):
"""NOT_RPYTHON"""
self.fromcache = InternalSpaceCache(self).getorbuild
+ self.w_NotImplemented = special.NotImplemented(self)
def _freeze_(self):
return True
@@ -194,6 +196,9 @@
def is_w(self, w_obj, w_what):
return w_obj is w_what
+ def eq_w(self, w_obj, w_what):
+ return w_obj == w_what
+
def issubtype(self, w_type1, w_type2):
return BoolObject(True)
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -371,17 +371,23 @@
w_ldtype = w_lhs.get_dtype()
w_rdtype = w_rhs.get_dtype()
if w_ldtype.is_str_type() and w_rdtype.is_str_type() and \
- self.comparison_func:
+ self.comparison_func:
pass
elif (w_ldtype.is_str_type() or w_rdtype.is_str_type()) and \
- self.comparison_func and w_out is None:
+ self.comparison_func and w_out is None:
return space.wrap(False)
- elif (w_ldtype.is_flexible_type() or \
- w_rdtype.is_flexible_type()):
- raise OperationError(space.w_TypeError, space.wrap(
- 'unsupported operand dtypes %s and %s for "%s"' % \
- (w_rdtype.get_name(), w_ldtype.get_name(),
- self.name)))
+ elif w_ldtype.is_flexible_type() or w_rdtype.is_flexible_type():
+ if self.comparison_func:
+ if self.name == 'equal' or self.name == 'not_equal':
+ res = w_ldtype.eq(space, w_rdtype)
+ if not res:
+ return space.wrap(self.name == 'not_equal')
+ else:
+ return space.w_NotImplemented
+ else:
+ raise oefmt(space.w_TypeError,
+ 'unsupported operand dtypes %s and %s for "%s"',
+ w_rdtype.name, w_ldtype.name, self.name)
if self.are_common_types(w_ldtype, w_rdtype):
if not w_lhs.is_scalar() and w_rhs.is_scalar():
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -3573,6 +3573,28 @@
exc = raises(ValueError, "a.view(('float32', 2))")
assert exc.value[0] == 'new type not compatible with array.'
+ def test_record_ufuncs(self):
+ """Ufunc behaviour on record (structured) arrays.
+
+ abs() must raise TypeError; equal/not_equal compare element-wise
+ when both dtypes match and yield a single boolean when they do not;
+ ordering comparisons (greater, less_equal) return NotImplemented."""
+ import numpy as np
+ # a, b and d share one dtype; c differs in field types (f8 vs i8),
+ # e differs in field count (an extra 'c' field)
+ a = np.zeros(3, dtype=[('a', 'i8'), ('b', 'i8')])
+ b = np.zeros(3, dtype=[('a', 'i8'), ('b', 'i8')])
+ c = np.zeros(3, dtype=[('a', 'f8'), ('b', 'f8')])
+ d = np.ones(3, dtype=[('a', 'i8'), ('b', 'i8')])
+ e = np.ones(3, dtype=[('a', 'i8'), ('b', 'i8'), ('c', 'i8')])
+ exc = raises(TypeError, abs, a)
+ assert exc.value[0] == 'Not implemented for this type'
+ # identical dtypes: comparison is element-wise, so use .all()/.any()
+ assert (a == a).all()
+ assert not (a != a).any()
+ assert (a == b).all()
+ assert not (a != b).any()
+ # mismatched dtypes: equal/not_equal collapse to one scalar bool
+ assert a != c
+ assert not a == c
+ assert (a != d).all()
+ assert not (a == d).any()
+ assert a != e
+ assert not a == e
+ # ordering comparisons are not supported for record dtypes
+ assert np.greater(a, a) is NotImplemented
+ assert np.less_equal(a, a) is NotImplemented
+
class AppTestPyPy(BaseNumpyAppTest):
def setup_class(cls):
if option.runappdirect and '__pypy__' not in sys.builtin_module_names:
diff --git a/pypy/module/micronumpy/types.py b/pypy/module/micronumpy/types.py
--- a/pypy/module/micronumpy/types.py
+++ b/pypy/module/micronumpy/types.py
@@ -1944,6 +1944,20 @@
pieces.append(")")
return "".join(pieces)
+ def eq(self, v1, v2):
+ assert isinstance(v1, interp_boxes.W_VoidBox)
+ assert isinstance(v2, interp_boxes.W_VoidBox)
+ s1 = v1.dtype.get_size()
+ s2 = v2.dtype.get_size()
+ assert s1 == s2
+ for i in range(s1):
+ if v1.arr.storage[v1.ofs + i] != v2.arr.storage[v2.ofs + i]:
+ return False
+ return True
+
+ def ne(self, v1, v2):
+ return not self.eq(v1, v2)
+
for tp in [Int32, Int64]:
if tp.T == lltype.Signed:
IntP = tp
More information about the pypy-commit mailing list