Tensorflow中的自定义中值池化

4

我正在尝试在tensorflow中实现中位数池化层。

然而,既没有tf.nn.median_pool,也没有tf.reduce_median

是否有一种方法可以使用Python API来实现这样的池化层?

5个回答

5
您可以使用类似以下的内容:
patches = tf.extract_image_patches(tensor, [1, k, k, 1], ...)
m_idx = int(k*k/2+1)
top = tf.top_k(patches, m_idx, sorted=True)
median = tf.slice(top, [0, 0, 0, m_idx-1], [-1, -1, -1, 1])

为了适应中位数核大小为偶数和多通道的情况,您需要进行扩展,但这应该可以帮助您完成大部分工作。

3

截至2017年3月,更简单的答案(在底层模式下与Alex建议的方式类似)是这样做:

patches = tf.extract_image_patches(x, [1, k, k, 1], [1, k, k, 1], 4*[1], 'VALID')
medians = tf.contrib.distributions.percentile(patches, 50, axis=3)

1
谢谢,不过这也是在“深度”维度上进行中位数池化。你知道有没有一种只在高度和宽度上进行池化的方法吗?(就像max_pooling2d一样) - cbodt

0
对我而言,Alex的答案在tf 1.4.1版本中不起作用。 tf.top_k 应该改为 tf.nn.top_k 并且应该在调用tf.nn.top_k时传递对应的值。
此外,如果输入是[1,H,W,C],两个答案都不能只处理高度和宽度而忽略通道。

0

我正在寻找一个适用于tensorflowjs的中值滤波器,但似乎找不到。tfa现在有一个中值滤波器,但对于tf.js,您可以使用这个。不确定它是否适用于nodegpu。

function medianFilter(x, filter, strides, pad) {
   //make Kernal
   //todo allow for filter as array or number
   let filterSize = filter ** 2;
   let locs = tf.range(0, filterSize, filterSize );
   //makes a bunc of arrays each one reprensentin one of the valuesin the median window ie 2x2 filter i in chanle and 4 out chanles
   let f = tf.oneHot(tf.range(0,filterSize,1, 'int32'), filterSize).reshape([filter, filter, 1, filterSize]);
   let y = tf.conv2d(x,f,strides,pad);
   let m_idx = Math.floor(filterSize/2)+1;

   let top = tf.topk(y, m_idx, true);
   //note that thse are 3d tensors and if you use 4d ones add a 0 and -1 infron like in above ansowers
   let median = tf.slice(top.values, [0,0,m_idx-1], [-1,-1,1] );

   return median;
}

0

通过一些附加的重塑操作,可以对通道逐个进行中位数池化:

# assuming NHWC layout
strides = rates = [1, 1, 1, 1]
patches = tf.extract_image_patches(x, [1, k, k, 1], strides, rates, 'VALID')
batch_size = tf.shape(x)[0]
n_channels = tf.shape(x)[-1]
n_patches_h = (tf.shape(x)[1] - k) // strides[1] + 1
n_patches_w = (tf.shape(x)[2] - k) // strides[2] + 1
n_patches = tf.shape(patches)[-1] // n_channels
patches = tf.reshape(patches, [batch_size, k, k, n_patches_h * n_patches_w, n_channels])
medians = tf.contrib.distributions.percentile(patches, 50, axis=[1,2])
medians = tf.reshape(medians, (batch_size, n_patches_h, n_patches_w, n_channels))

虽然不是很高效。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接