[Numpy-svn] r3977 - in trunk/numpy/core: . tests
numpy-svn at scipy.org
numpy-svn at scipy.org
Mon Aug 20 09:47:15 EDT 2007
Author: stefan
Date: 2007-08-20 08:46:55 -0500 (Mon, 20 Aug 2007)
New Revision: 3977
Modified:
trunk/numpy/core/numeric.py
trunk/numpy/core/tests/test_numeric.py
Log:
Fix allclose and add tests (based on a patch by Matthew Brett).
Modified: trunk/numpy/core/numeric.py
===================================================================
--- trunk/numpy/core/numeric.py 2007-08-20 13:43:21 UTC (rev 3976)
+++ trunk/numpy/core/numeric.py 2007-08-20 13:46:55 UTC (rev 3977)
@@ -835,21 +835,16 @@
"""
x = array(a, copy=False)
y = array(b, copy=False)
- d1 = less_equal(absolute(x-y), atol + rtol * absolute(y))
xinf = isinf(x)
- yinf = isinf(y)
- if (not xinf.any() and not yinf.any()):
- return d1.all()
- d3 = (x[xinf] == y[yinf])
- d4 = (~xinf & ~yinf)
- if d3.size < 2:
- if d3.size==0:
- return False
- return d3
- if d3.all():
- return d1[d4].all()
- else:
+ if not all(xinf == isinf(y)):
return False
+ if not any(xinf):
+ return all(less_equal(absolute(x-y), atol + rtol * absolute(y)))
+ if not all(x[xinf] == y[xinf]):
+ return False
+ x = x[~xinf]
+ y = y[~xinf]
+ return all(less_equal(absolute(x-y), atol + rtol * absolute(y)))
def array_equal(a1, a2):
try:
Modified: trunk/numpy/core/tests/test_numeric.py
===================================================================
--- trunk/numpy/core/tests/test_numeric.py 2007-08-20 13:43:21 UTC (rev 3976)
+++ trunk/numpy/core/tests/test_numeric.py 2007-08-20 13:46:55 UTC (rev 3977)
@@ -668,7 +668,62 @@
self.clip(a, m, M, ac)
assert_array_strict_equal(a, ac)
-
+class test_allclose_inf(ParametricTestCase):
+ # Parametric tests for allclose: boundary cases of the tolerance test
+ # |x - y| <= atol + rtol*|y|, plus handling of +/-inf and nan.
+ # NOTE(review): every allclose(x,y) call below relies on allclose's default
+ # tolerances; confirm those defaults equal these rtol/atol values.
+ rtol = 1e-5
+ atol = 1e-8
+
+ # Helpers are named tst_* rather than test_*, presumably so the collector
+ # does not run them directly; the testip_* factories yield them with data.
+ def tst_allclose(self,x,y):
+ assert allclose(x,y), "%s and %s not close" % (x,y)
+
+ def tst_not_allclose(self,x,y):
+ assert not allclose(x,y), "%s and %s shouldn't be close" % (x,y)
+
+ def testip_allclose(self):
+ """Parametric test factory: yield (x, y) pairs allclose must accept."""
+ arr = array([100,1000])
+ aran = arange(125).reshape((5,5,5))
+
+ atol = self.atol
+ rtol = self.rtol
+
+ # Tolerance is atol + rtol*|y|, i.e. scaled by the SECOND argument.
+ # Cases in order:
+ #   identical arrays;
+ #   difference exactly atol with y == 0 (threshold is inclusive);
+ #   difference rtol+atol, within atol + rtol*|y| because |y| > 1;
+ #   difference arr*rtol (+2*atol), covered by the extra rtol*|y| slack.
+ # NOTE(review): for the arr=100 element of the "+ atol*2" case the
+ # margin is only ~2e-13 (it needs rtol**2*100 >= atol), so it is
+ # sensitive to any change in rtol, atol or the element values.
+ data = [([1,0], [1,0]),
+ ([atol], [0]),
+ ([1], [1+rtol+atol]),
+ (arr, arr + arr*rtol),
+ (arr, arr + arr*rtol + atol*2),
+ (aran, aran + aran*rtol),]
+
+ for (x,y) in data:
+ yield (self.tst_allclose,x,y)
+
+ def testip_not_allclose(self):
+ """Parametric test factory: yield (x, y) pairs allclose must reject."""
+ aran = arange(125).reshape((5,5,5))
+
+ atol = self.atol
+ rtol = self.rtol
+
+ # Cases in order:
+ #   inf in different positions;
+ #   inf vs finite;
+ #   inf positions only partly shared;
+ #   inf vs all-finite;
+ #   inf of opposite sign in the same position;
+ #   nan (never close, since comparisons with nan are false);
+ #   difference 2*atol with y == 0;
+ #   difference rtol+2*atol, which exceeds atol + rtol*|y| by ~atol;
+ #   a 3-d array case;
+ #   mismatched inf positions in float arrays.
+ # NOTE(review): in the aran case only element 0 (diff 2*atol vs tolerance
+ # atol) fails; "aran*atol" may have been meant as "aran*rtol", but element
+ # 0 drives the failure either way.
+ data = [([inf,0], [1,inf]),
+ ([inf,0], [1,0]),
+ ([inf,inf], [1,inf]),
+ ([inf,inf], [1,0]),
+ ([-inf, 0], [inf, 0]),
+ ([nan,0], [nan,0]),
+ ([atol*2], [0]),
+ ([1], [1+rtol+atol*2]),
+ (aran, aran + aran*atol + atol*2),
+ (array([inf,1]), array([0,inf]))]
+
+ for (x,y) in data:
+ yield (self.tst_not_allclose,x,y)
+
+ # NOTE(review): no case covers x and y of different, broadcast-compatible
+ # shapes (or scalars) when infs are present; the mask indexing y[xinf] in
+ # allclose looks like it assumes y has x's shape. Confirm and add a case.
+ def test_no_parameter_modification(self):
+ # allclose must leave its input arrays unchanged.
+ # NOTE(review): these inputs have inf in different positions, so the
+ # patched allclose returns at its first check and never reaches the
+ # x = x[~xinf] step; inputs with matching inf positions would cover it.
+ x = array([inf,1])
+ y = array([0,inf])
+ allclose(x,y)
+ assert_array_equal(x,array([inf,1]))
+ assert_array_equal(y,array([0,inf]))
+
import sys
if sys.version_info[:2] >= (2, 5):
set_local_path()
More information about the Numpy-svn mailing list