N皇后问题II使用回溯算法很慢。

5
n皇后问题是将n个皇后放置在一个n x n的棋盘上,使得任意两个皇后都不会互相攻击的问题。
给定一个整数n,返回n皇后问题的不同解决方案数量。

https://leetcode.com/problems/n-queens-ii/

我的解决方案:

class Solution:
    def totalNQueens(self, n: int) -> int:
        def genRestricted(restricted, r, c):
            restricted = set(restricted)
            for row in range(n): restricted.add((row, c))
            for col in range(n): restricted.add((r, col))
            movements = [[-1, -1], [-1, 1], [1, -1], [1, 1]]
            for movement in movements:
                row, col = r, c
                while 0 <= row < n and 0 <= col < n:
                    restricted.add((row, col))
                    row += movement[0]
                    col += movement[1]
            return restricted

        def gen(row, col, curCount, restricted):
            count, total_count = curCount, 0

            for r in range(row, n):
                for c in range(col, n):
                    if (r, c) not in restricted:
                        count += 1
                        if count == n: total_count += 1
                        total_count += gen(row + 1, 0, count, genRestricted(restricted, r, c))
                        count -= 1

            return total_count

        return gen(0, 0, 0, set())

它在n=8时失败。我无法弄清原因,也不知如何少进行迭代。看来我已经在尽可能少的迭代次数下进行了操作。
6个回答

5

restricted集合似乎在时间和空间上都比较浪费。在成功递归的末尾,n层深度时它会增长到n^2大小,这将使总复杂度变为O(n^3)。而且它并不是真正需要的。通过查看已放置的皇后,检查该方格是否可用更加容易(请原谅这个象棋术语; file代表垂直方向,rank代表水平方向):

def square_is_safe(file, rank, queens_placed):
    for queen_rank, queen_file in enumerate(queens_placed):
        if queen_file == file:                      # vertical attack
            return false
        if queen_file - file == queen_rank - rank:  # diagonal attack
            return false
        if queen_file - file == rank - queen_rank:  # anti-diagonal attack
            return false
    return true

用于

def place_queen_at_rank(queens_placed, rank):
    if rank == n:
        total_count += 1
        return

    for file in range(0, n):
        if square_is_safe(file, rank, queens_placed):
            queens_placed.append(file)
            place_queen_at_rank(queens_placed, rank + 1)

    queens_placed.pop()

还有很多优化的空间。例如,您可能希望特别处理第一个排名:由于对称性,您只需要检查其中一半(将执行时间缩短了2倍)。


2

只需要一次修改(去掉gen中的r循环),就可以使你的解决方案AC。

主要原因是你的gen函数有一个叫做row的参数,它将使用row + 1调用自身,所以没有必要使用for r in range(row, n):来进行迭代。这是不必要的。只需去除它,你的解决方案就很可接受了。(我们需要在嵌套调用之前添加else


以下是结果:

1 1      1.8358230590820312e-05
2 0      5.7697296142578125e-05
3 0      0.00036835670471191406
4 2      0.0021448135375976562
5 10     0.02212214469909668
6 4      0.23602914810180664
7 40     3.0731561183929443

修改后:

1 1      1.6450881958007812e-05
2 0      3.1948089599609375e-05
3 0      0.0001366138458251953
4 2      0.0002281665802001953
5 10     0.0008234977722167969
6 4      0.0028502941131591797
7 40     0.01242375373840332
8 92     0.05443763732910156
9 352    0.2279810905456543

对于n = 7的情况,它只使用了原始版本0.4%的时间,而n = 8则完全可以运行。

class Solution:
    def totalNQueens(self, n: int) -> int:
        def genRestricted(restricted, r, c):
            restricted = set(restricted)
            for row in range(n): restricted.add((row, c))
            for col in range(n): restricted.add((r, col))
            movements = [[-1, -1], [-1, 1], [1, -1], [1, 1]]
            for movement in movements:
                row, col = r, c
                while 0 <= row < n and 0 <= col < n:
                    restricted.add((row, col))
                    row += movement[0]
                    col += movement[1]
            return restricted

        def gen(row, col, curCount, restricted):
            count, total_count = curCount, 0

            for c in range(col, n):
                if (row, c) not in restricted:
                    count += 1
                    if count == n: total_count += 1
                    else: total_count += gen(row + 1, 0, count, genRestricted(restricted, row, c))
                    count -= 1

            return total_count

        return gen(0, 0, 0, set())
if __name__ == '__main__':
    import time
    s = Solution()
    for i in range(1, 8):
        t0 = time.time()
        print(i, s.totalNQueens(i), '\t', time.time() - t0)

当然,还有其他的增强措施可以采取。但这是最大的一个。

例如,在添加每个点之后,您可以更新和创建一个新的受限/禁止点。 顺便说一下,我不同意@user58697对于restricted的看法,它在您的解决方案中是必要的,因为您需要克隆和更新以获得一个新的点来避免在递归调用循环中恢复它。


顺便说一下,以下是我的解决方案,仅供参考:

class Solution:
    def solveNQueens_n(self, n): #: int) -> List[List[str]]:
        cols = [-1] * n # index means row index
        self.res = 0
        usedCols = set() # this and cols can avoid vertical and horizontal conflict

        def dfs(r): # current row to fill in
            def valid(c):
                for r0 in range(r):
                    # (r0, c0), (r1, c1) in the (back-)diagonal, |r1 - r0| = |c1 - c0|
                    if abs(c - cols[r0]) == abs(r - r0):
                        return False
                return True
            if r == n: # valid answer
                self.res += 1
                return
            for c in range(n):
                if c not in usedCols and valid(c):
                    usedCols.add(c)
                    cols[r] = c
                    dfs(r + 1)
                    usedCols.remove(c)
                    cols[r] = -1

        dfs(0)
        return self.res

只需运行你的代码-它能达到60%,相当不错。分析也很好。 - Daniel Hao
谢谢丹尼尔。看起来 @Good_Evening 没有时间查看(这么多)答案,这个问题仍然是开放的。 - Yang Liu

2
对于n ≤ 9(在链接的谜题中的限制),枚举所有合法的车的位置并验证是否存在攻击对角线的移动是足够的。
import itertools


def is_valid(ranks):
    return not any(
        abs(f1 - f2) == abs(r1 - r2)
        for f1, r1 in enumerate(ranks)
        for f2, r2 in enumerate(ranks[:f1])
    )


def count_valid(n):
    return sum(map(is_valid, itertools.permutations(range(n))))


print(*(count_valid(i) for i in range(1, 10)), sep=",")

2
在这种问题中,你首先要关注算法,而不是代码。
接下来,我将重点介绍算法,并用C++举例说明。
一个主要问题是能够快速检测一个给定的位置是否已经被现有皇后控制。
一种简单的可能性是对对角线(从0到2N-1)进行索引,并在数组中跟踪相应的对角线、反对角线或列是否已经受到控制。任何一种索引对角线或反对角线的方法都可以完成任务。对于给定的(row, column)点,我使用以下方式:
diagonal index = row + column
antidiagonal index = n-1 + col - row

此外,我使用了一个简单的对称性:只需要计算从0到n/2-1(如果n为奇数,则为n/2)的行索引的可能性数量。当然,可以通过使用其他对称性来稍微加快速度。但是,就目前而言,对于小于或等于9的n值,它看起来已经足够快了。
结果:
2 : 0 time : 0.001 ms
3 : 0 time : 0.001 ms
4 : 2 time : 0.001 ms
5 : 10 time : 0.002 ms
6 : 4 time : 0.004 ms
7 : 40 time : 0.015 ms
8 : 92 time : 0.05 ms
9 : 352 time : 0.241 ms
10 : 724 time : 0.988 ms
11 : 2680 time : 5.55 ms
12 : 14200 time : 31.397 ms
13 : 73712 time : 188.12 ms
14 : 365596 time : 1046.43 ms

以下是 C++ 代码。由于代码非常简单,您应该可以轻松将其转换为 Python。


#include <iostream>
#include <chrono>

constexpr int N_MAX = 14;
constexpr int N_DIAG = 2*N_MAX + 1;

class Solution {
public:
    int n;
    int Col[N_MAX] = {0};
    int Diag[N_DIAG] = {0};
    int AntiDiag[N_DIAG] = {0};
    
    int totalNQueens(int n1) {
        n = n1;
        if (n <= 1) return n;
        int count = 0;
        for (int col = 0; col < n/2; ++col) {
            count += sum_from (0, col);
        }
        count *= 2;
        if (n%2) count += sum_from (0, n/2);
        return count;
    }
    
    int sum_from (int row, int col) {
        if (Col[col]) return 0;
        int diag = row + col;
        if (Diag[diag]) return 0;
        int antidiag = n-1 + col - row;
        if(AntiDiag[antidiag]) return 0;
        if (row == n-1) return 1;
        int count = 0;
        Col[col] = 1;
        Diag[diag] = 1;
        AntiDiag[antidiag] = 1;
        for (int k = 0; k < n; ++k) {
            count += sum_from (row+1, k);
        }
        Col[col] = 0;
        Diag[diag] = 0;
        AntiDiag[antidiag] = 0;
        return count;
    }
};


int main () {
    int n = 1;
    while (n++ < N_MAX) {
        auto start = std::chrono::high_resolution_clock::now();
        Solution Sol;
        std::cout << n << " : " << Sol.totalNQueens (n) << " time : ";
        auto diff = std::chrono::high_resolution_clock::now() - start;
        auto duration = std::chrono::duration_cast<std::chrono::microseconds>(diff).count();
        std::cout << double(duration)/1000 << " ms" << std::endl;
    }
    return 0;
}


“通过使用其他对称性,肯定可以稍微加速它。”- 你能否解释一下你所说的其他“对称性”的细节?谢谢。 - abhimanyue
1
@abhimanyue 有许多对称性。从一个解决方案中,您可以通过旋转游戏或考虑水平线中点的反射对称性或对角线对称性获得其他解决方案...然而,在实践中似乎很难使用这些对称性。另一个问题是解的数量,即复杂度随着大小n的增加而迅速增加。将时间减半并不能帮助太多... - Damien

2
你可以通过每行只放置一个皇后来避免检查水平冲突。这也使你能够通过仅标记后续行来减小对角线冲突矩阵的大小。使用简单的布尔标志列表来检查列冲突也是一种时间节省的方法(而不是在矩阵中标记多个条目)。
以下是一个解决方案生成器的示例:
def genNQueens(size=8):
    # setup queen coverage from each position {position:set of positions}
    reach = { (r,c):[] for r in range(size) for c in range(0,size) }
    for R in range(size):
        for C in range(size):
            for h in (1,-1): # diagonals on next rows
                reach[R,C].extend((R+i,C+h*i) for i in range(1,size))
            reach[R,C] = [P for P in reach[R,C] if P in reach]
    reach.update({(r,-1):[] for r in range(size)}) # for unplaced rows

    # place 1 queen on each row, with backtracking
    cols     = [-1]*size            # column of each queen (start unplaced)
    usedCols = [False]*(size+1)     # column conflict detection
    usedDiag = [[0]*(size+1) for _ in range(size+1)] # for diagonal conflicts
    r        = 0
    while r >= 0:
        usedCols[cols[r]] = False
        for ur,uc in reach[r,cols[r]]: usedDiag[ur][uc] -= 1
        cols[r] = next((c for c in range(cols[r]+1,size)
                        if not usedCols[c] and not usedDiag[r][c]),-1)
        usedCols[cols[r]] = True
        for ur,uc in reach[r,cols[r]]: usedDiag[ur][uc] += 1
        r += 1 if cols[r]>=0 else -1   # progress or backtrack
        if r<size : continue           # continue until all rows placed
        yield [*enumerate(cols)]       # return result
        r -= 1                         # backtrack to find more
        

输出:

from timeit import timeit
for n in range(3,13):
    t = timeit(lambda:sum(1 for _ in genNQueens(n)), number=1)
    c = sum(1 for _ in genNQueens(n))
    print(f"solutions for {n}x{n}:", c, "time:",f"{t:.4g}")    

solutions for 3x3: 0 time: 0.000108
solutions for 4x4: 2 time: 0.0002044
solutions for 5x5: 10 time: 0.0004365
solutions for 6x6: 4 time: 0.0008741
solutions for 7x7: 40 time: 0.003386
solutions for 8x8: 92 time: 0.009881
solutions for 9x9: 352 time: 0.03402
solutions for 10x10: 724 time: 0.1228
solutions for 11x11: 2680 time: 0.5707
solutions for 12x12: 14200 time: 2.77
    

0

好的,我错过的一件事是每一行必须有一个皇后。这是非常重要的观察。gen方法必须像这样进行修改:

    def gen(row, col, curCount, restricted):
        if row == n: return 0
        
        count, total_count = curCount, 0
        
        for c in range(col, n):
            if (row, c) not in restricted:
                if count + 1 == n: total_count += 1
                total_count += gen(row + 1, 0, count + 1, genRestricted(restricted, row, c))
                    
        return total_count

它只能击败约20%的提交,因此它根本不完美。远非如此。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接