Tensorflow中的笛卡尔积

9

在Tensorflow中是否有类似于itertools.product的简便方法来进行笛卡尔积操作?我想要获取两个张量(ab)元素的组合,在Python中可以使用itertools通过list(product(a, b))实现。我正在寻找Tensorflow中的替代方法。

5个回答

11

我在这里假设ab都是1维张量。

要获得两者的笛卡尔积,我会使用tf.expand_dimstf.tile的组合:

a = tf.constant([1,2,3]) 
b = tf.constant([4,5,6,7]) 

tile_a = tf.tile(tf.expand_dims(a, 1), [1, tf.shape(b)[0]])  
tile_a = tf.expand_dims(tile_a, 2) 
tile_b = tf.tile(tf.expand_dims(b, 0), [tf.shape(a)[0], 1]) 
tile_b = tf.expand_dims(tile_b, 2) 

cartesian_product = tf.concat([tile_a, tile_b], axis=2) 

cart = tf.Session().run(cartesian_product) 

print(cart.shape) 
print(cart) 

最终得到一个长度为 len(a) * len(b) * 2 的张量,其中最后一维表示了元素对 (a, b) 的每个组合。


7

使用广播的相同较短解决方案,使用 tf.add() (已测试):

import tensorflow as tf

a = tf.constant([1,2,3]) 
b = tf.constant([4,5,6,7]) 

a, b = a[ None, :, None ], b[ :, None, None ]
cartesian_product = tf.concat( [ a + tf.zeros_like( b ),
                                 tf.zeros_like( a ) + b ], axis = 2 )

with tf.Session() as sess:
    print( sess.run( cartesian_product ) )

将输出:

[[[1 4]
[2 4]
[3 4]]

[[1 5]
[2 5]
[3 5]]

[[1 6]
[2 6]
[3 6]]

[[1 7]
[2 7]
[3 7]]]


2
这个答案很棒!虽然可能不太易读,但更具有普适性。在第一维度后,似乎适用于任意维度的张量。例如,对于以下输入,它仍然会返回您所期望的结果: a = tf.constant([[[1,2,3],[4,5,6]],[[1,1,1],[1,1,1]]]) b = tf.constant([[[7,8,9],[10,11,12]]]) - marko

3
import tensorflow as tf

a = tf.constant([0, 1, 2])
b = tf.constant([2, 3])
c = tf.stack(tf.meshgrid(a, b, indexing='ij'), axis=-1)
c = tf.reshape(c, (-1, 2))
with tf.Session() as sess:
    print(sess.run(c))

输出:

[[0 2]
 [0 3]
 [1 2]
 [1 3]
 [2 2]
 [2 3]]

感谢jdehesa提供的贡献:链接


1

Sunreef的更简洁的回答使用tf.stack而不是tf.concat

a = tf.constant([1,2,3]) 
b = tf.constant([4,5,6,7]) 

tile_a = tf.tile(tf.expand_dims(a, 1), [1, tf.shape(b)[0]]) 
tile_b = tf.tile(tf.expand_dims(b, 0), [tf.shape(a)[0], 1])  
ans = tf.stack([tile_a, tile_b], -1)

0
我受到Jaba答案的启发。如果你想要得到两个二维张量的笛卡尔积,可以按照以下方式进行:
输入a:[N,L]和b:[M,L],获取[N*M,L]连接张量。
tile_a = tf.tile(tf.expand_dims(a, 1), [1, M, 1])  
tile_b = tf.tile(tf.expand_dims(b, 0), [N, 1, 1]) 

cartesian_product = tf.concat([tile_a, tile_b], axis=2)   
cartesian = tf.reshape(cartesian_product, [N*M, -1])

cart = tf.Session().run(cartesian) 

print(cart.shape)
print(cart) 

如果a或b为空,这段代码将无效,但是您可以通过在[N * M,-1]中使用a.shape[-1] + b.shape[-1]来修复它。 - joel

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接