使用Numpy或Tensorflow将线性数组向量化创建一个由对角方阵组成的数组

3
我有一个形状为[batch_size, N]的数组,例如:
[[1  2]
 [3  4]
 [5  6]]

我需要创建一个形状为[batch_size, N, N]的3个索引数组,其中对于每个batch,我有一个N x N的对角线矩阵,其中对角线由相应的batch元素取得,例如在这种情况下,我要找的结果是:
[
  [[1,0],[0,2]],
  [[3,0],[0,4]],
  [[5,0],[0,6]],
]

如何在不使用for循环和向量化的情况下进行此操作?我猜这是一个维度扩展问题,但我找不到正确的函数来实现它。(我需要这个功能,因为我正在使用tensorflow并在numpy中进行原型设计。)
5个回答

2
使用 np.expand_dimsnp.eye 的逐元素乘积。
a = np.array([[1,  2],
              [3,  4],
              [5, 6]])
N = a.shape[1]
a = np.expand_dims(a, axis=1)
a*np.eye(N)

array([[[1., 0.],
       [0., 2.]],

      [[3., 0.],
       [0., 4.]],

      [[5., 0.],
       [0., 6.]]])

解释

np.expand_dims(a, axis=1)会在a中添加一个新的维度,现在a将成为一个(3, 1, 2)的ndarray:

array([[[1, 2]],

       [[3, 4]],

       [[5, 6]]])

现在,您可以使用大小为N的单位矩阵,通过np.eye生成,将该数组进行乘法运算:

np.eye(N)
array([[1., 0.],
       [0., 1.]])

这将产生所需的输出:

a*np.eye(N)

array([[[1., 0.],
        [0., 2.]],

       [[3., 0.],
        [0., 4.]],

       [[5., 0.],
        [0., 6.]]])

2

在TensorFlow中尝试一下:

import tensorflow as tf
A = [[1,2],[3 ,4],[5,6]]
B = tf.matrix_diag(A)
print(B.eval(session=tf.Session()))
[[[1 0]
  [0 2]]

 [[3 0]
  [0 4]]

 [[5 0]
  [0 6]]]

2

方法 #1

这是一个向量化的方法,使用 np.einsum 处理输入数组 a -

# Initialize o/p array
out = np.zeros(a.shape + (a.shape[1],),dtype=a.dtype)

# Get diagonal view and assign into it input array values
diag = np.einsum('ijj->ij',out)
diag[:] = a

方法二

另一种基于切片的赋值方法 -

m,n = a.shape
out = np.zeros((m,n,n),dtype=a.dtype)
out.reshape(-1,n**2)[...,::n+1] = a

1
你可以使用 numpy.diag
m = [[1, 2],
 [3, 4],
 [5, 6]]

[np.diag(b) for b in m]

编辑:下图显示了上述解决方案的平均执行时间(实线),并将其与@Divakar的执行时间(虚线)进行了比较,针对不同的批处理大小和不同的矩阵大小。

enter image description here

我不相信你会得到很大的改进,但这仅基于这个简单的度量标准。


我同意这是最简单的解决方案,但我想知道是否可以通过像einsum这样的函数使它高度向量化并在一行中完成。 - linello
另一个区别是你的输出是列表,而不是数组,所以不适合进行进一步的向量计算。 - B. M.

0

你基本上想要一个函数,它可以执行与 np.block(..) 相反/翻转的操作。

我也需要同样的功能,所以我写了这个小函数:

def split_blocks(x, m=2, n=2):
    """
    Reverse the action of np.block(..)

    >>> x = np.random.uniform(-1, 1, (2, 18, 20))
    >>> assert (np.block(split_blocks(x, 3, 4)) == x).all()

    :param x: (.., M, N) input matrix to split into blocks
    :param m: number of row splits
    :param n: number of column, splits
    :return:
    """
    x = np.array(x, copy=False)
    nd = x.ndim
    *shape, nr, nc = x.shape
    return list(map(list, x.reshape((*shape, m, nr//m, n, nc//n)).transpose(nd-2, nd, *range(nd-2), nd-1, nd+1)))

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接