[Scipy-svn] r4870 - in trunk/scipy/interpolate: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Sat Nov 1 08:30:08 EDT 2008


Author: ptvirtan
Date: 2008-11-01 07:29:56 -0500 (Sat, 01 Nov 2008)
New Revision: 4870

Modified:
   trunk/scipy/interpolate/rbf.py
   trunk/scipy/interpolate/tests/test_rbf.py
Log:
interpolate.Rbf: accept array-like inputs in __call__; accept a complex data vector in __init__.

Modified: trunk/scipy/interpolate/rbf.py
===================================================================
--- trunk/scipy/interpolate/rbf.py	2008-11-01 12:00:34 UTC (rev 4869)
+++ trunk/scipy/interpolate/rbf.py	2008-11-01 12:29:56 UTC (rev 4870)
@@ -43,7 +43,7 @@
 """
 
 from numpy import (sqrt, log, asarray, newaxis, all, dot, float64, exp, eye,
-                   isnan)
+                   isnan, float_)
 from scipy import linalg
 
 class Rbf(object):
@@ -117,9 +117,10 @@
             raise ValueError, 'Invalid basis function name'
 
     def __init__(self, *args, **kwargs):
-        self.xi = asarray([asarray(a, dtype=float64).flatten() for a in args[:-1]])
+        self.xi = asarray([asarray(a, dtype=float_).flatten()
+                           for a in args[:-1]])
         self.N = self.xi.shape[-1]
-        self.di = asarray(args[-1], dtype=float64).flatten()
+        self.di = asarray(args[-1]).flatten()
 
         assert [x.size==self.di.size for x in self.xi], \
                'All arrays must be equal length'
@@ -143,10 +144,11 @@
         return self.norm(x1, x2)
 
     def __call__(self, *args):
+        args = [asarray(x) for x in args]
         assert all([x.shape == y.shape \
                     for x in args \
                     for y in args]), 'Array lengths must be equal'
         shp = args[0].shape
-        self.xa = asarray([a.flatten() for a in args], dtype=float64)
+        self.xa = asarray([a.flatten() for a in args], dtype=float_)
         r = self._call_norm(self.xa, self.xi)
         return dot(self._function(r), self.nodes).reshape(shp)

Modified: trunk/scipy/interpolate/tests/test_rbf.py
===================================================================
--- trunk/scipy/interpolate/tests/test_rbf.py	2008-11-01 12:00:34 UTC (rev 4869)
+++ trunk/scipy/interpolate/tests/test_rbf.py	2008-11-01 12:29:56 UTC (rev 4870)
@@ -2,7 +2,7 @@
 # Created by John Travers, Robert Hetland, 2007
 """ Test functions for rbf module """
 
-from numpy.testing import assert_array_almost_equal
+from numpy.testing import assert_array_almost_equal, assert_almost_equal
 from numpy import linspace, sin, random, exp
 from scipy.interpolate.rbf import Rbf
 
@@ -15,11 +15,12 @@
     rbf = Rbf(x, y, function=function)
     yi = rbf(x)
     assert_array_almost_equal(y, yi)
+    assert_almost_equal(rbf(float(x[0])), y[0])
 
 def check_rbf2d(function):
     x = random.rand(50,1)*4-2
     y = random.rand(50,1)*4-2
-    z = x*exp(-x**2-y**2)
+    z = x*exp(-x**2-1j*y**2)
     rbf = Rbf(x, y, z, epsilon=2, function=function)
     zi = rbf(x, y)
     zi.shape = x.shape




More information about the Scipy-svn mailing list