1) 在将数据输入会话之前,应该对成对数据进行配对采样。将每个成对标记为布尔值标签,例如y = 1表示匹配对,否则为0。
2) 只需计算每对的正/负项,并让0-1标签y选择哪个添加到损失中。
首先创建占位符,y_用于布尔标签。
dim = 64
x1_ = tf.placeholder('float32', shape=(None, dim))
x2_ = tf.placeholder('float32', shape=(None, dim))
y_ = tf.placeholder('uint8', shape=[None])
然后可以通过该函数创建损失张量。
def loss(x1, x2, y):
l2diff = tf.sqrt( tf.reduce_sum(tf.square(tf.sub(x1, x2)),
reduction_indices=1))
margin = tf.constant(1.)
labels = tf.to_float(y)
match_loss = tf.square(l2diff, 'match_term')
mismatch_loss = tf.maximum(0., tf.sub(margin, tf.square(l2diff)), 'mismatch_term')
loss = tf.add(tf.mul(labels, match_loss), \
tf.mul((1 - labels), mismatch_loss), 'loss_add')
loss_mean = tf.reduce_mean(loss)
return loss_mean
loss_ = loss(x1_, x2_, y_)
然后输入你的数据(例如随机生成的):
batchsize = 4
x1 = np.random.rand(batchsize, dim)
x2 = np.random.rand(batchsize, dim)
y = np.array([0,1,1,0])
l = sess.run(loss_, feed_dict={x1_:x1, x2_:x2, y_:y})