如何使用另一个数组的元素作为索引,对tensorflow中的张量进行切片?

3
我正在寻找类似于tf.unsorted_segment_sum的函数,但我不想对分段进行求和,而是想将每个分段作为张量获取。
例如,我有以下代码: (实际上,我有一个形状为(10000, 63)的张量,分段数为2500)
    to_be_sliced = tf.constant([[0.1, 0.2, 0.3, 0.4, 0.5],
                            [0.3, 0.2, 0.2, 0.6, 0.3],
                            [0.9, 0.8, 0.7, 0.6, 0.5],
                            [2.0, 2.0, 2.0, 2.0, 2.0]])

indices = tf.constant([0, 2, 0, 1])
num_segments = 3
tf.unsorted_segment_sum(to_be_sliced, indices, num_segments)

输出结果将会显示在这里

array([sum(row1+row3), row4, row2]

我需要的是三个形状不同的张量(可能是张量列表),第一个包含原始数据的第一行和第三行(形状为(2,5)),第二个包含第四行(形状为(1,5)),第三个包含第二行,如下所示:
[array([[0.1, 0.2, 0.3, 0.4, 0.5],
        [0.9, 0.8, 0.7, 0.6, 0.5]]),
 array([[2.0, 2.0, 2.0, 2.0, 2.0]]),
 array([[0.3, 0.2, 0.2, 0.6, 0.3]])]

提前致谢!

2个回答

0

对于您的情况,您可以在Tensorflow中使用Numpy切片。因此,这将起作用:

sliced_1 = to_be_sliced[:3, :]
# [[0.4 0.5 0.5 0.7 0.8]
#  [0.3 0.2 0.2 0.6 0.3]
#  [0.3 0.2 0.2 0.6 0.3]]
sliced_2 = to_be_sliced[3, :]
# [0.3 0.2 0.2 0.6 0.3]

或者更一般的选项,您可以按照以下方式进行:

to_be_sliced = tf.constant([[0.1, 0.2, 0.3, 0.4, 0.5],
                        [0.3, 0.2, 0.2, 0.6, 0.3],
                        [0.9, 0.8, 0.7, 0.6, 0.5],
                        [2.0, 2.0, 2.0, 2.0, 2.0]])

first_tensor = tf.gather_nd(to_be_sliced, [[0], [2]])
second_tensor = tf.gather_nd(to_be_sliced, [[3]])
third_tensor = tf.gather_nd(to_be_sliced, [[1]])

concat = tf.concat([first_tensor, second_tensor, third_tensor], axis=0)

是的,这个例子是错误的,我的问题更加普遍。我会编辑问题。 - tusker

0
你可以这样做:
import tensorflow as tf

to_be_sliced = tf.constant([[0.1, 0.2, 0.3, 0.4, 0.5],
                            [0.3, 0.2, 0.2, 0.6, 0.3],
                            [0.9, 0.8, 0.7, 0.6, 0.5],
                            [2.0, 2.0, 2.0, 2.0, 2.0]])
indices = tf.constant([0, 2, 0, 1])
num_segments = 3
result = [tf.boolean_mask(to_be_sliced, tf.equal(indices, i)) for i in range(num_segments)]
with tf.Session() as sess:
    print(*sess.run(result), sep='\n')

输出:

[[0.1 0.2 0.3 0.4 0.5]
 [0.9 0.8 0.7 0.6 0.5]]
[[2. 2. 2. 2. 2.]]
[[0.3 0.2 0.2 0.6 0.3]]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接