Python中与MATLAB的“ismember”函数等效的函数

20
在多次尝试优化代码后,似乎最后一个资源是尝试使用多个核心运行下面的代码。我不知道如何转换/重构我的代码,以便在使用多个核心时可以运行得更快。如果我能得到指导来实现最终目标,我将不胜感激。最终目标是能够尽可能快地运行数组A和B,其中每个数组包含约700,000个元素。这里是使用小数组的代码。 700k元素数组已被注释掉。
import numpy as np

def ismember(a,b):
    for i in a:
        index = np.where(b==i)[0]
        if index.size == 0:
            yield 0
        else:
            yield index


def f(A, gen_obj):
    my_array = np.arange(len(A))
    for i in my_array:
        my_array[i] = gen_obj.next()
    return my_array


#A = np.arange(700000)
#B = np.arange(700000)
A = np.array([3,4,4,3,6])
B = np.array([2,5,2,6,3])

gen_obj = ismember(A,B)

f(A, gen_obj)

print 'done'
# if we print f(A, gen_obj) the output will be: [4 0 0 4 3]
# notice that the output array needs to be kept the same size as array A.
我想做的是模仿一个名为ismember[2]的 MATLAB 函数(格式为:[Lia,Locb]=ismember(A,B)),我只需要得到Locb部分。
从Matlab中可以看出:Locb包含A中每个值在B中的最低索引。输出数组Locb在A不属于B的地方包含0。
主要问题之一是我需要尽可能高效地执行此操作。测试时,我有两个包含70万个元素的数组。创建生成器并遍历生成器的值似乎无法快速完成任务。
5个回答

18

在考虑多核处理之前,我会通过使用字典来删除你的ismember函数中的线性扫描:

def ismember(a, b):
    bind = {}
    for i, elt in enumerate(b):
        if elt not in bind:
            bind[elt] = i
    return [bind.get(itm, None) for itm in a]  # None can be replaced by any other "not in b" value

您原来的实现需要对B中的每个元素进行完整扫描,使其时间复杂度为O(len(A)*len(B))。上述代码只需要对B进行一次完整扫描以生成字典Bset。使用字典,您可以有效地使得对于A中的每个元素,查找B中的每个元素的操作都是常数级别的,因此操作的时间复杂度为O(len(A)+len(B))。如果这仍然太慢,那么请考虑让上述函数在多个核心上运行。

编辑:我还稍微修改了您的索引。Matlab使用0,因为它的所有数组都从索引1开始。Python/numpy将数组从0开始,因此如果您的数据集如下所示:

A = [2378, 2378, 2378, 2378]
B = [2378, 2379]

如果你返回0代表没有元素,那么你的结果将会排除A数组中的所有元素。上面的程序返回None代表没有索引而不是0。 返回-1也是一种选择,但Python会将其解释为数组中最后一个元素。如果None用作数组的索引,它将引发异常。如果您想要不同的行为,请将Bind.get(item, None)表达式中的第二个参数更改为您想要返回的值。

哇,这真的非常快!你不知道我有多么欣赏你的解决方案。非常感谢!你使用特定的工具输出性能分析报告吗? - zd5151
7
不,这是一种直接的算法分析。使用大O符号np.where需要对B进行线性扫描,这需要O(len(B))个操作。 然后使用一个外循环,需要O(len(A))个操作,使得您的原始算法大约需要O(len(A)*len(B))个操作。 生成Bind需要len(B)个操作。 字典实现为hash表,具有常数O(1)的查找,因此扫描A是O(len(A));总体复杂度为O(len(A)+len(B)) - sfstewman
明白了。感谢您提供维基百科的参考。 - zd5151
1
不,你把代码搞砸了。现在返回的元素是列表中的最后一个出现的元素,而不是第一个。我在原始代码中没有使用字典推导式是有原因的。 - sfstewman
1
据我所知,您可以通过迭代反向范围来使用字典推导式:{ B[i] : i for i in xrange(len(B)-1,-1,-1) }。或者使用反向迭代器:{ elt : len(B)-i-1) for (i,elt) in enumerate(reversed(B)) }。两种方法都不太美观(或简单)。第一种假设B是可索引的,第二种假设B是可逆的。它们还假设随机索引/反向迭代是廉价的。如果B是非常大的链表,则性能将非常糟糕。是否有一种使用推导式仅假定迭代以检索第一个索引的方法? - sfstewman
显示剩余6条评论

15

sfstewman的出色回答很可能为您解决了问题。

我想补充说明如何仅使用numpy实现相同的功能。

我利用了numpy的uniquein1d函数。

B_unique_sorted, B_idx = np.unique(B, return_index=True)
B_in_A_bool = np.in1d(B_unique_sorted, A, assume_unique=True)
  • B_unique_sorted 包含了已排序的 B 中的唯一值。
  • B_idx 保存了这些值在原来的 B 中的索引。
  • B_in_A_bool 是一个大小与 B_unique_sorted 相同的布尔数组,用于存储 B_unique_sorted 中的值是否在 A 中。
    注意:我需要在 A 中查找 (B 的唯一值),因为我需要根据 B_idx 返回输出。
    注意:我假设 A 已经是唯一的。

现在您可以使用 B_in_A_bool 来获取共同的值。

B_unique_sorted[B_in_A_bool]

以及它们在原始的B中各自对应的索引。

B_idx[B_in_A_bool]

最后,我认为这比纯Python for循环快得多,尽管我没有测试过。


尽可能在代码中使用numpy,这样可以实现更快的速度提升(正如我艰难地学到的 >_<)。 - James Porter
2
小心!这不保留索引的顺序!用range(1,6)和[5,1]试试。如果不需要索引的顺序,我认为你可以直接使用np.in1d(),然后np.nonzero()[0]。 - aless80
1
请查看此处的答案:https://dev59.com/h1wX5IYBdhLWcg3wjgCf,以获取正确顺序的索引。 - aless80

2
尝试使用ismember库。
pip install ismember

简单的例子:

# Import library
from ismember import ismember
import numpy as np

# data
A = np.array([3,4,4,3,6])
B = np.array([2,5,2,6,3])

# Lookup
Iloc,idx = ismember(A, B)
 
# Iloc is boolean defining existence of d in d_unique
print(Iloc)
# [ True False False  True  True]

# indexes of d_unique that exists in d
print(idx)
# [4 4 3]

print(B[idx])
# [3 3 6]

print(A[Iloc])
# [3 3 6]

# These vectors will match
A[Iloc]==B[idx]

速度检测:

from ismember import ismember
from datetime import datetime

t1=[]
t2=[]
# Create some random vectors
ns = np.random.randint(10,10000,1000)

for n in ns:
    a_vec = np.random.randint(0,100,n)
    b_vec = np.random.randint(0,100,n)

    # Run stack version
    start = datetime.now()
    out1=ismember_stack(a_vec, b_vec)
    end = datetime.now()
    t1.append(end - start)

    # Run ismember
    start = datetime.now()
    out2=ismember(a_vec, b_vec)
    end = datetime.now()
    t2.append(end - start)


print(np.sum(t1))
# 0:00:07.778331

print(np.sum(t2))
# 0:00:04.609801

# %%
def ismember_stack(a, b):
    bind = {}
    for i, elt in enumerate(b):
        if elt not in bind:
            bind[elt] = i
    return [bind.get(itm, None) for itm in a]  # None can be replaced by any other "not in b" value

ismember函数来自pypi,速度几乎快了2倍。

大向量,例如700000个元素:

from ismember import ismember
from datetime import datetime

A = np.random.randint(0,100,700000)
B = np.random.randint(0,100,700000)

# Lookup
start = datetime.now()
Iloc,idx = ismember(A, B)
end = datetime.now()

# Print time
print(end-start)
# 0:00:01.194801

你认为这个解决方案怎么样? - zd5151

1
这是确切的MATLAB等效代码,返回与MATLAB匹配的输出参数[Lia, Locb],但在Python中0也是有效索引。因此,此函数不返回0。它实际上返回Locb(Locb> 0)。性能也与MATLAB相当。
def ismember(a_vec, b_vec):
    """ MATLAB equivalent ismember function """

    bool_ind = np.isin(a_vec,b_vec)
    common = a[bool_ind]
    common_unique, common_inv  = np.unique(common, return_inverse=True)     # common = common_unique[common_inv]
    b_unique, b_ind = np.unique(b_vec, return_index=True)  # b_unique = b_vec[b_ind]
    common_ind = b_ind[np.isin(b_unique, common_unique, assume_unique=True)]
    return bool_ind, common_ind[common_inv]

这里有一个备用实现,虽然速度较慢(约为原来的5倍),但不使用unique函数:
def ismember(a_vec, b_vec):
    ''' MATLAB equivalent ismember function. Slower than above implementation'''
    b_dict = {b_vec[i]: i for i in range(0, len(b_vec))}
    indices = [b_dict.get(x) for x in a_vec if b_dict.get(x) is not None]
    booleans = np.in1d(a_vec, b_vec)
    return booleans, np.array(indices, dtype=int)

1
尝试使用列表推导式;
In [1]: import numpy as np

In [2]: A = np.array([3,4,4,3,6])

In [3]: B = np.array([2,5,2,6,3])

In [4]: [x for x in A if not x in B]
Out[4]: [4, 4]

通常情况下,列表推导式比for循环更快。

要获取等长度的列表;

In [19]: map(lambda x: x if x not in B else False, A)
Out[19]: [False, 4, 4, False, False]

对于小数据集来说,这非常快:

In [20]: C = np.arange(10000)

In [21]: D = np.arange(15000, 25000)

In [22]: %timeit map(lambda x: x if x not in D else False, C)
1 loops, best of 3: 756 ms per loop

对于大型数据集,您可以尝试使用 multiprocessing.Pool.map() 来加速操作。


输出数组需要保持相同的大小。 - zd5151
@z5151:请查看增强版答案。如果您愿意,可以将lambda表达式更改为返回0而不是False,但这会掩盖结果中的真实0。 - Roland Smith
这对于元素数量较少的数组非常有用。感谢您强调列表推导比循环快得多。 - zd5151
你的答案返回元素,而不是B中元素的索引。 - sfstewman

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接