我试图在一个numpy数组中查找重复的行。以下代码复制了我的数组结构,该数组具有 n 行、m 列和每行 nz 个非零条目:
import numpy as np
import random
import datetime
def create_mat(n, m, nz):
sample_mat = np.zeros((n, m), dtype='uint8')
random.seed(42)
for row in range(0, n):
counter = 0
while counter < nz:
random_col = random.randrange(0, m-1, 1)
if sample_mat[row, random_col] == 0:
sample_mat[row, random_col] = 1
counter += 1
test = np.all(np.sum(sample_mat, axis=1) == nz)
print(f'All rows have {nz} elements: {test}')
return sample_mat
我尝试优化的代码如下:
if __name__ == '__main__':
threshold = 2
mat = create_mat(1800000, 108, 8)
print(f'Time: {datetime.datetime.now()}')
unique_rows, _, duplicate_counts = np.unique(mat, axis=0, return_counts=True, return_index=True)
duplicate_indices = [int(x) for x in np.argwhere(duplicate_counts >= threshold)]
print(f'Time: {datetime.datetime.now()}')
print(f'Unique rows: {len(unique_rows)} Sample inds: {duplicate_indices[0:5]} Sample counts: {duplicate_counts[0:5]}')
print(f'Sample rows:')
print(unique_rows[0:5])
我的输出如下:
All rows have 8 elements: True
Time: 2022-06-29 12:08:07.320834
Time: 2022-06-29 12:08:23.281633
Unique rows: 1799994 Sample inds: [508991, 553136, 930379, 1128637, 1290356] Sample counts: [1 1 1 1 1]
Sample rows:
[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 0 0 1 1 0 0 0 1 0 0 0 0 0 0 1 0 1 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 1 0 0 0 0 0 1 1 1 1 0 0 0 0 1 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 1 0 1 0 0 1 0 0 0 1 0 1 0 1 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0 0 0 0 1 1 0 0 0 0 0 0 1 1 0 1 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 1 0 0 0 1 0 1 1 0 0 0 0 0 1 0 0 0 0 1 0]]
我考虑使用numba,但挑战在于它不使用轴参数。同样,将其转换为列表并利用集合是一个选择,但是循环执行重复计数似乎“不符合Python风格”。
由于我需要多次运行此代码(因为我正在修改numpy数组,然后需要重新搜索重复项),所以时间很关键。我还尝试使用多进程处理该步骤,但np.unique似乎是阻塞的(即使我尝试运行多个版本,我仍然被限制在一个线程上,CPU利用率为6%,而其他线程处于空闲状态)。
np.argwhere(unique_counts >= threshold).ravel()
。 - Michael Szczesnynp.unique
是单线程的。然而,加速比仅约为40%。 - Michael Szczesnyunique
函数在给定axis
参数时,会创建一个结构化数组视图(view
),将每一行封装为一个记录(record
)。然后对这个数组进行unique1d
操作,其中排序是非常重要的。实际上,它会将“行”(作为不可变单元)进行排序,并删除相邻的重复项。 - hpauljnp.unique
。但是,这仅在列数是无符号整数类型的倍数时才有效。笔记本 - Michael Szczesnynumba
的实现可以工作,但在这个例子中比np.unique
慢了15%。 - Michael Szczesny