[Scipy-svn] r5594 - in trunk/scipy/integrate: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Tue Feb 24 18:16:58 EST 2009


Author: jtravs
Date: 2009-02-24 17:16:53 -0600 (Tue, 24 Feb 2009)
New Revision: 5594

Modified:
   trunk/scipy/integrate/ode.py
   trunk/scipy/integrate/tests/test_integrate.py
Log:
Add a complex wrapper to ode class.


Modified: trunk/scipy/integrate/ode.py
===================================================================
--- trunk/scipy/integrate/ode.py	2009-02-24 19:53:18 UTC (rev 5593)
+++ trunk/scipy/integrate/ode.py	2009-02-24 23:16:53 UTC (rev 5594)
@@ -29,6 +29,13 @@
     y1 = integrator.integrate(t1,step=0,relax=0)
     flag = integrator.successful()
 
+class complex_ode
+-----------------
+
+This class has the same generic interface as ode, except it can handle complex
+f, y and Jacobians by transparently translating them into the equivalent
+real-valued system. It supports the real-valued solvers (i.e. not zvode) and
+is an alternative to ode with the zvode solver, sometimes performing better.
 """
 
 integrator_info = \
@@ -195,14 +202,14 @@
 # if myodeint.runner:
 #     IntegratorBase.integrator_classes.append(myodeint)
 
-__all__ = ['ode']
+__all__ = ['ode', 'complex_ode']
 __version__ = "$Id$"
 __docformat__ = "restructuredtext en"
 
 import re
 import warnings
 
-from numpy import asarray, array, zeros, int32, isscalar
+from numpy import asarray, array, zeros, int32, isscalar, real, imag
 
 import vode as _vode
 import dop as _dop
@@ -337,6 +344,72 @@
         self.jac_params = args
         return self
 
+class complex_ode(ode):
+    """Wrap ode for complex systems via an equivalent real-valued one."""
+
+    def __init__(self, f, jac=None):
+        """
+        Define equation y' = f(t, y), where y and f can be complex.
+
+        Parameters
+        ----------
+        f : f(t, y, *f_args)
+            Rhs of the equation. t is a scalar, y.shape == (n,).
+            f_args is set by calling set_f_params(*args)
+        jac : jac(t, y, *jac_args)
+            Jacobian of the rhs, jac[i,j] = d f[i] / d y[j]
+            jac_args is set by calling set_jac_params(*args)
+        """
+        self.cf = f
+        self.cjac = jac
+        real_jac = self._wrap_jac if jac is not None else None
+        ode.__init__(self, self._wrap, real_jac)
+
+    def _wrap(self, t, y, *f_args):
+        # y holds [re y0, im y0, re y1, ...]; rebuild complex y, split f.
+        f = self.cf(*((t, y[::2] + 1j*y[1::2]) + f_args))
+        self.tmp[::2] = real(f)
+        self.tmp[1::2] = imag(f)
+        return self.tmp
+
+    def _wrap_jac(self, t, y, *jac_args):
+        # J = A + iB maps to interleaved real 2x2 blocks [[A, -B], [B, A]].
+        jac = self.cjac(*((t, y[::2] + 1j*y[1::2]) + jac_args))
+        self.jac_tmp[1::2,1::2] = self.jac_tmp[::2,::2] = real(jac)
+        self.jac_tmp[1::2,::2] = imag(jac)
+        self.jac_tmp[::2,1::2] = -self.jac_tmp[1::2,::2]
+        return self.jac_tmp
+
+    def set_integrator(self, name, **integrator_params):
+        """
+        Set integrator by name.
+
+        Parameters
+        ----------
+        name : str
+            Name of the integrator (a real-valued solver; not 'zvode').
+        integrator_params :
+            Additional parameters for the integrator.
+        """
+        if name == 'zvode':
+            raise ValueError("zvode should be used with ode, not complex_ode")
+        return ode.set_integrator(self, name, **integrator_params)
+
+    def set_initial_value(self, y, t=0.0):
+        """Set initial conditions y(t) = y; y may be complex."""
+        y = asarray(y)
+        self.tmp = zeros(y.size*2, 'float')
+        self.tmp[::2] = real(y)
+        self.tmp[1::2] = imag(y)
+        if self.cjac is not None:
+            self.jac_tmp = zeros((y.size*2, y.size*2), 'float')
+        return ode.set_initial_value(self, self.tmp, t)
+
+    def integrate(self, t, step=0, relax=0):
+        """Find y=y(t), set y as an initial condition, return complex y."""
+        y = ode.integrate(self, t, step, relax)
+        return y[::2] + 1j*y[1::2]
+
 #------------------------------------------------------------------------------
 # ODE integrators
 #------------------------------------------------------------------------------

Modified: trunk/scipy/integrate/tests/test_integrate.py
===================================================================
--- trunk/scipy/integrate/tests/test_integrate.py	2009-02-24 19:53:18 UTC (rev 5593)
+++ trunk/scipy/integrate/tests/test_integrate.py	2009-02-24 23:16:53 UTC (rev 5594)
@@ -1,4 +1,4 @@
-# Authors: Nils Wagner, Ed Schofield, Pauli Virtanen
+# Authors: Nils Wagner, Ed Schofield, Pauli Virtanen, John Travers
 """
 Tests for numerical integration.
 """
@@ -8,7 +8,7 @@
                   allclose
 
 from numpy.testing import *
-from scipy.integrate import odeint, ode
+from scipy.integrate import odeint, ode, complex_ode
 
 #------------------------------------------------------------------------------
 # Test ODE integrators
@@ -75,6 +75,7 @@
             problem = problem_cls()
             if problem.cmplx: continue
             if problem.stiff: continue
+            if hasattr(problem, 'jac'): continue
             self._do_problem(problem, 'dopri5')
             
     def test_dop853(self):
@@ -83,8 +84,56 @@
             problem = problem_cls()
             if problem.cmplx: continue
             if problem.stiff: continue
+            if hasattr(problem, 'jac'): continue
             self._do_problem(problem, 'dop853')
 
+class TestComplexOde(TestCase):
+    """
+    Check integrate.complex_ode on every problem in PROBLEMS.
+    """
+    def _do_problem(self, problem, integrator, method='adams'):
+        """Integrate `problem` with complex_ode and verify the endpoint."""
+        # ode callbacks take (t, z); the test problems define (z, t)
+        rhs = lambda t, z: problem.f(z, t)
+        if hasattr(problem, 'jac'):
+            rhs_jac = lambda t, z: problem.jac(z, t)
+        else:
+            rhs_jac = None
+        solver = complex_ode(rhs, rhs_jac)
+        solver.set_integrator(integrator, method=method,
+                              atol=problem.atol/10,
+                              rtol=problem.rtol/10)
+        solver.set_initial_value(problem.z0, t=0.0)
+        z_end = solver.integrate(problem.stop_t)
+
+        assert solver.successful(), (problem, method)
+        assert problem.verify(array([z_end]), problem.stop_t), \
+            (problem, method)
+
+    def _check_nonstiff(self, integrator):
+        """Run `integrator` on the non-stiff problems that supply no jac."""
+        for problem in (problem_cls() for problem_cls in PROBLEMS):
+            # skip stiff problems and problems that define a Jacobian
+            if problem.stiff or hasattr(problem, 'jac'):
+                continue
+            self._do_problem(problem, integrator)
+
+    def test_vode(self):
+        """Check the vode solver"""
+        for problem_cls in PROBLEMS:
+            problem = problem_cls()
+            # stiff problems use the BDF method, the rest use Adams
+            method = 'bdf' if problem.stiff else 'adams'
+            self._do_problem(problem, 'vode', method)
+
+    def test_dopri5(self):
+        """Check the dopri5 solver"""
+        self._check_nonstiff('dopri5')
+
+    def test_dop853(self):
+        """Check the dop853 solver"""
+        self._check_nonstiff('dop853')
+
 #------------------------------------------------------------------------------
 # Test problems
 #------------------------------------------------------------------------------




More information about the Scipy-svn mailing list