Numpy追加:自动转换错误维度的数组

5

有没有不使用if语句来完成以下操作的方法?

我正在使用pupynere读取一组netcdf文件,并希望使用numpy append构建一个数组。有时输入数据是多维的(如下面的变量“a”),有时是一维的(“b”),但第一维中元素的数量始终相同(在下面的示例中为“9”)。

> import numpy as np
> a = np.arange(27).reshape(3,9)
> b = np.arange(9)
> a.shape
(3, 9)
> b.shape
(9,)

这按预期工作:

> np.append(a,a, axis=0)
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
   [ 9, 10, 11, 12, 13, 14, 15, 16, 17],
   [18, 19, 20, 21, 22, 23, 24, 25, 26],
   [ 0,  1,  2,  3,  4,  5,  6,  7,  8],
   [ 9, 10, 11, 12, 13, 14, 15, 16, 17],
   [18, 19, 20, 21, 22, 23, 24, 25, 26]])

但是,添加 b 并不那么优雅:
> np.append(a,b, axis=0)
ValueError: arrays must have same number of dimensions

从numpy手册得知,append存在的问题是:

"当指定轴时,值必须具有正确的形状。"

为了获得正确的结果,我需要先进行强制转换。

> np.append(a,b.reshape(1,9), axis=0)
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
   [ 9, 10, 11, 12, 13, 14, 15, 16, 17],
   [18, 19, 20, 21, 22, 23, 24, 25, 26],
   [ 0,  1,  2,  3,  4,  5,  6,  7,  8]])

因此,在我的文件读取循环中,我目前正在使用类似于以下的if语句:
for i in [a, b]:
    if np.size(i.shape) == 2:
        result = np.append(result, i, axis=0)
    else:
        result = np.append(result, i.reshape(1,9), axis=0)

有没有一种方法可以不使用if语句来追加"a"和"b"?

编辑:虽然@Sven完美地回答了原始问题(使用np.atleast_2d()),但他(以及其他人)指出代码效率低下。在下面的答案中,我结合了他们的建议并替换了我的原始代码。现在应该更有效率了。谢谢。

4个回答

3
您可以使用numpy.atleast_2d()函数:
result = np.append(result, np.atleast_2d(i), axis=0)

注意,反复使用 numpy.append() 是构建 NumPy 数组非常低效的方式 -- 它必须在每一步重新分配。如果可能的话,请预先分配所需最终大小的数组,并使用切片后填充它。


谢谢您的快速回答。有趣的是,我不知道atleast_2d()方法,但它似乎有效。不过,我想您是指np.atleast_2d(i)吧?关于分配,我不知道最终大小,我还能做些什么来减少低效率吗? - Sebastian
1
如果性能很重要,您可以首先将要连接的所有数组收集到Python列表中,计算数组的最终大小,然后分配和填充数组。(如果使用此方法有任何问题,请提出新的问题。) - Sven Marnach
@Sven,关于存储数组的问题。在大多数情况下,这可能比我的建议更好(假设有足够的内存来存储所有的数组两次)。 - Henry Gomersall
@Sven @Henry,非常感谢你们的帮助和提示。如果我遇到速度问题,我会尝试使用Python列表方法。再见! - Sebastian
2
@Sebastian:另一个想法——也许在按照上述描述创建列表后调用numpy.vstack()是你最好的猜测。 - Sven Marnach

2
你可以将所有的数组添加到一个列表中,然后使用np.vstack()在最后将它们全部连接起来。这样可以避免每次追加时都不断重新分配增长数组的内存空间。
|1> a = np.arange(27).reshape(3,9)

|2> b = np.arange(9)

|3> np.vstack([a,b])
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
       [ 9, 10, 11, 12, 13, 14, 15, 16, 17],
       [18, 19, 20, 21, 22, 23, 24, 25, 26],
       [ 0,  1,  2,  3,  4,  5,  6,  7,  8]])

是的,这似乎有效,@Sven也推荐了这样做。我不应该在我的for循环中调用result = np.vstack(result,i)的原因是效率低下,对吗? - Sebastian

1

我将在@Sven、@Henry和@Robert的帮助下改进我的代码。@Sven回答了这个问题,因此他为这个问题赢得了声誉,但正如他和其他人所强调的那样,有一种更有效的方法来实现我想要的。

这涉及使用Python列表,它允许使用性能惩罚为O(1)的附加操作,而numpy.append()的性能惩罚为O(N**2)。之后,将列表转换为numpy数组:

假设i是类型ab中的任意一个:

> a = np.arange(27).reshape(3,9)
> b = np.arange(9)
> a.shape
(3, 9)
> b.shape
(9,)

初始化列表并将所有读取的数据附加到其中,例如,如果数据按顺序出现为'aaba'。

> mList = []
> for i in [a,a,b,a]:
     mList.append(i)

你的mList将会长这样:

> mList
[array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
   [ 9, 10, 11, 12, 13, 14, 15, 16, 17],
   [18, 19, 20, 21, 22, 23, 24, 25, 26]]),
 array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
   [ 9, 10, 11, 12, 13, 14, 15, 16, 17],
   [18, 19, 20, 21, 22, 23, 24, 25, 26]]),
 array([0, 1, 2, 3, 4, 5, 6, 7, 8]),
 array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
   [ 9, 10, 11, 12, 13, 14, 15, 16, 17],
   [18, 19, 20, 21, 22, 23, 24, 25, 26]])]

最后,使用vstack将列表堆叠起来形成一个numpy数组:
> result = np.vstack(mList[:])
> result.shape
(10, 9)

再次感谢您宝贵的帮助。


使用np.vstack()时不需要np.atleast_2d(),但也没有什么坏处。 - Robert Kern
@Robert 哦,你说得对。太棒了,这样可以让它更紧凑。我会修改我的回答。 - Sebastian

0
如被指出的那样,每个NumPy数组都需要重新分配内存才能使用append函数。一个替代方案是只分配一次内存,可以像这样实现:
total_size = 0
for i in [a,b]:
    total_size += i.size

result = numpy.empty(total_size, dtype=a.dtype)
offset = 0
for i in [a,b]:
    # copy in the array
    result[offset:offset+i.size] = i.ravel()
    offset += i.size

# if you know its always divisible by 9:
result = result.reshape(result.size//9, 9)

如果无法预先计算数组大小,那么也许可以对大小设定一个上限,然后只需预分配足够大的一块空间即可。然后将结果作为该块内的视图即可。
result = result[0:known_final_size]

好的代码片段,谢谢。但是这需要我打开每个文件(几百个)两次,对吗?首先获取大小(循环1),其次提取数据(循环2)。将文件(指针?)存储在内存中可能也不是高效的方法(但我无法判断)。你觉得呢? - Sebastian
如果我不熟悉相关的库,请原谅我无法完全转移答案。如果问题在于每次加载文件,那么您能否从文件元数据中获取total_size(当然,假设它没有被压缩,文件大小将对数组大小设置一个上限)?关于时间差异,只需测量即可!(timeit模块使其非常简单)。顺便说一句,解决多次加载文件的最简单方法是按照@Sven在他的帖子中建议的那样,将所有数组加载到内存中。 - Henry Gomersall

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接