[Numpy-svn] r3977 - in trunk/numpy/core: . tests

numpy-svn at scipy.org numpy-svn at scipy.org
Mon Aug 20 09:47:15 EDT 2007


Author: stefan
Date: 2007-08-20 08:46:55 -0500 (Mon, 20 Aug 2007)
New Revision: 3977

Modified:
   trunk/numpy/core/numeric.py
   trunk/numpy/core/tests/test_numeric.py
Log:
Fix allclose and add tests (based on a patch by Matthew Brett).


Modified: trunk/numpy/core/numeric.py
===================================================================
--- trunk/numpy/core/numeric.py	2007-08-20 13:43:21 UTC (rev 3976)
+++ trunk/numpy/core/numeric.py	2007-08-20 13:46:55 UTC (rev 3977)
@@ -835,21 +835,16 @@
     """
     x = array(a, copy=False)
     y = array(b, copy=False)
-    d1 = less_equal(absolute(x-y), atol + rtol * absolute(y))
     xinf = isinf(x)
-    yinf = isinf(y)
-    if (not xinf.any() and not yinf.any()):
-        return d1.all()
-    d3 = (x[xinf] == y[yinf])
-    d4 = (~xinf & ~yinf)
-    if d3.size < 2:
-        if d3.size==0:
-            return False
-        return d3
-    if d3.all():
-        return d1[d4].all() 
-    else:
+    if not all(xinf == isinf(y)):
         return False
+    if not any(xinf):
+        return all(less_equal(absolute(x-y), atol + rtol * absolute(y)))
+    if not all(x[xinf] == y[xinf]):
+        return False
+    x = x[~xinf]
+    y = y[~xinf]
+    return all(less_equal(absolute(x-y), atol + rtol * absolute(y)))
 
 def array_equal(a1, a2):
     try:

Modified: trunk/numpy/core/tests/test_numeric.py
===================================================================
--- trunk/numpy/core/tests/test_numeric.py	2007-08-20 13:43:21 UTC (rev 3976)
+++ trunk/numpy/core/tests/test_numeric.py	2007-08-20 13:46:55 UTC (rev 3977)
@@ -668,7 +668,62 @@
         self.clip(a, m, M, ac)
         assert_array_strict_equal(a, ac)
 
-        
+class test_allclose_inf(ParametricTestCase):
+    rtol = 1e-5
+    atol = 1e-8
+
+    def tst_allclose(self,x,y):
+        assert allclose(x,y), "%s and %s not close" % (x,y)
+
+    def tst_not_allclose(self,x,y):
+        assert not allclose(x,y), "%s and %s shouldn't be close" % (x,y)
+
+    def testip_allclose(self):
+        """Parametric test factory."""
+        arr = array([100,1000])
+        aran = arange(125).reshape((5,5,5))
+
+        atol = self.atol
+        rtol = self.rtol
+
+        data = [([1,0], [1,0]),
+                ([atol], [0]),
+                ([1], [1+rtol+atol]),
+                (arr, arr + arr*rtol),
+                (arr, arr + arr*rtol + atol*2),
+                (aran, aran + aran*rtol),]
+
+        for (x,y) in data:
+            yield (self.tst_allclose,x,y)
+
+    def testip_not_allclose(self):
+        """Parametric test factory."""
+        aran = arange(125).reshape((5,5,5))
+
+        atol = self.atol
+        rtol = self.rtol
+
+        data = [([inf,0], [1,inf]),
+                ([inf,0], [1,0]),
+                ([inf,inf], [1,inf]),
+                ([inf,inf], [1,0]),
+                ([-inf, 0], [inf, 0]),
+                ([nan,0], [nan,0]),
+                ([atol*2], [0]),
+                ([1], [1+rtol+atol*2]),
+                (aran, aran + aran*atol + atol*2),
+                (array([inf,1]), array([0,inf]))]
+
+        for (x,y) in data:
+            yield (self.tst_not_allclose,x,y)
+
+    def test_no_parameter_modification(self):
+        x = array([inf,1])
+        y = array([0,inf])
+        allclose(x,y)
+        assert_array_equal(x,array([inf,1]))
+        assert_array_equal(y,array([0,inf]))
+
 import sys
 if sys.version_info[:2] >= (2, 5):
     set_local_path()




More information about the Numpy-svn mailing list