In this example I have created a sample `Model` class that is callable, so an instance of the class can be passed as the function argument to `odeint`. The `Model` class is defined in model_cython.pyx, which contains all the Cython code. In ode_script.py the model is integrated with `odeint`, which shows how to use the Cython class from a normal Python script. setup.py is needed to compile model_cython.pyx (see this post); build the extension in place, e.g. with `python setup.py build_ext --inplace`, before running ode_script.py.
model_cython.pyx:
import sys

# Math functions from the C math library (all operate on C doubles).
from libc.math cimport exp, fabs, sin

# Physical constants, defined as ordinary module-level names so that they can
# be imported from Python scripts as well as read inside this module.
# Not used by the model below; left in as an example.
R_GAS = 8.31        # Gas constant, J/(mol K)
K_B = 1.38066e-23   # Boltzmann constant, J/K


cpdef double Velocity(double t, double amplitude=1.0, double period=4.):
    """
    Return the rectified sine ``amplitude * |sin(t * period)|``.

    Callable from both C (Cython) code and Python.

    -t is the time at which the velocity is evaluated
    -amplitude scales the result
    -period multiplies t inside the sine
    """
    # Doubles are used throughout: sin/fabs take and return double, so
    # single-precision floats would only throw precision away.
    return fabs(sin(t * period)) * amplitude


class Model:
    """
    This class can be used to integrate a differential equation.

    An instance is callable as ``model(y, t)`` and returns dy/dt, so it can
    be passed directly as the function argument of an ODE integrator.
    """

    def __init__(self, double a, double p):
        """
        -a is the amplitude of the sine function that describes the speed
         oscillation as a function of time
        -p controls the period of the sine function
        """
        self.a = a
        self.p = p

    def __call__(self, double y, double t):
        """
        y is the current state of the initial value problem:
        y: position
        t: time

        Returns the velocity ``Velocity(t, a, p) * exp(-y)``.
        """
        # NOTE(review): an integrator may pass y as a length-1 array; this
        # relies on its implicit conversion to a C double -- confirm with the
        # NumPy version in use.
        cdef double pos = y
        # Work hardening: the undamped velocity decays exponentially with
        # position.  Positional arguments keep this a direct C-level call.
        cdef double v = Velocity(t, self.a, self.p) * exp(-pos)
        return v
ode_script.py:
#!/usr/bin/env python
"""Integrate the Cython ``Model`` with ``odeint`` and plot the result."""
import sys

from numpy import r_
import pylab
import scipy.integrate

# Make the compiled extension in the current directory importable.
sys.path.append(".")
from model_cython import Model, Velocity  # noqa: E402

# Amplitude and period of the velocity oscillation
amp = 1.0
period = 4.0

# Set initial position
pos0 = 0.0

# Create an instance of the Model object with the chosen amplitude and period
myModel = Model(amp, period)

# Create time vector: 500 points from 0 to 4 seconds, both ends included
tvec = r_[0.0:4.0:500j]

# Integrate over the time vector
y = scipy.integrate.odeint(myModel, pos0, tvec)

# Plot results: integrated position on top, undamped velocity below
pylab.figure(1)
pylab.clf()

pylab.subplot(211)
pylab.plot(tvec, y)
pylab.xlabel("Time (s)")
pylab.ylabel("Position (m)")
# Pass the flag positionally: the keyword was renamed from ``b`` to
# ``visible`` in newer Matplotlib releases, positional works with both.
pylab.grid(True)

pylab.subplot(212)
v_list = [Velocity(t, amplitude=amp, period=period) for t in tvec]
pylab.plot(tvec, v_list)
pylab.xlabel("Time (s)")
pylab.ylabel("Undamped velocity (m/s)")
pylab.grid(True)

pylab.show()
setup.py:
"""Build script that compiles model_cython.pyx into the ``model_cython`` extension."""
# distutils was removed from the standard library in Python 3.12, so prefer
# setuptools when it is installed and fall back to distutils otherwise.
# setuptools must be imported before anything that pulls in distutils.
try:
    from setuptools import Extension, setup
except ImportError:
    from distutils.core import setup
    from distutils.extension import Extension

from Cython.Distutils import build_ext

# Link against the C math library ("m") used by the extension.
ext_modules = [
    Extension("model_cython", ["model_cython.pyx"], libraries=["m"]),
]

setup(
    name='Generic model class',
    cmdclass={'build_ext': build_ext},
    ext_modules=ext_modules,
)
Hope this helps.