[Scipy-svn] r6982 - trunk/scipy/signal/tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Mon Nov 29 18:41:08 EST 2010
Author: warren.weckesser
Date: 2010-11-29 17:41:08 -0600 (Mon, 29 Nov 2010)
New Revision: 6982
Modified:
trunk/scipy/signal/tests/test_signaltools.py
Log:
TST: signal: make the decimal precision of the complex tests of the correlate function depend on the data dtype.
Modified: trunk/scipy/signal/tests/test_signaltools.py
===================================================================
--- trunk/scipy/signal/tests/test_signaltools.py 2010-11-29 14:57:55 UTC (rev 6981)
+++ trunk/scipy/signal/tests/test_signaltools.py 2010-11-29 23:41:08 UTC (rev 6982)
@@ -1,4 +1,4 @@
-#this program corresponds to special.py
+
from decimal import Decimal
from numpy.testing import TestCase, run_module_suite, assert_equal, \
@@ -12,6 +12,7 @@
from numpy import array, arange
import numpy as np
+
class _TestConvolve(TestCase):
def test_basic(self):
a = [3,4,5,6,5,4]
@@ -293,7 +294,10 @@
class TestWiener(TestCase):
def test_basic(self):
g = array([[5,6,4,3],[3,5,6,2],[2,3,5,6],[1,6,9,7]],'d')
- correct = array([[2.16374269,3.2222222222, 2.8888888889, 1.6666666667],[2.666666667, 4.33333333333, 4.44444444444, 2.8888888888],[2.222222222, 4.4444444444, 5.4444444444, 4.801066874837],[1.33333333333, 3.92735042735, 6.0712560386, 5.0404040404]])
+ correct = array([[2.16374269,3.2222222222, 2.8888888889, 1.6666666667],
+ [2.666666667, 4.33333333333, 4.44444444444, 2.8888888888],
+ [2.222222222, 4.4444444444, 5.4444444444, 4.801066874837],
+ [1.33333333333, 3.92735042735, 6.0712560386, 5.0404040404]])
h = signal.wiener(g)
assert_array_almost_equal(h,correct,decimal=6)
@@ -449,8 +453,11 @@
class TestLinearFilterDecimal(_TestLinearFilter):
dt = np.dtype(Decimal)
+
class _TestCorrelateReal(TestCase):
+
dt = None
+
def _setup_rank1(self):
# a.size should be greated than b.size for the tests
a = np.linspace(0, 3, 4).astype(self.dt)
@@ -568,6 +575,7 @@
assert_array_almost_equal(y, y_r)
self.assertTrue(y.dtype == self.dt)
+
def _get_testcorrelate_class(i, base):
class TestCorrelateX(base):
dt = i
@@ -580,9 +588,19 @@
cls = _get_testcorrelate_class(i, _TestCorrelateReal)
globals()[cls.__name__] = cls
+
class _TestCorrelateComplex(TestCase):
+
+ # The numpy data type to use.
dt = None
+
+ # The decimal precision to be used for comparing results.
+ # This value will be passed as the 'decimal' keyword argument of
+ # assert_array_almost_equal().
+ decimal = None
+
def _setup_rank1(self, mode):
+ np.random.seed(9)
a = np.random.randn(10).astype(self.dt)
a += 1j * np.random.randn(10).astype(self.dt)
b = np.random.randn(8).astype(self.dt)
@@ -597,19 +615,19 @@
def test_rank1_valid(self):
a, b, y_r = self._setup_rank1('valid')
y = correlate(a, b, 'valid', old_behavior=False)
- assert_array_almost_equal(y, y_r)
+ assert_array_almost_equal(y, y_r, decimal=self.decimal)
self.assertTrue(y.dtype == self.dt)
def test_rank1_same(self):
a, b, y_r = self._setup_rank1('same')
y = correlate(a, b, 'same', old_behavior=False)
- assert_array_almost_equal(y, y_r)
+ assert_array_almost_equal(y, y_r, decimal=self.decimal)
self.assertTrue(y.dtype == self.dt)
def test_rank1_full(self):
a, b, y_r = self._setup_rank1('full')
y = correlate(a, b, 'full', old_behavior=False)
- assert_array_almost_equal(y, y_r)
+ assert_array_almost_equal(y, y_r, decimal=self.decimal)
self.assertTrue(y.dtype == self.dt)
def test_rank3(self):
@@ -624,28 +642,28 @@
correlate(a.imag, b.real, old_behavior=False))
y = correlate(a, b, 'full', old_behavior=False)
- assert_array_almost_equal(y, y_r, decimal=4)
+ assert_array_almost_equal(y, y_r, decimal=self.decimal-1)
self.assertTrue(y.dtype == self.dt)
@dec.deprecated()
def test_rank1_valid_old(self):
a, b, y_r = self._setup_rank1('valid')
y = correlate(b, a.conj(), 'valid')
- assert_array_almost_equal(y, y_r)
+ assert_array_almost_equal(y, y_r, decimal=self.decimal)
self.assertTrue(y.dtype == self.dt)
@dec.deprecated()
def test_rank1_same_old(self):
a, b, y_r = self._setup_rank1('same')
y = correlate(b, a.conj(), 'same')
- assert_array_almost_equal(y, y_r)
+ assert_array_almost_equal(y, y_r, decimal=self.decimal)
self.assertTrue(y.dtype == self.dt)
@dec.deprecated()
def test_rank1_full_old(self):
a, b, y_r = self._setup_rank1('full')
y = correlate(b, a.conj(), 'full')
- assert_array_almost_equal(y, y_r)
+ assert_array_almost_equal(y, y_r, decimal=self.decimal)
self.assertTrue(y.dtype == self.dt)
@dec.deprecated()
@@ -661,13 +679,20 @@
correlate(a.imag, b.real, old_behavior=False))
y = correlate(b, a.conj(), 'full')
- assert_array_almost_equal(y, y_r, decimal=4)
+ assert_array_almost_equal(y, y_r, decimal=self.decimal-1)
self.assertTrue(y.dtype == self.dt)
-for i in [np.csingle, np.cdouble, np.clongdouble]:
+
+# Create three classes, one for each complex data type: TestCorrelateComplex64,
+# TestCorrelateComplex128 and TestCorrelateComplex256.
+# The second number in the pairs is used in the 'decimal' keyword argument of
+# the array comparisons in the tests.
+for i, decimal in [(np.csingle, 5), (np.cdouble, 10), (np.clongdouble, 15)]:
cls = _get_testcorrelate_class(i, _TestCorrelateComplex)
+ cls.decimal = decimal
globals()[cls.__name__] = cls
+
class TestFiltFilt:
def test_basic(self):
out = signal.filtfilt([1,2,3], [1,2,3], np.arange(12))
More information about the Scipy-svn mailing list