给定一个二叉搜索树及其根节点,打印出所有产生相同BST的节点序列。

21

给定一个二叉搜索树,找到从根节点开始的所有节点序列,它们会生成相同的二叉搜索树。

给定一个BST,例如:

  3
 /  \
1    5

答案应该是3,1,5和3,5,1。

另一个例子

       5
     /   \
    4     7
   /     / \
  1     6   10

输出结果将会是

5,4,1,7,6,10

5,4,7,6,10,1

5,7,6,10,4,1

等等

然而,这里的不变条件是父节点的索引必须始终小于其子节点。我在实现时遇到了困难。


你的意思是给定节点的二叉树表示中的数量吗? - Dineshkumar
可能是查找给定整数序列产生相同二叉搜索树的排列数量的重复问题。 - Gaurang Tandon
9个回答

24

我假设您想要一个列表,其中包含将生成相同BST的所有序列。
在此答案中,我们将使用分而治之。 我们将创建一个名为findAllSequences(Node * ptr)的函数,该函数以节点指针作为输入并返回将生成从ptr悬挂的子树的所有不同序列。 此函数将返回一个Vector of Vector of int,即包含所有序列的vector<vector<int>>

生成序列的主要思想根必须在其所有子节点之前

算法:

基本情况1:
如果ptrNULL,则返回一个带有空序列的向量。

if (ptr == NULL) {
    vector<int> seq;
    vector<vector<int> > v;
    v.push_back(seq);
    return v;
}

基础情形2:
如果ptr是一个叶子节点,则返回一个具有单个序列的向量。显然这个序列只包含一个元素,即该节点的值。

if (ptr -> left == NULL && ptr -> right == NULL) {
    vector<int> seq;
    seq.push_back(ptr -> val);
    vector<vector<int> > v;
    v.push_back(seq);
    return v;
}

分割部分这一部分非常简单。
我们假设有一个可以解决此问题的函数,因此我们将其用于左子树和右子树。

vector<vector<int> > leftSeq  = findAllSeq(ptr -> left);
vector<vector<int> > rightSeq = findAllSeq(ptr -> right);

合并这两个解决方案。(关键在于这一步。
到目前为止,我们有两个包含不同序列的集合:

i. leftSeq  - all sequences in this set will generate left subtree.
ii. rightSeq - all sequences in this set will generate right subtree.

现在左子树中的每个序列都可以与右子树中的每个序列合并。 在合并时,我们应该注意保留元素的相对顺序。 同时,在每个合并的序列中,我们将在开头添加当前节点的值,因为根节点必须位于所有子节点之前。

合并的伪代码

vector<vector<int> > results
for all sequences L in leftSeq
    for all sequences R in rightSeq
        create a vector flags with l.size() 0's and R.size() 1's
        for all permutations of flag
            generate the corresponding merged sequence.
            append the current node's value in beginning
            add this sequence to the results.

return results. 

解释: 假设我们从集合 leftSeq 中取一个大小为 n 的序列 L,从集合 rightSeq 中取一个大小为 m 的序列 R
现在这两个序列可以m+nCn种方式合并!
证明: 在合并后,新的序列将有 m + n 个元素。因为我们必须保持相对顺序,所以首先我们将在总共 (m+n) 个位置中任意填充 L 中的所有 n 个元素。之后,剩余的 m 个位置可以用 R 的元素填充。因此,我们必须从 (m+n) 个位置中选择 n 个位置
为了做到这一点,让我们创建一个布尔向量,称为 flags,并用 n0'sm1's 填充它。值为 0 表示来自 left 序列的成员,而值为 1 表示来自 right 序列的成员。现在,我们只需要生成这些 flags 向量的所有 permutations,可以使用 next_permutation 来完成。对于 flags 的每个排列,我们将得到一个不同的 LR 合并序列。
例如: 假设 L={1, 2, 3},R={4, 5}
所以,n=3m=2
因此,我们可以有3+2C3 种合并序列,即10种。
1.最初,flags= {0 0 0 1 1},填充了3个 0's 和 2个 1's
这将导致以下合并序列: 1 2 3 4 5
2.调用 nextPermutation 后,我们将获得
flags = {0 0 1 0 1}
这将生成序列: 1 2 4 3 5
3.再次调用 nextPermutation 后,我们将获得
flags = {0 0 1 1 0}
这将生成序列: 1 2 4 5 3
以此类推...

C++ 代码

vector<vector<int> > findAllSeq(TreeNode *ptr)
{
    if (ptr == NULL) {
        vector<int> seq;
        vector<vector<int> > v;
        v.push_back(seq);
        return v;
    }


    if (ptr -> left == NULL && ptr -> right == NULL) {
        vector<int> seq;
        seq.push_back(ptr -> val);
        vector<vector<int> > v;
        v.push_back(seq);
        return v;
    }

    vector<vector<int> > results, left, right;
    left  = findAllSeq(ptr -> left);
    right = findAllSeq(ptr -> right);
    int size = left[0].size() + right[0].size() + 1;

    vector<bool> flags(left[0].size(), 0);
    for (int k = 0; k < right[0].size(); k++)
        flags.push_back(1);

    for (int i = 0; i < left.size(); i++) {
        for (int j = 0; j < right.size(); j++) {
            do {
                vector<int> tmp(size);
                tmp[0] = ptr -> val;
                int l = 0, r = 0;
                for (int k = 0; k < flags.size(); k++) {
                    tmp[k+1] = (flags[k]) ? right[j][r++] : left[i][l++];
                }
                results.push_back(tmp);
            } while (next_permutation(flags.begin(), flags.end()));
        }
    }

    return results;
}

更新于2017年3月3日:如果原始树包含重复项,则此解决方案不会完全起作用。


如果我只想计算这些序列的数量,即results.size(),而不是枚举它们,会怎样呢?它能够扩展到N=50,即50个整数吗? - evandrix
@Atul,你认为 vector<vector<int>> 会返回子树的所有序列吗?因为每次调用都会初始化一个新的序列。 - Rahul
@Rahul 是的。这个方法可行,因为它是一个递归解决方案。 - vaibhavatul47

5

我为你编写了一份清晰,简洁且文档完备的解决方案,使用Python 3 编写。希望它可以帮到您!

代码: bst_sequences.py

from binarytree import bst, Node


def weave_lists(first: list, second: list, results: list, prefix: list) -> None:
    """Recursively Weave the first list into the second list and append 
    it to the results list.  The prefix list grows by an element with the 
    depth of the call stack.  Ultimately, either the first or second list will 
    be exhausted and the base case will append a result."""
    # base case
    if not first or not second:
        results.append(prefix + first + second)
        return

    # recursive case
    first_head, first_tail = first[0], first[1:]
    weave_lists(first_tail, second, results, prefix + [first_head])

    second_head, second_tail = second[0], second[1:]
    weave_lists(first, second_tail, results, prefix + [second_head])


def all_sequences(root: Node) -> list:
    """Splits the tree into three lists: prefix, left, and right."""
    if root is None:
        return []

    answer = []
    prefix = [root.value]
    left = all_sequences(root.left) or [[]]
    right = all_sequences(root.right) or [[]]

    # At a minimum, left and right must be a list containing an empty list
    # for the following nested loop
    for i in range(len(left)):
        for j in range(len(right)):
            weaved = []
            weave_lists(left[i], right[j], weaved, prefix)
        answer.extend(weaved)

    return answer


if __name__ == "__main__":
    t = bst(2)
    print(t)
    solution = all_sequences(t)
    for e, item in enumerate(solution):
        print(f"{e:03}: {item}")

样例输出

    __4
   /   \
  1     5
 / \     \
0   2     6

000: [4, 1, 0, 2, 5, 6]
001: [4, 1, 0, 5, 2, 6]
002: [4, 1, 0, 5, 6, 2]
003: [4, 1, 5, 0, 2, 6]
004: [4, 1, 5, 0, 6, 2]
005: [4, 1, 5, 6, 0, 2]
006: [4, 5, 1, 0, 2, 6]
007: [4, 5, 1, 0, 6, 2]
008: [4, 5, 1, 6, 0, 2]
009: [4, 5, 6, 1, 0, 2]
010: [4, 1, 2, 0, 5, 6]
011: [4, 1, 2, 5, 0, 6]
012: [4, 1, 2, 5, 6, 0]
013: [4, 1, 5, 2, 0, 6]
014: [4, 1, 5, 2, 6, 0]
015: [4, 1, 5, 6, 2, 0]
016: [4, 5, 1, 2, 0, 6]
017: [4, 5, 1, 2, 6, 0]
018: [4, 5, 1, 6, 2, 0]
019: [4, 5, 6, 1, 2, 0]

Process finished with exit code 0

7
这看起来非常像《程序员面试金典》中的解决方案复制品。如果是这样的话,我建议您引用来源 :) - Sanchit Batra
你应该提到你是从《程序员面试金典》第四章中借鉴了这段代码,而不是声称你自己编写的。 - Abhijit Sarkar

2
请注意,此问题实际上是关于树的拓扑排序:找到执行拓扑排序的所有可能方式。也就是说,我们不关心树的具体构建方式,重要的是元素始终作为叶子添加,从不改变现有节点的结构。输出的约束条件是节点永远不会在其祖先之前 - 将树视为经典的依赖图
但与一般DAG的拓扑排序不同,这里不需要引用计数,因为这是一棵树 - 引用计数始终为1或0。
下面是一个简单的Python实现:
    def all_toposorts_tree(sources, history):
        if not sources:
            print(history)
            return
        for t in sources:
            all_toposorts((sources - {t}) | {t.left, t.right} - {None}, history + [t.v])
    
    all_toposorts_tree({root}, [])

这是《程序员面试金典》第六版中的问题4.9。

2
我有一个更简短的解决方案。你觉得怎么样?
function printSequences(root){
    let combinations = [];

    function helper(node, comb, others){
        comb.push(node.values);

        if(node.left) others.push(node.left);
        if(node.right) others.push(node.right);

        if(others.length === 0){
            combinations.push(comb);
            return;
        }else{
            for(let i = 0; i<others.length; i++){
                helper(others[i], comb.slice(0), others.slice(0, i).concat(others.slice(i+1, others.length)));
            }
        }
    }

    helper(root, [], []);
    return combinations;
}

0

这是我的Python代码,用于生成同一二叉搜索树的所有元素/数字序列。对于算法逻辑,我参考了Gayle Laakmann Mcdowell的《Cracking the Coding Interview》一书。

from binarytree import  Node, bst, pprint

def wavelist_list(first, second, wave, prefix):
    if first:
       fl = len(first)
    else:
       fl = 0

    if second:       
        sl = len(second)
    else:
       sl = 0   
    if fl == 0 or sl == 0:
       tmp = list()
       tmp.extend(prefix)
       if first:
          tmp.extend(first)
       if second:   
          tmp.extend(second)
       wave.append(tmp)
       return

    if fl:
        fitem = first.pop(0)
        prefix.append(fitem)
        wavelist_list(first, second, wave, prefix)
        prefix.pop()
        first.insert(0, fitem)

    if sl:
        fitem = second.pop(0)
        prefix.append(fitem)
        wavelist_list(first, second, wave, prefix)
        prefix.pop()
        second.insert(0, fitem)        


def allsequences(root):
    result = list()
    if root == None:
       return result

    prefix = list()
    prefix.append(root.value)

    leftseq = allsequences(root.left)
    rightseq = allsequences(root.right)
    lseq = len(leftseq)
    rseq = len(rightseq)

    if lseq and rseq:
       for i in range(lseq):
          for j in range(rseq):
            wave = list()
            wavelist_list(leftseq[i], rightseq[j], wave, prefix)
            for k in range(len(wave)):
                result.append(wave[k])

    elif lseq:
      for i in range(lseq):
        wave = list()
        wavelist_list(leftseq[i], None, wave, prefix)
        for k in range(len(wave)):
            result.append(wave[k])

    elif rseq:
      for j in range(rseq):
        wave = list()
        wavelist_list(None, rightseq[j], wave, prefix)
        for k in range(len(wave)):
            result.append(wave[k])
   else:
       result.append(prefix) 

   return result



if __name__=="__main__":
    n = int(input("what is height of tree?"))
    my_bst = bst(n)
    pprint(my_bst)

    seq = allsequences(my_bst)
    print("All sequences")
    for i in range(len(seq)):
        print("set %d = " %(i+1), end="")
        print(seq[i])

 example output:
 what is height of tree?3

       ___12      
      /     \     
  __ 6       13   
 /   \        \  
 0     11       14
  \               
   2              


  All sequences
  set 1 = [12, 6, 0, 2, 11, 13, 14]
  set 2 = [12, 6, 0, 2, 13, 11, 14]
  set 3 = [12, 6, 0, 2, 13, 14, 11]
  set 4 = [12, 6, 0, 13, 2, 11, 14]
  set 5 = [12, 6, 0, 13, 2, 14, 11]
  set 6 = [12, 6, 0, 13, 14, 2, 11]
  set 7 = [12, 6, 13, 0, 2, 11, 14]
  set 8 = [12, 6, 13, 0, 2, 14, 11]
  set 9 = [12, 6, 13, 0, 14, 2, 11]
  set 10 = [12, 6, 13, 14, 0, 2, 11]
  set 11 = [12, 13, 6, 0, 2, 11, 14]
  set 12 = [12, 13, 6, 0, 2, 14, 11]
  set 13 = [12, 13, 6, 0, 14, 2, 11]
  set 14 = [12, 13, 6, 14, 0, 2, 11]
  set 15 = [12, 13, 14, 6, 0, 2, 11]
  set 16 = [12, 6, 0, 11, 2, 13, 14]
  set 17 = [12, 6, 0, 11, 13, 2, 14]
  set 18 = [12, 6, 0, 11, 13, 14, 2]
  set 19 = [12, 6, 0, 13, 11, 2, 14]
  set 20 = [12, 6, 0, 13, 11, 14, 2]
  set 21 = [12, 6, 0, 13, 14, 11, 2]
  set 22 = [12, 6, 13, 0, 11, 2, 14]
  set 23 = [12, 6, 13, 0, 11, 14, 2]
  set 24 = [12, 6, 13, 0, 14, 11, 2]
  set 25 = [12, 6, 13, 14, 0, 11, 2]
  set 26 = [12, 13, 6, 0, 11, 2, 14]
  set 27 = [12, 13, 6, 0, 11, 14, 2]
  set 28 = [12, 13, 6, 0, 14, 11, 2]
  set 29 = [12, 13, 6, 14, 0, 11, 2]
  set 30 = [12, 13, 14, 6, 0, 11, 2]
  set 31 = [12, 6, 11, 0, 2, 13, 14]
  set 32 = [12, 6, 11, 0, 13, 2, 14]
  set 33 = [12, 6, 11, 0, 13, 14, 2]
  set 34 = [12, 6, 11, 13, 0, 2, 14]
  set 35 = [12, 6, 11, 13, 0, 14, 2]
  set 36 = [12, 6, 11, 13, 14, 0, 2]
  set 37 = [12, 6, 13, 11, 0, 2, 14]
  set 38 = [12, 6, 13, 11, 0, 14, 2]
  set 39 = [12, 6, 13, 11, 14, 0, 2]
  set 40 = [12, 6, 13, 14, 11, 0, 2]
  set 41 = [12, 13, 6, 11, 0, 2, 14]
  set 42 = [12, 13, 6, 11, 0, 14, 2]
  set 43 = [12, 13, 6, 11, 14, 0, 2]
  set 44 = [12, 13, 6, 14, 11, 0, 2]
  set 45 = [12, 13, 14, 6, 11, 0, 2]

对于上述代码,我使用了二叉树包来创建给定长度的BST。 - Keshava Munegowda

0

public class Solution {
    ArrayList<LinkedList<Long>> result;
    /*Return the children of a node */
    ArrayList<TreeNode> getChilden(TreeNode parent) {
        ArrayList<TreeNode> child = new ArrayList<TreeNode>();
        if(parent.left != null) child.add(parent.left);
        if(parent.right != null) child.add(parent.right);
        return child;
    }
    /*Gets all the possible Compinations*/
    void getPermutations(ArrayList<TreeNode> permutations, LinkedList<Long> current) {
        if(permutations.size() == 0) {
            result.add(current);
            return;
        }
        int length = permutations.size();
        for(int i = 0; i < length; i++) {
            TreeNode node = permutations.get(i);
            permutations.remove(i);
            ArrayList<TreeNode> newPossibilities = new ArrayList<TreeNode>();
            newPossibilities.addAll(permutations);
            newPossibilities.addAll(getChilden(node));
            LinkedList<Long> newCur = new LinkedList<Long>();
            newCur.addAll(current);
            newCur.add(node.val);
            getPermutations(newPossibilities, newCur);
            permutations.add(i,node);
        }
    }

    /*This method returns a array of arrays which will lead to a given BST*/
    ArrayList<LinkedList<Long>> inputSequencesForBst(TreeNode node) { 
        result = new ArrayList<LinkedList<Long>>();
        if(node == null)
            return result;
        ArrayList<TreeNode> permutations = getChilden(node);
        LinkedList<Long> current = new LinkedList<Long>();
        current.add(node.val);
        getPermutations(permutations, current);
        return result;
    }
}

我的解决方案。完美运作。


您可能需要在代码之外添加一些解释。 - Zeina

0

这里是另一个基于简洁递归的易于理解的解决方案:

from binarytree import  Node, bst, pprint

def allsequences1(root):
    if not root:
        return None
    lt = allsequences1(root.left)
    rt = allsequences1(root.right)
    ret = []
    if not lt and not rt:
        ret.append([root])
    elif not rt:
        for one in lt:
            ret.append([root]+one)
    elif not lt:
        for two in rt:
            ret.append([root]+two)
    else:
        for one in lt:
            for two in rt:
                ret.append([root]+one+two)
                ret.append([root]+two+one)
    return ret



if __name__=="__main__":
    n = int(input("what is height of tree?"))
    my_bst = bst(n)
    pprint(my_bst)
    seg = allsequences1(my_bst)
    print("All sequences ..1")
    for i in range(len(seq)):
        print("set %d = " %(i+1), end="")
        print(seq[i])

请您能否添加一些注释或者简单描述一下您的意图?谢谢 :) - Vishnu

0

首先让我们观察创建相同二叉搜索树必须遵循的规则。这里唯一足够的规则是在插入左右子节点之前插入父节点。因为,如果我们可以保证对于某个节点(我们有兴趣插入的节点),所有父节点(包括祖父节点)都已经插入,但它的任何一个子节点都没有被插入,那么该节点将找到适当的位置进行插入。

根据这个观察结果,我们可以编写回溯算法来生成产生相同二叉搜索树的所有序列。

active_list = {root}
current_order = {}
result ={{}}
backtrack():
     if(len(current_order) == total_node):
         result.push(current_order)
         return;
     for(node in active_list):
          current_order.push(node.value)

          if node.left : 
               active_list.push(node.left)
          if node.right: 
               active_list.push(node.right)

          active_list.remove(node)
          backtrack()
          active_list.push(node)

          if node.left : 
               active_list.remove(node.left)
          if node.right: 
               active_list.remove(node.right)
          current_order.remove(node.val)

这不是一个可用的实现,仅用于说明目的。


0

这是我的 Python 解法,其中包含大量的解释。

我们通过从左到右为每个位置选择一个节点来构建每个数组。我们将节点值添加到路径中,并将其子节点(如果有)添加到可能性列表中,然后继续递归。当没有更多选择时,我们就有了一个候选数组。要生成其余的数组,我们会回溯直到可以做出不同的选择,然后再次递归。

关键是使用适合持有可能性的数据结构。列表可行,但需要在回溯时将节点放回前一个位置(顺序很重要,因为我们已经添加了节点的子节点,必须在节点之后访问它们)。从列表中插入和删除需要线性时间。集合不起作用,因为它不维护顺序。字典最好,因为Python字典记住插入顺序,所有操作都以常数时间运行。

def bst_seq(root: TreeNode) -> list[list[int]]:
    def _loop(choices: MutableMapping[TreeNode, bool], path: list[int], result: list[list[int]]) -> None:
        if not choices:
            result.append([*path])
        else:
            # Take a snapshot of the keys to avoid concurrent modification exception
            for choice in list(choices.keys()):
                del choices[choice]
                children = list(filter(None, [choice.left, choice.right]))
                for child in children:
                    choices[child] = False
                path.append(choice.val)
                _loop(choices, path, result)
                path.pop()
                choices[choice] = False
                for child in children:
                    del choices[child]

    result = []
    _loop({root: False}, [], result)
    return result

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接