[Scipy-svn] r4870 - in trunk/scipy/interpolate: . tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Sat Nov 1 08:30:08 EDT 2008
Author: ptvirtan
Date: 2008-11-01 07:29:56 -0500 (Sat, 01 Nov 2008)
New Revision: 4870
Modified:
trunk/scipy/interpolate/rbf.py
trunk/scipy/interpolate/tests/test_rbf.py
Log:
interpolate.Rbf: accept array-like inputs in __call__. Accept complex data vector in __init__
Modified: trunk/scipy/interpolate/rbf.py
===================================================================
--- trunk/scipy/interpolate/rbf.py 2008-11-01 12:00:34 UTC (rev 4869)
+++ trunk/scipy/interpolate/rbf.py 2008-11-01 12:29:56 UTC (rev 4870)
@@ -43,7 +43,7 @@
"""
from numpy import (sqrt, log, asarray, newaxis, all, dot, float64, exp, eye,
- isnan)
+ isnan, float_)
from scipy import linalg
class Rbf(object):
@@ -117,9 +117,10 @@
raise ValueError, 'Invalid basis function name'
def __init__(self, *args, **kwargs):
- self.xi = asarray([asarray(a, dtype=float64).flatten() for a in args[:-1]])
+ self.xi = asarray([asarray(a, dtype=float_).flatten()
+ for a in args[:-1]])
self.N = self.xi.shape[-1]
- self.di = asarray(args[-1], dtype=float64).flatten()
+ self.di = asarray(args[-1]).flatten()
assert [x.size==self.di.size for x in self.xi], \
'All arrays must be equal length'
@@ -143,10 +144,11 @@
return self.norm(x1, x2)
def __call__(self, *args):
+ args = [asarray(x) for x in args]
assert all([x.shape == y.shape \
for x in args \
for y in args]), 'Array lengths must be equal'
shp = args[0].shape
- self.xa = asarray([a.flatten() for a in args], dtype=float64)
+ self.xa = asarray([a.flatten() for a in args], dtype=float_)
r = self._call_norm(self.xa, self.xi)
return dot(self._function(r), self.nodes).reshape(shp)
Modified: trunk/scipy/interpolate/tests/test_rbf.py
===================================================================
--- trunk/scipy/interpolate/tests/test_rbf.py 2008-11-01 12:00:34 UTC (rev 4869)
+++ trunk/scipy/interpolate/tests/test_rbf.py 2008-11-01 12:29:56 UTC (rev 4870)
@@ -2,7 +2,7 @@
# Created by John Travers, Robert Hetland, 2007
""" Test functions for rbf module """
-from numpy.testing import assert_array_almost_equal
+from numpy.testing import assert_array_almost_equal, assert_almost_equal
from numpy import linspace, sin, random, exp
from scipy.interpolate.rbf import Rbf
@@ -15,11 +15,12 @@
rbf = Rbf(x, y, function=function)
yi = rbf(x)
assert_array_almost_equal(y, yi)
+ assert_almost_equal(rbf(float(x[0])), y[0])
def check_rbf2d(function):
x = random.rand(50,1)*4-2
y = random.rand(50,1)*4-2
- z = x*exp(-x**2-y**2)
+ z = x*exp(-x**2-1j*y**2)
rbf = Rbf(x, y, z, epsilon=2, function=function)
zi = rbf(x, y)
zi.shape = x.shape
More information about the Scipy-svn mailing list