Cython使用教程(2)
接上文:Cython使用教程(1)
接下来我们对代码进行重构,需要重构的代码用于计算混沌摆的.
未重构的代码如下:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation
from math import pi,sin
from multiprocess import Pool
import json
import time
from tqdm import *
g = 9.8
class pendulum():
def __init__(self,l = 9.8,Fd =1.44,omegad = 2/3.0,dt = 0.001, q = 0.5, theta0 = 0.2):
self.l = l
self.Fd = Fd
self.omegad = omegad
self.dt = dt
self.q = q
self.theta0 = theta0
def calculate(self, t0 = 50):
self.theta = []
self.omega = []
self.t = []
k = g/self.l
w = 0
a = self.theta0
t_ = 0
while t_ <= t0:
self.omega.append(w)
self.theta.append(a)
self.t.append(t_)
w += (-k*sin(a)-self.q*w+self.Fd*sin(self.omegad*t_)) * self.dt
a += w * self.dt
if a > pi:
a -= 2*pi
elif a < -pi:
a += 2*pi
t_ += self.dt
def bifurcationDiagram(self):
x,y,cx,cy = [],[],[],[]
for fd in tqdm(np.linspace(1.35,1.48,200)):
self.Fd = fd
self.calculate(800*pi/self.omegad + 2)
index = int((600*pi/self.omegad)/self.dt)
flag = 600*pi
for i in range(index, len(self.t)):
if (self.omegad * self.t[i] - flag - 2 * pi) > 0:
if np.abs(self.omegad * self.t[i - 1] - flag) < np.abs(self.omegad * self.t[i] - flag):
x.append(fd)
y.append(self.theta[i - 1])
else:
x.append(fd)
y.append(self.theta[i])
flag += 2*pi
print(len(x))
plt.scatter(x,y,s = 5,color = 'blue', marker = '.')
'''
with open('./{}-{}.txt'.format(self.omegad, self.dt),'w') as fp:
json.dump([x,y], fp)
'''
#plt.scatter(cx,cy,color='red', s = 8)
plt.title("$\Omega_d$={}, dt={}".format(self.omegad,self.dt))
#plt.show()
plt.savefig("./{}-{}.png".format(self.omegad, self.dt))
def main():
wdList = np.linspace(1/5, 2, 200)
pList = []
pool = Pool(processes=64)
for i in range(len(wdList)):
pList.append(pendulum(omegad=wdList[i],dt = 0.001))
pool.apply_async(pList[i].bifurcationDiagram)
print('Submit success!')
pool.close()
pool.join()
if __name__ == '__main__':
main()
这段代码实现了对混沌摆的计算,并绘制系统分叉图.但是此情况下计算效率较低,使用64核196G的配置进行计算,仍需要约14小时才能全部绘制,如果希望制作较为连贯的动图,则需要更高的采样率.时间成本进一步上升.
因此,我们想通过Cython来提升计算效率,节约时间成本.起初,我想通过Cython直接对代码进行重构,但是Cython对OOP的支持不是很好.
cdef self.l = l
无法通过编译,报错:
main_Cython.pyx:11:20: Syntax error in C variable declaration
后来,在网上查找的一个解决方法:
cdef double l(self): return l
仍然报错:
main_Cython.pyx:11:23: C function definition not allowed here
是在没有办法,对代码进行了全重构,去除了类,而是直接通过两个函数进行计算,最后代码为:
# FileName: main_Cython.pyx
from libc.math cimport sin
import json
import matplotlib.pyplot as plt
import numpy as np
cdef calculate(double l=9.8, double Fd = 1.44,double omegad = 2/3.0, double dt = 0.001, double q = 0.5, double theta0 = 0.2, double t0 = 50):
cdef theta = []
cdef omega = []
cdef t = []
cdef double k,w,a,t_,pi,g
g = 9.8
pi = 3.141592653
t = []
k = g/l
w = 0
a = theta0
t_ = 0
while t_ <= t0:
omega.append(w)
theta.append(a)
t.append(t_)
w += (-k*sin(a)-q*w+Fd*sin(omegad*t_)) * dt
a += w * dt
if a > pi:
a -= 2*pi
elif a < -pi:
a += 2*pi
t_ += dt
return theta, omega, t
def bifurcationDiagram(double omegad, double dt):
cdef double pi, flag, fd
cdef int i, index
pi = 3.141592653
cdef x = []
cdef y = []
for fd in np.linspace(1.35,1.48,200):
a,w,t = calculate(9.8, fd, omegad, dt, 0.5, 0.2, 800*pi/omegad + 2)
index = int((600*pi/omegad)/dt)
flag = 600*pi
for i in range(index, len(t)):
if (omegad * t[i] - flag - 2 * pi) > 0:
x.append(fd)
y.append(a[i])
flag += 2*pi
plt.scatter(x,y,s = 5)
plt.title("$\Omega_d$={}, dt={}".format(omegad,dt))
plt.savefig("./{}-{}.png".format(omegad, dt))
计算代码:
from main_Cython import bifurcationDiagram
from multiprocessing import Pool
import numpy as np
import matplotlib.pyplot as plt
def main():
wdList = np.linspace(1/5, 2, 1000)
pool = Pool(processes=16)
for i in range(len(wdList)):
pool.apply_async(bifurcationDiagram, args = (wdList[i], 0.001))
print('Submit success!')
pool.close()
pool.join()
if __name__ == '__main__':
main()
setup.py中有一个小技巧,通过加一个参数ext_modules = cythonize("main_Cython.pyx", annotate=True),可以生成一个.html文件,显示了与python交互的代码.理论上与python交互的越少,代码运算越快(numpy除外,这个库的优化是真的恐怖,比随手写的cpp还要快).
查看.html文件发现, 优化程度不高, 不过最关键的数值计算代码优化了, 所以性能应该还不错吧.
粗测了一下计算效率,大概能达到10倍到12倍,性能还不错.
后面把代码上传到超算上的时候,发现了一个问题,就是linux系统里面编译生成的不是.pyd文件,还是.so文件,因此需要重新运行setup.py进行编译一次.
后面具体的计算结果等会回来补发上来.

浙公网安备 33010602011771号