[Scipy-svn] r7111 - trunk/scipy/fftpack/tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Mon Jan 31 16:15:09 EST 2011
Author: ptvirtan
Date: 2011-01-31 15:15:09 -0600 (Mon, 31 Jan 2011)
New Revision: 7111
Modified:
trunk/scipy/fftpack/tests/test_basic.py
trunk/scipy/fftpack/tests/test_pseudo_diffs.py
trunk/scipy/fftpack/tests/test_real_transforms.py
Log:
TST: fftpack: add tests checking fft routine overwrite behavior
Modified: trunk/scipy/fftpack/tests/test_basic.py
===================================================================
--- trunk/scipy/fftpack/tests/test_basic.py 2011-01-31 21:14:50 UTC (rev 7110)
+++ trunk/scipy/fftpack/tests/test_basic.py 2011-01-31 21:15:09 UTC (rev 7111)
@@ -652,5 +652,109 @@
except ValueError:
pass
+
+
+class TestOverwrite(object):
+ """
+ Check input overwrite behavior of the FFT functions
+ """
+
+ real_dtypes = [np.float32, np.float64]
+ dtypes = real_dtypes + [np.complex64, np.complex128]
+
+ def _check(self, x, routine, fftsize, axis):
+ x2 = x.copy()
+ y = routine(x2, fftsize, axis)
+
+ sig = "%s(%s%r, %r, axis=%r)" % (routine.__name__, x.dtype, x.shape,
+ fftsize, axis)
+ assert_equal(x2, x, err_msg="spurious overwrite in %s" % sig)
+
+ def _check_1d(self, routine, dtype, shape, axis):
+ np.random.seed(1234)
+ if np.issubdtype(dtype, np.complexfloating):
+ data = np.random.randn(*shape) + 1j*np.random.randn(*shape)
+ else:
+ data = np.random.randn(*shape)
+ data = data.astype(dtype)
+
+ for fftsize in [8, 16, 32]:
+ self._check(data, routine, fftsize, axis)
+
+ def test_fft(self):
+ for dtype in self.dtypes:
+ self._check_1d(fft, dtype, (16,), -1)
+ self._check_1d(fft, dtype, (16, 2), 0)
+ self._check_1d(fft, dtype, (2, 16), 1)
+
+ def test_ifft(self):
+ for dtype in self.dtypes:
+ self._check_1d(ifft, dtype, (16,), -1)
+ self._check_1d(ifft, dtype, (16, 2), 0)
+ self._check_1d(ifft, dtype, (2, 16), 1)
+
+ def test_rfft(self):
+ for dtype in self.real_dtypes:
+ self._check_1d(rfft, dtype, (16,), -1)
+ self._check_1d(rfft, dtype, (16, 2), 0)
+ self._check_1d(rfft, dtype, (2, 16), 1)
+
+ def test_irfft(self):
+ for dtype in self.real_dtypes:
+ self._check_1d(irfft, dtype, (16,), -1)
+ self._check_1d(irfft, dtype, (16, 2), 0)
+ self._check_1d(irfft, dtype, (2, 16), 1)
+
+ def _check_nd_one(self, routine, dtype, shape, axes):
+ np.random.seed(1234)
+ if np.issubdtype(dtype, np.complexfloating):
+ data = np.random.randn(*shape) + 1j*np.random.randn(*shape)
+ else:
+ data = np.random.randn(*shape)
+ data = data.astype(dtype)
+
+ def fftshape_iter(shp):
+ if len(shp) <= 0:
+ yield ()
+ else:
+ for j in (shp[0]//2, shp[0], shp[0]*2):
+ for rest in fftshape_iter(shp[1:]):
+ yield (j,) + rest
+
+ if axes is None:
+ part_shape = shape
+ else:
+ part_shape = tuple(np.take(shape, axes))
+
+ for fftshape in fftshape_iter(part_shape):
+ self._check(data, routine, fftshape, axes)
+ if data.ndim > 1:
+ # check fortran order: it never overwrites
+ self._check(data.T, routine, fftshape, axes)
+
+ def _check_nd(self, routine, dtype):
+ self._check_nd_one(routine, dtype, (16,), None)
+ self._check_nd_one(routine, dtype, (16,), (0,))
+ self._check_nd_one(routine, dtype, (16, 2), (0,))
+ self._check_nd_one(routine, dtype, (2, 16), (1,))
+ self._check_nd_one(routine, dtype, (8, 16), None)
+ self._check_nd_one(routine, dtype, (8, 16), (0, 1))
+ self._check_nd_one(routine, dtype, (8, 16, 2), (0, 1))
+ self._check_nd_one(routine, dtype, (8, 16, 2), (1, 2))
+ self._check_nd_one(routine, dtype, (8, 16, 2), (0,))
+ self._check_nd_one(routine, dtype, (8, 16, 2), (1,))
+ self._check_nd_one(routine, dtype, (8, 16, 2), (2,))
+ self._check_nd_one(routine, dtype, (8, 16, 2), None)
+ self._check_nd_one(routine, dtype, (8, 16, 2), (0,1,2))
+
+ def test_fftn(self):
+ for dtype in self.dtypes:
+ self._check_nd(fftn, dtype)
+
+ def test_ifftn(self):
+ for dtype in self.dtypes:
+ self._check_nd(ifftn, dtype)
+
+
if __name__ == "__main__":
run_module_suite()
Modified: trunk/scipy/fftpack/tests/test_pseudo_diffs.py
===================================================================
--- trunk/scipy/fftpack/tests/test_pseudo_diffs.py 2011-01-31 21:14:50 UTC (rev 7110)
+++ trunk/scipy/fftpack/tests/test_pseudo_diffs.py 2011-01-31 21:15:09 UTC (rev 7111)
@@ -13,8 +13,10 @@
from numpy.testing import *
from scipy.fftpack import diff, fft, ifft, tilbert, itilbert, hilbert, \
- ihilbert, shift, fftfreq
+ ihilbert, shift, fftfreq, cs_diff, sc_diff, \
+ ss_diff, cc_diff
+import numpy as np
from numpy import arange, sin, cos, pi, exp, tanh, sum, sign
def random(size):
@@ -312,5 +314,68 @@
assert_array_almost_equal(shift(sin(x),pi/2),cos(x))
+class TestOverwrite(object):
+ """
+ Check input overwrite behavior
+ """
+
+ real_dtypes = [np.float32, np.float64]
+ dtypes = real_dtypes + [np.complex64, np.complex128]
+
+ def _check(self, x, routine, *args, **kwargs):
+ x2 = x.copy()
+ y = routine(x2, *args, **kwargs)
+ sig = routine.__name__
+ if args:
+ sig += repr(args)
+ if kwargs:
+ sig += repr(kwargs)
+ assert_equal(x2, x, err_msg="spurious overwrite in %s" % sig)
+
+ def _check_1d(self, routine, dtype, shape, *args, **kwargs):
+ np.random.seed(1234)
+ if np.issubdtype(dtype, np.complexfloating):
+ data = np.random.randn(*shape) + 1j*np.random.randn(*shape)
+ else:
+ data = np.random.randn(*shape)
+ data = data.astype(dtype)
+ self._check(data, routine, *args, **kwargs)
+
+ def test_diff(self):
+ for dtype in self.dtypes:
+ self._check_1d(diff, dtype, (16,))
+
+ def test_tilbert(self):
+ for dtype in self.dtypes:
+ self._check_1d(tilbert, dtype, (16,), 1.6)
+
+ def test_itilbert(self):
+ for dtype in self.dtypes:
+ self._check_1d(itilbert, dtype, (16,), 1.6)
+
+ def test_hilbert(self):
+ for dtype in self.dtypes:
+ self._check_1d(hilbert, dtype, (16,))
+
+ def test_cs_diff(self):
+ for dtype in self.dtypes:
+ self._check_1d(cs_diff, dtype, (16,), 1.0, 4.0)
+
+ def test_sc_diff(self):
+ for dtype in self.dtypes:
+ self._check_1d(sc_diff, dtype, (16,), 1.0, 4.0)
+
+ def test_ss_diff(self):
+ for dtype in self.dtypes:
+ self._check_1d(ss_diff, dtype, (16,), 1.0, 4.0)
+
+ def test_cc_diff(self):
+ for dtype in self.dtypes:
+ self._check_1d(cc_diff, dtype, (16,), 1.0, 4.0)
+
+ def test_shift(self):
+ for dtype in self.dtypes:
+ self._check_1d(shift, dtype, (16,), 1.0)
+
if __name__ == "__main__":
run_module_suite()
Modified: trunk/scipy/fftpack/tests/test_real_transforms.py
===================================================================
--- trunk/scipy/fftpack/tests/test_real_transforms.py 2011-01-31 21:14:50 UTC (rev 7110)
+++ trunk/scipy/fftpack/tests/test_real_transforms.py 2011-01-31 21:15:09 UTC (rev 7111)
@@ -3,7 +3,7 @@
import numpy as np
from numpy.fft import fft as numfft
-from numpy.testing import assert_array_almost_equal, TestCase
+from numpy.testing import assert_array_almost_equal, assert_equal, TestCase
from scipy.fftpack.realtransforms import dct, idct
@@ -47,8 +47,8 @@
# XXX: we divide by np.max(y) because the tests fail otherwise. We
# should really use something like assert_array_approx_equal. The
# difference is due to fftw using a better algorithm w.r.t error
- # propagation compared to the ones from fftpack.
- assert_array_almost_equal(y / np.max(y), yr / np.max(y), decimal=self.dec,
+ # propagation compared to the ones from fftpack.
+ assert_array_almost_equal(y / np.max(y), yr / np.max(y), decimal=self.dec,
err_msg="Size %d failed" % i)
def test_axis(self):
@@ -144,8 +144,8 @@
# XXX: we divide by np.max(y) because the tests fail otherwise. We
# should really use something like assert_array_approx_equal. The
# difference is due to fftw using a better algorithm w.r.t error
- # propagation compared to the ones from fftpack.
- assert_array_almost_equal(x / np.max(x), xr / np.max(x), decimal=self.dec,
+ # propagation compared to the ones from fftpack.
+ assert_array_almost_equal(x / np.max(x), xr / np.max(x), decimal=self.dec,
err_msg="Size %d failed" % i)
class TestIDCTIDouble(_TestIDCTBase):
@@ -184,5 +184,46 @@
self.dec = 5
self.type = 3
+class TestOverwrite(object):
+ """
+ Check input overwrite behavior
+ """
+
+ real_dtypes = [np.float32, np.float64]
+
+ def _check(self, x, routine, type, fftsize, axis, norm):
+ x2 = x.copy()
+ y = routine(x2, type, fftsize, axis, norm)
+
+ sig = "%s(%s%r, %r, axis=%r)" % (
+ routine.__name__, x.dtype, x.shape, fftsize, axis)
+ assert_equal(x2, x, err_msg="spurious overwrite in %s" % sig)
+
+ def _check_1d(self, routine, dtype, shape, axis):
+ np.random.seed(1234)
+ if np.issubdtype(dtype, np.complexfloating):
+ data = np.random.randn(*shape) + 1j*np.random.randn(*shape)
+ else:
+ data = np.random.randn(*shape)
+ data = data.astype(dtype)
+
+ for type in [1, 2, 3]:
+ for norm in [None, 'ortho']:
+ if type == 1 and norm == 'ortho':
+ continue
+ self._check(data, routine, type, None, axis, norm)
+
+ def test_dct(self):
+ for dtype in self.real_dtypes:
+ self._check_1d(dct, dtype, (16,), -1)
+ self._check_1d(dct, dtype, (16, 2), 0)
+ self._check_1d(dct, dtype, (2, 16), 1)
+
+ def test_idct(self):
+ for dtype in self.real_dtypes:
+ self._check_1d(idct, dtype, (16,), -1)
+ self._check_1d(idct, dtype, (16, 2), 0)
+ self._check_1d(idct, dtype, (2, 16), 1)
+
if __name__ == "__main__":
np.testing.run_module_suite()
More information about the Scipy-svn mailing list