高效地遍历数组以匹配布尔条件?

6
这是另一个关于《Scala快学快用》的问题,要求编写一个函数lteqgt(values: Array[Int], v: Int),返回一个三元组,其中包含小于v、等于v和大于v的值的计数。
我的实现方式如下:
scala> def lteqgt(values: Array[Int], v: Int): (Int, Int, Int) = (values.count(_ < v), values.count(_ == v), values.count(_ > v))
lteqgt: (values: Array[Int], v: Int)(Int, Int, Int)

scala> lteqgt(Array(0,0,1,1,1,2,2), 1)
res47: (Int, Int, Int) = (2,3,2)

问题?
我正在遍历数组3次以收集计数,有没有一种方法可以在第一次遍历中收集值?一种惯用的方式吗?

3个回答

10

使用foldLeft是完美的情况。它将正好遍历您的集合一次,而不创建另一个集合(如groupBy所做的那样),并且与更通用的aggregate相比更加简洁。

def lteqgt(values: Array[Int], v: Int): (Int, Int, Int) =
  values.foldLeft((0, 0, 0)) {
    case ((lt, eq, gt), el) =>
      if (el < v) (lt + 1, eq, gt)
      else if (el == v) (lt, eq + 1, gt)
      else (lt, eq, gt + 1)
  }

如果你想实现最终的效率,同时避免使用命令式方法,那么尾递归是一个可行的方案:

def lteqgt(values: Array[Int], v: Int): (Int, Int, Int) = {
  def rec(i:Int, lt:Int, eq:Int, gt:Int):(Int, Int, Int) =
    if (i == values.length) (lt, eq, gt)
    else if (values(i) < v) rec(i + 1, lt + 1, eq, gt)
    else if (values(i) == v) rec(i + 1, lt, eq + 1, gt)
    else rec(i + 1, lt, eq, gt + 1)
  rec(0, 0, 0, 0)
}

这样可以避免在每次迭代中构建 Tuple 和装箱的 Int。整个过程会编译成 Java 的 while 循环(如果您感兴趣,这里是转换后的输出)。


2

虽然功能性解决方案可能更加优雅,但我们不要忘记“无聊”但高效的命令式解决方案。

def lteqgt(values: Array[Int], v: Int): (Int, Int, Int) = {
    var lt = 0
    var eq = 0
    var gt = 0
    values.foreach (i => {
      if      (i<v)  lt += 1
      else if (i==v) eq += 1
      else           gt += 1
    })
    (lt,eq,gt)
  }

如下代码可能会在每次循环中生成一个函数调用,正如Aivean在下面指出的那样。为了提高效率,我们可以手动去掉闭包。(尽管编译器尚未进行这样的优化,这是有点遗憾的)

 def lteqgt(values: Array[Int], v: Int): (Int, Int, Int) = {
   var lt = 0
   var eq = 0
   var gt = 0
   var i = 0
   while (i < values.length) {
     if      (values(i) < v ) lt += 1
     else if (values(i) == v) eq += 1
     else                     gt += 1
     i += 1
   }
   (lt,eq,gt)
  }

1
它可能没有看起来那么高效。将您的代码转换为Java尾递归解决方案进行比较。我没有对其进行基准测试,但我认为每次迭代的函数调用可能不太有效率。 - Aivean
2
你可能想使用 while 循环。Scala 不支持 Java 风格的 for 循环。Scala 的 for 循环会被编译成 foreach/map - Aivean
@Aivean 谢谢。已更正。 - chi

0

与@chi的类似,但让我们稍微超出设计规范,这样我们就可以比较除了Ints之外的其他数据。而且标题说“遍历”,所以让我们使用Traversable,这样我们就可以分析所有其他集合。机器可能不会更快地执行此操作,但这可能是程序员时间高效的。

def lteqgt[T](values:Traversable[T], v:T)(implicit cmp: Ordering[T]):(Int,Int,Int) =  {
            var l = 0;
            var e = 0; 
            var g = 0;
            for(x <- values){
                if (cmp.equiv(v,x)) e += 1
                else if (cmp.gt(x,v)) g += 1
                else l += 1
            }
        (l,e,g);
};

使用方法:

scala> lteqgt(List(1.0,2.0,3.0,2.5), 2.5)
res0: (Int, Int, Int) = (2,1,1)

scala> lteqgt(Array(1.0,2.0,3.0,2.5), 2.5)
res1: (Int, Int, Int) = (2,1,1)

scala> lteqgt(Vector(1.0,2.0,3.0,2.5), 2.5)
res2: (Int, Int, Int) = (2,1,1)

scala> lteqgt(Set(1.0,2.0,3.0,2.5), 2.5)
res3: (Int, Int, Int) = (2,1,1)

scala> lteqgt(1 to 100, 45)
res4: (Int, Int, Int) = (44,1,55)

scala> lteqgt(Range(0,100,3), 50)
res5: (Int, Int, Int) = (17,0,17)

scala> lteqgt(List("fee","fie","foe","fum"), "foo")
res6: (Int, Int, Int) = (3,0,1)

scala> lteqgt(None, 1)
res7: (Int, Int, Int) = (0,0,0)

scala> lteqgt(Some(2), 1)
res8: (Int, Int, Int) = (0,0,1)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接