In this example I have created a sample `Model` class whose instances are callable, so an instance can be passed directly as the `func` argument of `odeint`. The `Model` class is defined in `model_cython.pyx`, which contains all the Cython code. `ode_script.py` integrates the model with `odeint` and thus shows how to use the compiled Cython class from a normal Python script. `setup.py` is needed to compile `model_cython.pyx`, e.g. with `python setup.py build_ext --inplace` (see this post).
model_cython.pyx:
import sys

# Example physical constants. They are plain module attributes so that Python
# scripts can read them (e.g. ``model_cython.R_GAS``). Cython does not allow
# ``cpdef`` on variables, and a module-level ``cdef`` variable would be
# invisible from Python. Not used by the model below; kept as an example.
R_GAS = 8.314462618   # molar gas constant N_A * k_B, J/(mol K)
K_B = 1.380649e-23    # Boltzmann constant, J/K (exact in the 2019 SI)

# C-level math functions from the C math library (libm); calling these avoids
# the Python-level overhead of the ``math`` module inside the ODE right-hand side.
from libc.math cimport sin, fabs, exp
cpdef float Velocity(float t, float amplitude=1.0, float period=4.):
    """Return the undamped speed ``|sin(t * period)| * amplitude``.

    Declared ``cpdef`` so it can be called from both Cython (at C speed) and
    Python code.

    Note: ``period`` multiplies ``t`` inside ``sin()``, so it behaves like an
    angular frequency; the resulting speed repeats every ``pi / period``.
    """
    # Typed as a C double so the whole computation stays in C. The original
    # untyped ``cdef f1`` made f1 a Python object and boxed every result.
    cdef double f1 = fabs(sin(t * period)) * amplitude
    return f1
class Model:
    """
    Callable right-hand side ``dy/dt = f(y, t)`` of a 1-D ODE.

    An instance can be passed directly as the ``func`` argument of
    ``scipy.integrate.odeint``.
    """

    def __init__(self, float a, float p):
        """
        -a is the amplitude of the sine function that describes the
         speed oscillation as a function of time
        -p is passed to Velocity() as ``period``, i.e. the factor multiplying
         time inside sin()
        """
        self.a = a
        self.p = p

    def __call__(self, y, float t):
        """
        Return dy/dt = Velocity(t, a, p) * exp(-y).

        y: current position. odeint passes it as a length-1 array; a plain
           number is accepted as well.
        t: time
        """
        cdef double pos, v
        # odeint hands over the state as a length-1 ndarray. Converting that
        # directly to a C float goes through ndarray.__float__, which NumPy
        # >= 1.25 deprecates for ndim > 0, so take the element explicitly.
        try:
            pos = y[0]
        except (TypeError, IndexError):
            # plain Python number (not subscriptable) or a 0-d array
            pos = y
        # the speed is damped exponentially by the distance already travelled
        v = Velocity(t, amplitude=self.a, period=self.p) * exp(-pos)
        return v
ode_script.py:
#!/usr/bin/env python
"""Integrate the Cython ``Model`` with odeint and plot position and speed."""
import sys

import pylab
import scipy.integrate
from numpy import r_

# make the compiled extension in the current directory importable
sys.path.append(".")
from model_cython import Model, Velocity

# amplitude and period parameter of the velocity oscillation
amp = 1.0
period = 4.0
# initial position
pos0 = 0.0
# callable right-hand side with speed amplitude `amp`
myModel = Model(amp, period)
# time vector: 500 evenly spaced points from 0 to 4 s, both ends included
tvec = r_[0.0:4.0:500j]
# integrate the ODE dy/dt = myModel(y, t) starting from pos0
y = scipy.integrate.odeint(myModel, pos0, tvec)

# plot results
pylab.figure(1)
pylab.clf()
pylab.subplot(211)
pylab.plot(tvec, y)
pylab.xlabel("Time (s)")
pylab.ylabel("Position (m)")
# grid(b=...) was renamed to `visible` in Matplotlib 3.5 and removed in 3.7;
# the positional form works on every version.
pylab.grid(True)
pylab.subplot(212)
v_list = [Velocity(t, amplitude=amp, period=period) for t in tvec]
pylab.plot(tvec, v_list)
pylab.xlabel("Time (s)")
pylab.ylabel("Undamped velocity (m/s)")
pylab.grid(True)
pylab.show()
setup.py:
"""Build script that compiles model_cython.pyx into an importable extension.

Typical use: ``python setup.py build_ext --inplace`` (puts the compiled
module next to the scripts that import it).
"""
import sys

# distutils is deprecated (PEP 632) and removed in Python 3.12, and
# Cython.Distutils.build_ext is legacy; setuptools + cythonize replace both.
from setuptools import Extension, setup
from Cython.Build import cythonize

# The module calls sin/fabs/exp from the C math library. On Unix-like systems
# that lives in libm and must be linked explicitly; MSVC has no separate libm.
math_libs = [] if sys.platform == "win32" else ["m"]

ext_modules = [
    Extension("model_cython", ["model_cython.pyx"], libraries=math_libs),
]

setup(
    name='Generic model class',
    # an explicit language_level avoids Cython's version-dependent default
    ext_modules=cythonize(ext_modules, compiler_directives={"language_level": "3"}),
)
Hope this helps.