Scala: functional aggregation of Seq[T] elements => Seq[Seq[T]] (preserving order)

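I would like to aggregate compatible elements of a sequence, i.e. transform a Seq[T] into a Seq[Seq[T]] in which the elements of each subsequence are compatible with each other, while preserving the original order. For example, from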

case class X(i: Int, n: Int) {
  def canJoin(that: X): Boolean = this.n == that.n
  override def toString: String = s"$i.$n"
}
val xs = Seq(X(1, 1), X(2, 3), X(3, 3), X(4, 3), X(5, 1), X(6, 2), X(7, 2), X(8, 1))
/* xs = List(1.1, 2.3, 3.3, 4.3, 5.1, 6.2, 7.2, 8.1) */

I would like to get

val js = join(xs)
/* js = List(List(1.1), List(2.3, 3.3, 4.3), List(5.1), List(6.2, 7.2), List(8.1)) */
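The requirements above can be pinned down as checkable properties, which is handy for testing any of the `join` implementations in this thread. A minimal sketch; `isValidGrouping` is a hypothetical helper, and its third check ("neighbouring groups must not be joinable") is implied by the expected output rather than stated explicitly. Like `canJoin` here, it assumes compatibility is an equivalence relation.

```scala
case class X(i: Int, n: Int) {
  def canJoin(that: X): Boolean = this.n == that.n
  override def toString: String = s"$i.$n"
}

// Hypothetical checker for the grouping requirements of the question.
def isValidGrouping(xs: Seq[X], groups: Seq[Seq[X]]): Boolean = {
  // Order is preserved: concatenating the groups gives back the input.
  val sameOrder = groups.flatten == xs
  // Every group is non-empty and all its elements are compatible with its head.
  val compatible = groups.forall(g => g.nonEmpty && g.forall(_ canJoin g.head))
  // Neighbouring groups cannot be joined (otherwise they should have been one group).
  val maximal = groups.zip(groups.drop(1)).forall { case (a, b) => !(a.last canJoin b.head) }
  sameOrder && compatible && maximal
}

val xs = Seq(X(1, 1), X(2, 3), X(3, 3), X(4, 3), X(5, 1), X(6, 2), X(7, 2), X(8, 1))
val js = Seq(Seq(X(1, 1)), Seq(X(2, 3), X(3, 3), X(4, 3)), Seq(X(5, 1)), Seq(X(6, 2), X(7, 2)), Seq(X(8, 1)))
println(isValidGrouping(xs, js)) // true
```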

I tried to do this in a functional way, but got stuck halfway.

Using a while loop:

def split(seq: Seq[X]): (Seq[X], Seq[X]) = seq.span(_ canJoin seq.head)
def join(seq: Seq[X]): Seq[Seq[X]] = {
  var pp = Seq[Seq[X]]()
  var s = seq
  while (!s.isEmpty) {
    val (p, r) = split(s)
    pp :+= p
    s = r
  }
  pp
}

I am happy with split, but join is a bit long.

To me this looks like a standard task, which leads to the following questions:

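  1. Is there a function in the collections library that would make this code shorter?
  2. Or is there perhaps a different approach to this task, in particular one that differs from the approaches in Rewriting a sequence by partitioning and collapsing?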
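Regarding question 1: Scala 2.13 and later provide `unfold` on the collection companions, which captures exactly this "split off a prefix, continue with the rest" loop. A sketch assuming Scala 2.13+ (this postdates the question and is not one of the answers in this thread); `split` is the function from the question.

```scala
// Requires Scala 2.13+ (unfold was added to the collection companions in 2.13).
case class X(i: Int, n: Int) {
  def canJoin(that: X): Boolean = this.n == that.n
  override def toString: String = s"$i.$n"
}

def split(seq: Seq[X]): (Seq[X], Seq[X]) = seq.span(_ canJoin seq.head)

// unfold applies the function to the remaining input until it returns None;
// each Some((group, rest)) emits `group` and continues with `rest`.
def join(seq: Seq[X]): Seq[Seq[X]] =
  Seq.unfold[Seq[X], Seq[X]](seq)(rest => if (rest.isEmpty) None else Some(split(rest)))

val xs = Seq(X(1, 1), X(2, 3), X(3, 3), X(4, 3), X(5, 1), X(6, 2), X(7, 2), X(8, 1))
println(join(xs))
```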

Replacing the while loop with tail recursion:

def join(xs: Seq[X]): Seq[Seq[X]] = {
  @annotation.tailrec
  def jointr(pp: Seq[Seq[X]], rem: Seq[X]): Seq[Seq[X]] = {
    val (p, r) = split(rem)
    val pp2 = pp :+ p
    if (r.isEmpty) pp2 else jointr(pp2, r)
  }
  jointr(Seq(), xs)
}
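A variant of the tail-recursive version, sketched under the assumption that the input is usually a `List` (the default `Seq`): appending with `:+` copies a `List` accumulator on every step, so this version prepends and reverses once at the end. Unlike `jointr` above, which returns one empty group for an empty input, it returns an empty result. `loop` is a name chosen here.

```scala
import scala.annotation.tailrec

case class X(i: Int, n: Int) {
  def canJoin(that: X): Boolean = this.n == that.n
  override def toString: String = s"$i.$n"
}

def split(seq: Seq[X]): (Seq[X], Seq[X]) = seq.span(_ canJoin seq.head)

def join(xs: Seq[X]): Seq[Seq[X]] = {
  // Groups are collected in reverse order: prepending to a List is O(1),
  // whereas `pp :+ p` copies the whole List accumulator on every step.
  @tailrec
  def loop(acc: List[Seq[X]], rem: Seq[X]): List[Seq[X]] =
    if (rem.isEmpty) acc
    else {
      val (p, r) = split(rem)
      loop(p :: acc, r)
    }
  loop(Nil, xs).reverse
}

val xs = Seq(X(1, 1), X(2, 3), X(3, 3), X(4, 3), X(5, 1), X(6, 2), X(7, 2), X(8, 1))
println(join(xs)) // List(List(1.1), List(2.3, 3.3, 4.3), List(5.1), List(6.2, 7.2), List(8.1))
```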

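My answer is not tail-recursive, because the last operation is the Seq concatenation. - Peter Schmitz
I would put jointr inside the definition of join as a helper function. - Peter Schmitz
@Peter Schmitz: As suggested, jointr is now a helper definition inside join. - binuWADa
Just a small hint, in case you didn't know: jointr(Seq(), xs) is sufficient, the types are inferred. - Peter Schmitz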
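This may be overkill, but processing pairs of elements is so common that I used it as my example of enriching the collections library: https://dev59.com/dW035IYBdhLWcg3wc_sm. With that enrichment you only need xs.groupedWhile(_ canJoin _) (that is, it solves exactly this problem). It is not simple, but it is the most flexible (it returns the right type, works on arrays, etc.). - Rex Kerr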
3 Answers

def join(seq: Seq[X]): Seq[Seq[X]] = {
  if (seq.isEmpty) return Seq()
  val (p,r) = split(seq)
  Seq(p) ++ join(r)
}

Very nice (short and easy to understand) :-) - binuWADa
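The recursive answer above calls `join(r)` before the `++`, so it is not tail-recursive and uses one stack frame per group. A lazy alternative, sketched assuming Scala 2.13+ (`LazyList` replaced `Stream` there); `joinLazy` is a hypothetical name. `#::` takes its tail by name, so each group is only computed when it is requested.

```scala
// Requires Scala 2.13+ for LazyList.
case class X(i: Int, n: Int) {
  def canJoin(that: X): Boolean = this.n == that.n
  override def toString: String = s"$i.$n"
}

def split(seq: Seq[X]): (Seq[X], Seq[X]) = seq.span(_ canJoin seq.head)

// The recursive call joinLazy(r) is deferred until the next group is needed.
def joinLazy(seq: Seq[X]): LazyList[Seq[X]] =
  if (seq.isEmpty) LazyList.empty
  else {
    val (p, r) = split(seq)
    p #:: joinLazy(r)
  }

val xs = Seq(X(1, 1), X(2, 3), X(3, 3), X(4, 3), X(5, 1), X(6, 2), X(7, 2), X(8, 1))
println(joinLazy(xs).take(2).toList) // List(List(1.1), List(2.3, 3.3, 4.3))
```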

Here is the foldLeft version:
def join(seq: Seq[X]) = seq.reverse.foldLeft(Nil: List[List[X]]) {
    case ((top :: group) :: rest, x) if x canJoin top => 
        (x :: top :: group) :: rest
    case (list, x) => (x :: Nil) :: list
} 

There is also a foldRight version (in this case you don't need to reverse the list):

def join(seq: Seq[X]) = seq.foldRight(Nil: List[List[X]]) {
    case (x, (top :: group) :: rest) if x canJoin top => 
        (x :: top :: group) :: rest
    case (x, list) => (x :: Nil) :: list
} 

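I tried it and it works. I'm just not familiar enough with lists yet and can only follow it intuitively. I'll look at it again more closely in a quiet moment; the code already looks good. - binuWADa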
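To see how the foldRight version builds its result step by step, its case function can be pulled out into a value (`step` is a name introduced here) and passed to `scanRight`, which returns every intermediate accumulator of the same fold: the final result first and the initial `Nil` last. A small sketch for following the list manipulation by hand:

```scala
case class X(i: Int, n: Int) {
  def canJoin(that: X): Boolean = this.n == that.n
  override def toString: String = s"$i.$n"
}

// The step function of the foldRight answer: extend the first group if x can
// join its top element, otherwise start a new group in front.
val step: (X, List[List[X]]) => List[List[X]] = {
  case (x, (top :: group) :: rest) if x canJoin top => (x :: top :: group) :: rest
  case (x, list)                                    => (x :: Nil) :: list
}

val xs = List(X(1, 1), X(2, 3), X(3, 3), X(4, 3), X(5, 1), X(6, 2), X(7, 2), X(8, 1))

// All intermediate accumulators, from the full result down to Nil.
val states = xs.scanRight(Nil: List[List[X]])(step)
states.foreach(println)
```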

Benchmark

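Since I had some time to spare, I wondered how the running times of the different approaches compare, given that lightweight syntax can hide heavyweight machinery.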

So I wrote a micro-benchmark that measures the running time on three sequences:

(1, 3, 3, 3, 1, 2, 2, 1)
(1, 2, 3, 4, 5, 6, 7, 8, 8, 8, 8, 8, 7, 6, 5, 4, 3, 3, 3, 2, 1, 2, 3)
(2, 2, 3, 4, 5, 6, 7, 8, 8, 8, 8, 8, 8, 8, 8, 7, 6, 5, 4, 4, 4, 4, 3, 3, 3, 2, 1)

and got the following results:

Summary

Edit: new results (start):

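While integrating the results into my real project, I noticed inconsistencies in the benchmark. So I ran it again with more warm-up rounds (now 1000), so that the JIT compiler can fully optimize the code. This reordered the results and produced a new favorite: X7 (pimp my lib), a guilt-free pleasure. The List version X8 (reverse.foldLeft) is now also very fast.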

(Dots are thousands separators: 15.202 ns means 15,202 ns.)
Nr (Approach)                      Running time (ns)  Contributor
X2 (poor.reference.impl)            in    15.202 ns
X1 (original while loop)            in     8.166 ns
X3 (tail recursion)                 in     7.473 ns
X4 (recursion with ++)              in     6.671 ns   Peter Schmitz
X5 (simplified recursion with ++)   in     6.161 ns   Peter Schmitz
X6 (foldRight)                      in     4.083 ns   tenshi
X7 (pimp my lib)                    in     1.677 ns   Rex Kerr
X8 (reverse.foldLeft)               in     1.349 ns   tenshi

Edit: new results (end)
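For readers who want to reproduce the warm-up effect described in the edit above, here is a minimal timing sketch. It is not the harness that produced these numbers (that code was not posted): `mkXs` and `bench` are hypothetical helpers, and `joinWhile` and `joinFoldLeft` are copies of X1 and X8 from the details below. For serious measurements a dedicated tool such as JMH is more reliable than `System.nanoTime` loops.

```scala
case class X(i: Int, n: Int) {
  def canJoin(that: X): Boolean = this.n == that.n
  override def toString: String = s"$i.$n"
}

def split(seq: Seq[X]): (Seq[X], Seq[X]) = seq.span(_ canJoin seq.head)

// Copy of X1 (while loop).
def joinWhile(xs: Seq[X]): Seq[Seq[X]] = {
  var xss = Seq.empty[Seq[X]]
  var s = xs
  while (s.nonEmpty) {
    val (p, r) = split(s)
    xss :+= p
    s = r
  }
  xss
}

// Copy of X8 (reverse.foldLeft).
def joinFoldLeft(xs: Seq[X]): List[List[X]] =
  xs.reverse.foldLeft(Nil: List[List[X]]) {
    case ((top :: group) :: rest, x) if x canJoin top => (x :: top :: group) :: rest
    case (list, x)                                    => (x :: Nil) :: list
  }

// Hypothetical helper: mkXs(1, 3, 3) == List(X(1, 1), X(2, 3), X(3, 3)).
def mkXs(ns: Int*): List[X] =
  ns.toList.zipWithIndex.map { case (n, i) => X(i + 1, n) }

// Hypothetical harness: `warmup` untimed runs (at least one) give the JIT a chance
// to compile the hot code, then the average of `rounds` timed runs is returned in ns.
// The last result is returned too, as a crude guard against the work being discarded.
def bench[A](warmup: Int, rounds: Int)(body: => A): (Double, A) = {
  var last: A = body
  for (_ <- 1 until warmup) last = body
  val start = System.nanoTime()
  for (_ <- 1 to rounds) last = body
  ((System.nanoTime() - start).toDouble / rounds, last)
}

// The three input sequences listed above.
val inputs = List(
  mkXs(1, 3, 3, 3, 1, 2, 2, 1),
  mkXs(1, 2, 3, 4, 5, 6, 7, 8, 8, 8, 8, 8, 7, 6, 5, 4, 3, 3, 3, 2, 1, 2, 3),
  mkXs(2, 2, 3, 4, 5, 6, 7, 8, 8, 8, 8, 8, 8, 8, 8, 7, 6, 5, 4, 4, 4, 4, 3, 3, 3, 2, 1))

// Sanity check before timing: both versions must produce the same grouping.
require(inputs.forall(in => joinWhile(in) == joinFoldLeft(in)))

val (nsWhile, _)    = bench(warmup = 1000, rounds = 10000)(inputs.map(in => joinWhile(in)))
val (nsFoldLeft, _) = bench(warmup = 1000, rounds = 10000)(inputs.map(in => joinFoldLeft(in)))
println(f"X1 (while loop):       $nsWhile%10.0f ns")
println(f"X8 (reverse.foldLeft): $nsFoldLeft%10.0f ns")
```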

Old results:

Nr (Approach)                      Running time (ns)  Contributor
X2 (poor.reference.impl)            in 2.972.015 ns
X7 (pimp my lib)                    in 1.185.599 ns   Rex Kerr
X3 (tail recursion)                 in 1.027.008 ns
X8 (reverse.foldLeft)               in   643.840 ns   tenshi
X6 (foldRight)                      in   608.112 ns   tenshi
X1 (original while loop)            in   564.726 ns
X4 (recursion with ++)              in   468.478 ns   Peter Schmitz
X5 (simplified recursion with ++)   in   447.699 ns   Peter Schmitz

Details

X2 (poor.reference.impl)

// in    15.202 ns
import collection.mutable.ArrayBuffer
def join2(seq: Seq[X]): Seq[Seq[X]] = {
  if (seq.isEmpty) return Seq() // seq(0) below would throw on an empty input
  var pp = Seq[ArrayBuffer[X]](ArrayBuffer(seq(0)))
  for (i <- 1 until seq.size) {
    if (seq(i) canJoin seq(i - 1)) {
      pp.last += seq(i)
    } else {
      pp :+= ArrayBuffer(seq(i))
    }
  }
  pp
}

X1 (while loop)

// in     8.166 ns
def join(xs: Seq[X]): Seq[Seq[X]] = {
  var xss = Seq.empty[Seq[X]]
  var s = xs
  while (!s.isEmpty) {
    val (p, r) = split(s)
    xss :+= p
    s = r
  }
  xss
}

This is the original imperative approach from the beginning of the question.

X3 (tail recursion)

// in     7.473 ns
def join(xs: Seq[X]): Seq[Seq[X]] = {
  @annotation.tailrec
  def jointr(xss: Seq[Seq[X]], rxs: Seq[X]): Seq[Seq[X]] = {
    val (g, r) = split(rxs)
    val xsn = xss :+ g
    if (r.isEmpty) xsn else jointr(xsn, r)
  }
  jointr(Seq(), xs)
}

X4 (recursion with ++)

// in     6.671 ns
def join(seq: Seq[X]): Seq[Seq[X]] = {
  if (seq.isEmpty) return Seq()
  val (p, r) = split(seq)
  Seq(p) ++ join(r)
}

X5 (simplified recursion with ++)

// in     6.161 ns
def join(xs: Seq[X]): Seq[Seq[X]] = if (xs.isEmpty) Seq() else {
  val (p, r) = split(xs)
  Seq(p) ++ join(r)
}

The simplified code is almost identical, but slightly faster.

X6 (foldRight)

// in     4.083 ns
def join(xs: Seq[X]) = xs.foldRight(Nil: List[List[X]]) {
  case (x, (top :: group) :: rest) if x canJoin top => (x :: top :: group) :: rest
  case (x, list)                                    => (x :: Nil) :: list
}

This tries to avoid reverse, but for lists foldRight seems to perform even worse than reverse.foldLeft.

X7 (pimp my lib)

// in     1.677 ns
// Compiles on Scala 2.12 and earlier only: CanBuildFrom was removed in Scala 2.13.
import collection.generic.CanBuildFrom
class GroupingCollection[A, C, D[C]](ca: C)(
    implicit c2i: C => Iterable[A],
    cbf: CanBuildFrom[C, C, D[C]],
    cbfi: CanBuildFrom[C, A, C]) {
  def groupedWhile(p: (A, A) => Boolean): D[C] = {
    val it = c2i(ca).iterator
    val cca = cbf()
    if (!it.hasNext) cca.result
    else {
      val as = cbfi()
      var olda = it.next
      as += olda
      while (it.hasNext) {
        val a = it.next
        if (p(olda, a)) as += a
        else { cca += as.result; as.clear; as += a }
        olda = a
      }
      cca += as.result
    }
    cca.result
  }
}
implicit def collections_have_grouping[A, C[A]](ca: C[A])(
  implicit c2i: C[A] => Iterable[A],
  cbf: CanBuildFrom[C[A], C[A], C[C[A]]],
  cbfi: CanBuildFrom[C[A], A, C[A]]) = {
  new GroupingCollection[A, C[A], C](ca)(c2i, cbf, cbfi)
}
// xs.groupedWhile(_ canJoin _)
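Because `CanBuildFrom` no longer exists in Scala 2.13, the code above does not compile there. A much simpler, non-generic sketch for Scala 2.13+ and Scala 3 follows; `GroupedWhileSyntax` and `GroupedWhileOps` are names chosen here, and unlike Rex Kerr's version it always returns `Seq[Seq[A]]` instead of preserving the collection type. Like the original, it compares each element with the previous one, whereas `split` compares with the head of the group; for an equivalence relation such as `canJoin` both give the same groups.

```scala
case class X(i: Int, n: Int) {
  def canJoin(that: X): Boolean = this.n == that.n
  override def toString: String = s"$i.$n"
}

object GroupedWhileSyntax {
  implicit class GroupedWhileOps[A](seq: Seq[A]) {
    // Starts a new group whenever p(previous, current) is false.
    def groupedWhile(p: (A, A) => Boolean): Seq[Seq[A]] = {
      val groups = Seq.newBuilder[Seq[A]]
      val it = seq.iterator
      if (it.hasNext) {
        var prev = it.next()
        var current = Seq.newBuilder[A]
        current += prev
        while (it.hasNext) {
          val a = it.next()
          if (!p(prev, a)) {            // boundary: close the current group
            groups += current.result()
            current = Seq.newBuilder[A]  // fresh builder instead of reusing one
          }
          current += a
          prev = a
        }
        groups += current.result()      // close the last group
      }
      groups.result()
    }
  }
}
import GroupedWhileSyntax._

val xs = Seq(X(1, 1), X(2, 3), X(3, 3), X(4, 3), X(5, 1), X(6, 2), X(7, 2), X(8, 1))
println(xs.groupedWhile(_ canJoin _))
```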

X8 (reverse.foldLeft)

// in     1.349 ns
def join(xs: Seq[X]) = xs.reverse.foldLeft(Nil: List[List[X]]) {
  case ((top :: group) :: rest, x) if x canJoin top => (x :: top :: group) :: rest
  case (list, x)                                    => (x :: Nil) :: list
}

Conclusion

The different approaches (X1, X3, X4, X5, X6) all perform at a similar level.

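Since X7 (pimp my lib) allows the very concise usage xs.groupedWhile(_ canJoin _), and the necessary code can be hidden away in one's own utility library, I decided to use it in my real project.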


Content provided by Stack Overflow.