我希望能够在每个像素的深度通道上映射一个TensorFlow函数,该函数对应于具有尺寸[batch_size, H, W, n_channels]的矩阵中的每个向量。换句话说,对于批次中的每个大小为H x W的图像:
我认为
您有任何想法,为什么会出现这种情况,以及我应该如何构造我的代码以避免错误?
这是我当前的函数实现:
特别地,错误如下:
整个错误堆栈和代码可以在这里找到。 感谢帮助, G.
- 我提取一些功能图F_k(其数量为n_channels),具有相同的大小H x W(因此,所有功能图共同形成一个形状的张量[H,W,n_channels];
- 然后,我希望将自定义函数应用于与每个特征图F_k的第i行和第j列相关联的向量v_ij,但是完全探索了深度通道(例如,v的维数为[1 x 1 x n_channels])。理想情况下,所有这些都会同时发生。
我认为
tf.map_fn()
可能是一个选项,我尝试了以下解决方案,其中我递归使用tf.map_fn()
来访问与每个像素相关联的特征。然而,这似乎有些次优,最重要的是,在尝试反向传播梯度时会引发错误。您有任何想法,为什么会出现这种情况,以及我应该如何构造我的代码以避免错误?
这是我当前的函数实现:
import tensorflow as tf
from tensorflow import layers
def apply_function_on_pixel_features(incoming):
# at first the input is [None, W, H, n_channels]
if len(incoming.get_shape()) > 1:
return tf.map_fn(lambda x: apply_function_on_pixel_features(x), incoming)
else:
# here the input is [n_channels]
# apply some function that applies a transfomration and returns a vetor of the same size
output = my_custom_fun(incoming) # my_custom_fun() doesn't change the shape
return output
和我的代码主体:
H = 128
W = 132
n_channels = 8
x1 = tf.placeholder(tf.float32, [None, H, W, 1])
x2 = layers.conv2d(x1, filters=n_channels, kernel_size=3, padding='same')
# now apply a function to the features vector associated to each pixel
x3 = apply_function_on_pixel_features(x2)
x4 = tf.nn.softmax(x3)
loss = cross_entropy(x4, labels)
optimizer = tf.train.AdamOptimizer(lr)
train_op = optimizer.minimize(loss) # <--- ERROR HERE!
特别地,错误如下:
File "/home/venvs/tensorflowGPU/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2481, in AddOp
self._AddOpInternal(op)
File "/home/venvs/tensorflowGPU/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2509, in _AddOpInternal
self._MaybeAddControlDependency(op)
File "/home/venvs/tensorflowGPU/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2547, in _MaybeAddControlDependency
op._add_control_input(self.GetControlPivot().op)
AttributeError: 'NoneType' object has no attribute 'op'
整个错误堆栈和代码可以在这里找到。 感谢帮助, G.
更新:
根据@thushv89的建议,我添加了一个可能的解决方案来解决问题。我仍然不知道为什么我的先前代码不起作用。任何关于此的见解仍将非常感激。