使用Scala进行N叉树遍历会导致堆栈溢出。

10

我试图从一个N-叉树数据结构中返回一组小部件的列表。在我的单元测试中,如果我大约有2000个小部件,每个小部件只有一个依赖关系,那么我会遇到堆栈溢出的问题。我认为发生了什么是for循环导致我的树遍历不是尾递归的。在scala中编写更好的方法是什么?这是我的函数:

protected def getWidgetTree(key: String) : ListBuffer[Widget] = {
  def traverseTree(accumulator: ListBuffer[Widget], current: Widget) : ListBuffer[Widget] = {
    accumulator.append(current)

    if (!current.hasDependencies) {
      accumulator
    }  else {
      for (dependencyKey <- current.dependencies) {
        if (accumulator.findIndexOf(_.name == dependencyKey) == -1) {
          traverseTree(accumulator, getWidget(dependencyKey))
        }
      }

      accumulator
    }
  }

  traverseTree(ListBuffer[Widget](), getWidget(key))
}

您能否将Widget类与测试用例一起发布? - Petro Semeniuk
这是您需要的Petro:case class Widget(name: String, dependencies: List[String]) - Donuts
3个回答

10

这个函数不是尾递归的原因在于您在函数内部进行了多次递归调用。要实现尾递归,递归调用必须是函数体中的最后一个表达式。毕竟,整个重点在于它像while循环一样工作(因此可以转换为循环)。循环不能在单次迭代内多次调用自身。

要执行这样的树遍历,可以使用队列来传递需要访问的节点。

假设我们有这棵树:

//        1
//       / \  
//      2   5
//     / \
//    3   4

使用这个简单的数据结构来表示:

case class Widget(name: String, dependencies: List[String]) {
  def hasDependencies = dependencies.nonEmpty
}

我们有一个指向每个节点的映射:

val getWidget = List(
  Widget("1", List("2", "5")),
  Widget("2", List("3", "4")),
  Widget("3", List()),
  Widget("4", List()),
  Widget("5", List()))
  .map { w => w.name -> w }.toMap

现在我们可以将您的方法重写为尾递归:

def getWidgetTree(key: String): List[Widget] = {
  @tailrec
  def traverseTree(queue: List[String], accumulator: List[Widget]): List[Widget] = {
    queue match {
      case currentKey :: queueTail =>        // the queue is not empty
        val current = getWidget(currentKey)  // get the element at the front
        val newQueueItems =                  // filter out the dependencies already known
          current.dependencies.filterNot(dependencyKey => 
            accumulator.exists(_.name == dependencyKey) && !queue.contains(dependencyKey))
        traverseTree(newQueueItems ::: queueTail, current :: accumulator) // 
      case Nil =>                            // the queue is empty
        accumulator.reverse                  // we're done
    }
  }

  traverseTree(key :: Nil, List[Widget]())
}

然后进行测试:

for (k <- 1 to 5)
  println(getWidgetTree(k.toString).map(_.name))

输出:

ListBuffer(1, 2, 3, 4, 5)
ListBuffer(2, 3, 4)
ListBuffer(3)
ListBuffer(4)
ListBuffer(5)

“accumulator.exists(_.name == dependencyKey)” 这行代码在添加成千上万个元素时可能会变得比较慢。有什么方法可以改进这种查找? - Donuts
@John,将累加器中的所有键保存在缓存中(使用Set即可),然后进行检查。这比遍历累加器要好得多。 - dhg
我在traverseTree函数中添加了一个可变的HashSet参数作为“keyAccumulator”,现在我检查它而不是Widget累加器,这显著提高了性能。如果我对每个字符串建立索引并在整个过程中只使用整数,或许我可以进一步提高性能。 - Donuts

4

对于@dhg回答中的相同示例,一个等效的尾递归函数没有可变状态(即ListBuffer):

case class Widget(name: String, dependencies: List[String])

val getWidget = List(
  Widget("1", List("2", "5")),
  Widget("2", List("3", "4")),
  Widget("3", List()),
  Widget("4", List()),
  Widget("5", List())).map { w => w.name -> w }.toMap

def getWidgetTree(key: String): List[Widget] = {
  def addIfNotAlreadyContained(widgetList: List[Widget], widgetNameToAdd: String): List[Widget] = {
    if (widgetList.find(_.name == widgetNameToAdd).isDefined) widgetList
    else                                                      widgetList :+ getWidget(widgetNameToAdd)
  }

  @tailrec
  def traverseTree(currentWidgets: List[Widget], acc: List[Widget]): List[Widget] = currentWidgets match {
    case Nil                                => {
      // If there are no more widgets in this branch return what we've traversed so far
      acc 
    }
    case Widget(name, Nil) :: rest          => {
      // If the first widget is a leaf traverse the rest and add the leaf to the list of traversed
      traverseTree(rest, addIfNotAlreadyContained(acc, name)) 
    }
    case Widget(name, dependencies) :: rest => {
      // If the first widget is a parent, traverse it's children and the rest and add it to the list of traversed
      traverseTree(dependencies.map(getWidget) ++ rest, addIfNotAlreadyContained(acc, name))
    } 
  }

  val root = getWidget(key)
  traverseTree(root.dependencies.map(getWidget) :+ root, List[Widget]())
}

对于相同的测试用例

for (k <- 1 to 5)
  println(getWidgetTree(k.toString).map(_.name).toList.sorted)

给你带来以下好处:
List(2, 3, 4, 5, 1)
List(3, 4, 2)
List(3)
List(4)
List(5)

请注意,这是后序遍历而不是前序遍历。

1
太棒了!谢谢。我不知道@tailrec注释。那是一个相当酷的小宝石。我必须稍微调整一下解决方案,因为带有自我引用的小部件会导致无限循环。而且,当调用traverseTree时,newQueueItems是一个Iterable,而它期望一个List,所以我必须将其转换为List。
def getWidgetTree(key: String): List[Widget] = {
  @tailrec
  def traverseTree(queue: List[String], accumulator: List[Widget]): List[Widget] = {
    queue match {
      case currentKey :: queueTail =>        // the queue is not empty
        val current = getWidget(currentKey)  // get the element at the front
        val newQueueItems =                  // filter out the dependencies already known
          current.dependencies.filter(dependencyKey =>
            !accumulator.exists(_.name == dependencyKey) && !queue.contains(dependencyKey)).toList
        traverseTree(newQueueItems ::: queueTail, current :: accumulator) //
      case Nil =>                            // the queue is empty
        accumulator.reverse                  // we're done
    }
  }

  traverseTree(key :: Nil, List[Widget]())
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接