从数组中获取n个最小值的最快方法

10

我需要找到一个double类型数组(我们称其为samples)中前n个最小值(排除0)。我需要在循环中多次执行此操作,因此执行速度非常关键。我尝试首先对数组进行排序,然后取前10个值(不包括0),但是,虽然Array.Sort被认为很快,但它成为了瓶颈:

const int numLowestSamples = 10;

double[] samples;

double[] lowestSamples = new double[numLowestSamples];

for (int count = 0; count < iterations; count++) // iterations typically around 2600000
{
    samples = whatever;
    Array.Sort(samples);
    lowestSamples = samples.SkipWhile(x => x == 0).Take(numLowestSamples).ToArray();
}

因此,我尝试了一种不太简洁的解决方案,首先读取前n个值,将它们排序,然后循环遍历samples中的所有其他值,检查该值是否小于已排序的lowestSamples数组中的最后一个值。如果该值较低,则用数组中的值替换它,并再次对数组进行排序。结果发现这种方法大约快了5倍:

const int numLowestSamples = 10;

double[] samples;

List<double> lowestSamples = new List<double>();

for (int count = 0; count < iterations; count++) // iterations typically around 2600000
{
    samples = whatever;

    lowestSamples.Clear();

    // Read first n values
    int i = 0;
    do
    {
        if (samples[i] > 0)
            lowestSamples.Add(samples[i]);

        i++;
    } while (lowestSamples.Count < numLowestSamples)

    // Sort the array
    lowestSamples.Sort();

    for (int j = numLowestSamples; j < samples.Count; j++) // samples.Count is typically 3600
    {
        // if value is larger than 0, but lower than last/highest value in lowestSamples
        // write value to array (replacing the last/highest value), then sort array so
        // last value in array still is the highest
        if (samples[j] > 0 && samples[j] < lowestSamples[numLowestSamples - 1])
        {
            lowestSamples[numLowestSamples - 1] = samples[j];
            lowestSamples.Sort();
        }
    }
}

虽然这个方法相对比较快,但我想挑战任何人提出更快更好的解决方案。


4
维护一个小根堆是否是一个好的解决方案呢? - ChaosPandion
ChaosPandion:你比我快了5秒钟 ;) - robbrit
如果这是一个只会被调用一次的东西,那么在代码可维护性上增加额外工作/复杂度是否值得性能收益? - Jake1164
你忘记递增i了,但我们明白了。 - Les
1
感谢所有的回答!我尝试了使用快速选择算法,但这至少在我测试的情况下比只循环一次集合慢了约3倍。至于斐波那契堆,我必须承认我不确定如何实现它。最终我通过调整算法来优化我的程序,而不是每次都对最低样本进行排序;相反,像建议的那样就地插入。此外,使用变量存储 samples[j] 的值而不是多次调用 samples[j] (就像 tumtumtum 的代码中一样) 也有所帮助。有了这些调整,我能够将执行时间减少了近一半。 - Roger Saele
6个回答

3

2

不要重复对lowestSamples进行排序,而是将样本插入到它应该在的位置:

int samplesCount = samples.Count;

for (int j = numLowestSamples; j < samplesCount; j++)
{
    double sample = samples[j];

    if (sample > 0 && sample < currentMax)
    {
        int k;

        for (k = 0; k < numLowestSamples; k++)
        {
           if (sample < lowestSamples[k])
           {
              Array.Copy(lowestSamples, k, lowestSamples, k + 1, numLowestSamples - k - 1);
              lowestSamples[k] = sample;

              break;
           }
        }

        if (k == numLowestSamples)
        {
           lowestSamples[numLowestSamples - 1] = sample;
        }

        currentMax = lowestSamples[numLowestSamples - 1];
    }
}

如果numLowestSamples需要非常大(接近于samples.count的大小),那么您可能希望使用优先队列,这可能会更快(通常插入新样本的时间复杂度为O(logn),而不是O(n/2),其中n是numLowestSamples)。优先队列能够有效地插入新值,并在O(logn)时间内删除最大值。

当numLowestSamples为10时,实际上没有必要使用它--特别是因为您只处理双精度浮点数,而不是复杂的数据结构。对于堆和小的numLowestSamples,分配堆节点的开销(大多数优先队列使用堆)可能比任何搜索/插入效率提高都更大(测试很重要)。


通过删除k for循环并使用Array.BinarySearch,您可能会挤出更多的性能。如果Array.BinarySearch(k)的返回值为0或正数,则忽略(找到完全匹配)。如果它是负数,则使k =〜k,并像往常一样执行Array.Copy。可能不会有太大的区别,因为log2(10)不会比O(10/2)好多少。 - tumtumtum

2
理想情况下,您只希望对集合进行一次遍历,因此您的解决方案非常巧妙。但是,每次插入时,您都会将整个子列表重新排序,而实际上您只需要提升它前面的数字。然而,对10个元素进行排序几乎可以忽略不计,并且增强此功能并不能为您带来太多好处。最坏的情况(在浪费性能方面)是从开头开始有9个最低数字,因此每次找到一个小于“lowestSamples[numLowestSamples-1]” 的后续数字时,您将对已经排序的列表进行排序(这是QuickSort的最坏情况)。
总之,由于您使用的数字很少,因此在使用托管语言进行此操作的开销方面,您不太可能有很大的数学改进空间。
祝贺您拥有这个酷炫的算法!

2

两种不同的想法:

  1. 不要对数组进行排序,只需在上面执行一次单个插入排序。 你已经知道新添加的项目是唯一未排序的项目,因此利用这一点。
  2. 看看堆排序。 它构建一个二进制最大堆(如果要从小到大排序),然后开始通过交换索引0处的最大元素和仍然是堆的最后一个元素来删除堆中的元素。现在,如果您假装按从最大到最小元素的顺序对数组进行排序,则可以在排序10个元素后停止排序。 数组末尾的10个元素将是最小的,剩余的数组仍然是数组表示中的二叉堆。 我不确定如何与维基百科上的基于快速排序的选择算法相比较。构建堆将始终针对整个数组执行,无论要选择多少个元素。

2

我认为你的想法是正确的。也就是说,通过一次遍历并保留最小大小的排序数据结构通常是最快的。你对此的性能改进是优化。

你的优化包括: 1)每次遍历时都要对结果进行排序。这对于小规模的情况可能是最快的,但对于较大的数据集来说不是最快的。考虑使用两种算法,一种用于给定阈值以下的数据,另一种(例如堆排序)用于超过该阈值的数据。 2)跟踪必须从最小集合中删除的任何值(你目前通过查看最后一个元素来实现)。你可以跳过插入和排序任何大于或等于被踢出的任何值的值。


1

我认为你可能想尝试维护一个最小堆,并测量性能差异。这里有一种数据结构,叫做斐波那契堆,我一直在研究它。它可能需要一些修改,但至少你可以测试我的假设。

public sealed class FibonacciHeap<TKey, TValue>
{
    readonly List<Node> _root = new List<Node>();
    int _count;
    Node _min;

    public void Push(TKey key, TValue value)
    {
        Insert(new Node {
            Key = key,
            Value = value
        });
    }       

    public KeyValuePair<TKey, TValue> Peek()
    {
        if (_min == null)
            throw new InvalidOperationException();
        return new KeyValuePair<TKey,TValue>(_min.Key, _min.Value);
    }       

    public KeyValuePair<TKey, TValue> Pop()
    {
        if (_min == null)
            throw new InvalidOperationException();
        var min = ExtractMin();
        return new KeyValuePair<TKey,TValue>(min.Key, min.Value);
    }

    void Insert(Node node)
    {
        _count++;
        _root.Add(node);
        if (_min == null)
        {
            _min = node;
        }
        else if (Comparer<TKey>.Default.Compare(node.Key, _min.Key) < 0)
        {
            _min = node;
        }
    }

    Node ExtractMin()
    {
        var result = _min;
        if (result == null)
            return null;
        foreach (var child in result.Children)
        {
            child.Parent = null;
            _root.Add(child);
        }
        _root.Remove(result);
        if (_root.Count == 0)
        {
            _min = null;
        }
        else
        {
            _min = _root[0];
            Consolidate();
        }
        _count--;
        return result;
    }

    void Consolidate()
    {
        var a = new Node[UpperBound()];
        for (int i = 0; i < _root.Count; i++)
        {
            var x = _root[i];
            var d = x.Children.Count;
            while (true)
            {   
                var y = a[d];
                if (y == null)
                    break;                  
                if (Comparer<TKey>.Default.Compare(x.Key, y.Key) > 0)
                {
                    var t = x;
                    x = y;
                    y = t;
                }
                _root.Remove(y);
                i--;
                x.AddChild(y);
                y.Mark = false;
                a[d] = null;
                d++;
            }
            a[d] = x;
        }
        _min = null;
        for (int i = 0; i < a.Length; i++)
        {
            var n = a[i];
            if (n == null)
                continue;
            if (_min == null)
            {
                _root.Clear();
                _min = n;
            }
            else
            {
                if (Comparer<TKey>.Default.Compare(n.Key, _min.Key) < 0)
                {
                    _min = n;
                }
            }
            _root.Add(n);
        }
    }

    int UpperBound()
    {
        return (int)Math.Floor(Math.Log(_count, (1.0 + Math.Sqrt(5)) / 2.0)) + 1;
    }

    class Node
    {
        public TKey Key;
        public TValue Value;
        public Node Parent;
        public List<Node> Children = new List<Node>();
        public bool Mark;

        public void AddChild(Node child)
        {
            child.Parent = this;
            Children.Add(child);
        }

        public override string ToString()
        {
            return string.Format("({0},{1})", Key, Value);
        }
    }
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接