[Scipy-svn] r5055 - in trunk/scipy/interpolate: . tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Mon Nov 10 17:50:42 EST 2008
Author: ptvirtan
Date: 2008-11-10 16:50:30 -0600 (Mon, 10 Nov 2008)
New Revision: 5055
Modified:
trunk/scipy/interpolate/interpolate.py
trunk/scipy/interpolate/tests/test_interpolate.py
Log:
Make interp1d treat scalar inputs differently from 1-d arrays, so that a scalar x_new no longer gains an extra length-1 axis in the result (fixes #660)
Modified: trunk/scipy/interpolate/interpolate.py
===================================================================
--- trunk/scipy/interpolate/interpolate.py 2008-11-10 19:01:13 UTC (rev 5054)
+++ trunk/scipy/interpolate/interpolate.py 2008-11-10 22:50:30 UTC (rev 5055)
@@ -286,7 +286,7 @@
return result.reshape(x_new.shape+result.shape[1:])
def __call__(self, x_new):
- """ Find linearly interpolated y_new = f(x_new).
+ """Find interpolated y_new = f(x_new).
Parameters
----------
@@ -296,13 +296,14 @@
Returns
-------
y_new : number or array
- Linearly interpolated value(s) corresponding to x_new.
+ Interpolated value(s) corresponding to x_new.
+
"""
# 1. Handle values in x_new that are outside of x. Throw error,
# or return a list of mask array indicating the outofbounds values.
# The behavior is set by the bounds_error variable.
- x_new = atleast_1d(x_new)
+ x_new = asarray(x_new)
out_of_bounds = self._check_bounds(x_new)
y_new = self._call(x_new)
@@ -318,7 +319,15 @@
# and
# 7. Rotate the values back to their proper place.
- if self._kind == 'linear':
+ if nx == 0:
+ # special case: x is a scalar
+ if out_of_bounds:
+ if ny == 0:
+ return self.fill_value
+ else:
+ y_new[...] = self.fill_value
+ return y_new
+ elif self._kind == 'linear':
y_new[..., out_of_bounds] = self.fill_value
axes = range(ny - nx)
axes[self.axis:self.axis] = range(ny - nx, ny)
@@ -330,7 +339,7 @@
return y_new.transpose(axes)
def _check_bounds(self, x_new):
- """ Check the inputs for being in the bounds of the interpolated data.
+ """Check the inputs for being in the bounds of the interpolated data.
Parameters
----------
Modified: trunk/scipy/interpolate/tests/test_interpolate.py
===================================================================
--- trunk/scipy/interpolate/tests/test_interpolate.py 2008-11-10 19:01:13 UTC (rev 5054)
+++ trunk/scipy/interpolate/tests/test_interpolate.py 2008-11-10 22:50:30 UTC (rev 5055)
@@ -158,13 +158,17 @@
bounds_error=False, kind=kind)
assert_array_equal(
extrap10(11.2),
- np.array([self.fill_value]),
+ np.array(self.fill_value),
)
assert_array_equal(
extrap10(-3.4),
- np.array([self.fill_value]),
+ np.array(self.fill_value),
)
assert_array_equal(
+ extrap10([[[11.2], [-3.4], [12.6], [19.3]]]),
+ np.array(self.fill_value),
+ )
+ assert_array_equal(
extrap10._check_bounds(np.array([-1.0, 0.0, 5.0, 9.0, 11.0])),
np.array([True, False, False, False, True]),
)
@@ -193,7 +197,7 @@
interp210 = interp1d(self.x10, self.y210, kind=kind)
assert_array_almost_equal(
interp210(1.5),
- np.array([[1.5], [11.5]]),
+ np.array([1.5, 11.5]),
)
assert_array_almost_equal(
interp210(np.array([1.5, 2.4])),
@@ -204,7 +208,7 @@
interp102 = interp1d(self.x10, self.y102, axis=0, kind=kind)
assert_array_almost_equal(
interp102(1.5),
- np.array([[3.0, 4.0]]),
+ np.array([3.0, 4.0]),
)
assert_array_almost_equal(
interp102(np.array([1.5, 2.4])),
More information about the Scipy-svn mailing list