在Python中寻找最佳匹配的块/补丁

3

我希望能够在一个较大的二维数组中,找到以位置(x,y)为中心、大小为WxW的窗口内最接近的NxN块。下面的代码可以工作,但对于我的需求来说非常慢,因为我需要多次运行此操作。有更好的方法吗?这里N=3,W=15,x=15,y=15,(bestx,besty)是最佳匹配块的中心。

import numpy as np

## Generate some test data
CurPatch = np.random.randint(20, size=(3, 3))
Data = np.random.randint(20,size=(30,30))

# Current Location 
x,y = 15,15
# Initialise Best Match
bestcost = 999.0
bestx = 0;besty=0

for Wy in xrange(-7,8):
    for Wx in xrange(-7,8):
            Ywj,Ywi = y+Wy,x+Wx 

            cost = 0.0
            for py in xrange(3):
                for px in xrange(3):
                    cost += abs(Data[Ywj+py-1,Ywi+px-1] - CurPatch[py,px]) 

            if cost < bestcost:
                bestcost = cost
                besty,bestx = Ywj,Ywi

print besty,bestx

你可以在 for px in xrange(3): 中检查 cost 是否大于或等于 bestcost,如果是,你可以使用 break 来跳出循环,这样可以节省很多不必要的迭代。 - Kobi K
3
这个问题似乎不适合本论坛,因为它涉及到改进现有的工作代码。您可以考虑将此问题发布在codereview.stackexchange.com上。 - Tim
2个回答

0

正如我在评论中所说,您可以在for px in xrange(3):内部检查cost是否大于或等于bestcost,如果是,则可以使用break语句跳出循环,这样可以节省许多不必要的迭代。

示例(将轻微更改以强调更大的迭代差异):

import numpy as np
import time

## Generate some test data
CurPatch = np.random.randint(100, size=(3, 3))
Data = np.random.randint(100, size=(3000,3000))

# Current Location 
x,y = 10, 10
# Initialise Best Match
bestcost = 999.0
bestx = 0;besty=0

t0 = time.time()
for Wy in xrange(-7,50):
    for Wx in xrange(-7,50):
            Ywj, Ywi = y+Wy, x+Wx

            cost = 0.0
            for py in xrange(3):
                for px in xrange(3):
                    cost += abs(Data[Ywj+py-1,Ywi+px-1] - CurPatch[py,px])
                    if cost >= bestcost:
                        break

            if cost < bestcost:
                bestcost = cost
                besty,bestx = Ywj,Ywi

print besty, bestx
print "time: {}".format(time.time() - t0)

执行时间为26毫秒

时间:0.0269999504089

如果不加break,你的代码将输出37毫秒:

时间:0.0379998683929

我还建议将此代码转换为函数。


0
为了感受速度,在子问题中,当w与你的大窗口大小相同时,使用numpy会更快(而且更简洁)。
a= '''import numpy as np


## Generate some test data
CurPatch = np.random.randint(20, size=(3, 3))
Data = np.random.randint(20,size=(30,30))


def best(CurPatch,Data):

    # Current Location 
    x,y = 15,15
    # Initialise Best Match
    bestcost = 999.0
    bestx = 0;besty=0

    for Wy in xrange(-14,14):
        for Wx in xrange(-14,14):
                Ywj,Ywi = y+Wy,x+Wx 

                cost = 0.0
                for py in xrange(3):
                    for px in xrange(3):
                        cost += (Data[Ywj+py-1,Ywi+px-1] - CurPatch[py,px])**2 

                if cost < bestcost:
                    bestcost = cost
                    besty,bestx = Ywj,Ywi
    return besty,bestx,bestcost



def minimize(CurPatch,W):
    max_sum=999
    s= CurPatch.shape[0]
    S= W.shape[0]
    for i in range(0,S-s):
        for j in range(0,S-s):
            running= np.sum(np.square((W[i:i+3,j:j+3]-CurPatch)))
            if running<max_sum:
                max_sum=running
                x=i+1;y=j+1
    return x,y,max_sum

'''


import timeit
print min(timeit.Timer('minimize(CurPatch,Data)', a).repeat(7, 10))
print min(timeit.Timer('best(CurPatch,Data)', a).repeat(7, 10))     

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接