Python heapq heappush: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

I'm getting an error when using the heapq library, specifically its heappush function. The error message below isn't helping me.

(Pdb) heapq.heappush(priority_queue, (f, depth, child_node_puzzle_state))
*** ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Here is the code fragment that causes the problem...
h = compute_heuristic(child_node_puzzle_state, solved_puzzle)
depth = current_node[1] + 1
f = h + depth
heapq.heappush(priority_queue, [f, depth, child_node_puzzle_state])

Note that h and depth are both ints, while child_node_puzzle_state is a numpy array. Here is some debugging output...

(Pdb) child_node_puzzle_state
array([[  5.,   4.,  18.,  15.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         99.],
       [ 99.,  10.,   6.,  14.,  12.,  20.,   0.,   0.,   0.,   0.,  99.,
         99.],
       [ 99.,  99.,  11.,  19.,  17.,  16.,   8.,   0.,   0.,  99.,  99.,
         99.],
       [ 99.,  99.,  99.,   2.,   3.,   0.,   0.,   0.,  99.,  99.,  99.,
         99.],
       [ 99.,  99.,  99.,  99.,   1.,  21.,   0.,  99.,  99.,  99.,  99.,
         99.],
       [ 99.,  99.,  99.,  99.,  99.,   9.,  13.,   7.,   0.,   0.,   0.,
          0.]])
(Pdb) child_node_puzzle_state.dtype
dtype('float64')
(Pdb) p h
3
(Pdb) depth
2
(Pdb) f
5
(Pdb) priority_queue
[(5, 2, array([[  9.,  15.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         99.],
       [ 99.,  10.,   6.,  14.,   5.,   4.,  18.,   0.,   0.,   0.,  99.,
         99.],
       [ 99.,  99.,  11.,  19.,  17.,  12.,  20.,   8.,   0.,  99.,  99.,
         99.],
       [ 99.,  99.,  99.,  16.,   3.,   0.,   0.,   0.,  99.,  99.,  99.,
         99.],
       [ 99.,  99.,  99.,  99.,   2.,   0.,   0.,  99.,  99.,  99.,  99.,
         99.],
       [ 99.,  99.,  99.,  99.,  99.,   1.,  21.,  13.,   7.,   0.,   0.,


...

(Pdb) len(priority_queue)
9

What I can't understand is that if I change one small thing, it works fine, but then it's semantically wrong. Here is the change...

h = compute_heuristic(child_node_puzzle_state, solved_puzzle)
depth = current_node[1] + 1
heapq.heappush(priority_queue, (h, depth, child_node_puzzle_state))

Do you see the difference? Instead of computing f = h + depth, I just use h. Why does that magically work?

It can't be a size issue, because I've checked that in the debugger...

(Pdb) len(priority_queue)
9

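This really doesn't make sense to me, so I'll include more code. First, here is everything needed to compute h. Nothing strange happens in it, so I really doubt this is the problem. All of the functions return integers (although they use numpy arrays internally)...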

def tubes_aligned(puzzle_state):

    current_index = 3 #test for index 3
    blue_tube = puzzle_state[3,:]
    len_of_top_tube = len(blue_tube[blue_tube < 99]) - 3
    correct_index = 6 - len_of_top_tube

    found = False
    distance = 3
    for i in range(3):
        if i == correct_index:
            distance = current_index - i
            found = True

    if not found:
        for i in range(5,2,-1):
            if i == correct_index:
                distance = i - current_index

    return distance

def balls_in_top_half(puzzle_state):

    for i in range(6):
        full_tube = puzzle_state[i,:]
        num_balls = full_tube[full_tube < 99]
        num_balls = len(num_balls[num_balls > 0])
        if (6 - i - num_balls) != 0:
            return 1

    return 0

def balls_in_correct_place(puzzle_state, solved_puzzle):
    if is_solved(puzzle_state, solved_puzzle):
        return 0
    else:
        return 1

def compute_heuristic(puzzle_state, solved_puzzle):
    # print "computing heuristic"
    # heuristic (sum all three):
    #     1. how many tubes is the puzzle state from tubes being aligned -- max is 3
    #     2. is there balls in the top portion? 1 -- yes || 0 -- no
    #     3. are there balls in the wrong place in the bottom half? 1 -- yes || 0 -- no
    part_1 = tubes_aligned(puzzle_state)
    part_2 = balls_in_top_half(puzzle_state)
    part_3 = balls_in_correct_place(puzzle_state, solved_puzzle)
    return part_1 + part_2 + part_3

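I've hit this error before when numpy.any() and numpy.all() shadowed the built-in any() and all(). But this error can arise in many different situations. If you post an MVCE, we may be able to help you. - moooeeeep
@moooeeeep Hmm... what should I do to fix this? And in my case, why would changing h to f cause an error? - Kendall Weihe
This error occurs when a boolean array (with more than one element) is used in a context that expects a scalar True/False. Whatever the details, heapq comparisons are inherently scalar, so putting an array on a heapq can trigger this error. - hpaulj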
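hpaulj's point can be seen directly in a minimal sketch, assuming NumPy is installed: comparing two arrays yields a boolean array, and it is converting that array to a single True/False that raises. The array values are made up, and `.any()`/`.all()` are shown only because the message suggests them; inside heapq you cannot apply them, since the comparison happens in heapq's own code.

```python
import numpy as np

a = np.array([9.0, 15.0])
b = np.array([10.0, 15.0])

# Comparing two arrays is elementwise: the result is another array, not a bool.
result = a < b
print(result)  # [ True False]

# Asking for one truth value of a multi-element array is what fails.
try:
    bool(result)
except ValueError as exc:
    print("ValueError:", exc)

# .any() / .all() reduce the array to a single bool explicitly.
print(result.any(), result.all())

# A one-element array still converts to a bool without complaint.
print(bool(np.array([True])))
```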
1 Answer

heapq.heappush compares your array with other arrays already on the heap whenever the elements that precede it in the tuples being compared are equal.
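Ties matter because Python compares tuples lexicographically: it finds the first position where the elements differ and decides there. The sketch below shows this with the standard library only; `Unorderable` is a hypothetical stand-in for any payload that does not support `<` (with plain objects the failure is a TypeError rather than NumPy's ValueError).

```python
class Unorderable:
    """Hypothetical payload type that defines no ordering (no __lt__)."""


# Decided at position 1 (depth 2 < 3); the strings are never compared.
print((5, 2, "b") < (5, 3, "a"))  # True

# Decided at position 0 (4 < 5); the payloads are never compared.
print((4, 9, Unorderable()) < (5, 0, Unorderable()))  # True

# Both leading values tie, so Python has to compare the payloads themselves.
try:
    (5, 2, Unorderable()) < (5, 2, Unorderable())
except TypeError as exc:
    print("TypeError:", exc)
```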

Here is a pure-Python implementation of heappush():

def heappush(heap, item):
    """Push item onto heap, maintaining the heap invariant."""
    heap.append(item)
    _siftdown(heap, 0, len(heap)-1)

def _siftdown(heap, startpos, pos):
    newitem = heap[pos]
    # Follow the path to the root, moving parents down until finding a place
    # newitem fits.
    while pos > startpos:
        parentpos = (pos - 1) >> 1
        parent = heap[parentpos]
        if newitem < parent:
            heap[pos] = parent
            pos = parentpos
            continue
        break
    heap[pos] = newitem
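To watch the `newitem < parent` comparisons happen, you can wrap items in a class that logs every `<` call. `Traced` is a hypothetical helper for illustration, and the sketch assumes CPython's C heapq calls `<` with the new item on the left, the same way `_siftdown` above does.

```python
import heapq


class Traced:
    """Hypothetical wrapper that records every `<` comparison made on it."""

    log = []  # shared list of (left_key, right_key) pairs, in call order

    def __init__(self, key):
        self.key = key

    def __lt__(self, other):
        Traced.log.append((self.key, other.key))
        return self.key < other.key


heap = []
for key in [5, 7, 9, 8, 6]:
    heapq.heappush(heap, Traced(key))

print([t.key for t in heap])  # [5, 6, 9, 8, 7]
print(Traced.log)             # [(7, 5), (9, 5), (8, 7), (6, 7), (6, 5)]
```

The last push (key 6) is compared only with 7 and then with the root 5, i.e. its ancestors, so a pushed tuple is only ever compared with the entries on its path to the root.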

The real implementation is written in C, which is why the error appears without a deeper traceback.

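Note the newitem < parent comparison: that is what raises the exception, because numpy array objects compare element by element and produce a boolean array of True and False results. If a state with equal f and depth is already on the heap, this comparison has to compare the arrays: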

>>> import numpy
>>> t1 = (5, 2, numpy.array([9.,  15.]))
>>> t2 = (5, 2, numpy.array([10.,  15.]))
>>> t1 < t2
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

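For you, the problem "went away" when you changed the value in the first position of the tuple, because that happened to make the first two values unique again with respect to the entries already on the heap. But that doesn't actually fix the underlying problem.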
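A minimal sketch of why switching from f to h only hid the problem, assuming NumPy and two made-up 2x2 states: pushes succeed while the (priority, depth) pairs differ, and fail as soon as two entries tie on both.

```python
import heapq

import numpy as np

# Hypothetical puzzle states; any two distinct multi-element arrays will do.
state_a = np.array([[9.0, 15.0], [99.0, 10.0]])
state_b = np.array([[5.0, 4.0], [99.0, 10.0]])

queue = []
heapq.heappush(queue, (3, 2, state_a))
heapq.heappush(queue, (4, 2, state_b))  # fine: the first elements differ

queue = []
heapq.heappush(queue, (3, 2, state_a))
try:
    # Tie on (priority, depth): Python falls through to comparing the arrays.
    heapq.heappush(queue, (3, 2, state_b))
except ValueError as exc:
    print("ValueError:", exc)
```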

You can avoid the problem by inserting a unique count (using itertools.count()) before the array:

import heapq
from itertools import count

# a global
tiebreaker = count()

# each time you push
heapq.heappush(priority_queue, (f, depth, next(tiebreaker), child_node_puzzle_state))
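The counter guarantees that the first three elements of your tuples are always unique. It also means that any later addition that matches a state already on the heap in both heuristic score and depth will sort after the older one, so tied states are popped first-in, first-out. If you want to reverse that relationship, use count(step=-1).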
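Here is the counter approach run end to end, as a sketch with made-up states and values; the first two entries deliberately tie on (f, depth). Because `next(tiebreaker)` is unique and sits before the array, the arrays are never compared.

```python
import heapq
from itertools import count

import numpy as np

tiebreaker = count()  # yields 0, 1, 2, ... one fresh number per push
priority_queue = []

# Hypothetical (f, depth, state) entries; the first two tie on f and depth.
states = [
    (5, 2, np.array([9.0, 15.0])),
    (5, 2, np.array([10.0, 15.0])),
    (4, 3, np.array([1.0, 2.0])),
]
for f, depth, state in states:
    # The counter settles any (f, depth) tie before the arrays are reached.
    heapq.heappush(priority_queue, (f, depth, next(tiebreaker), state))

while priority_queue:
    f, depth, _, state = heapq.heappop(priority_queue)
    print(f, depth, state)
```

The (4, 3) entry comes out first, then the two tied entries in the order they were pushed. With `count(step=-1)` the counter would yield 0, -1, -2, ..., so the newest tied entry would be popped first instead.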

This is documented in detail here: https://docs.python.org/3/library/heapq.html#priority-queue-implementation-notes - M. Franklin
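The linked notes also describe wrapping entries in a dataclass whose payload field is excluded from comparisons with `field(compare=False)`. Adapted to this puzzle, it might look like the sketch below; `Node` and its field names are hypothetical, NumPy is assumed, and without a counter field tied entries come out in no guaranteed order.

```python
import heapq
from dataclasses import dataclass, field

import numpy as np


@dataclass(order=True)
class Node:
    """Queue entry ordered by (f, depth) only; the state never takes part."""

    f: int
    depth: int
    state: np.ndarray = field(compare=False)


queue = []
heapq.heappush(queue, Node(5, 2, np.array([9.0, 15.0])))
heapq.heappush(queue, Node(5, 2, np.array([10.0, 15.0])))  # tie, but no error
heapq.heappush(queue, Node(4, 3, np.array([1.0, 2.0])))

best = heapq.heappop(queue)
print(best.f, best.depth, best.state)
```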

Content provided by Stack Overflow.