生成最优二叉搜索树(Cormen)

7

我正在阅读Cormen et al.的《算法导论》(第3版)(PDF),第15.4节关于最优二叉搜索树,但是在Python中实现optimal_bst函数的伪代码时遇到了一些问题。

这是我尝试应用最优二叉搜索树的示例:

enter image description here

让我们定义e[i,j]为包含从标记为ij的键的最优二叉搜索树的预期搜索成本。最终,我们希望计算e[1,n],其中n是键的数量(此示例中为5)。最终递归公式为:

enter image description here

以下伪代码应该被实现:

enter image description here

注意,伪代码交替使用基于1和基于0的索引,而Python仅使用后者。因此,我在实现伪代码时遇到了麻烦。以下是我目前的进展:
import numpy as np

p = [0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p)

e = np.diag(q)
w = np.diag(q)
root = np.zeros((n, n))
for l in range(1, n+1):
    for i in range(n-l+1):
        j = i + l
        e[i, j] = np.inf
        w[i, j] = w[i, j-1] + p[j-1] + q[j]
        for r in range(i, j+1):
            t = e[i-1, r-1] + e[r, j] + w[i-1, j]
            if t < e[i-1, j]:
                e[i-1, j] = t
                root[i-1, j] = r

print(w)
print(e)

然而,如果我运行这个程序,权重w会被正确计算,但期望的搜索值e仍然停留在它们初始化的值上。
[[ 0.05  0.3   0.45  0.55  0.7   1.  ]
 [ 0.    0.1   0.25  0.35  0.5   0.8 ]
 [ 0.    0.    0.05  0.15  0.3   0.6 ]
 [ 0.    0.    0.    0.05  0.2   0.5 ]
 [ 0.    0.    0.    0.    0.05  0.35]
 [ 0.    0.    0.    0.    0.    0.1 ]]
[[ 0.05   inf   inf   inf   inf   inf]
 [ 0.    0.1    inf   inf   inf   inf]
 [ 0.    0.    0.05   inf   inf   inf]
 [ 0.    0.    0.    0.05   inf   inf]
 [ 0.    0.    0.    0.    0.05   inf]
 [ 0.    0.    0.    0.    0.    0.1 ]]

我希望ewroot的值如下所示:

enter image description here

我已经调试了几个小时,但仍然卡住了。有人能指出上面的Python代码有什么问题吗?

3个回答

0

我觉得你在索引上犯了一个错误。我无法按预期工作,但以下代码应该给你一个方向指示(可能有一个偏移量错误):

import numpy as np

p = [0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p)

def get2(m, i, j):
    return m[i - 1, j - 1]


def set2(m, i, j, v):
    m[i - 1, j - 1] = v


def get1(m, i):
    return m[i - 1]


def set1(m, i, v):
    m[i - 1] = v


e = np.diag(q)
w = np.diag(q)
root = np.zeros((n, n))
for l in range(1, n + 1):
    for i in range(n - l + 2):
        j = i + l - 1
        set2(e, i, j, np.inf)
        set2(w, i, j, get2(w, i, j - 1) + get1(p, j) + get1(q, j))
        for r in range(i, j + 1):
            t = get2(e, i, r - 1) + get2(e, r + 1, j) + get2(w, i, j)
            if t < get2(e, i, j):
                set2(e, i, j, t)
                set2(root, i, j, r)

print(w)
print(e)

结果:

[[ 0.2   0.4   0.5   0.65  0.9   0.  ]
 [ 0.    0.2   0.3   0.45  0.7   0.  ]
 [ 0.    0.    0.1   0.25  0.5   0.  ]
 [ 0.    0.    0.    0.15  0.4   0.  ]
 [ 0.    0.    0.    0.    0.25  0.  ]
 [ 0.5   0.7   0.8   0.95  0.    0.3 ]]
[[ 0.2   0.6   0.8   1.2   1.95  0.  ]
 [ 0.    0.2   0.4   0.8   1.35  0.  ]
 [ 0.    0.    0.1   0.35  0.85  0.  ]
 [ 0.    0.    0.    0.15  0.55  0.  ]
 [ 0.    0.    0.    0.    0.25  0.  ]
 [ 0.7   1.2   1.5   2.    0.    0.3 ]]

0
最终我使用了 pandasSeriesDataFrame 对象,这些对象使用自定义的 indexcolumns 进行初始化,以强制数组具有与伪代码中相同的索引。之后,伪代码几乎可以直接复制粘贴:

import numpy as np
import pandas as pd

P = [0.15, 0.10, 0.05, 0.10, 0.20]
Q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(P)

p = pd.Series(P, index=range(1, n+1))
q = pd.Series(Q)

e = pd.DataFrame(np.diag(Q), index=range(1, n+2))
w = pd.DataFrame(np.diag(Q), index=range(1, n+2))
root = pd.DataFrame(np.zeros((n, n)), index=range(1, n+1), columns=range(1, n+1))

for l in range(1, n+1):
    for i in range(1, n-l+2):
        j = i+l-1
        e.set_value(i, j, np.inf)
        w.set_value(i, j, w.get_value(i, j-1) + p[j] + q[j])
        for r in range(i, j+1):
            t = e.get_value(i, r-1) + e.get_value(r+1, j) + w.get_value(i, j)
            if t < e.get_value(i, j):
                e.set_value(i, j, t)
                root.set_value(i, j, r)

print(e)
print(w)
print(root)

这将产生预期的结果:

      0     1     2     3     4     5
1  0.05  0.45  0.90  1.25  1.75  2.75
2  0.00  0.10  0.40  0.70  1.20  2.00
3  0.00  0.00  0.05  0.25  0.60  1.30
4  0.00  0.00  0.00  0.05  0.30  0.90
5  0.00  0.00  0.00  0.00  0.05  0.50
6  0.00  0.00  0.00  0.00  0.00  0.10
      0    1     2     3     4     5
1  0.05  0.3  0.45  0.55  0.70  1.00
2  0.00  0.1  0.25  0.35  0.50  0.80
3  0.00  0.0  0.05  0.15  0.30  0.60
4  0.00  0.0  0.00  0.05  0.20  0.50
5  0.00  0.0  0.00  0.00  0.05  0.35
6  0.00  0.0  0.00  0.00  0.00  0.10
     1    2    3    4    5
1  1.0  1.0  2.0  2.0  2.0
2  0.0  2.0  2.0  2.0  4.0
3  0.0  0.0  3.0  4.0  5.0
4  0.0  0.0  0.0  4.0  5.0
5  0.0  0.0  0.0  0.0  5.0

不过,我仍然对使用Numpy数组的解决方案感兴趣,因为这似乎更加优雅。


0

Kurt, 感谢您的帖子!这是我找到的唯一可行的解决方案。我花了很多时间来处理索引。这是我的numpy数组实现。

import numpy as np
import math

def optimalBST(p,q,n):

    e = np.zeros((n+1)**2).reshape(n+1,n+1)
    w = np.zeros((n+1)**2).reshape(n+1,n+1)
    root = np.zeros((n+1)**2).reshape(n+1,n+1)

    # Initialization
    for i in range(n+1):
        e[i,i] = q[i]
        w[i,i] = q[i]
    for i in range(0,n):
        root[i,i] = i+1

    for l in range(1,n+1):
        for i in range(0, n-l+1):
            j = i+l
            min_ = math.inf
            w[i,j] = w[i,j-1] + p[j] + q[j]
            for r in range(i,j):
                t = e[i, r-1+1] + e[r+1,j] +  w[i,j]
                if t < min_:
                    min_ = t                
                    e[i, j] = t
                    root[i, j-1] = r+1

    root_pruned = np.delete(np.delete(root, n, 1), n, 0)        # Trim last col & row.

    print("------ e -------")
    print(e)
    print("------ w -------")
    print(w)
    print("----- root -----")
    print(root_pruned)

def main():

    p = [0,.15,.1,.05,.1,.2]
    q = [.05,.1,.05,.05,.05,.1]
    n = len(p)-1

    optimalBST(p,q,n)

if __name__ == '__main__':
    main()

输出:

------ e -------
[[0.05 0.45 0.9  1.25 1.75 2.75]
 [0.   0.1  0.4  0.7  1.2  2.  ]
 [0.   0.   0.05 0.25 0.6  1.3 ]
 [0.   0.   0.   0.05 0.3  0.9 ]
 [0.   0.   0.   0.   0.05 0.5 ]
 [0.   0.   0.   0.   0.   0.1 ]]
------ w -------
[[0.05 0.3  0.45 0.55 0.7  1.  ]
 [0.   0.1  0.25 0.35 0.5  0.8 ]
 [0.   0.   0.05 0.15 0.3  0.6 ]
 [0.   0.   0.   0.05 0.2  0.5 ]
 [0.   0.   0.   0.   0.05 0.35]
 [0.   0.   0.   0.   0.   0.1 ]]
----- root -----
[[1. 1. 2. 2. 2.]
 [0. 2. 2. 2. 4.]
 [0. 0. 3. 4. 5.]
 [0. 0. 0. 4. 5.]
 [0. 0. 0. 0. 5.]]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接