将特征图(3D数组)拆分为2D数组。

3

假设我有一个特征图(即3D数组)的形状为(32, 32, 96)

In [573]: feature_map = np.random.randint(low=0, high=255, size=(32, 32, 96))

现在,我想单独可视化每个特征图。因此,我想提取每个前面的切片(即形状为(32, 32)的2D数组),这样应该会得到96个这样的特征图。
如何获取这些数组,可能不需要复制以节省内存?由于这仅用于可视化,因此一个视图就足够了!
3个回答

6
你可以使用 np.transpose 和切片操作(而不是创建数组的副本):
feature_map = np.random.randint(low=0, high=255, size=(32, 32, 96))
feature_map = np.transpose(feature_map, axes=[2, 0, 1])
for i in range(feature_map.shape[0]):
  print(feature_map[i].shape)  # a view of original array. shape=(32, 32)

...或仅仅做切片:

for i in range(feature_map.shape[2]):
  print(feature_map[:, :, i].shape)  # a view of original array. shape=(32, 32)

为什么需要“转置”?我认为feature_map[..., i]应该可以胜任这项工作,对吗? - kmario23
1
是的,两个版本都可以。事实上,我更喜欢第二个版本。 - Maxim
谢谢!还请查看我的答案 ;) - kmario23

0
import numpy as np

def do_something(array_slice):
    print array_slice

feature_map = np.random.randint(low=0, high=255, size=(3, 3, 9))

# loop over the indices of the last dimension of the array (i.e. 0 to 8)
for level in range(feature_map.shape[2]):
    # now take only the 2d-slice of the first two dimensions at the height of 'level'
    do_something(feature_map[:,:,level])

# you could also take a slice from another dimension
for level in range(feature_map.shape[1]):    
    do_something(feature_map[:,level,:])

虽然这段代码可能回答了问题,但提供有关它如何以及/或为什么解决问题的附加上下文将改善答案的长期价值。请阅读此如何回答以提供高质量的答案。 - thewaywewere

0

我还意识到numpy.dsplit()可以用于这样的三维数组,因为我们试图沿深度方向分割它。但是,我还需要使用np.squeeze()来消除第三个维度。此外,根据我的情况需要,它还返回一个数组的视图

# splitting it into 96 slices in one-go!
In [659]: np.dsplit(feature_map, feature_map.shape[-1])

In [660]: np.dsplit(feature_map, feature_map.shape[-1])[10].flags
Out[660]: 
  C_CONTIGUOUS : False
  F_CONTIGUOUS : False
  OWNDATA : False   #<============== NO copy is made
  WRITEABLE : True
  ALIGNED : True
  UPDATEIFCOPY : False

In [661]: np.dsplit(feature_map, feature_map.shape[-1])[10].shape
Out[661]: (32, 32, 1)

# getting rid of unitary dimension with `np.squeeze`
In [662]: np.squeeze(np.dsplit(feature_map, feature_map.shape[-1])[10]).shape
Out[662]: (32, 32)

In [663]: np.squeeze(np.dsplit(feature_map, feature_map.shape[-1])[10]).flags
Out[663]: 
  C_CONTIGUOUS : False
  F_CONTIGUOUS : False
  OWNDATA : False   #<============== NO copy is made
  WRITEABLE : True
  ALIGNED : True
  UPDATEIFCOPY : False

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接