将索引列表转换为二维numpy数组中值为1的最快方法

7

我有一个索引列表

a = [
  [1,2,4],
  [0,2,3],
  [1,3,4],
  [0,2]]

什么是将此内容转换为numpy数组的最快方法,使每个索引都显示1出现的位置?即我想要的是:
output = array([
  [0,1,1,0,1],
  [1,0,1,1,0],
  [0,1,0,1,1],
  [1,0,1,0,0]])

我事先知道数组的最大尺寸。我知道我可以循环遍历每个列表,并在每个索引位置插入1,但是否有更快/向量化的方法来做这件事?

我的用例可能有数千行/列,我需要重复执行数千次,所以速度越快越好。


1
因为a是一个不规则列表,所以它不容易进行向量化处理。 - cs95
1
我想可能没有好的方法,但只是想看看stackoverflow的大脑信托能想出什么 :) - Hansang
6个回答

10

这个怎么样:

ncol = 5
nrow = len(a)
out = np.zeros((nrow, ncol), int)
out[np.arange(nrow).repeat([*map(len,a)]), np.concatenate(a)] = 1
out
# array([[0, 1, 1, 0, 1],
#        [1, 0, 1, 1, 0],
#        [0, 1, 0, 1, 1],
#        [1, 0, 1, 0, 0]])

这里是一个1000x1000的二进制数组的时间表,注意我使用了上面优化过的版本,参见下面的pp函数:

pp 21.717635259992676 ms
ts 37.10938713003998 ms
u9 37.32933565042913 ms

生成时间的代码:

import itertools as it
import numpy as np

def make_data(n,m):
    I,J = np.where(np.random.random((n,m))<np.random.random((n,1)))
    return [*map(np.ndarray.tolist, np.split(J, I.searchsorted(np.arange(1,n))))]

def pp():
    sz = np.fromiter(map(len,a),int,nrow)
    out = np.zeros((nrow,ncol),int)
    out[np.arange(nrow).repeat(sz),np.fromiter(it.chain.from_iterable(a),int,sz.sum())] = 1
    return out

def ts():
    out = np.zeros((nrow,ncol),int)
    for i, ix in enumerate(a):
        out[i][ix] = 1
    return out

def u9():
    out = np.zeros((nrow,ncol),int)
    for i, (x, y) in enumerate(zip(a, out)):
        y[x] = 1
        out[i] = y
    return out

nrow,ncol = 1000,1000
a = make_data(nrow,ncol)

from timeit import timeit
assert (pp()==ts()).all()
assert (pp()==u9()).all()

print("pp", timeit(pp,number=100)*10, "ms")
print("ts", timeit(ts,number=100)*10, "ms")
print("u9", timeit(u9,number=100)*10, "ms")

2
从外观上看,使用多个 numpy 函数和 map 也会更慢(当然没有尝试过就无法确认)。 - Teshan Shanuka J
2
@TeshanShanukaJ,你是在暗示你的解决方案更快吗?你有任何时间测试来支持它吗?性能取决于数据,而且我认为这个解决方案会很好地扩展(这也是我为什么点赞的原因)。 - cs95
1
我不知道。我只是发出一个警告,因为OP正在寻找最快的解决方案。我已经提到我的解决方案也不会是最快的。这取决于OP来测试时间。 - Teshan Shanuka J
1
@TeshanShanukaJ 实际上,在中等大小的例子(比如1000x1000)上,你的速度似乎更快一些(约10%)。 - Paul Panzer
1
@TeshanShanukaJ 经过一些调整,我现在快了约40%。 - Paul Panzer

6
这可能不是最快的方法。您需要使用大型数组比较这些答案的执行时间,以找到最快的方法。以下是我的解决方案。
output = np.zeros((4,5))
for i, ix in enumerate(a):
    output[i][ix] = 1

# output -> 
#   array([[0, 1, 1, 0, 1],
#   [1, 0, 1, 1, 0],
#   [0, 1, 0, 1, 1],
#   [1, 0, 1, 0, 0]])

1
如果提供实际的时间信息,答案将会更加准确。 - aaaaa says reinstate Monica

4
如果你能使用Cython并且愿意的话,你可以创建一个易读(至少如果你不介意输入)和快速的解决方案。
在这里,我正在使用Cython的IPython绑定将其编译为Jupyter笔记本:
%load_ext cython

%%cython

cimport cython
cimport numpy as cnp
import numpy as np

@cython.boundscheck(False)  # remove this if you cannot guarantee that nrow/ncol are correct
@cython.wraparound(False)
cpdef cnp.int_t[:, :] mseifert(list a, int nrow, int ncol):
    cdef cnp.int_t[:, :] out = np.zeros([nrow, ncol], dtype=int)
    cdef list subl
    cdef int row_idx
    cdef int col_idx
    for row_idx, subl in enumerate(a):
        for col_idx in subl:
            out[row_idx, col_idx] = 1
    return out

为了比较这里介绍的解决方案的性能,我使用了我的库simple_benchmark

enter image description here

请注意,这里使用对数轴同时显示小和大数组的差异。根据我的基准测试,我的函数实际上是解决方案中最快的,然而也值得指出的是,所有解决方案的差距并不太大。
以下是我用于基准测试的完整代码:
import numpy as np
from simple_benchmark import BenchmarkBuilder, MultiArgument
import itertools

b = BenchmarkBuilder()

@b.add_function()
def pp(a, nrow, ncol):
    sz = np.fromiter(map(len, a), int, nrow)
    out = np.zeros((nrow, ncol), int)
    out[np.arange(nrow).repeat(sz), np.fromiter(itertools.chain.from_iterable(a), int, sz.sum())] = 1
    return out

@b.add_function()
def ts(a, nrow, ncol):
    out = np.zeros((nrow, ncol), int)
    for i, ix in enumerate(a):
        out[i][ix] = 1
    return out

@b.add_function()
def u9(a, nrow, ncol):
    out = np.zeros((nrow, ncol), int)
    for i, (x, y) in enumerate(zip(a, out)):
        y[x] = 1
        out[i] = y
    return out

b.add_functions([mseifert])

@b.add_arguments("number of rows/columns")
def argument_provider():
    for n in range(2, 13):
        ncols = 2**n
        a = [
            sorted(set(np.random.randint(0, ncols, size=np.random.randint(0, ncols)))) 
            for _ in range(ncols)
        ]
        yield ncols, MultiArgument([a, ncols, ncols])

r = b.run()
r.plot()

实际上,我很惊讶Cython在这里的普及度如此之低,考虑到输入格式对于numpy来说有些不太方便。 - Paul Panzer
@PaulPanzer 我也有点惊讶 - 我认为唯一相关(关于性能)的部分是迭代列表的列表。在你的情况下,它是 itertools.chain.from_iterable,而在我的情况下是显式迭代。其他所有内容基本上都是恒定的开销。 - MSeifert

3
可能不是最好的方法,但这是我能想到的唯一方式:
output = np.zeros((4,5))
for i, (x, y) in enumerate(zip(a, output)):
    y[x] = 1
    output[i] = y
print(output)

将输出:

[[ 0.  1.  1.  0.  1.]
 [ 1.  0.  1.  1.  0.]
 [ 0.  1.  0.  1.  1.]
 [ 1.  0.  1.  0.  0.]]

2
这非常整洁(比我的尝试漂亮多了),尽管在运行时间方面,它看起来与手动编写循环相同? - Hansang
1
@Spcoggthesecond 然后使用Paul的解决方案。 - U13-Forward

1
如何使用数组索引?如果您对输入有更多了解,可以消除首先转换为线性数组的惩罚。
import numpy as np


def main():
    row_count = 4
    col_count = 5
    a = [[1,2,4],[0,2,3],[1,3,4],[0,2]]

    # iterate through each row, concatenate all indices and convert them to linear

    # numpy append performs copy even if you don't want it, list append is faster
    b = []
    for row_idx, row in enumerate(a):
        b.append(np.array(row, dtype=np.int64) + (row_idx * col_count))

    linear_idxs = np.hstack(b)
    #could skip previous steps if given index inputs well before hand, or in linear index order. 
    c = np.zeros(row_count * col_count)
    c[linear_idxs] = 1
    c = c.reshape(row_count, col_count)
    print(c)


if __name__ == "__main__":
    main()

#output
# [[0. 1. 1. 0. 1.]
#  [1. 0. 1. 1. 0.]
#  [0. 1. 0. 1. 1.]
#  [1. 0. 1. 0. 0.]]

1

根据您的使用情况,您可能需要考虑使用稀疏矩阵。输入矩阵看起来很像压缩稀疏行(CSR)矩阵。也许可以尝试使用该类型矩阵。

import numpy as np
from scipy.sparse import csr_matrix
from itertools import accumulate


def ragged2csr(inds):
    offset = len(inds[0])
    lens = [len(x) for x in inds]
    indptr = list(accumulate(lens))
    indptr = np.array([x - offset for x in indptr])
    indices = np.array([val for sublist in inds for val in sublist])
    n = indices.size
    data = np.ones(n)
    return csr_matrix((data, indices, indptr))

如果符合您的使用情况,稀疏矩阵将允许元素级/掩码操作随着非零元素的数量而扩展,而不是元素的数量(行*列),这可能会带来显着的加速(对于足够稀疏的矩阵)。

CSR矩阵的另一个很好的介绍是迭代方法的第3.4节。在这种情况下,dataaaindicesjaindptria。这种格式还具有在不同的软件包/库之间非常流行的好处。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接