Scala, Eratosthenes:有没有一种简单的方法可以用迭代替换流?

9

我写了一个函数,可以无限生成质数(维基百科:埃拉托斯特尼筛法),使用流式处理。它返回一个流,但它也在内部合并质数倍数的流来标记即将出现的合成数。如果我自己这么说的话,定义是简明、功能性强、优雅且易于理解的:

def primes(): Stream[Int] = {
  def merge(a: Stream[Int], b: Stream[Int]): Stream[Int] = {
    def next = a.head min b.head
    Stream.cons(next, merge(if (a.head == next) a.tail else a,
                            if (b.head == next) b.tail else b))
  }
  def test(n: Int, compositeStream: Stream[Int]): Stream[Int] = {
    if (n == compositeStream.head) test(n+1, compositeStream.tail)
    else Stream.cons(n, test(n+1, merge(compositeStream, Stream.from(n*n, n))))
  }
  test(2, Stream.from(4, 2))
}

但是,当我尝试生成第1000个质数时,我会收到"java.lang.OutOfMemoryError: GC overhead limit exceeded"的错误提示。

我有一个替代方案,它返回一个质数迭代器,并在内部使用元组(多个,用于生成多个的质数)的优先队列来标记即将出现的合数。它运行良好,但需要两倍的代码量,而且基本上我必须从头开始重新开始:

import scala.collection.mutable.PriorityQueue
def primes(): Iterator[Int] = {
  // Tuple (composite, prime) is used to generate a primes multiples
  object CompositeGeneratorOrdering extends Ordering[(Long, Int)] {
    def compare(a: (Long, Int), b: (Long, Int)) = b._1 compare a._1
  }
  var n = 2;
  val composites = PriorityQueue(((n*n).toLong, n))(CompositeGeneratorOrdering)
  def advance = {
    while (n == composites.head._1) { // n is composite
      while (n == composites.head._1) { // duplicate composites
        val (multiple, prime) = composites.dequeue
        composites.enqueue((multiple + prime, prime))
      }
      n += 1
    }
    assert(n < composites.head._1)
    val prime = n
    n += 1
    composites.enqueue((prime.toLong * prime.toLong, prime))
    prime
  }
  Iterator.continually(advance)
}

有没有一种简单的方法将使用流的代码翻译成使用迭代器的代码?或者有没有一种简单的方法使我的第一次尝试更加内存高效?
以流的方式思考更容易,我宁愿从那里开始,然后在必要时调整我的代码。

请注意,我的第一段代码在第10,000个质数时就会出现整数溢出问题。 - stewSquared
4个回答

9

我猜这是当前Stream实现中的一个错误。

primes().drop(999).head运行良好:

primes().drop(999).head
// Int = 7919

如果像这样存储 Stream ,您将会遇到 OutOfMemoryError 问题:

val prs = primes()

prs.drop(999).head
// Exception in thread "main" java.lang.OutOfMemoryError: GC overhead limit exceeded

这里的问题在于Cons实现:它不仅包含计算出的tail,还包括一个计算该tail的函数。即使tail已经计算出来了,该函数也不再需要,但仍然存在。

在这种情况下,函数非常耗费资源,即使只存储1000个函数也会导致OutOfMemoryError

我们必须想办法去掉这些函数。

直觉上的修复方法失败了:

val prs = primes().iterator.toStream

prs.drop(999).head
// Exception in thread "main" java.lang.OutOfMemoryError: GC overhead limit exceeded

使用流上的迭代器iterator,你可以得到StreamIterator,通过StreamIterator#toStream,你将获得初始的庞大Stream

解决方法

因此,我们需要手动进行转换:

def toNewStream[T](i: Iterator[T]): Stream[T] =
  if (i.hasNext) Stream.cons(i.next, toNewStream(i))
  else Stream.empty

val prs = toNewStream(primes().iterator)
// Stream[Int] = Stream(2, ?)

prs.drop(999).head
// Int = 7919

1
谢谢!我喜欢这个解决方法,或者说,我喜欢它能够保留我的实现。但是,我真的不理解它。我查看了Stream的源代码,但我不明白它有什么问题。你能详细说明一下吗? - stewSquared
1
@stewSquared:这里有一个修复。使用tlVal = tltl应该被存储在一个字段中。使用private[this] var tlFunc: () => Stream[A] = tl _tl在构造函数中被使用。 - senia

8
在你的第一段代码中,你应该推迟合并操作,直到在候选项中看到一个质数的平方。这将大大减少使用的流的数量,从而极大地改善了内存使用问题。要得到第1000个质数,即7919,我们只需要考虑不超过其平方根88的质数。这只需要考虑23个质数/它们的倍数流,而不是999(如果我们从一开始就忽略偶数,则为22)。对于第10000个质数,只需要考虑66个倍数流,而不是9999个。对于第100000个质数,只需要考虑189个。
诀窍是通过递归调用将正在消耗的质数与正在生产的质数分开:
def primes(): Stream[Int] = {
  def merge(a: Stream[Int], b: Stream[Int]): Stream[Int] = {
    def next = a.head min b.head
    Stream.cons(next, merge(if (a.head == next) a.tail else a,
                            if (b.head == next) b.tail else b))
  }
  def test(n: Int, q: Int, 
                   compositeStream: Stream[Int], 
                   primesStream: Stream[Int]): Stream[Int] = {
    if (n == q) test(n+2, primesStream.tail.head*primesStream.tail.head,
                          merge(compositeStream, 
                                Stream.from(q, 2*primesStream.head).tail),
                          primesStream.tail)
    else if (n == compositeStream.head) test(n+2, q, compositeStream.tail,
                                                     primesStream)
    else Stream.cons(n, test(n+2, q, compositeStream, primesStream))
  }
  Stream.cons(2, Stream.cons(3, Stream.cons(5, 
     test(7, 25, Stream.from(9, 6), primes().tail.tail))))
}

作为额外的奖励,不需要将质数的平方存储为Long。这也会更快,具有更好的算法复杂性(时间和空间),因为它避免了许多不必要的工作。Ideone测试显示它在生成高达n = 80,000个质数时以大约~n1.5..1.6经验增长率运行。
这里仍然存在一个算法问题:创建的结构仍然是线性向左倾斜的结构(((mults_of_2 + mults_of_3) + mults_of_5) + ...),更频繁产生的流位于更深的位置(因此数字需要通过更多层级上升)。右倾结构应该更好,mults_of_2 + (mults_of_3 + (mults_of_5 + ...))。将其制作成树应该会在时间复杂度上带来真正的改进(通常将其推向约~ n1.2..1.25)。有关讨论,请参见this haskellwiki page

“真正的”埃拉托色尼筛法通常在产生n个质数时运行约为~n1.1,而最佳试除法筛法在约为~n1.40..1.45时运行。您的原始代码运行时间大约为立方级别或更差。使用命令式可变数组通常是最快的,通过段(a.k.a. 埃拉托色尼筛法)工作。

在您的第二个代码中,Python是如何实现这一点的


是的!你说得对!这两点都是我代码中可以进行的明显优化。 - stewSquared

7
有没有一种简单的方法将使用流的代码转换为使用迭代器的代码?或者有没有一种简单的方法使我的第一次尝试更具内存效率?
@Will Ness使用流给出了一个改进的答案,并解释了为什么你的代码需要这么多内存和时间,例如在早期添加流和左倾线性结构。但是,没有人完全回答了你的问题的第二部分(或者可能是主要部分),即能否使用迭代器实现真正的增量埃拉托色尼筛选法。
首先,我们应该归功于这个右倾算法,你的第一个代码是一个粗糙的(左倾)示例(因为它过早地将所有素数复合流添加到合并操作中),这归功于Richard Bird,就像Melissa E. O'Neill关于增量埃拉托色尼筛选法的定义性论文的附录中所述。
其次,不,无法在此算法中用Iterator替换Stream,因为它依赖于在不重新启动流的情况下移动流,并且尽管可以访问迭代器的头部(当前位置),但使用下一个值(跳过头部)生成其余的迭代作为流需要建立一个全新的迭代器,这将极大地消耗内存和时间。然而,我们可以使用Iterator按顺序输出素数序列的结果,以最小化内存使用并使使用高阶函数的迭代器变得容易,正如您将在我的代码中看到的那样。
现在,Will Ness已经向您介绍了推迟将素数复合流添加到计算中直到需要它们的原则,当将它们存储在像Priority Queue或HashMap这样的结构中时,这种方法非常有效,甚至在O'Neill的论文中也被忽略了,但对于Richard Bird算法来说,这并不是必要的,因为未来的流值只有在需要时才会被访问,因此不会被存储,如果Streams被正确地惰性构建(如惰性和左倾),事实上,这个算法甚至不需要完整的Stream的记忆和开销,每个复合数筛选序列仅向前移动,而没有参考任何过去的质数,除了需要一个单独的基本质数源,可以通过同一算法的递归调用来提供。
为了方便参考,让我们将Richard Bird算法的Haskell代码列出如下:
primes = 2:([3..] ‘minus‘ composites)
  where
    composites = union [multiples p | p <− primes]
    multiples n = map (n*) [n..]
    (x:xs) ‘minus‘ (y:ys)
      | x < y = x:(xs ‘minus‘ (y:ys))
      | x == y = xs ‘minus‘ ys
      | x > y = (x:xs) ‘minus‘ ys
    union = foldr merge []
      where
        merge (x:xs) ys = x:merge’ xs ys
        merge’ (x:xs) (y:ys)
          | x < y = x:merge’ xs (y:ys)
          | x == y = x:merge’ xs ys
          | x > y = y:merge’ (x:xs) ys

在下面的代码中,我简化了“minus”函数(称为“minusStrtAt”),因为我们不需要构建一个全新的流,而是可以将组合减法操作与原始生成(在我的情况下仅为奇数)序列相结合。我还简化了“union”函数(将其重命名为“mrgMltpls”)
流操作作为非记忆化通用的Co Inductive Stream(CIS)实现,作为一个通用类,其中类的第一个字段是流当前位置的值,第二个字段是thunk(返回通过另一个函数嵌入闭包参数的流下一个值的零参数函数)。
def primes(): Iterator[Long] = {
  // generic class as a Co Inductive Stream element
  class CIS[A](val v: A, val cont: () => CIS[A])

  def mltpls(p: Long): CIS[Long] = {
    var px2 = p * 2
    def nxtmltpl(cmpst: Long): CIS[Long] =
      new CIS(cmpst, () => nxtmltpl(cmpst + px2))
    nxtmltpl(p * p)
  }
  def allMltpls(mps: CIS[Long]): CIS[CIS[Long]] =
    new CIS(mltpls(mps.v), () => allMltpls(mps.cont()))
  def merge(a: CIS[Long], b: CIS[Long]): CIS[Long] =
    if (a.v < b.v) new CIS(a.v, () => merge(a.cont(), b))
    else if (a.v > b.v) new CIS(b.v, () => merge(a, b.cont()))
    else new CIS(b.v, () => merge(a.cont(), b.cont()))
  def mrgMltpls(mlps: CIS[CIS[Long]]): CIS[Long] =
    new CIS(mlps.v.v, () => merge(mlps.v.cont(), mrgMltpls(mlps.cont())))
  def minusStrtAt(n: Long, cmpsts: CIS[Long]): CIS[Long] =
    if (n < cmpsts.v) new CIS(n, () => minusStrtAt(n + 2, cmpsts))
    else minusStrtAt(n + 2, cmpsts.cont())
  // the following are recursive, where cmpsts uses oddPrms and
  // oddPrms uses a delayed version of cmpsts in order to avoid a race
  // as oddPrms will already have a first value when cmpsts is called to generate the second
  def cmpsts(): CIS[Long] = mrgMltpls(allMltpls(oddPrms()))
  def oddPrms(): CIS[Long] = new CIS(3, () => minusStrtAt(5L, cmpsts()))
  Iterator.iterate(new CIS(2L, () => oddPrms()))
                   {(cis: CIS[Long]) => cis.cont()}
    .map {(cis: CIS[Long]) => cis.v}
}

上述代码在ideone上大约用1.3秒钟生成第100,000个质数(1299709),额外开销约为0.36秒,并且在计算前600,000个质数时具有约为1.43的经验计算复杂度。除程序代码使用的内存外,内存使用可以忽略不计。
虽然可以使用Scala Streams来实现上述代码,但是这会带来性能和内存使用开销(一个常数因子),而该算法并不需要。使用Streams可以直接使用它们,无需额外生成Iterator代码,但由于仅用于序列的最终输出,因此成本不高。
要实现Will Ness建议的一些基本树折叠,只需要添加一个“pairs”函数并将其连接到“mrgMltpls”函数即可。
def primes(): Iterator[Long] = {
  // generic class as a Co Inductive Stream element
  class CIS[A](val v: A, val cont: () => CIS[A])

  def mltpls(p: Long): CIS[Long] = {
    var px2 = p * 2
    def nxtmltpl(cmpst: Long): CIS[Long] =
      new CIS(cmpst, () => nxtmltpl(cmpst + px2))
    nxtmltpl(p * p)
  }
  def allMltpls(mps: CIS[Long]): CIS[CIS[Long]] =
    new CIS(mltpls(mps.v), () => allMltpls(mps.cont()))
  def merge(a: CIS[Long], b: CIS[Long]): CIS[Long] =
    if (a.v < b.v) new CIS(a.v, () => merge(a.cont(), b))
    else if (a.v > b.v) new CIS(b.v, () => merge(a, b.cont()))
    else new CIS(b.v, () => merge(a.cont(), b.cont()))
  def pairs(mltplss: CIS[CIS[Long]]): CIS[CIS[Long]] = {
    val tl = mltplss.cont()
    new CIS(merge(mltplss.v, tl.v), () => pairs(tl.cont()))
  }
  def mrgMltpls(mlps: CIS[CIS[Long]]): CIS[Long] =
    new CIS(mlps.v.v, () => merge(mlps.v.cont(), mrgMltpls(pairs(mlps.cont()))))
  def minusStrtAt(n: Long, cmpsts: CIS[Long]): CIS[Long] =
    if (n < cmpsts.v) new CIS(n, () => minusStrtAt(n + 2, cmpsts))
    else minusStrtAt(n + 2, cmpsts.cont())
  // the following are recursive, where cmpsts uses oddPrms and
  // oddPrms uses a delayed version of cmpsts in order to avoid a race
  // as oddPrms will already have a first value when cmpsts is called to generate the second
  def cmpsts(): CIS[Long] = mrgMltpls(allMltpls(oddPrms()))
  def oddPrms(): CIS[Long] = new CIS(3, () => minusStrtAt(5L, cmpsts()))
  Iterator.iterate(new CIS(2L, () => oddPrms()))
                   {(cis: CIS[Long]) => cis.cont()}
    .map {(cis: CIS[Long]) => cis.v}
}

上述代码在ideone中大约0.75秒内生成了第100,000个质数(1299709),并且有约0.37秒的开销,到第1,000,000个质数(15485863)的经验计算复杂度约为1.09 (5.13秒)。除程序代码使用的内存外,内存使用可以忽略不计。
请注意,上述代码完全是功能性的,没有使用可变状态,但是Bird算法(甚至是树折叠版本)对于更大的范围不如使用优先队列或HashMap快,因为处理树合并的操作数量具有比Priority Queue的log n开销更高的计算复杂度,或者HashMap的线性(分摊)性能(虽然存在大量的常数因子开销来处理哈希,所以这种优势直到使用一些真正大的范围才能看到)。
这些代码之所以使用如此少的内存,是因为CIS流没有永久引用流的起始点,因此在使用流时对其进行垃圾回收,只留下最少量的基本质数复合序列占位符。正如Will Ness所解释的那样,仅使用546个基本质数复合数字流即可生成前100万个素数(最大值为15485863),每个占位符仅占用几十字节(Long数字占用8个字节,64位函数引用占用8个字节,指向闭包参数的指针和函数和类开销各占用另外几个字节),每个流占用的总空间可能只有40字节左右,生成100万个素数的序列所需的总空间不会超过20千字节。

-1

如果您只想要一个无限的素数流,这是我认为最优雅的方法:

def primes = {
  def sieve(from : Stream[Int]): Stream[Int] = from.head #:: sieve(from.tail.filter(_ % from.head != 0))
  sieve(Stream.from(2))
}

1
请注意问题中有 Stream.from(n*n, n),因此 primes 不应筛选所有整数。 - senia
1
这确实是一个无限的质数流。然而,它使用试除法而不是埃拉托斯特尼筛法,即它很慢。在我的实现中,primes.drop(10000).head需要40秒,在你的实现中需要3分钟才能获得GC超限。阅读这篇文章:http://www.cs.hmc.edu/~oneill/papers/Sieve-JFP.pdf此外,我不仅仅是在寻找一个无限质数流,而是想要一种使用Streams而不会出现GC超限的方法。 - stewSquared
13
更好的版本:val primes: Stream[Int] = 2 #:: Stream.from(3, 2).filter(i => primes.takeWhile(j => j * j <= i).forall(k => i % k > 0))。这需要不到一秒钟来运行primes.drop(10000).head - John Landahl
1
@JohnLandahl 感谢您提供这个经典的试除法算法的代码!它运行在*~ n^1.45*,对于 n = 25k..100k,结果如预期一样。 :) - Will Ness
1
@Kigyo 请查看https://dev59.com/jFnUa4cB1Zd3GeqPbYkm,该链接解释了你的代码存在的问题。 - Will Ness
显示剩余4条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接