Scala中的叉积

59

我希望在Scala中有一个可遍历的二元算子cross(叉积/笛卡尔积):

val x = Seq(1, 2)
val y = List('hello', 'world', 'bye')
val z = x cross y    # i can chain as many traversables e.g. x cross y cross w etc

assert z == ((1, 'hello'), (1, 'world'), (1, 'bye'), (2, 'hello'), (2, 'world'), (2, 'bye'))

在Scala中,仅使用Scala本身(而不是像scalaz这样的库)实现此操作的最佳方法是什么?


"cross" 的类型是什么? - Jesper Nordenberg
1
可能是两个列表的笛卡尔积的重复问题。 - Martin Thoma
8个回答

90

你可以在Scala 2.10中使用隐式类和for推导式来非常简单地实现这一点:

implicit class Crossable[X](xs: Traversable[X]) {
  def cross[Y](ys: Traversable[Y]) = for { x <- xs; y <- ys } yield (x, y)
}

val xs = Seq(1, 2)
val ys = List("hello", "world", "bye")

现在:

scala> xs cross ys
res0: Traversable[(Int, String)] = List((1,hello), (1,world), ...

在2.10版本之前这也是可能的,只不过不够简洁,因为你需要定义一个类和一个隐式转换方法。

你也可以这样写:

scala> xs cross ys cross List('a, 'b)
res2: Traversable[((Int, String), Symbol)] = List(((1,hello),'a), ...

如果你想让 xs cross ys cross zs 返回一个 Tuple3,你需要大量的样板代码或者像Shapeless这样的库。


1
你可以通过为隐式类扩展AnyVal来进行优化和使用ValueClass。ValueClass - twillouer
2
谢谢,但我希望 x cross y cross z 返回的是一个 Tuple3 而不是一个 Tuple2(Tuple2, Value) - pathikrit
3
好的,请提供一个使用shapeless的示例。 - pathikrit
2
重载一个方法以具有不同的返回类型并不是一个好主意。在这种情况下最好使用HList/HArray。 - Jesper Nordenberg
2
@JesperNordenbergпјҡдҪ еә”иҜҘеҸҜд»ҘйҖҡиҝҮдёҺProductLensе’Ң~ж–№жі•дёӯзңӢеҲ°зҡ„зӣёеҗҢж–№жі•жқҘе®үе…Ёең°иҺ·еҫ—x cross y cross zгҖӮ - Travis Brown
显示剩余4条评论

35

使用以下代码交叉x_listy_list

val cross = x_list.flatMap(x => y_list.map(y => (x, y)))

14

下面是任意数量列表的递归叉乘实现:

def crossJoin[T](list: Traversable[Traversable[T]]): Traversable[Traversable[T]] =
  list match {
    case xs :: Nil => xs map (Traversable(_))
    case x :: xs => for {
      i <- x
      j <- crossJoin(xs)
    } yield Traversable(i) ++ j
  }

crossJoin(
  List(
    List(3, "b"),
    List(1, 8),
    List(0, "f", 4.3)
  )
)

res0: Traversable[Traversable[Any]] = List(List(3, 1, 0), List(3, 1, f), List(3, 1, 4.3), List(3, 8, 0), List(3, 8, f), List(3, 8, 4.3), List(b, 1, 0), List(b, 1, f), List(b, 1, 4.3), List(b, 8, 0), List(b, 8, f), List(b, 8, 4.3))

1
代码很好,但在for语句之外进行递归调用可能更有效率。此外,添加一个case Nil => Nil可以捕获空列表的边界情况。 - Tim
@Tim,你如何在for语句之外调用递归? - Milad Khajavi
1
我在我的另一个答案中发布了对您答案的修改。这是问题链接 - Tim

10

猫用户的替代方案:

List[List[A]] 上的 sequence 创建了一个叉积:

import cats.implicits._

val xs = List(1, 2)
val ys = List("hello", "world", "bye")

List(xs, ys).sequence 
//List(List(1, hello), List(1, world), List(1, bye), List(2, hello), List(2, world), List(2, bye))

3

这里有一个类似于Milad's response的内容,但是不是递归的。

def cartesianProduct[T](seqs: Seq[Seq[T]]): Seq[Seq[T]] = {
  seqs.foldLeft(Seq(Seq.empty[T]))((b, a) => b.flatMap(i => a.map(j => i ++ Seq(j))))
}

基于这篇博客文章


2
class CartesianProduct(product: Traversable[Traversable[_ <: Any]]) {
  override def toString(): String = {
    product.toString
  }

  def *(rhs: Traversable[_ <: Any]): CartesianProduct = {
      val p = product.flatMap { lhs =>
        rhs.map { r =>
          lhs.toList :+ r
        }
      }

      new CartesianProduct(p)
  }
}

object CartesianProduct {
  def apply(traversable: Traversable[_ <: Any]): CartesianProduct = {
    new CartesianProduct(
      traversable.map { t =>
        Traversable(t)
      }
    )
  }
}

// TODO: How can this conversion be made implicit?
val x = CartesianProduct(Set(0, 1))
val y = List("Alice", "Bob")
val z = Array(Math.E, Math.PI)

println(x * y * z) // Set(List(0, Alice, 3.141592653589793), List(0, Alice, 2.718281828459045), List(0, Bob, 3.141592653589793), List(1, Alice, 2.718281828459045), List(0, Bob, 2.718281828459045), List(1, Bob, 3.141592653589793), List(1, Alice, 3.141592653589793), List(1, Bob, 2.718281828459045))

// TODO: How can this conversion be made implicit?
val s0 = CartesianProduct(Seq(0, 0))
val s1 = Seq(0, 0)

println(s0 * s1) // List(List(0, 0), List(0, 0), List(0, 0), List(0, 0))

1
当你执行 Seq(0, 0) * Seq(0, 0) 会发生什么?我期望会得到4个项目,但实际上只返回了1个项目。而且这种类型不太适合我。也许可以使用HLists来解决? - pathikrit
我不知道你所说的“非确定性”是什么意思。你是指Set是无序的吗? - Noel Yap
是的,Set是无序的。 - pathikrit
我已经更改了代码以使用列表。结果仍然取决于第一个参数是否使用Set作为底层类型。我想这是由于使用了Traversable - Noel Yap

0

需要进行一些小的编辑。请添加到mapN文档的链接。定义xs,ys或重用问题中的x,y。采用不同的方法也是可以接受的解决方案。 - devilpreet

0

和其他回答类似,这是我的方法。

def loop(lst: List[List[Int]],acc:List[Int]): List[List[Int]] = {
  lst match {
    case head :: Nil => head.map(_ :: acc)
    case head :: tail => head.flatMap(x => loop(tail,x :: acc))
    case Nil => ???
  }
}
val l1 = List(10,20,30,40)
val l2 = List(2,4,6)
val l3 = List(3,5,7,9,11)

val lst = List(l1,l2,l3)

loop(lst,List.empty[Int])

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接