Tensorflow,获取矩阵每行非零值的索引

3

我希望能够获取矩阵中每行非零值的索引。我尝试使用tf.where,但输出结果并非我所期望的。

我现在的代码如下:

b = tf.constant([[1,0,0,0,0],
                 [1,0,1,0,1]],dtype=tf.float32)
zero = tf.constant(0, dtype=tf.float32)
where = tf.not_equal(b, zero)
indices = tf.where(where)

索引输出结果如下:

<tf.Tensor: id=136, shape=(4, 2), dtype=int64, numpy=
array([[0, 0],
       [1, 0],
       [1, 2],
       [1, 4]])>

但是我希望输出的结果是:

[[0],
 [0,2,4]]

我有一个列表,每一行都有其索引。

谢谢。


1
问题在于结果不是一个合适的张量,因为每一行的列数不相同。如果您想要这样的结果,可以将该结果“填充”以使所有行具有相同的长度(例如使用-1进行填充,这是一个无效值,或者使用另一个向量指示每行上的有效索引数),或者使用不规则张量(更复杂,但实际上可以表示您想要的内容)。 - jdehesa
1个回答

1

这不可能是一个合适的张量,因为维度不统一。如果您愿意使用不规则张量,您可以执行以下操作:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    b = tf.constant([[1, 0, 0, 0, 0],
                     [1, 0, 1, 0, 1]],dtype=tf.float32)
    num_rows = tf.shape(b)[0]
    zero = tf.constant(0, dtype=tf.float32)
    where = tf.not_equal(b, zero)
    indices = tf.where(where)
    s = tf.ragged.segment_ids_to_row_splits(indices[:, 0], num_rows)
    row_start = s[:-1]
    elem_per_row = s[1:] - row_start
    idx = tf.expand_dims(row_start, 1) + tf.ragged.range(elem_per_row)
    result = tf.gather(indices[:, 1], idx)
    print(sess.run(result))
    # <tf.RaggedTensorValue [[0], [0, 2, 4]]>

编辑:如果您不想或无法使用不规则张量,则可以尝试另一种方法。您可以生成一个填充有“无效”值的张量。您可以在这些无效值中使用例如-1,也可以只是有一个1D张量,告诉您每行有多少个有效值:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    b = tf.constant([[1, 0, 0, 0, 0],
                     [1, 0, 1, 0, 1]],dtype=tf.float32)
    num_rows = tf.shape(b)[0]
    zero = tf.constant(0, dtype=tf.float32)
    where = tf.not_equal(b, zero)
    indices = tf.where(where)
    num_indices = tf.shape(indices)[0]
    elem_per_row = tf.bincount(tf.cast(indices[:, 0], tf.int32), minlength=num_rows)
    row_start = tf.concat([[0], tf.cumsum(elem_per_row[:-1])], axis=0)
    max_elem_per_row = tf.reduce_max(elem_per_row)
    r = tf.range(max_elem_per_row)
    idx = tf.expand_dims(row_start, 1) + r
    idx = tf.minimum(idx, num_indices - 1)
    result = tf.gather(indices[:, 1], idx)
    # Optional: replace invalid elements with -1
    result = tf.where(tf.expand_dims(elem_per_row, 1) > r, result, -tf.ones_like(result))
    print(sess.run(result))
    # [[ 0 -1 -1]
    #  [ 0  2  4]]
    print(sess.run(elem_per_row))
    # [1 3]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接