C语言中的滚动中位数算法

128

我目前正在编写一个用于在 C 中实现滚动中位数滤波器(类似于滚动平均滤波器)的算法。通过查阅文献,似乎有两种比较有效的方法来实现它。第一种方法是对初始值窗口进行排序,然后执行二分搜索以在每次迭代时插入新值并删除现有值。

第二种方法(来自 Hardle 和 Steiger,在 1995 年的 JRSS-C,算法 296 中)构建了一个双端堆结构,其中一个端点是最大堆,另一个端点是最小堆,并且中位数位于中间。这会产生一个线性时间算法,而不是 O(n log n) 的算法。

我的问题在于:虽然第一种方法可行,但我需要在数百万个时间序列上运行此程序,因此效率非常重要。而实现第二种方法却很困难。我在 R 统计软件包的 stats 包中的 Trunmed.c 文件中找到了代码,但它相当难以理解。

有人知道一个良好编写的 C 实现线性时间滚动中位数算法吗?

编辑:Trunmed.c 代码链接http://google.com/codesearch/p?hl=en&sa=N&cd=1&ct=rc#mYw3h_Lb_e0/R-2.2.0/src/library/stats/src/Trunmed.c


刚刚实现了移动平均值...移动中位数有点棘手。试着谷歌一下移动中位数。 - hookenz
每个时间序列中有多少个数字?即使有百万个时间序列,如果你只有几千个数字,如果你的代码写得高效,运行时间可能不会超过一两分钟。 - Dana the Sane
那段代码的参考资料太古老了!R 2.2.0已经过去三年了,我们现在使用的是R 2.9.1,而且预计将于9月24日发布2.9.2版本,10月份发布R 2.10.0。 - Dirk Eddelbuettel
18
两堆解法为什么是线性的?因为它的时间复杂度为O(n log k),其中k代表窗口大小,而堆的删除操作是O(log k)的。 - yairchu
5
一些实现和比较:https://github.com/suomela/median-filter - Jukka Suomela
显示剩余2条评论
13个回答

30

我曾多次查看 R 语言的 src/library/stats/src/Trunmed.c,因为我想在一个独立的 C++ 类 / C 子程序中实现类似的功能。请注意,这实际上是两个实现,可以参见 src/library/stats/man/runmed.Rd(帮助文件的源)。

\details{
  Apart from the end values, the result \code{y = runmed(x, k)} simply has
  \code{y[j] = median(x[(j-k2):(j+k2)])} (k = 2*k2+1), computed very
  efficiently.

  The two algorithms are internally entirely different:
  \describe{
    \item{"Turlach"}{is the Härdle-Steiger
      algorithm (see Ref.) as implemented by Berwin Turlach.
      A tree algorithm is used, ensuring performance \eqn{O(n \log
        k)}{O(n * log(k))} where \code{n <- length(x)} which is
      asymptotically optimal.}
    \item{"Stuetzle"}{is the (older) Stuetzle-Friedman implementation
      which makes use of median \emph{updating} when one observation
      enters and one leaves the smoothing window.  While this performs as
      \eqn{O(n \times k)}{O(n * k)} which is slower asymptotically, it is
      considerably faster for small \eqn{k} or \eqn{n}.}
  }
}

希望能够将此代码以更独立的方式重新使用。你愿意吗? 我可以协助一些 R 方面的内容。

编辑1:除了上述 Trunmed.c 旧版本链接之外,这里还有以下的 SVN 版本:

编辑2: Ryan Tibshirani 在 快速中位数分箱 中有一些C和Fortran代码,可能是窗口方法的合适起点。


谢谢Dirk。一旦我得到一个干净的解决方案,我计划在GPL下发布它。我也有兴趣建立R和Python接口。 - AWB
9
这个想法最终发展成了什么?你是否将解决方案纳入软件包中了? - Xu Wang

28

我找不到一种带有顺序统计的现代化实现的c++数据结构,因此最终在MAK建议的top coders链接中实现了两个想法(比赛编者按:向下滚动到FloatingMedian)。

两个multiset

第一个想法将数据分成两个数据结构(堆、multiset等),每次插入/删除的复杂度为O(ln N),但是不允许动态更改分位数而不付出巨大代价。也就是说,我们可以有一个滚动的中位数或滚动的75%,但不能同时有。

区间树

第二个想法使用区间树,插入/删除/查询的复杂度为O(ln N),但更加灵活。最好的是,“N”是您的数据范围的大小。因此,如果您的滚动中位数有100万个项目的窗口,但您的数据变化范围为1..65536,则每移动100万个滚动窗口只需要执行16个操作!

c++代码类似于Denis上面发布的代码(“这是一种用于量化数据的简单算法”)。

GNU顺序统计树

就在放弃之前,我发现stdlibc++包含顺序统计树!!!

这些树有两个关键操作:

iter = tree.find_by_order(value)
order = tree.order_of_key(value)

请参阅libstdc++手册policy_based_data_structures_test(搜索“split and join”)。

我已经将该树包装成一个方便的头文件,以供支持c++0x/c++11风格的部分typedef的编译器使用:

#if !defined(GNU_ORDER_STATISTIC_SET_H)
#define GNU_ORDER_STATISTIC_SET_H
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

// A red-black tree table storing ints and their order
// statistics. Note that since the tree uses
// tree_order_statistics_node_update as its update policy, then it
// includes its methods by_order and order_of_key.
template <typename T>
using t_order_statistic_set = __gnu_pbds::tree<
                                  T,
                                  __gnu_pbds::null_type,
                                  std::less<T>,
                                  __gnu_pbds::rb_tree_tag,
                                  // This policy updates nodes'  metadata for order statistics.
                                  __gnu_pbds::tree_order_statistics_node_update>;

#endif //GNU_ORDER_STATISTIC_SET_H

实际上,由于设计原因,libstdc++扩展容器不允许多个值!正如我的名称(t_order_statistic_set)所示,多个值被合并。 所以,对于我们的目的,它们需要更多的工作 :-( - Leo Goodstadt
我们需要: 1)制作一个值到计数的映射(而不是集合) 2)分支大小应反映键的计数 (libstdc++-v3/include/ext/pb_ds/detail/tree_policy/order_statistics_imp.hpp) 从树继承,并且 3)重载insert()以增加计数/如果该值已经存在,则调用update_to_top() 4)重载erase()以减少计数/如果该值不唯一,则调用update_to_top() (请参见libstdc++-v3/include/ext/pb_ds/detail/rb_tree_map_/rb_tree_.hpp)有志愿者吗? - Leo Goodstadt

18

我已经在这里做了一个C实现更多细节请参见此问题:Rolling median in C - Turlach implementation

使用样例:

int main(int argc, char* argv[])
{
   int i, v;
   Mediator* m = MediatorNew(15);
 
   for (i=0; i<30; i++) {
      v = rand() & 127;
      printf("Inserting %3d \n", v);
      MediatorInsert(m, v);
      v = MediatorMedian(m);
      printf("Median = %3d.\n\n", v);
      ShowTree(m);
   }
}

6
基于最小中位数最大堆的实现快速、清晰,非常出色。做得很好。 - Johannes Rudolph
1
我该如何找到此解决方案的Java版本? - Hengameh
@AShelly,我从我的Go库中调用你的C实现比我的纯Go实现要慢,我不知道是因为:A. CGo的开销 B. 还是我的Go实现比你的C更快 C. 或者我忘记释放一些内存https://github.com/JaderDias/movingmedian/commit/9bd9c4f62210a101e9eee2ff85921ef253193858 - Jader Dias
在搜索算法描述时:Turlach实现了[Härdle,W. Steiger。Optimal Median Smoothing(1994)](https://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.45.993)算法。 - maxschlepzig

15
我使用这个增量中位数估计器:
median += eta * sgn(sample - median)

它的形式与更常见的均值估计器相同:

mean += eta * (sample - mean)

这里的eta是一个小的学习率参数(例如0.001),sgn()是符号函数,返回值为{-1, 0, 1}中的一个。(如果数据是非稳态的且想要跟踪随时间发生的变化,请使用类似以下的常数eta;否则,对于稳态源,请使用像eta = 1 / n这样的东西来收敛,其中n是到目前为止所观察到的样本数。)

此外,我修改了中位数估计器,使其适用于任意的分位数。一般来说,分位函数告诉您将数据分成两个分数:p1-p 的值。以下是增量地估计该值的方法:

quantile += eta * (sgn(sample - quantile) + 2.0 * p - 1.0)

变量 p 的取值范围应在[0, 1]之间。这实际上将sgn()函数对称输出的{-1, 0, 1}向一侧倾斜,将数据样本分成两个大小不等的箱(分别是数据的分位数估计小于/大于比例p1 - p)。请注意,当p=0.5时,这就相当于中位数估计。


2
很棒,这里有一个修改版,它根据运行平均值调整 'eta'...(平均值用作中位数的粗略估计,因此它以相同的速率收敛于大值和小值)。即 eta 自动调整。https://dev59.com/2mbWa4cB1Zd3GeqPTRt2#15150968 - Jeff McClintock
3
如需了解类似技术,请参阅有关节约流技术的这篇论文:http://arxiv.org/pdf/1407.1121v1.pdf。它可以估计任何四分位数,并适应均值的变化。它只需要存储两个值:上次估计和上次调整方向(+1或-1)。该算法易于实现。我发现误差在97%的情况下仅在5%之内。 - Paul Chernoch

9
这里有一个针对量化数据(几个月后)的简单算法:
""" median1.py: moving median 1d for quantized, e.g. 8-bit data

Method: cache the median, so that wider windows are faster.
    The code is simple -- no heaps, no trees.

Keywords: median filter, moving median, running median, numpy, scipy

See Perreault + Hebert, Median Filtering in Constant Time, 2007,
    http://nomis80.org/ctmf.html: nice 6-page paper and C code,
    mainly for 2d images

Example:
    y = medians( x, window=window, nlevel=nlevel )
    uses:
    med = Median1( nlevel, window, counts=np.bincount( x[0:window] ))
    med.addsub( +, - )  -- see the picture in Perreault
    m = med.median()  -- using cached m, summ

How it works:
    picture nlevel=8, window=3 -- 3 1s in an array of 8 counters:
        counts: . 1 . . 1 . 1 .
        sums:   0 1 1 1 2 2 3 3
                        ^ sums[3] < 2 <= sums[4] <=> median 4
        addsub( 0, 1 )  m, summ stay the same
        addsub( 5, 1 )  slide right
        addsub( 5, 6 )  slide left

Updating `counts` in an `addsub` is trivial, updating `sums` is not.
But we can cache the previous median `m` and the sum to m `summ`.
The less often the median changes, the faster;
so fewer levels or *wider* windows are faster.
(Like any cache, run time varies a lot, depending on the input.)

See also:
    scipy.signal.medfilt -- runtime roughly ~ window size
    https://dev59.com/73M_5IYBdhLWcg3wmkUK

"""

from __future__ import division
import numpy as np  # bincount, pad0

__date__ = "2009-10-27 oct"
__author_email__ = "denis-bz-py at t-online dot de"


#...............................................................................
class Median1:
    """ moving median 1d for quantized, e.g. 8-bit data """

    def __init__( s, nlevel, window, counts ):
        s.nlevel = nlevel  # >= len(counts)
        s.window = window  # == sum(counts)
        s.half = (window // 2) + 1  # odd or even
        s.setcounts( counts )

    def median( s ):
        """ step up or down until sum cnt to m-1 < half <= sum to m """
        if s.summ - s.cnt[s.m] < s.half <= s.summ:
            return s.m
        j, sumj = s.m, s.summ
        if sumj <= s.half:
            while j < s.nlevel - 1:
                j += 1
                sumj += s.cnt[j]
                # print "j sumj:", j, sumj
                if sumj - s.cnt[j] < s.half <= sumj:  break
        else:
            while j > 0:
                sumj -= s.cnt[j]
                j -= 1
                # print "j sumj:", j, sumj
                if sumj - s.cnt[j] < s.half <= sumj:  break
        s.m, s.summ = j, sumj
        return s.m

    def addsub( s, add, sub ):
        s.cnt[add] += 1
        s.cnt[sub] -= 1
        assert s.cnt[sub] >= 0, (add, sub)
        if add <= s.m:
            s.summ += 1
        if sub <= s.m:
            s.summ -= 1

    def setcounts( s, counts ):
        assert len(counts) <= s.nlevel, (len(counts), s.nlevel)
        if len(counts) < s.nlevel:
            counts = pad0__( counts, s.nlevel )  # numpy array / list
        sumcounts = sum(counts)
        assert sumcounts == s.window, (sumcounts, s.window)
        s.cnt = counts
        s.slowmedian()

    def slowmedian( s ):
        j, sumj = -1, 0
        while sumj < s.half:
            j += 1
            sumj += s.cnt[j]
        s.m, s.summ = j, sumj

    def __str__( s ):
        return ("median %d: " % s.m) + \
            "".join([ (" ." if c == 0 else "%2d" % c) for c in s.cnt ])

#...............................................................................
def medianfilter( x, window, nlevel=256 ):
    """ moving medians, y[j] = median( x[j:j+window] )
        -> a shorter list, len(y) = len(x) - window + 1
    """
    assert len(x) >= window, (len(x), window)
    # np.clip( x, 0, nlevel-1, out=x )
        # cf http://scipy.org/Cookbook/Rebinning
    cnt = np.bincount( x[0:window] )
    med = Median1( nlevel=nlevel, window=window, counts=cnt )
    y = (len(x) - window + 1) * [0]
    y[0] = med.median()
    for j in xrange( len(x) - window ):
        med.addsub( x[j+window], x[j] )
        y[j+1] = med.median()
    return y  # list
    # return np.array( y )

def pad0__( x, tolen ):
    """ pad x with 0 s, numpy array or list """
    n = tolen - len(x)
    if n > 0:
        try:
            x = np.r_[ x, np.zeros( n, dtype=x[0].dtype )]
        except NameError:
            x += n * [0]
    return x

#...............................................................................
if __name__ == "__main__":
    Len = 10000
    window = 3
    nlevel = 256
    period = 100

    np.set_printoptions( 2, threshold=100, edgeitems=10 )
    # print medians( np.arange(3), 3 )

    sinwave = (np.sin( 2 * np.pi * np.arange(Len) / period )
        + 1) * (nlevel-1) / 2
    x = np.asarray( sinwave, int )
    print "x:", x
    for window in ( 3, 31, 63, 127, 255 ):
        if window > Len:  continue
        print "medianfilter: Len=%d window=%d nlevel=%d:" % (Len, window, nlevel)
            y = medianfilter( x, window=window, nlevel=nlevel )
        print np.array( y )

# end median1.py

4
滚动中位数可以通过维护两个数字分区来找到。
使用Min Heap和Max Heap来维护分区。
Max Heap将包含小于等于中位数的数字。
Min Heap将包含大于等于中位数的数字。
平衡约束:如果元素总数为偶数,则两个堆都应具有相等的元素。
如果元素总数为奇数,则Max Heap将比Min Heap多一个元素。
中位数元素:如果两个分区具有相同数量的元素,则中位数将是第一个分区的最大元素和第二个分区的最小元素的和的一半。
否则,中位数将是第一个分区的最大元素。
算法-
1-获取两个堆(1 Min Heap和1 Max Heap) Max Heap将包含前一半数量的元素 Min Heap将包含后一半数量的元素
2-将流中的新数字与Max Heap的top进行比较, 如果它小于或等于在max heap中添加该数字。 否则在Min Heap中添加数字。
3-如果min Heap中的元素比Max Heap更多 然后从Min Heap中删除顶部元素并添加到Max Heap中。 如果Max Heap中的元素多于Min Heap中的一个,则在Min Heap中删除顶部元素 并将其添加到Max Heap中。
4-如果两个堆具有相等的元素,则 中位数将是Max Heap和Min Heap的最小元素的和的一半。 否则,中位数将是第一个分区的最大元素。
public class Solution {

    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        RunningMedianHeaps s = new RunningMedianHeaps();
        int n = in.nextInt();
        for(int a_i=0; a_i < n; a_i++){
            printMedian(s,in.nextInt());
        }
        in.close();       
    }

    public static void printMedian(RunningMedianHeaps s, int nextNum){
            s.addNumberInHeap(nextNum);
            System.out.printf("%.1f\n",s.getMedian());
    }
}

class RunningMedianHeaps{
    PriorityQueue<Integer> minHeap = new PriorityQueue<Integer>();
    PriorityQueue<Integer> maxHeap = new PriorityQueue<Integer>(Comparator.reverseOrder());

    public double getMedian() {

        int size = minHeap.size() + maxHeap.size();     
        if(size % 2 == 0)
            return (maxHeap.peek()+minHeap.peek())/2.0;
        return maxHeap.peek()*1.0;
    }

    private void balanceHeaps() {
        if(maxHeap.size() < minHeap.size())
        {
            maxHeap.add(minHeap.poll());
        }   
        else if(maxHeap.size() > 1+minHeap.size())
        {
            minHeap.add(maxHeap.poll());
        }
    }

    public void addNumberInHeap(int num) {
        if(maxHeap.size()==0 || num <= maxHeap.peek())
        {
            maxHeap.add(num);
        }
        else
        {
            minHeap.add(num);
        }
        balanceHeaps();
    }
}

2
对我来说,第三个Java答案对于一个C语言问题到底有多大的好处并不清楚。你应该提出一个新问题,然后在那个问题中提供你的Java答案。 - jww
逻辑在阅读这句话后崩溃:“然后从最小堆中删除顶部元素并添加到最小堆中。”。请在发布内容前至少礼貌地阅读一遍算法。 - Cyclotron3x3
7
这个算法不适用于滚动中位数,而是用于不断增加元素的中位数。对于滚动中位数,还需要从堆中删除一个元素,需要先找到该元素。 - Walter

3
值得提醒的是,有一种特殊情况具有简单的精确解:当流中的所有值都是在一个(相对)小的定义范围内的整数时。例如,假设它们必须全部位于0和1023之间。在这种情况下,只需定义一个包含1024个元素和一个计数器的数组,并清除所有这些值。针对流中的每个值增加相应的bin和计数。在流结束后,找到包含count/2最高值的bin——只需从0开始添加连续的bins即可轻松完成。使用相同的方法可以找到任意等级排序的值。(如果需要检测bin饱和度并“升级”存储bin的大小,则会有一个小的复杂性。)
这种特殊情况可能看起来很人为,但在实践中非常普遍。如果实数落在一个范围内,并且已知“足够好”的精度,它也可以被视为对实数的近似值。这适用于几乎任何一组“现实世界”对象的测量结果,例如一群人的身高或体重。不够大?对于全球所有(个体)细菌的长度或重量同样适用——假设有人能提供数据!
貌似我误读了原始文本——它似乎想要一个滑动窗口中位数,而不仅仅是长流的中位数。这种方法仍然适用于此。首先将前N个流值加载到初始窗口中,然后对于第N+1个流值,增加相应的bin,并减少与第0个流值对应的bin。在这种情况下,有必要保留最后N个值以允许减量操作,可以通过循环访问大小为N的数组来有效地完成。由于在滑动窗口的每个步骤中,中位数的位置只能变化-2,-1,0,1,2,因此不需要在每个步骤中对所有bins进行求和以找到中位数,只需根据修改了哪些侧面的bins调整“中位指针”即可。例如,如果新值和正在删除的值都低于当前中位数,则它不会发生变化(偏移= 0)。当N变得过大而无法方便地存储在内存中时,该方法将失效。

1
如果您能够按时间点引用值作为函数,那么您可以使用替换抽样来应用自助法生成置信区间内的自助法中位数值。这可能比不断将传入的值排序到数据结构中更有效地让您计算近似中位数。

0

对于那些需要在Java中运行中位数的人来说...PriorityQueue是你的好朋友。O(log N)插入,O(1)当前中位数和O(N)删除。如果您知道数据的分布,您可以做得比这更好。

public class RunningMedian {
  // Two priority queues, one of reversed order.
  PriorityQueue<Integer> lower = new PriorityQueue<Integer>(10,
          new Comparator<Integer>() {
              public int compare(Integer arg0, Integer arg1) {
                  return (arg0 < arg1) ? 1 : arg0 == arg1 ? 0 : -1;
              }
          }), higher = new PriorityQueue<Integer>();

  public void insert(Integer n) {
      if (lower.isEmpty() && higher.isEmpty())
          lower.add(n);
      else {
          if (n <= lower.peek())
              lower.add(n);
          else
              higher.add(n);
          rebalance();
      }
  }

  void rebalance() {
      if (lower.size() < higher.size() - 1)
          lower.add(higher.remove());
      else if (higher.size() < lower.size() - 1)
          higher.add(lower.remove());
  }

  public Integer getMedian() {
      if (lower.isEmpty() && higher.isEmpty())
          return null;
      else if (lower.size() == higher.size())
          return (lower.peek() + higher.peek()) / 2;
      else
          return (lower.size() < higher.size()) ? higher.peek() : lower
                  .peek();
  }

  public void remove(Integer n) {
      if (lower.remove(n) || higher.remove(n))
          rebalance();
  }
}

C++标准库的扩展中,GNU提供了有序统计树。请参见我下面的帖子。 - Leo Goodstadt
我认为你的代码没有正确地放在这里。有一些不完整的部分,比如:}), higher = new PriorityQueue<Integer>(); 或者 new PriorityQueue<Integer>(10,。我无法运行代码。 - Hengameh
@Hengameh Java用分号结束语句——换行不重要。你一定是复制错误了。 - Matthew Read
1
你应该提出一个新问题,然后在那个问题中提供你的Java答案。 - jww

0

这是一个可以在精确输出不重要时使用的函数(用于显示等)。 您需要 totalcount 和 lastmedian,以及 newvalue。

{
totalcount++;
newmedian=lastmedian+(newvalue>lastmedian?1:-1)*(lastmedian==0?newvalue: lastmedian/totalcount*2);
}

对于像页面显示时间这样的事情,可以产生相当精确的结果。

规则:输入流需要在页面显示时间的顺序上平稳,数量大(>30等),并且具有非零中位数。

例如: 页面加载时间,800个项目,10毫秒...3000毫秒,平均90毫秒,真实中位数:11毫秒

经过30次输入后,中位数误差通常<=20%(9ms..12ms),并且越来越少。 经过800次输入后,误差为+-2%。

另一个思考者提出了类似的解决方案:Median Filter Super efficient implementation


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接