[Scipy-svn] r6215 - trunk/scipy/interpolate
scipy-svn at scipy.org
scipy-svn at scipy.org
Sun Feb 7 02:13:04 EST 2010
Author: oliphant
Date: 2010-02-07 01:13:04 -0600 (Sun, 07 Feb 2010)
New Revision: 6215
Modified:
trunk/scipy/interpolate/rbf.py
Log:
Add ability to use arbitrary basis function to Rbf constructor for radial basis function interpolation.
Modified: trunk/scipy/interpolate/rbf.py
===================================================================
--- trunk/scipy/interpolate/rbf.py 2010-02-05 04:31:54 UTC (rev 6214)
+++ trunk/scipy/interpolate/rbf.py 2010-02-07 07:13:04 UTC (rev 6215)
@@ -3,6 +3,7 @@
Written by John Travers <jtravs at gmail.com>, February 2007
Based closely on Matlab code by Alex Chirokov
Additional, large, improvements by Robert Hetland
+Some additional alterations by Travis Oliphant
Permission to use, modify, and distribute this software is given under the
terms of the SciPy (BSD style) license. See LICENSE.txt that came with
@@ -42,10 +43,11 @@
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
-from numpy import (sqrt, log, asarray, newaxis, all, dot, float64, exp, eye,
- isnan, float_)
+from numpy import (sqrt, log, asarray, newaxis, all, dot, exp, eye,
+ float_)
from scipy import linalg
+
class Rbf(object):
"""
Rbf(*args)
@@ -58,18 +60,22 @@
*args : arrays
x, y, z, ..., d, where x, y, z, ... are the coordinates of the nodes
and d is the array of values at the nodes
- function : str, optional
+ function : str or callable, optional
The radial basis function, based on the radius, r, given by the norm
(defult is Euclidean distance); the default is 'multiquadric'::
'multiquadric': sqrt((r/self.epsilon)**2 + 1)
- 'inverse multiquadric': 1.0/sqrt((r/self.epsilon)**2 + 1)
+ 'inverse': 1.0/sqrt((r/self.epsilon)**2 + 1)
'gaussian': exp(-(r/self.epsilon)**2)
'linear': r
'cubic': r**3
'quintic': r**5
- 'thin-plate': r**2 * log(r)
+ 'thin_plate': r**2 * log(r)
+ If callable, then it must take 2 arguments (self, r). The epsilon parameter
+ will be available as self.epsilon. Other keyword arguments passed in will
+ be available as well.
+
epsilon : float, optional
Adjustable constant for gaussian or multiquadrics functions
- defaults to approximate average distance between nodes (which is
@@ -99,26 +105,67 @@
def _euclidean_norm(self, x1, x2):
return sqrt( ((x1 - x2)**2).sum(axis=0) )
- def _function(self, r):
- if self.function.lower() == 'multiquadric':
+ def _h_multiquadric(self, r):
return sqrt((1.0/self.epsilon*r)**2 + 1)
- elif self.function.lower() == 'inverse multiquadric':
+ def _h_inverse_multiquadric(self, r):
return 1.0/sqrt((1.0/self.epsilon*r)**2 + 1)
- elif self.function.lower() == 'gaussian':
+ def _h_gaussian(self, r):
return exp(-(1.0/self.epsilon*r)**2)
- elif self.function.lower() == 'linear':
- return r
- elif self.function.lower() == 'cubic':
- return r**3
- elif self.function.lower() == 'quintic':
- return r**5
- elif self.function.lower() == 'thin-plate':
- result = r**2 * log(r)
- result[r == 0] = 0 # the spline is zero at zero
- return result
- else:
- raise ValueError, 'Invalid basis function name'
+ def _h_linear(self, r):
+ return r
+ def _h_cubic(self, r):
+ return r**3
+ def _h_quintic(self, r):
+ return r**5
+ def _h_thin_plate(self, r):
+ result = r**2 * log(r)
+ result[r == 0] = 0 # the spline is zero at zero
+ return result
+ # Setup self._function and do smoke test on initial r
+ def _init_function(self, r):
+ if isinstance(self.function, str):
+ self.function = self.function.lower()
+ _mapped = {'inverse': 'inverse_multiquadric',
+ 'inverse multiquadric': 'inverse_multiquadric',
+ 'thin-plate': 'thin_plate'}
+ if self.function in _mapped:
+ self.function = _mapped[self.function]
+
+ func_name = "_h_" + self.function
+ if hasattr(self, func_name):
+ self._function = getattr(self, func_name)
+ else:
+ functionlist = [x[3:] for x in dir(self) if x.startswith('_h_')]
+ raise ValueError, "function must be a callable or one of ", \
+ ", ".join(functionlist)
+ self._function = getattr(self, "_h_"+self.function)
+ elif callable(self.function):
+ import new
+ allow_one = False
+ if hasattr(self.function, 'func_code'):
+ val = self.function
+ allow_one = True
+ elif hasattr(self.function, "im_func"):
+ val = self.function.im_func
+ elif hasattr(self.function, "__call__"):
+ val = self.function.__call__.im_func
+ else:
+ raise ValueError, "Cannot determine number of arguments to function"
+
+ argcount = val.func_code.co_argcount
+ if allow_one and argcount == 1:
+ self._function = self.function
+ elif argcount == 2:
+ self._function = new.instancemethod(self.function, self, Rbf)
+ else:
+ raise ValueError, "Function argument must take 1 or 2 arguments."
+
+ a0 = self._function(r)
+ if a0.shape != r.shape:
+ raise ValueError, "Callable must take array and return array of the same shape"
+ return a0
+
def __init__(self, *args, **kwargs):
self.xi = asarray([asarray(a, dtype=float_).flatten()
for a in args[:-1]])
@@ -131,10 +178,17 @@
self.norm = kwargs.pop('norm', self._euclidean_norm)
r = self._call_norm(self.xi, self.xi)
self.epsilon = kwargs.pop('epsilon', r.mean())
- self.function = kwargs.pop('function', 'multiquadric')
self.smooth = kwargs.pop('smooth', 0.0)
- self.A = self._function(r) - eye(self.N)*self.smooth
+ self.function = kwargs.pop('function', self._h_multiquadric)
+
+ # attach anything left in kwargs to self
+ # for use by any user-callable function or
+ # to save on the object returned.
+ for item, value in kwargs.items():
+ setattr(self, item, value)
+
+ self.A = self._init_function(r) - eye(self.N)*self.smooth
self.nodes = linalg.solve(self.A, self.di)
def _call_norm(self, x1, x2):
More information about the Scipy-svn mailing list