[Scipy-svn] r6215 - trunk/scipy/interpolate

scipy-svn at scipy.org scipy-svn at scipy.org
Sun Feb 7 02:13:04 EST 2010


Author: oliphant
Date: 2010-02-07 01:13:04 -0600 (Sun, 07 Feb 2010)
New Revision: 6215

Modified:
   trunk/scipy/interpolate/rbf.py
Log:
Add the ability to pass an arbitrary (callable) basis function to the Rbf constructor for radial basis function interpolation.

Modified: trunk/scipy/interpolate/rbf.py
===================================================================
--- trunk/scipy/interpolate/rbf.py	2010-02-05 04:31:54 UTC (rev 6214)
+++ trunk/scipy/interpolate/rbf.py	2010-02-07 07:13:04 UTC (rev 6215)
@@ -3,6 +3,7 @@
 Written by John Travers <jtravs at gmail.com>, February 2007
 Based closely on Matlab code by Alex Chirokov
 Additional, large, improvements by Robert Hetland
+Some additional alterations by Travis Oliphant
 
 Permission to use, modify, and distribute this software is given under the
 terms of the SciPy (BSD style) license.  See LICENSE.txt that came with
@@ -42,10 +43,11 @@
 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 """
 
-from numpy import (sqrt, log, asarray, newaxis, all, dot, float64, exp, eye,
-                   isnan, float_)
+from numpy import (sqrt, log, asarray, newaxis, all, dot, exp, eye,
+                   float_)
 from scipy import linalg
 
+
 class Rbf(object):
     """
     Rbf(*args)
@@ -58,18 +60,22 @@
     *args : arrays
         x, y, z, ..., d, where x, y, z, ... are the coordinates of the nodes
         and d is the array of values at the nodes
-    function : str, optional
+    function : str or callable, optional
         The radial basis function, based on the radius, r, given by the norm
         (defult is Euclidean distance); the default is 'multiquadric'::
 
             'multiquadric': sqrt((r/self.epsilon)**2 + 1)
-            'inverse multiquadric': 1.0/sqrt((r/self.epsilon)**2 + 1)
+            'inverse': 1.0/sqrt((r/self.epsilon)**2 + 1)
             'gaussian': exp(-(r/self.epsilon)**2)
             'linear': r
             'cubic': r**3
             'quintic': r**5
-            'thin-plate': r**2 * log(r)
+            'thin_plate': r**2 * log(r)
 
+        If callable, then it must take 2 arguments (self, r).  The epsilon parameter
+        will be available as self.epsilon.  Other keyword arguments passed in will
+        be available as well.
+
     epsilon : float, optional
         Adjustable constant for gaussian or multiquadrics functions
         - defaults to approximate average distance between nodes (which is
@@ -99,26 +105,67 @@
     def _euclidean_norm(self, x1, x2):
         return sqrt( ((x1 - x2)**2).sum(axis=0) )
 
-    def _function(self, r):
-        if self.function.lower() == 'multiquadric':
+    def _h_multiquadric(self, r):
             return sqrt((1.0/self.epsilon*r)**2 + 1)
-        elif self.function.lower() == 'inverse multiquadric':
+    def _h_inverse_multiquadric(self, r):
             return 1.0/sqrt((1.0/self.epsilon*r)**2 + 1)
-        elif self.function.lower() == 'gaussian':
+    def _h_gaussian(self, r):
             return exp(-(1.0/self.epsilon*r)**2)
-        elif self.function.lower() == 'linear':
-            return r
-        elif self.function.lower() == 'cubic':
-            return r**3
-        elif self.function.lower() == 'quintic':
-            return r**5
-        elif self.function.lower() == 'thin-plate':
-            result = r**2 * log(r)
-            result[r == 0] = 0 # the spline is zero at zero
-            return result
-        else:
-            raise ValueError, 'Invalid basis function name'
+    def _h_linear(self, r):
+        return r
+    def _h_cubic(self, r):
+        return r**3
+    def _h_quintic(self, r):
+        return r**5
+    def _h_thin_plate(self, r):
+        result = r**2 * log(r)
+        result[r == 0] = 0 # the spline is zero at zero
+        return result
 
+    # Setup self._function and do smoke test on initial r
+    def _init_function(self, r):
+        if isinstance(self.function, str):
+           self.function = self.function.lower()
+           _mapped = {'inverse': 'inverse_multiquadric',
+                      'inverse multiquadric': 'inverse_multiquadric',
+                      'thin-plate': 'thin_plate'}                   
+           if self.function in _mapped:
+               self.function = _mapped[self.function]
+
+           func_name = "_h_" + self.function
+           if hasattr(self, func_name):
+               self._function = getattr(self, func_name)
+           else:
+               functionlist = [x[3:] for x in dir(self) if x.startswith('_h_')]
+               raise ValueError, "function must be a callable or one of ", \
+                   ", ".join(functionlist)               
+           self._function = getattr(self, "_h_"+self.function)
+        elif callable(self.function):
+            import new
+            allow_one = False
+            if hasattr(self.function, 'func_code'):
+                val = self.function
+                allow_one = True
+            elif hasattr(self.function, "im_func"):
+                val = self.function.im_func
+            elif hasattr(self.function, "__call__"):
+                val = self.function.__call__.im_func
+            else:
+                raise ValueError, "Cannot determine number of arguments to function"
+            
+            argcount = val.func_code.co_argcount
+            if allow_one and argcount == 1:
+                self._function = self.function
+            elif argcount == 2:
+                self._function = new.instancemethod(self.function, self, Rbf)
+            else:
+                raise ValueError, "Function argument must take 1 or 2 arguments."
+                
+        a0 = self._function(r)
+        if a0.shape != r.shape:
+            raise ValueError, "Callable must take array and return array of the same shape"
+        return a0
+
     def __init__(self, *args, **kwargs):
         self.xi = asarray([asarray(a, dtype=float_).flatten()
                            for a in args[:-1]])
@@ -131,10 +178,17 @@
         self.norm = kwargs.pop('norm', self._euclidean_norm)
         r = self._call_norm(self.xi, self.xi)
         self.epsilon = kwargs.pop('epsilon', r.mean())
-        self.function = kwargs.pop('function', 'multiquadric')
         self.smooth = kwargs.pop('smooth', 0.0)
 
-        self.A = self._function(r) - eye(self.N)*self.smooth
+        self.function = kwargs.pop('function', self._h_multiquadric)
+
+        # attach anything left in kwargs to self
+        #  for use by any user-callable function or 
+        #  to save on the object returned.
+        for item, value in kwargs.items():
+            setattr(self, item, value)
+   
+        self.A = self._init_function(r) - eye(self.N)*self.smooth
         self.nodes = linalg.solve(self.A, self.di)
 
     def _call_norm(self, x1, x2):




More information about the Scipy-svn mailing list