LINQ计算SortedList<dateTime,double>的移动平均值

17
我有一个时间序列,形式为SortedList<dateTime,double>。我想计算这个序列的移动平均值。我可以使用简单的for循环来做到这一点。我想知道是否有更好的方法使用linq来实现。
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace ConsoleApplication1
{
    class Program
    {
        static void Main(string[] args)
        {
            var mySeries = new SortedList<DateTime, double>();
            mySeries.Add(new DateTime(2011, 01, 1), 10);
            mySeries.Add(new DateTime(2011, 01, 2), 25);
            mySeries.Add(new DateTime(2011, 01, 3), 30);
            mySeries.Add(new DateTime(2011, 01, 4), 45);
            mySeries.Add(new DateTime(2011, 01, 5), 50);
            mySeries.Add(new DateTime(2011, 01, 6), 65);

            var calcs = new calculations();
            var avg = calcs.MovingAverage(mySeries, 3);
            foreach (var item in avg)
            {
                Console.WriteLine("{0} {1}", item.Key, item.Value);                
            }
        }
    }
    class calculations
    {
        public SortedList<DateTime, double> MovingAverage(SortedList<DateTime, double> series, int period)
        {
            var result = new SortedList<DateTime, double>();

            for (int i = 0; i < series.Count(); i++)
            {
                if (i >= period - 1)
                {
                    double total = 0;
                    for (int x = i; x > (i - period); x--)
                        total += series.Values[x];
                    double average = total / period;
                    result.Add(series.Keys[i], average);  
                }

            }
            return result;
        }
    }
}

1
在转换到 LINQ 之前,我会先进行测试。通常,手写简单的 for 循环在性能上会胜过 LINQ。 - Mike M.
1
经过测试,手写的非Linq解决方案确实是更好(读取更快)的解决方案。 - Andre P.
8个回答

19

为了达到O(n)的渐进性能(与手写解决方案相似),可以使用Aggregate函数,就像以下代码:

series.Skip(period-1).Aggregate(
  new {
    Result = new SortedList<DateTime, double>(), 
    Working = List<double>(series.Take(period-1).Select(item => item.Value))
  }, 
  (list, item)=>{
     list.Working.Add(item.Value); 
     list.Result.Add(item.Key, list.Working.Average()); 
     list.Working.RemoveAt(0);
     return list;
  }
).Result;
累加值(实现为匿名类型)包含两个字段:Result 包含到目前为止构建的结果列表,Working 包含最后 period-1 个元素。聚合函数将当前值添加到 Working 列表中,构建当前平均值并将其添加到结果中,然后从工作列表中删除第一个(即最旧的)值。
"种子"(即累加的起始值)是通过将前 period-1 个元素放入 Working 中,并将 Result 初始化为空列表来构建的。
因此,聚合从元素 period 开始(通过在开始时跳过 (period-1) 元素)。
顺便提一下,在功能编程中,这是聚合(或 fold)函数的典型使用模式。
两点说明:
该解决方案在“功能”上不是干净的,因为每一步都重复使用相同的列表对象(WorkingResult)。我不确定如果某些未来的编译器尝试自动并行化聚合函数是否会导致问题(另一方面,我也不确定这是否可能...)。纯粹的功能解决方案应该在每一步中"创建"新的列表。
还要注意的是,C#缺乏强大的列表表达式。在某些假想的Python-C#混合伪代码中,可以编写聚合函数如下:
(list, item)=>
  new {
    Result = list.Result + [(item.Key, (list.Working+[item.Value]).Average())], 
    Working=list.Working[1::]+[item.Value]
  }

在我谦虚的意见中,这样会更加优雅些 :)


13

要以最有效的方式使用LINQ计算移动平均值,您不应该使用LINQ!

相反,我建议创建一个帮助类以最有效的方式计算移动平均值(使用循环缓冲区和因果移动平均滤波器),然后创建一个扩展方法使其可供LINQ使用。

首先是移动平均值。

public class MovingAverage
{
    private readonly int _length;
    private int _circIndex = -1;
    private bool _filled;
    private double _current = double.NaN;
    private readonly double _oneOverLength;
    private readonly double[] _circularBuffer;
    private double _total;

    public MovingAverage(int length)
    {
        _length = length;
        _oneOverLength = 1.0 / length;
        _circularBuffer = new double[length];
    }       

    public MovingAverage Update(double value)
    {
        double lostValue = _circularBuffer[_circIndex];
        _circularBuffer[_circIndex] = value;

        // Maintain totals for Push function
        _total += value;
        _total -= lostValue;

        // If not yet filled, just return. Current value should be double.NaN
        if (!_filled)
        {
            _current = double.NaN;
            return this;
        }

        // Compute the average
        double average = 0.0;
        for (int i = 0; i < _circularBuffer.Length; i++)
        {
            average += _circularBuffer[i];
        }

        _current = average * _oneOverLength;

        return this;
    }

    public MovingAverage Push(double value)
    {
        // Apply the circular buffer
        if (++_circIndex == _length)
        {
            _circIndex = 0;
        }

        double lostValue = _circularBuffer[_circIndex];
        _circularBuffer[_circIndex] = value;

        // Compute the average
        _total += value;
        _total -= lostValue;

        // If not yet filled, just return. Current value should be double.NaN
        if (!_filled && _circIndex != _length - 1)
        {
            _current = double.NaN;
            return this;
        }
        else
        {
            // Set a flag to indicate this is the first time the buffer has been filled
            _filled = true;
        }

        _current = _total * _oneOverLength;

        return this;
    }

    public int Length { get { return _length; } }
    public double Current { get { return _current; } }
}

这个类提供了一个非常快速和轻量级的MovingAverage过滤器实现。它创建了一个长度为N的循环缓冲区,并在每个添加的数据点上计算一个加法、一个减法和一个乘法,而不是对于暴力实现的每个点进行N次乘加运算。
接下来,让我们将其转换为LINQ!
internal static class MovingAverageExtensions
{
    public static IEnumerable<double> MovingAverage<T>(this IEnumerable<T> inputStream, Func<T, double> selector, int period)
    {
        var ma = new MovingAverage(period);
        foreach (var item in inputStream)
        {
            ma.Push(selector(item));
            yield return ma.Current;
        }
    }

    public static IEnumerable<double> MovingAverage(this IEnumerable<double> inputStream, int period)
    {
        var ma = new MovingAverage(period);
        foreach (var item in inputStream)
        {
            ma.Push(item);
            yield return ma.Current;
        }
    }
}

上述扩展方法包装了 MovingAverage 类,并允许在 IEnumerable 流中插入。
现在来使用它吧!
int period = 50;

// Simply filtering a list of doubles
IEnumerable<double> inputDoubles;
IEnumerable<double> outputDoubles = inputDoubles.MovingAverage(period);   

// Or, use a selector to filter T into a list of doubles
IEnumerable<Point> inputPoints; // assuming you have initialised this
IEnumerable<double> smoothedYValues = inputPoints.MovingAverage(pt => pt.Y, period);

谢谢,强大的for循环嘲笑.Zip.Scan.Select(Tuple)方法! - Dr. Andrew Burnett-Thompson
1
几年后,但实际上是一种可靠的方法。 - Eric

7

你已经有一个回答展示了如何使用LINQ,但老实说,在这里我不会使用LINQ,因为它很可能与你当前的解决方案相比表现较差,并且你现有的代码已经很清晰。

然而,你可以在每一步中保持一个运行总数并在每次迭代时进行调整,而不是在每个步骤上计算前面period元素的总和。也就是说,将这个:

total = 0;
for (int x = i; x > (i - period); x--)
    total += series.Values[x];

转换为:

if (i >= period) {
    total -= series.Values[i - period];
}
total += series.Values[i];

这意味着您的代码执行时间不会因period的大小而改变。

这并没有真正回答问题。OP想知道如何在Linq中实现它。 - Brian Genisio
3
在我看来,“不使用LINQ”是对这个问题的一个有效答案。LINQ非常好用,但在这里使用它是错误的工具。 - Mark Byers
1
实际上,我只是想知道如何做得更好。话虽如此,以后我可能会直接从SQL数据库中提取这些值。在这种情况下,全LINQ解决方案可能更好。我将对它们进行基准测试,以查看哪个更快。 - Andre P.

7
这是一个HTML代码块。
double total = 0;
for (int x = i; x > (i - period); x--)
    total += series.Values[x];
double average = total / period;

可以重写为:

double average = series.Values.Skip(i - period + 1).Take(period).Sum() / period;

你的方法可能看起来像这样:

series.Skip(period - 1)
    .Select((item, index) =>
        new 
        {
            item.Key,            
            series.Values.Skip(index).Take(period).Sum() / period
        });

正如你所看到的,linq非常具有表现力。我建议从一些教程开始学习,比如介绍LINQ101个LINQ示例


3
注意算法的运行时间为*O(n^2),因为每一步都需要跳过越来越多的元素(据我所知,Skip(i)必须调用IEnumerator.MoveNext i次)。请参阅我的回复,以获取O(n)*时间复杂度的解决方案......(我刚刚注意到下面OPs的评论,他/她可能会在未来从SQL DB中获取值。 在这种情况下,我强烈反对使用此解决方案!) - MartinStettner
1
@Andre 欢迎你。 @MartinStettner 是的,你说得对。我尽量写出最优雅的解决方案,而不是最高效的... - Branimir

3
要以更具功能性的方式做到这一点,您需要一个Scan方法,在Rx中存在而不在LINQ中。

让我们看看如果我们有一个扫描方法,它会是什么样子

var delta = 3;
var series = new [] {1.1, 2.5, 3.8, 4.8, 5.9, 6.1, 7.6};

var seed = series.Take(delta).Average();
var smas = series
    .Skip(delta)
    .Zip(series, Tuple.Create)
    .Scan(seed, (sma, values)=>sma - (values.Item2/delta) + (values.Item1/delta));
smas = Enumerable.Repeat(0.0, delta-1).Concat(new[]{seed}).Concat(smas);

这里是扫描方法,取自并调整自这里

public static IEnumerable<TAccumulate> Scan<TSource, TAccumulate>(
    this IEnumerable<TSource> source,
    TAccumulate seed,
    Func<TAccumulate, TSource, TAccumulate> accumulator
)
{
    if (source == null) throw new ArgumentNullException("source");
    if (seed == null) throw new ArgumentNullException("seed");
    if (accumulator == null) throw new ArgumentNullException("accumulator");

    using (var i = source.GetEnumerator())
    {
        if (!i.MoveNext())
        {
            throw new InvalidOperationException("Sequence contains no elements");
        }
        var acc = accumulator(seed, i.Current);

        while (i.MoveNext())
        {
            yield return acc;
            acc = accumulator(acc, i.Current);
        }
        yield return acc;
    }
}

这种方法的性能应该比暴力方法更好,因为我们使用了一个运行总数来计算SMA。
发生了什么?
首先,我们需要计算第一个周期,我们在这里称之为seed。然后,每个后续值都是从累积的种子值计算出来的。为此,我们需要旧值(即t-delta)和最新值,将它们一起压缩成序列,一次从开头开始,一次按delta移位。
最后,我们通过添加第一个周期的长度的零和添加初始种子值来进行一些清理。

刚看到这个。非常有趣!得试一下看看它是否能改进 C# 的 i 循环。 - Andre P.
@AndreP。除了比暴力算法更高效外,这些值是以懒惰的方式计算的。因此,假设您有200k个值,但只写smas.Take(1000),它只会计算前1000个移动平均值。 - lukebuehler
在阅读问题(而不是所有答案)后,我也想出了同样的东西(尽管我把函数称为“AggregateSeq”)。 - James Curran

2
另一个选项是使用MoreLINQWindowed方法,这可以极大地简化代码:
var averaged = mySeries.Windowed(period).Select(window => window.Average(keyValuePair => keyValuePair.Value));

0

我使用这段代码来计算SMA:

private void calculateSimpleMA(decimal[] values, out decimal[] buffer)
{
    int period = values.Count();     // gets Period (assuming Period=Values-Array-Size)
    buffer = new decimal[period];    // initializes buffer array
    var sma = SMA(period);           // gets SMA function
    for (int i = 0; i < period; i++)
        buffer[i] = sma(values[i]);  // fills buffer with SMA calculation
}

static Func<decimal, decimal> SMA(int p)
{
    Queue<decimal> s = new Queue<decimal>(p);
    return (x) =>
    {
        if (s.Count >= p)
        {
            s.Dequeue();
        }
        s.Enqueue(x);
        return s.Average();
    };
}

0

这是一个扩展方法:

public static IEnumerable<double> MovingAverage(this IEnumerable<double> source, int period)
{
    if (source is null)
    {
        throw new ArgumentNullException(nameof(source));
    }

    if (period < 1)
    {
        throw new ArgumentOutOfRangeException(nameof(period));
    }

    return Core();

    IEnumerable<double> Core()
    {
        var sum = 0.0;
        var buffer = new double[period];
        var n = 0;
        foreach (var x in source)
        {
            n++;
            sum += x;
            var index = n % period;
            if (n >= period)
            {
                sum -= buffer[index];
                yield return sum / period;
            }

            buffer[index] = x;
        }
    }
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接