查找三个阵列中最接近的三个x,y点

在Python中,我有三个包含x和y坐标的列表。每个列表包含128个点。我怎样才能以有效的方式找到最接近的三点?查找三个阵列中最接近的三个x,y点

这是我的工作Python代码,但它是没有效率不够:

def findclosest(c1, c2, c3): 

mina = 999999999

for i in c1:

for j in c2:

for k in c3:

# calculate sum of distances between points

d = xy3dist(i,j,k)

if d < mina:

mina = d

def xy3dist(a, b, c):

l1 = math.sqrt((a[0]-b[0]) ** 2 + (a[1]-b[1]) ** 2)

l2 = math.sqrt((b[0]-c[0]) ** 2 + (b[1]-c[1]) ** 2)

l3 = math.sqrt((a[0]-c[0]) ** 2 + (a[1]-c[1]) ** 2)

return l1+l2+l3

任何想法如何这可以使用numpy的做什么?

回答:

您可以使用NumPy的广播功能,以矢量化两个内环:

import numpy as np

def findclosest(c1, c2, c3):

c1 = np.asarray(c1)

c2 = np.asarray(c2)

c3 = np.asarray(c3)

for arr in (c1, c2, c3):

if not (arr.ndim == 2 and arr.shape[1] == 2):

raise ValueError("expected arrays of 2D coordinates")

min_val = np.inf

min_pos = None

for a, i in enumerate(c1):

d = xy3dist(i, c2.T[:,:,np.newaxis], c3.T[:,np.newaxis,:])

k = np.argmin(d)

if d.flat[k] < min_val:

min_val = d.flat[k]

b, c = np.unravel_index(k, d.shape)

min_pos = (a, b, c)

print a, min_val, d.min()

return min_val, min_pos

def xy3dist(a, b, c):

l1 = np.sqrt((a[0]-b[0]) ** 2 + (a[1]-b[1]) ** 2)

l2 = np.sqrt((b[0]-c[0]) ** 2 + (b[1]-c[1]) ** 2)

l3 = np.sqrt((a[0]-c[0]) ** 2 + (a[1]-c[1]) ** 2)

return l1+l2+l3

np.random.seed(1234)

c1 = np.random.rand(5, 2)

c2 = np.random.rand(9, 2)

c3 = np.random.rand(7, 2)

val, pos = findclosest(c1, c2, c3)

a, b, c = pos

print val, xy3dist(c1[a], c2[b], c3[c])

也有可能向量化所有的3环

def findclosest2(c1, c2, c3):

c1 = np.asarray(c1)

c2 = np.asarray(c2)

c3 = np.asarray(c3)

d = xy3dist(c1.T[:,:,np.newaxis,np.newaxis], c2.T[:,np.newaxis,:,np.newaxis], c3.T[:,np.newaxis,np.newaxis,:])

k = np.argmin(d)

min_val = d.flat[k]

a, b, c = np.unravel_index(k, d.shape)

min_pos = (a, b, c)

return min_val, min_pos

If your arrays are very big, findclosest可能优于findclosest2,因为它使用较少的内存。 (如果你的数组是巨大的,仅矢量化的一个最里面的循环。)

您可以谷歌“numpy的广播”,以了解更多什么np.newaxis确实

回答:

让我们尝试一些时间不同的解决方案看。

我打算用numpy的随机函数初始化三个数组。如果您有现成的变量是元组列表或列表列表,请在其上调用np.array

import numpy as np 

c1 = np.random.normal(size=(128, 2))

c2 = np.random.normal(size=(128, 2))

c3 = np.random.normal(size=(128, 2))


首先让我们来一次你的代码,所以我们有一个起点。这可能是有益的

def findclosest(c1, c2, c3): 

mina = 999999999

for i in c1:

for j in c2:

for k in c3:

# calculate sum of distances between points

d = xy3dist(i,j,k)

if d < mina:

mina = d

return mina

def xy3dist(a, b, c):

l1 = math.sqrt((a[0]-b[0]) ** 2 + (a[1]-b[1]) ** 2)

l2 = math.sqrt((b[0]-c[0]) ** 2 + (b[1]-c[1]) ** 2)

l3 = math.sqrt((a[0]-c[0]) ** 2 + (a[1]-c[1]) ** 2)

return l1+l2+l3

%timeit findclosest(c1, c2, c3)

# 1 loops, best of 3: 23.3 s per loop


一个功能是scipy.spatial.distance.cdist,其计算分两个阵列之间的所有成对距离。因此,我们可以使用它来预先计算并存储所有距离,然后只需从这些数组中获取并添加距离即可。我也将使用itertools.product来简化循环,尽管它不会做任何加速工作。

from scipy.spatial.distance import cdist 

from itertools import product

def findclosest_usingcdist(c1, c2, c3):

dists_12 = cdist(c1, c2)

dists_23 = cdist(c2, c3)

dists_13 = cdist(c1, c3)

min_dist = np.inf

ind_gen = product(range(len(c1)), range(len(c2)), range(len(c3)))

for i1, i2, i3 in ind_gen:

dist = dists_12[i1, i2] + dists_23[i2, i3] + dists_13[i1, i3]

if dist < min_dist:

min_dist = dist

min_points = (c1[i1], c2[i2], c3[i3])

return min_dist, min_points

%timeit findclosest_usingcdist(c1, c2, c3)

# 1 loops, best of 3: 2.02 s per loop

因此使用cdist购买我们一个数量级的加速。


然而,这甚至没有比较@ pv的答案。他的一些实现被剥离出来,与以前的解决方案进行了更好的比较(请参阅@pv针对实现返回点的答案)。

def findclosest2(c1, c2, c3): 

d = xy3dist(c1.T[:,:,np.newaxis,np.newaxis],

c2.T[:,np.newaxis,:,np.newaxis],

c3.T[:,np.newaxis,np.newaxis,:])

k = np.argmin(d)

min_val = d.flat[k]

i1, i2, i3 = np.unravel_index(k, d.shape)

min_points = (c1[i1], c2[i2], c3[i3])

return min_val, min_points

def xy3dist(a, b, c):

l1 = np.sqrt((a[0]-b[0]) ** 2 + (a[1]-b[1]) ** 2)

l2 = np.sqrt((b[0]-c[0]) ** 2 + (b[1]-c[1]) ** 2)

l3 = np.sqrt((a[0]-c[0]) ** 2 + (a[1]-c[1]) ** 2)

return l1+l2+l3

%timeit findclosest_usingbroadcasting(c1, c2, c3)

# 100 loops, best of 3: 19.1 ms per loop

所以这是一个巨大的加速,绝对是正确的答案。

以上是 查找三个阵列中最接近的三个x,y点 的全部内容, 来源链接: utcz.com/qa/258390.html

回到顶部