流れに従ひて己を失はず.

日々の研究について書きます

SymPyでの幾何計算速度

pythonのライブラリであるSymPyを使って幾何計算をやるととても遅いという話.

以下は2直線の交点導出を100セット解いた例.
SymPyを使った場合とベクトルで解いた場合の比較.10^3オーダーで変わってくる.

SymPyを使った場合

import numpy as np
import sympy.geometry as sg
import time

#点[listA]を通り方向ベクトル[listB]を持つ直線と
#点[listC]を通り方向ベクトル[listD]を持つ直線との交点を求める
listA = np.random.rand(100,2)
listB = np.random.rand(100,2)
listC = np.random.rand(100,2)
listD = np.random.rand(100,2)

start = time.time()

for i in range(0,100):
    pointA = listA[i,:]
    direcB = listB[i,:]
    pointC = listC[i,:]
    direcD = listD[i,:]
    line1 = sg.Line( sg.Point(pointA), 
                                        sg.Point(pointA+direcB))
    line2 = sg.Line( sg.Point(pointC), 
                                        sg.Point(pointC+direcD))
    result = sg.intersection(line1, line2)

elapsed_time = time.time() - start
print ("sympy calc_time:{0}".format(round(elapsed_time,3)) + "[sec]")

sympy calc_time:15.911[sec]


ベクトルで解いた場合

import numpy as np
import sympy.geometry as sg
import time

#点[listA]を通り方向ベクトル[listB]を持つ直線と
#点[listC]を通り方向ベクトル[listD]を持つ直線との交点を求める
listA = np.random.rand(100,2)
listB = np.random.rand(100,2)
listC = np.random.rand(100,2)
listD = np.random.rand(100,2)

#90度の回転行列
rot90 = np.array([[np.cos(np.pi/2), -np.sin(np.pi/2)],
                              [np.sin(np.pi/2), np.cos(np.pi/2)]])

def calcintersection(point1,direc1,point2,direc2):
    # 参考
    # https://qiita.com/tmakimoto/items/2da05225633272ef935c
    
    #2直線が平行なら[]を返す
    if direc1[0]*direc2[1]==direc1[1]*direc2[0]:
        return []
    
    #方向ベクトルの正規化
    direc1_n = direc1 / np.linalg.norm(direc1) 
    direc2_n = direc2 / np.linalg.norm(direc2) 
    
    #direc2_nに垂直な単位ベクトル
    direc3_n = np.dot(direc2_n, rot90)
    
    #直線1における交点のパラメータloを求める
    lo = np.dot(direc3_n,(point2-point1)) / np.dot(direc3_n,direc1_n)
    
    return point1+direc1_n*lo


start = time.time()

for i in range(0,100):
    pointA = listA[i,:]
    direcB = listB[i,:]
    pointC = listC[i,:]
    direcD = listD[i,:]
    
    result = calcintersection(pointA,direcB,pointC,direcD)

elapsed_time = time.time() - start
print ("vector calc_time:{0}".format(round(elapsed_time,3)) + "[sec]")

vector calc_time:0.004[sec]