将NumPy排序数组合并到新数组中?

20

有没有办法使用numpy函数来像归并排序中的合并一样进行合并操作?

类似于 merge 的某些函数:

a = np.array([1,3,5])
b = np.array([2,4,6])
c = merge(a, b) # c == np.array([1,2,3,4,5,6])

我希望能够通过使用numpy获得大数据的高性能


1
请看这个回答,它提供了一个非常高效的解决方案:https://dev59.com/lWct5IYBdhLWcg3wAYzq#12427633 - Imran
可能是合并两个数组并排序的重复问题。 - Imran
你只需要拼接还是排序? - Dima Tisnek
3个回答

18

你可以使用

from numpy import concatenate, sort

c = concatenate((a,b))
c.sort(kind='mergesort')

除非您编写自己的排序函数作为Python扩展,例如使用 cython ,否则恐怕您无法做得比这更好。

有类似问题可以看看这个链接。不过那里是将合并后的数组中仅保留唯一值。那里的基准测试和评论也很有见地。


2
好吧,可能那是最好的方式。我搜索了很多相关答案,但似乎总是得这样… - darwinsenior
这个解决方案的时间复杂度为O(n log n),而合并两个数组的时间复杂度为O(n)。 - G. Cohen

11

sortednp包实现了一种高效的合并已排序的numpy数组的方法:

import numpy as np
import sortednp
a = np.array([1,3,5])
b = np.array([2,4,6])
c = sortednp.merge(a, b) # c == np.array([1,2,3,4,5,6])

受Sander的文章启发,我使用以下代码测试了不同数组大小和a与b之间大小比率下numpy的mergesort(v1.17.4)、Sander的答案和sortednp(v0.2.1)的性能:

from timeit import timeit as t
import sortednp as snp
import numpy as np

def numpy_mergesort(a, b):
    c = np.concatenate((a,b))
    c.sort(kind='mergesort')
    return c

def sanders_merge(a, b):
    if len(a) < len(b):
        b, a = a, b
    c = np.empty(len(a) + len(b), dtype=a.dtype)
    b_indices = np.arange(len(b)) + np.searchsorted(a, b)
    a_indices = np.ones(len(c), dtype=bool)
    a_indices[b_indices] = False
    c[b_indices] = b
    c[a_indices] = a
    return c

results = []

for size_factor in range(3):
    for max_digits in range(3, 8):
        size = 10**max_digits
        # size difference of a factor 10 here makes the difference!
        a = np.arange(size // 10**size_factor, dtype=np.int)
        b = np.arange(size, dtype=np.int)
        assert np.array_equal(numpy_mergesort(a, b), sanders_merge(a, b))
        assert np.array_equal(numpy_mergesort(a, b), snp.merge(a, b))
        classic = t(lambda: numpy_mergesort(a, b), number=10)
        sanders = t(lambda: sanders_merge(a, b), number=10)
        snp_result = t(lambda: snp.merge(a, b), number=10)
        results.append((size_factor, max_digits, classic, sanders, snp_result))

text_format = " ".join(["{:<18}"] * 5)
print(text_format.format("log10(size factor)", "log10(max size)", "np mergesort", "Sander's merge", "sortednp"))
table_format = "            ".join(["{:.5f}"] * 5)
for result in results:
    print(table_format.format(*result))

结果表明,sortednp一直是最快的实现:
log10(size factor) log10(max size)    np mergesort       Sander's merge     sortednp          
0.00000            3.00000            0.00016            0.00062            0.00005
0.00000            4.00000            0.00135            0.00469            0.00029
0.00000            5.00000            0.01160            0.03813            0.00292
0.00000            6.00000            0.14952            0.54160            0.03527
0.00000            7.00000            2.00566            5.91691            0.67119
1.00000            3.00000            0.00005            0.00017            0.00002
1.00000            4.00000            0.00019            0.00058            0.00014
1.00000            5.00000            0.00304            0.00633            0.00137
1.00000            6.00000            0.03743            0.06893            0.01828
1.00000            7.00000            0.62334            1.01523            0.38732
2.00000            3.00000            0.00004            0.00015            0.00002
2.00000            4.00000            0.00012            0.00028            0.00013
2.00000            5.00000            0.00217            0.00275            0.00122
2.00000            6.00000            0.03457            0.03205            0.01524
2.00000            7.00000            0.51307            0.50120            0.34335

4
当一个数组比另一个数组大很多时,可以通过执行np.searchsorted获得相当快的加速(在我的电脑上为5倍),其速度主要受制于搜索较小数组的插入索引:
import numpy as np

def classic_merge(a, b):
    c = np.concatenate((a,b))
    c.sort(kind='mergesort')
    return c

def new_merge(a, b):
    if len(a) < len(b):
        b, a = a, b
    c = np.empty(len(a) + len(b), dtype=a.dtype)
    b_indices = np.arange(len(b)) + np.searchsorted(a, b)
    a_indices = np.ones(len(c), dtype=bool)
    a_indices[b_indices] = False
    c[b_indices] = b
    c[a_indices] = a
    return c

定时给出结果:

from timeit import timeit as t

results = []
for size_digits in range(2, 8):
    size = 10**size_digits
    # size difference of a factor 10 here makes the difference!
    a = np.arange(size // 10, dtype=np.int)
    b = np.arange(size, dtype=np.int)
    classic = t(lambda: classic_merge(a, b), number=10)
    new = t(lambda: new_merge(a, b), number=10)
    results.append((size_digits, classic, new))

if True:
    text_format = " ".join(["{:<15}"] * 3)
    print(text_format.format("log10(size)", "Classic", "New"))
    table_format = "         ".join(["{:.5f}"] * 3)
    for result in results:
        print(table_format.format(*result))

log10(size)     Classic         New            
2.00000         0.00009         0.00027
3.00000         0.00021         0.00030
4.00000         0.00233         0.00082
5.00000         0.02827         0.00601
6.00000         0.33322         0.06059
7.00000         4.40571         0.86764

当a和b的长度大致相等时,差异较小:
from timeit import timeit as t

results = []
for size_digits in range(2, 8):
    size = 10**size_digits
    # same size
    a = np.arange(size , dtype=np.int)
    b = np.arange(size, dtype=np.int)
    classic = t(lambda: classic_merge(a, b), number=10)
    new = t(lambda: new_merge(a, b), number=10)
    results.append((size_digits, classic, new))

if True:
    text_format = " ".join(["{:<15}"] * 3)
    print(text_format.format("log10(size)", "Classic", "New"))
    table_format = "         ".join(["{:.5f}"] * 3)
    for result in results:
        print(table_format.format(*result))

log10(size)     Classic         New            
2.00000         0.00026         0.00087
3.00000         0.00108         0.00182
4.00000         0.01257         0.01243
5.00000         0.16333         0.12692
6.00000         1.05006         0.49186
7.00000         8.35967         5.93732

感谢您的快速 new_merge 函数。 唯一的评论是:必须在使用之前定义 c。 - jan

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接