我在使用 heapq
库时遇到了一个错误 -- 特别是 heappush
函数。下面的错误代码并没有给我提供帮助。
(Pdb) heapq.heappush(priority_queue, (f, depth, child_node_puzzle_state))
*** ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
这里是导致问题的代码片段...
h = compute_heuristic(child_node_puzzle_state, solved_puzzle)
depth = current_node[1] + 1
f = h + depth
heapq.heappush(priority_queue, [f, depth, child_node_puzzle_state])
需要注意的是,h
和depth
都是整数类型(int
),而child_node_puzzle_state
是一个numpy数组。请查看一些调试代码...
(Pdb) child_node_puzzle_state
array([[ 5., 4., 18., 15., 0., 0., 0., 0., 0., 0., 0.,
99.],
[ 99., 10., 6., 14., 12., 20., 0., 0., 0., 0., 99.,
99.],
[ 99., 99., 11., 19., 17., 16., 8., 0., 0., 99., 99.,
99.],
[ 99., 99., 99., 2., 3., 0., 0., 0., 99., 99., 99.,
99.],
[ 99., 99., 99., 99., 1., 21., 0., 99., 99., 99., 99.,
99.],
[ 99., 99., 99., 99., 99., 9., 13., 7., 0., 0., 0.,
0.]])
(Pdb) child_node_puzzle_state.dtype
dtype('float64')
(Pdb) p h
3
(Pdb) depth
2
(Pdb) f
5
(Pdb) priority_queue
[(5, 2, array([[ 9., 15., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
99.],
[ 99., 10., 6., 14., 5., 4., 18., 0., 0., 0., 99.,
99.],
[ 99., 99., 11., 19., 17., 12., 20., 8., 0., 99., 99.,
99.],
[ 99., 99., 99., 16., 3., 0., 0., 0., 99., 99., 99.,
99.],
[ 99., 99., 99., 99., 2., 0., 0., 99., 99., 99., 99.,
99.],
[ 99., 99., 99., 99., 99., 1., 21., 13., 7., 0., 0.,
...
(Pdb) len(priority_queue)
9
我无法理解的是,如果我更改了一点点内容,它就可以正常工作——但这在语义上是错误的。这就是更改的内容...
h = compute_heuristic(child_node_puzzle_state, solved_puzzle)
depth = current_node[1] + 1
heapq.heappush(priority_queue, (h, depth, child_node_puzzle_state))
你注意到了区别吗?我没有计算f = h + depth
,而是只使用了h
。这样就可以奇迹般地工作了吗?
这不可能是大小的问题,因为我在调试中已经证明了...
(Pdb) len(priority_queue)
9
这对我来说真的没有意义,所以我会包含更多的代码。首先,这里是计算h
所需的一切,没有任何奇怪的事情发生,所以我真的怀疑这是问题所在。所有函数都返回整数(虽然它们使用numpy数组)...
def tubes_aligned(puzzle_state):
current_index = 3 #test for index 3
blue_tube = puzzle_state[3,:]
len_of_top_tube = len(blue_tube[blue_tube < 99]) - 3
correct_index = 6 - len_of_top_tube
found = False
distance = 3
for i in range(3):
if i == correct_index:
distance = current_index - i
found = True
if not found:
for i in range(5,2,-1):
if i == correct_index:
distance = i - current_index
return distance
def balls_in_top_half(puzzle_state):
for i in range(6):
full_tube = puzzle_state[i,:]
num_balls = full_tube[full_tube < 99]
num_balls = len(num_balls[num_balls > 0])
if (6 - i - num_balls) != 0:
return 1
return 0
def balls_in_correct_place(puzzle_state, solved_puzzle):
if is_solved(puzzle_state, solved_puzzle):
return 0
else:
return 1
def compute_heuristic(puzzle_state, solved_puzzle):
# print "computing heuristic"
# heuristic (sum all three):
# 1. how many tubes is the puzzle state from tubes being aligned -- max is 3
# 2. is there balls in the top portion? 1 -- yes || 0 -- no
# 3. are there balls in the wrong place in the bottom half? 1 -- yes || 0 -- no
part_1 = tubes_aligned(puzzle_state)
part_2 = balls_in_top_half(puzzle_state)
part_3 = balls_in_correct_place(puzzle_state, solved_puzzle)
return part_1 + part_2 + part_3
numpy.any()
和numpy.all()
遮盖了内置的any()
和all()
。但是我发现,这个错误可能有很多种情况。如果您发布一个 MVCE,我们或许可以帮助您。 - moooeeeeph
改为f
会导致什么错误? - Kendall Weiheheapq
比较本质上是标量的。因此,在heapq
上放置一个数组会导致引发此错误。 - hpaulj