当满足某些条件时,使用Scala的foldLeft函数

10

如何在Scala中模拟以下行为?即在累加器满足某些特定条件的情况下继续折叠。

def foldLeftWhile[B](z: B, p: B => Boolean)(op: (B, A) => B): B
例如
scala> val seq = Seq(1, 2, 3, 4)
seq: Seq[Int] = List(1, 2, 3, 4)
scala> seq.foldLeftWhile(0, _ < 3) { (acc, e) => acc + e }
res0: Int = 1
scala> seq.foldLeftWhile(0, _ < 7) { (acc, e) => acc + e }
res1: Int = 6

更新:

根据 @Dima 的答案,我意识到我的意图有点带有副作用。因此我将其与 takeWhile 同步,即如果谓词不匹配,则不会有进展。并添加一些更多的例子,使其更清晰。(注意:这将无法与 Iterator 一起使用)


好吧,我猜这不在标准库中,因为如果条件来自累加器本身,则可以直接将其合并,例如(acc,e) => if (acc <3) acc + e else acc。您甚至可以创建一个高阶函数,从其基础和条件创建“最终”累加器。我理解这可能比从折叠中提前中断效率低,但否则它看起来是等效的。 - GPI
@GPI,同意。但我认为这不仅仅是效率的问题。正如我们在下面讨论的那样,谓词可能会匹配另一个条件,从而产生意料之外的结果。 - ntviet18
6个回答

8

首先需要指出的是,你提供的例子似乎存在错误。如果我理解正确,结果应该是1(满足谓词_ < 3的最后一个值),而不是6

最简单的方法是使用return语句,但在Scala中这种做法非常不被赞同。但出于完备性的考虑,我还是会提一下。

def foldLeftWhile[A, B](seq: Seq[A], z: B, p: B => Boolean)(op: (B, A) => B): B = foldLeft(z) { case (b, a) => 
   val result = op(b, a) 
   if(!p(result)) return b
   result
}

由于我们希望避免使用return,scanLeft可能是一种选择:

seq.toStream.scanLeft(z)(op).takeWhile(p).last

这种方法有些浪费,因为它会累积所有(匹配的)结果。你可以使用iterator代替toStream来避免这种情况,但是由于某些原因,Iterator没有.last方法,所以你需要显式地再次扫描它一次:

 seq.iterator.scanLeft(z)(op).takeWhile(p).foldLeft(z) { case (_, b) => b }

谢谢,我认为你的第一个解决方案是一个非常好的想法,但我不确定它是否在所有情况下都有效,例如 Seq(1, 2, 2, 1).foldLeftWhile(0, _ < 5) { case (acc, e) => acc + e } 可能会返回4而不是3。 - ntviet18
我为你的第二个解决方案点了赞。非常简洁 :) - ntviet18
我进行了快速检查,结果发现你的第一个解决方案也像魔术一样奏效。请问为什么return语句会忽略foldLeft闭包? - ntviet18
1
@ntviet18,是的,这就是Scala中return的魔力。它基本上忽略了其周围所有的lambda,并返回到最近方法上面的级别。它通过抛出一个特殊的异常来实现,该异常在方法末尾被捕获。这就是为什么它通常被认为是“代码异味” - 太多的魔法,没有引用透明度,并且通常很难理解。 - Dima
感谢您的解释。 - ntviet18

4

在Scala中,定义你想要的内容是非常直观的。你可以定义一个隐式类,将你的函数添加到任何TraversableOnce(包括Seq)中。

implicit class FoldLeftWhile[A](trav: TraversableOnce[A]) {
  def foldLeftWhile[B](init: B)(where: B => Boolean)(op: (B, A) => B): B = {
    trav.foldLeft(init)((acc, next) => if (where(acc)) op(acc, next) else acc)
  }
}
Seq(1,2,3,4).foldLeftWhile(0)(_ < 3)((acc, e) => acc + e)

更新,因为问题已经修改:

implicit class FoldLeftWhile[A](trav: TraversableOnce[A]) {
  def foldLeftWhile[B](init: B)(where: B => Boolean)(op: (B, A) => B): B = {
    trav.foldLeft((init, false))((a,b) => if (a._2) a else {
      val r = op(a._1, b)
      if (where(r)) (op(a._1, b), false) else (a._1, true)
    })._1
  }
}

请注意,我将您的(z: B, p: B => Boolean)拆分为两个高阶函数。这只是我个人在Scala编程中的风格偏好。

我认为,这违反了“while”对于非“单调”的谓词的语义: Seq(1,2,3,4).foldLeftWhile(0)(_ % 2 != 0) { case (b, a) => b } 返回 3,但应该是 1 - Dima
因为它似乎回答了我的原始问题,所以我点赞了。谢谢@Dima - ntviet18
@Dima 我想你的意思是'(a, b) => b',因为'(b, a) => b'总是返回您的起始值。无论是否应用我的更正,它都会返回0,因为0%2!= 0立即返回false,并且您的折叠函数实际上从未运行。您是否打算传递除0以外的其他初始值?如果您采用我的更正并传递1,则会得到2。这不是您想要或期望的吗? - Jack Davidson
@JackDavidson,是的,我在那个例子中搞砸了。重点是如果谓词变为false,你的函数会继续执行...所以,如果它再次变为true,它将重新开始累积。因为acc是唯一的参数,只要谓词是纯的就没关系,但它也取决于一些外部状态,这会破坏语义... - Dima
@dima 是的,那是正确的。如果谓词不是纯函数,就不能保证函数不会恢复处理输入。我已经更新了一个替代方案,反映了新问题,并应该永久停止第一次函数返回 false。 - Jack Davidson

2
这个怎么样:
def foldLeftWhile[A, B](z: B, xs: Seq[A], p: B => Boolean)(op: (B, A) => B): B = {
  def go(acc: B, l: Seq[A]): B = l match {
    case h +: t => 
        val nacc = op(acc, h)
        if(p(nacc)) go(op(nacc, h), t) else nacc
    case _ => acc
  }
  go(z, xs)
}

val a = Seq(1,2,3,4,5,6)
val r = foldLeftWhile(0, a, (x: Int) => x <= 3)(_ + _)
println(s"$r")

当谓词为真时,递归迭代集合,然后返回累加器。

您可以在scalafiddle上尝试它。


1
这会进行额外的迭代:p(result)总是_false(除非它一直保持为真,直到结束) - Dima
谢谢,这是对我最初问题的答案。但我同意 @Dima 的观点,我们不应该再深入探讨了。所以,我给你的回答点了赞。 - ntviet18
@Dima,我不明白,你能解释一下为什么会多做一次迭代吗?我没看出来。谢谢! - Alejandro Alcalde
1
@ElBaulP 上一次 p(acc) 是 true,即使现在它是 false,你还是调用了 go。你必须先计算结果,再检查条件,如果是 true 就传递到 go - Dima

1

过了一会儿,我收到了很多看起来不错的答案。因此,我将它们合并成了这篇文章。

@Dima 提供了一个非常简洁的解决方案。

implicit class FoldLeftWhile[A](seq: Seq[A]) {

  def foldLeftWhile[B](z: B)(p: B => Boolean)(op: (B, A) => B): B = {
    seq.toStream.scanLeft(z)(op).takeWhile(p).lastOption.getOrElse(z)
  }
}

由@ElBaulP编写(我稍作修改以匹配@Dima的评论)

implicit class FoldLeftWhile[A](seq: Seq[A]) {

  def foldLeftWhile[B](z: B)(p: B => Boolean)(op: (B, A) => B): B = {
    @tailrec
    def foldLeftInternal(acc: B, seq: Seq[A]): B = seq match {
      case x :: _ =>
        val newAcc = op(acc, x)
        if (p(newAcc))
          foldLeftInternal(newAcc, seq.tail)
        else
          acc
      case _ => acc
    }

    foldLeftInternal(z, seq)
  }
}

我(涉及副作用)的回答

implicit class FoldLeftWhile[A](seq: Seq[A]) {

  def foldLeftWhile[B](z: B)(p: B => Boolean)(op: (B, A) => B): B = {
    var accumulator = z
    seq
      .map { e =>
        accumulator = op(accumulator, e)
        accumulator -> e
      }
      .takeWhile { case (acc, _) =>
        p(acc)
      }
      .lastOption
      .map { case (acc, _) =>
        acc
      }
      .getOrElse(z)
  }
}

1

第一个例子:对于每个元素的谓词

首先,您可以使用内部尾递归函数

implicit class TravExt[A](seq: TraversableOnce[A]) {
  def foldLeftWhile[B](z: B, f: A => Boolean)(op: (A, B) => B): B = {
    @tailrec
    def rec(trav: TraversableOnce[A], z: B): B = trav match {
      case head :: tail if f(head) => rec(tail, op(head, z))
      case _ => z
    }
    rec(seq, z)
  }
}

或简短版本

implicit class TravExt[A](seq: TraversableOnce[A]) {
  @tailrec
  final def foldLeftWhile[B](z: B, f: A => Boolean)(op: (A, B) => B): B = seq match {
    case head :: tail if f(head) => tail.foldLeftWhile(op(head, z), f)(op)
    case _ => z
  }
}

然后使用它

val a = List(1, 2, 3, 4, 5, 6).foldLeftWhile(0, _ < 3)(_ + _) //a == 3

第二个例子:累加器的值:

implicit class TravExt[A](seq: TraversableOnce[A]) {
  def foldLeftWhile[B](z: B, f: A => Boolean)(op: (A, B) => B): B = {
    @tailrec
    def rec(trav: TraversableOnce[A], z: B): B = trav match {
      case _ if !f(z) => z
      case head :: tail => rec(tail, op(head, z))
      case _ => z
    }
    rec(seq, z)
  }
}

或简短版本

implicit class TravExt[A](seq: TraversableOnce[A]) {
  @tailrec
  final def foldLeftWhile[B](z: B, f: A => Boolean)(op: (A, B) => B): B = seq match {
    case _ if !f(z) => z
    case head :: tail => tail.foldLeftWhile(op(head, z), f)(op)
    case _ => z
  }
}

你的解决方案非常好。如果我们扩展AnyVal来创建值类,我认为它会非常高效 :) - ntviet18
在我重新考虑之后,你应该将谓词应用于累加器,而不是元素。 - ntviet18

-1

只需在累加器上使用分支条件:

seq.foldLeft(0, _ < 3) { (acc, e) => if (acc < 3) acc + e else acc}

然而,您将运行序列的每个条目。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接