使用STL容器计算中位数时应采取正确的方法是什么?

63
假设我需要从1000000个随机数值中检索中位数。
如果不使用任何东西,除了std::list,我就没有(内置的)方法对序列进行排序以计算中位数。
如果使用std::list,我无法随机访问值来检索已排序序列的中间(中位数)。
是更好地实现自己的排序并使用例如std::vector,还是更好地使用std::list并使用std::list :: iterator进行循环遍历到中位值?后者似乎更少开销,但感觉更丑陋。
还是有更多、更好的选择吗?
10个回答

128

任何支持随机访问的容器(例如std::vector)都可以使用标准的std::sort算法进行排序,该算法在<algorithm>头文件中可用。

要查找中位数,最好使用std::nth_element;它会对足够多的元素进行排序,以将一个选择的元素放在正确的位置,但并不完全对容器进行排序。因此,您可以像这样查找中位数:

int median(vector<int> &v)
{
    size_t n = v.size() / 2;
    nth_element(v.begin(), v.begin()+n, v.end());
    return v[n];
}

3
哦,我没意识到nth_element的存在,看来我在回答中重新实现了它... - ephemient
9
需要注意的是,nth_element 以不可预测的方式修改向量!如果有必要,您可能需要对索引向量进行排序。 - Matthieu M.
61
如果物品数量是偶数,那么中位数就是中间两个数的平均值。 - sje397
6
@sje397 说得没错,当向量包含偶数个元素时,这个算法有一半的概率是错误的。调用nth_element函数两次(用于获取中间的两个元素)比调用一次sort函数更昂贵吗? - Agostino
1
@Fabian,partial_sort仍然是O(N*log(N)),而nth_element是O(N)(如果执行两次,则为O(2N),仍然是线性的),因此我预计随着N的增加,nth_element会更快,但我还没有进行任何分析来确认这一点。 - ClydeTheGhost
显示剩余3条评论

46

中位数比Mike Seymour的答案更加复杂。中位数取决于样本中的项数是奇数还是偶数。如果项数为偶数,则中位数是中间两个项的平均值。这意味着整数列表的中位数可以是一个分数。最后,空列表的中位数未定义。以下是通过我的基本测试用例的代码:

///Represents the exception for taking the median of an empty list
class median_of_empty_list_exception:public std::exception{
  virtual const char* what() const throw() {
    return "Attempt to take the median of an empty list of numbers.  "
      "The median of an empty list is undefined.";
  }
};

///Return the median of a sequence of numbers defined by the random
///access iterators begin and end.  The sequence must not be empty
///(median is undefined for an empty set).
///
///The numbers must be convertible to double.
template<class RandAccessIter>
double median(RandAccessIter begin, RandAccessIter end) 
  if(begin == end){ throw median_of_empty_list_exception(); }
  std::size_t size = end - begin;
  std::size_t middleIdx = size/2;
  RandAccessIter target = begin + middleIdx;
  std::nth_element(begin, target, end);

  if(size % 2 != 0){ //Odd number of elements
    return *target;
  }else{            //Even number of elements
    double a = *target;
    RandAccessIter targetNeighbor= target-1;
    std::nth_element(begin, targetNeighbor, end);
    return (a+*targetNeighbor)/2.0;
  }
}

23
我知道这是很久以前的事情了,但因为我刚在 Google 上找到这个:std::nth_element 实际上还保证任何前面的元素都小于等于目标元素,任何后面的元素都大于等于目标元素。所以你可以直接使用 targetNeighbor = std::min_element(begin, target) 跳过部分排序,这可能会稍微快一点。(nth_element 的平均时间复杂度是线性的,而 min_element 显然也是线性的。)即使你更喜欢再次使用 nth_element,只需要执行 nth_element(begin, targetNeighbor, target) 就能等效且可能更快地完成任务。 - Danica
12
我理解你的意思是在这种情况下 targetNeighbor = std::max_element(begin, target) 是正确的吗? - izak
@Dougal 我知道这条评论是很久以前的了 ;),但我不知道你的方法应该如何工作,你确定这会给出正确的结果吗? - 463035818_is_not_a_number
2
@tobi303 你的“永远”是我的两倍长。 :) 而且,确实应该这样做:调用std::nth_element后,序列就像[小于目标值, 目标值, 大于目标值]一样。因此,您知道第target-1个元素在数组的前半部分,您只需要找到target之前的元素的最大值即可得到中位数。 - Danica
1
@AlexisWilke 为什么这些会是特殊情况?如果只有1个元素,则 size=1middleIdx=0target=begin,因此 nth_element 是无操作的。如果有2个元素,则 size=2middleIdx=1target=begin+1=end-1,因此第一个 nth_element 被调用时使用 (begin, end-1, end),第二个 nth_element 使用 (begin,begin,end)。在任何地方,target 都不等于 end - Ruslan
显示剩余2条评论

22

使用STL的nth_element算法(摊销O(N))和max_element算法(O(n)),该算法有效地处理偶数大小和奇数大小的输入。请注意,nth_element还有另一个保证的副作用,即在n之前的所有元素都保证小于v[n],但不一定排序。

//post-condition: After returning, the elements in v may be reordered and the resulting order is implementation defined.
double median(vector<double> &v)
{
  if(v.empty()) {
    return 0.0;
  }
  auto n = v.size() / 2;
  nth_element(v.begin(), v.begin()+n, v.end());
  auto med = v[n];
  if(!(v.size() & 1)) { //If the set size is even
    auto max_it = max_element(v.begin(), v.begin()+n);
    med = (*max_it + med) / 2.0;
  }
  return med;    
}

1
我喜欢你的答案,但当向量为空时返回零不适合我的应用程序,我更希望在向量为空时抛出异常。 - Alessandro Jacopson

12

这里是迈克·西摩答案更完整的版本:

// Could use pass by copy to avoid changing vector
double median(std::vector<int> &v)
{
  size_t n = v.size() / 2;
  std::nth_element(v.begin(), v.begin()+n, v.end());
  int vn = v[n];
  if(v.size()%2 == 1)
  {
    return vn;
  }else
  {
    std::nth_element(v.begin(), v.begin()+n-1, v.end());
    return 0.5*(vn+v[n-1]);
  }
}

它可以处理奇数或偶数长度的输入。


1
对于传递副本,您是不是想在输入上删除引用(&)? - chappjc
1
我只是想作为一条注释来说明,一个人可以使用传递复印件的方式,这种情况下,是的,应该删除& - Alec Jacobson
这个版本中有一个bug。在再次执行nth_element之前,您需要提取v[n],因为第二轮后,v[n]可能包含不同的值。 - Matthew Fioravante
1
@MatthewFioravante,我明白了。根据文档,我猜nth_element不需要稳定。(已经相应地修改了我的答案)。 - Alec Jacobson
3
不要再调用nth_element了,直接从v[0]v[n]迭代并确定其中的最大值会更有效率吧? - bluenote10

8

综合此帖子中的所有见解,我最终得到了这个例程。它适用于任何STL容器或任何提供输入迭代器的类,并处理奇数和偶数大小的容器。它还在容器的副本上工作,以不修改原始内容。

template <typename T = double, typename C>
inline const T median(const C &the_container)
{
    std::vector<T> tmp_array(std::begin(the_container), 
                             std::end(the_container));
    size_t n = tmp_array.size() / 2;
    std::nth_element(tmp_array.begin(), tmp_array.begin() + n, tmp_array.end());

    if(tmp_array.size() % 2){ return tmp_array[n]; }
    else
    {
        // even sized vector -> average the two middle values
        auto max_it = std::max_element(tmp_array.begin(), tmp_array.begin() + n);
        return (*max_it + tmp_array[n]) / 2.0;
    }
}

正如Matthew Fioravante在https://dev59.com/IHI-5IYBdhLWcg3wtaxq#D0cQoYgBc1ULPQZFJdsq中所提到的,“在再次使用nth_element之前,您需要先提取v[n],因为第二轮后v[n]可能包含不同的值。” 因此,让med = tmp_array [n],然后正确的返回行是:return (*max_it + med) / 2.0; - trig-ger
5
这个解决方案中只使用了一次@trig-ger nth_element。这不是问题。 - denver
static_assert(std::is_same_v<typename C::value_type, T>, "容器类型和元素类型不匹配") 或许可以? - einpoklum

4

您可以使用库函数std::sortstd::vector进行排序。

std::vector<int> vec;
// ... fill vector with stuff
std::sort(vec.begin(), vec.end());

2
存在一种线性时间选择算法。下面的代码只适用于容器具有随机访问迭代器的情况,但可以进行修改以使其在没有随机访问迭代器的情况下工作——您只需要更加小心地避免使用像end-beginiter+n这样的快捷方式即可。
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <vector>

template<class A, class C = std::less<typename A::value_type> >
class LinearTimeSelect {
public:
    LinearTimeSelect(const A &things) : things(things) {}
    typename A::value_type nth(int n) {
        return nth(n, things.begin(), things.end());
    }
private:
    static typename A::value_type nth(int n,
            typename A::iterator begin, typename A::iterator end) {
        int size = end - begin;
        if (size <= 5) {
            std::sort(begin, end, C());
            return begin[n];
        }
        typename A::iterator walk(begin), skip(begin);
#ifdef RANDOM // randomized algorithm, average linear-time
        typename A::value_type pivot = begin[std::rand() % size];
#else // guaranteed linear-time, but usually slower in practice
        while (end - skip >= 5) {
            std::sort(skip, skip + 5);
            std::iter_swap(walk++, skip + 2);
            skip += 5;
        }
        while (skip != end) std::iter_swap(walk++, skip++);
        typename A::value_type pivot = nth((walk - begin) / 2, begin, walk);
#endif
        for (walk = skip = begin, size = 0; skip != end; ++skip)
            if (C()(*skip, pivot)) std::iter_swap(walk++, skip), ++size;
        if (size <= n) return nth(n - size, walk, end);
        else return nth(n, begin, walk);
    }
    A things;
};

int main(int argc, char **argv) {
    std::vector<int> seq;
    {
        int i = 32;
        std::istringstream(argc > 1 ? argv[1] : "") >> i;
        while (i--) seq.push_back(i);
    }
    std::random_shuffle(seq.begin(), seq.end());
    std::cout << "unordered: ";
    for (std::vector<int>::iterator i = seq.begin(); i != seq.end(); ++i)
        std::cout << *i << " ";
    LinearTimeSelect<std::vector<int> > alg(seq);
    std::cout << std::endl << "linear-time medians: "
        << alg.nth((seq.size()-1) / 2) << ", " << alg.nth(seq.size() / 2);
    std::sort(seq.begin(), seq.end());
    std::cout << std::endl << "medians by sorting: "
        << seq[(seq.size()-1) / 2] << ", " << seq[seq.size() / 2] << std::endl;
    return 0;
}

2

这里是一个考虑了@MatthieuM建议的答案,即不修改输入向量。它使用单个部分排序(在索引向量上)来处理偶数和奇数基数范围,而空范围则通过向量的at方法抛出异常来处理:

double median(vector<int> const& v)
{
    bool isEven = !(v.size() % 2); 
    size_t n    = v.size() / 2;

    vector<size_t> vi(v.size()); 
    iota(vi.begin(), vi.end(), 0); 

    partial_sort(begin(vi), vi.begin() + n + 1, end(vi), 
        [&](size_t lhs, size_t rhs) { return v[lhs] < v[rhs]; }); 

    return isEven ? 0.5 * (v[vi.at(n-1)] + v[vi.at(n)]) : v[vi.at(n)];
}

Demo


1

Armadillo有一个实现看起来像这个答案https://dev59.com/IHI-5IYBdhLWcg3wtaxq#34077478中的,作者是https://stackoverflow.com/users/2608582/matthew-fioravante

它使用了一次nth_element和一次max_element的调用,代码在这里: https://gitlab.com/conradsnicta/armadillo-code/-/blob/9.900.x/include/armadillo_bits/op_median_meat.hpp#L380

//! find the median value of a std::vector (contents is modified)
template<typename eT>
inline 
eT
op_median::direct_median(std::vector<eT>& X)
  {
  arma_extra_debug_sigprint();
  
  const uword n_elem = uword(X.size());
  const uword half   = n_elem/2;
  
  typename std::vector<eT>::iterator first    = X.begin();
  typename std::vector<eT>::iterator nth      = first + half;
  typename std::vector<eT>::iterator pastlast = X.end();
  
  std::nth_element(first, nth, pastlast);
  
  if((n_elem % 2) == 0)  // even number of elements
    {
    typename std::vector<eT>::iterator start   = X.begin();
    typename std::vector<eT>::iterator pastend = start + half;
    
    const eT val1 = (*nth);
    const eT val2 = (*(std::max_element(start, pastend)));
    
    return op_mean::robust_mean(val1, val2);
    }
  else  // odd number of elements
    {
    return (*nth);
    }
  }

0
you can use this approch. It also takes care of sliding window.
Here days are no of trailing elements for which we want to find median and this makes sure the original container is not changed


#include<bits/stdc++.h>

using namespace std;

int findMedian(vector<int> arr, vector<int> brr, int d, int i)
{
    int x,y;
    x= i-d;
    y=d;
    brr.assign(arr.begin()+x, arr.begin()+x+y);


    sort(brr.begin(), brr.end());

    if(d%2==0)
    {
        return((brr[d/2]+brr[d/2 -1]));
    }

    else
    {
        return (2*brr[d/2]);
    }

    // for (int i = 0; i < brr.size(); ++i)
    // {
    //     cout<<brr[i]<<" ";
    // }

    return 0;

}

int main()
{
    int n;
    int days;
    int input;
    int median;
    int count=0;

    cin>>n>>days;

    vector<int> arr;
    vector<int> brr;

    for (int i = 0; i < n; ++i)
    {
        cin>>input;
        arr.push_back(input);
    }

    for (int i = days; i < n; ++i)
    {
        median=findMedian(arr,brr, days, i);

        
    }



    return 0;
}

1
请在添加代码片段时尽可能添加解释。 - Yunus Temurlenk

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接