有没有办法用 numpy 函数实现类似归并排序中“合并”步骤的操作?
例如某个类似 merge 的函数:
a = np.array([1,3,5])
b = np.array([2,4,6])
c = merge(a, b) # c == np.array([1,2,3,4,5,6])
我希望借助 numpy 在处理大规模数据时获得较高的性能。
有没有办法用 numpy 函数实现类似归并排序中“合并”步骤的操作?
例如某个类似 merge 的函数:
a = np.array([1,3,5])
b = np.array([2,4,6])
c = merge(a, b) # c == np.array([1,2,3,4,5,6])
我希望借助 numpy 在处理大规模数据时获得较高的性能。
你可以使用以下代码:
from numpy import concatenate, sort
# Join both arrays end to end into a new array.
c = concatenate((a,b))
# Sort the joined array in place, explicitly requesting the mergesort kind.
c.sort(kind='mergesort')
除非您自己编写排序函数并将其实现为 Python 扩展(例如使用 Cython),否则恐怕无法做得比这更好。
有一个类似的问题可以参考这个链接,不过那里讨论的是合并后的数组只保留唯一值的情况。那里的基准测试和评论也很有参考价值。
sortednp包实现了一种高效的合并已排序的numpy数组的方法:
import numpy as np
import sortednp
# Two already-sorted input arrays.
a = np.array([1,3,5])
b = np.array([2,4,6])
# Merge the two sorted arrays into one sorted array with sortednp.
c = sortednp.merge(a, b) # c == np.array([1,2,3,4,5,6])
受 Sander 的答案启发,我用以下代码比较了 numpy 的 mergesort(v1.17.4)、Sander 的答案和 sortednp(v0.2.1)在不同数组大小以及 a 与 b 不同大小比例下的性能:
from timeit import timeit as t
import sortednp as snp
import numpy as np
def numpy_mergesort(a, b):
    """Merge two arrays by concatenating them and sorting the result.

    Args:
        a: First input array.
        b: Second input array.

    Returns:
        A new sorted array holding every element of ``a`` and ``b``.
    """
    c = np.concatenate((a, b))
    # Sort in place on the freshly allocated copy; the inputs are untouched.
    c.sort(kind='mergesort')
    return c
def sanders_merge(a, b):
    """Merge two sorted 1-D arrays into one sorted array without re-sorting.

    The shorter input is located inside the longer one with a binary search,
    and both inputs are then scattered directly into their final positions.

    Args:
        a: Sorted 1-D array (or array-like).
        b: Sorted 1-D array (or array-like).

    Returns:
        A new sorted array containing all elements of ``a`` and ``b``. Its
        dtype is the promotion of both input dtypes, so merging an integer
        array with a float array does not truncate the float values.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    # Search the shorter array into the longer one to minimise the number
    # of binary searches.
    if len(a) < len(b):
        b, a = a, b
    # Promote the dtypes instead of forcing the longer array's dtype onto
    # the result, which would silently cast the shorter array's values.
    c = np.empty(len(a) + len(b), dtype=np.promote_types(a.dtype, b.dtype))
    # Final slot of b[i] = i (elements of b before it) + number of elements
    # of a that sort before it.
    b_indices = np.arange(len(b)) + np.searchsorted(a, b)
    # Every slot not taken by b belongs to a, in order.
    a_indices = np.ones(len(c), dtype=bool)
    a_indices[b_indices] = False
    c[b_indices] = b
    c[a_indices] = a
    return c
# Benchmark the three merge implementations over several total sizes and
# several size ratios between the two inputs.
results = []
for size_factor in range(3):
    for max_digits in range(3, 8):
        size = 10**max_digits
        # size difference of a factor 10 here makes the difference!
        # `np.int` was removed in NumPy 1.24; the builtin `int` is the
        # dtype it used to alias.
        a = np.arange(size // 10**size_factor, dtype=int)
        b = np.arange(size, dtype=int)
        # Sanity check: all three implementations must agree before timing.
        assert np.array_equal(numpy_mergesort(a, b), sanders_merge(a, b))
        assert np.array_equal(numpy_mergesort(a, b), snp.merge(a, b))
        classic = t(lambda: numpy_mergesort(a, b), number=10)
        sanders = t(lambda: sanders_merge(a, b), number=10)
        snp_result = t(lambda: snp.merge(a, b), number=10)
        results.append((size_factor, max_digits, classic, sanders, snp_result))

# Print a header row followed by one row of timings per configuration.
text_format = " ".join(["{:<18}"] * 5)
print(text_format.format("log10(size factor)", "log10(max size)", "np mergesort", "Sander's merge", "sortednp"))
table_format = " ".join(["{:.5f}"] * 5)
for result in results:
    print(table_format.format(*result))
log10(size factor) log10(max size) np mergesort Sander's merge sortednp
0.00000 3.00000 0.00016 0.00062 0.00005
0.00000 4.00000 0.00135 0.00469 0.00029
0.00000 5.00000 0.01160 0.03813 0.00292
0.00000 6.00000 0.14952 0.54160 0.03527
0.00000 7.00000 2.00566 5.91691 0.67119
1.00000 3.00000 0.00005 0.00017 0.00002
1.00000 4.00000 0.00019 0.00058 0.00014
1.00000 5.00000 0.00304 0.00633 0.00137
1.00000 6.00000 0.03743 0.06893 0.01828
1.00000 7.00000 0.62334 1.01523 0.38732
2.00000 3.00000 0.00004 0.00015 0.00002
2.00000 4.00000 0.00012 0.00028 0.00013
2.00000 5.00000 0.00217 0.00275 0.00122
2.00000 6.00000 0.03457 0.03205 0.01524
2.00000 7.00000 0.51307 0.50120 0.34335
import numpy as np
def classic_merge(a, b):
    """Merge two arrays by concatenating them and sorting the result.

    Args:
        a: First input array.
        b: Second input array.

    Returns:
        A new sorted array holding every element of ``a`` and ``b``.
    """
    c = np.concatenate((a, b))
    # Sort in place on the freshly allocated copy; the inputs are untouched.
    c.sort(kind='mergesort')
    return c
def new_merge(a, b):
    """Merge two sorted 1-D arrays into one sorted array without re-sorting.

    The shorter input is located inside the longer one with a binary search,
    and both inputs are then scattered directly into their final positions.

    Args:
        a: Sorted 1-D array (or array-like).
        b: Sorted 1-D array (or array-like).

    Returns:
        A new sorted array containing all elements of ``a`` and ``b``. Its
        dtype is the promotion of both input dtypes, so merging an integer
        array with a float array does not truncate the float values.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    # Search the shorter array into the longer one to minimise the number
    # of binary searches.
    if len(a) < len(b):
        b, a = a, b
    # Promote the dtypes instead of forcing the longer array's dtype onto
    # the result, which would silently cast the shorter array's values.
    c = np.empty(len(a) + len(b), dtype=np.promote_types(a.dtype, b.dtype))
    # Final slot of b[i] = i (elements of b before it) + number of elements
    # of a that sort before it.
    b_indices = np.arange(len(b)) + np.searchsorted(a, b)
    # Every slot not taken by b belongs to a, in order.
    a_indices = np.ones(len(c), dtype=bool)
    a_indices[b_indices] = False
    c[b_indices] = b
    c[a_indices] = a
    return c
计时结果如下:
from timeit import timeit as t
# Benchmark classic_merge against new_merge when one input is ten times
# longer than the other.
results = []
for size_digits in range(2, 8):
    size = 10**size_digits
    # size difference of a factor 10 here makes the difference!
    # `np.int` was removed in NumPy 1.24; the builtin `int` is the dtype it
    # used to alias.
    a = np.arange(size // 10, dtype=int)
    b = np.arange(size, dtype=int)
    classic = t(lambda: classic_merge(a, b), number=10)
    new = t(lambda: new_merge(a, b), number=10)
    results.append((size_digits, classic, new))

# Print a header row followed by one row of timings per size.
text_format = " ".join(["{:<15}"] * 3)
print(text_format.format("log10(size)", "Classic", "New"))
table_format = " ".join(["{:.5f}"] * 3)
for result in results:
    print(table_format.format(*result))
log10(size) Classic New
2.00000 0.00009 0.00027
3.00000 0.00021 0.00030
4.00000 0.00233 0.00082
5.00000 0.02827 0.00601
6.00000 0.33322 0.06059
7.00000 4.40571 0.86764
from timeit import timeit as t
# Benchmark classic_merge against new_merge when both inputs have the same
# length.
results = []
for size_digits in range(2, 8):
    size = 10**size_digits
    # same size
    # `np.int` was removed in NumPy 1.24; the builtin `int` is the dtype it
    # used to alias.
    a = np.arange(size, dtype=int)
    b = np.arange(size, dtype=int)
    classic = t(lambda: classic_merge(a, b), number=10)
    new = t(lambda: new_merge(a, b), number=10)
    results.append((size_digits, classic, new))

# Print a header row followed by one row of timings per size.
text_format = " ".join(["{:<15}"] * 3)
print(text_format.format("log10(size)", "Classic", "New"))
table_format = " ".join(["{:.5f}"] * 3)
for result in results:
    print(table_format.format(*result))
log10(size) Classic New
2.00000 0.00026 0.00087
3.00000 0.00108 0.00182
4.00000 0.01257 0.01243
5.00000 0.16333 0.12692
6.00000 1.05006 0.49186
7.00000 8.35967 5.93732