计算编辑距离(feed_dict错误)

5

我在Tensorflow中编写了一些代码,用于计算一个字符串与一组字符串之间的编辑距离。但是我无法找出错误所在。

import tensorflow as tf
sess = tf.Session()

# Create input data
test_string = ['foo']
ref_strings = ['food', 'bar']

def create_sparse_vec(word_list):
    num_words = len(word_list)
    indices = [[xi, 0, yi] for xi,x in enumerate(word_list) for yi,y in enumerate(x)]
    chars = list(''.join(word_list))
    return(tf.SparseTensor(indices, chars, [num_words,1,1]))


test_string_sparse = create_sparse_vec(test_string*len(ref_strings))
ref_string_sparse = create_sparse_vec(ref_strings)

sess.run(tf.edit_distance(test_string_sparse, ref_string_sparse, normalize=True))

这段代码有效,运行后会产生以下输出:

array([[ 0.25],
       [ 1.  ]], dtype=float32)

但是当我试图通过稀疏占位符输入稀疏张量时,出现错误。

test_input = tf.sparse_placeholder(dtype=tf.string)
ref_input = tf.sparse_placeholder(dtype=tf.string)

edit_distances = tf.edit_distance(test_input, ref_input, normalize=True)

feed_dict = {test_input: test_string_sparse,
             ref_input: ref_string_sparse}

sess.run(edit_distances, feed_dict=feed_dict)

以下是错误追踪信息:

Traceback (most recent call last):

  File "<ipython-input-29-4e06de0b7af3>", line 1, in <module>
    sess.run(edit_distances, feed_dict=feed_dict)

  File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 372, in run
run_metadata_ptr)

  File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 597, in _run
    for subfeed, subfeed_val in _feed_fn(feed, feed_val):

  File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 558, in _feed_fn
    return feed_fn(feed, feed_val)

  File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 268, in <lambda>
    [feed.indices, feed.values, feed.shape], feed_val)),

TypeError: zip argument #2 must support iteration

这里发生了什么事情,有任何想法吗?


错误可能来自于test_string_parseref_string_parse的值,你能提供它们的创建代码吗? - Olivier Moindrot
1个回答

4
TL;DR: 在create_sparse_vec()的返回类型中,使用tf.SparseTensorValue而不是tf.SparseTensor
问题在于create_sparse_vec()的返回类型为tf.SparseTensor,在调用sess.run()时,它不能被解析为输入的值。
当你输入一个(密集的)tf.Tensor时,期望的值类型是NumPy数组(或能转换为数组的某些对象)。 当你输入一个tf.SparseTensor时,期望的值类型是tf.SparseTensorValue,它类似于一个tf.SparseTensor,但其indices、values和shape属性是NumPy数组(或像您的示例中的列表一样可以转换为数组的某些对象)。
下面的代码应该可以工作:
def create_sparse_vec(word_list):
    num_words = len(word_list)
    indices = [[xi, 0, yi] for xi,x in enumerate(word_list) for yi,y in enumerate(x)]
    chars = list(''.join(word_list))
    return tf.SparseTensorValue(indices, chars, [num_words,1,1])

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接