例如,
然而,这个对象实际上是形状为
我们如何在批次维度上广播元组索引?
np.array([[1,2],[3,4]])[np.triu_indices(2)]
的形状为(3,)
,是上三角元素展平后的列表。但是,如果我有一个2x2矩阵的批量:foo = np.repeat(np.array([[[1,2],[3,4]]]), 30, axis=0)
如果我想获取每个矩阵的上三角索引,尝试的朴素方法是:
foo[:,np.triu_indices(2)]
然而,这个对象实际上是形状为
(30,2,3,2)
的(而不是如果我们按批次提取上三角条目时可能期望的(30,3)
)。我们如何在批次维度上广播元组索引?
x, y = np.triu_indices(2); foo[:,x,y]
。 - Mazdak