Python默认字典中给定键的最小值

4

我有一个以列表为值和元组为键的defaultdict(代码中的ddict)。我想要找到给定一组键的值的最小值和最大值。这些键以numpy数组形式给出。该numpy数组是一个包含键的3D数组。每个3D数组的行是我们需要查找minmax的键块。也就是说,对于每一行,我们取相应的2D数组条目,并获取对应于这些条目的值,然后在这些值上找到minmax。我需要对所有3D数组的行执行此操作。

from operator import itemgetter
import numpy as np

ddict =  {(1.0, 1.0): [1,2,3,4], (1.0, 2.5): [2,3,4,5], (1.0, 3.75): [], (1.5, 1.0): [8,9,10], (1.5, 2.5): [2,6,8,19,1,31], (1.5,3.75): [4]}
indA = np.array([ [ [( 1.0, 1.0), ( 1.0, 3.75)], [(1.5,1.0), (1.5,3.75)] ], [ [(1.0, 2.5), (1.5,1.0)], [(1.5, 2.5), (1.5,3.75)] ] ], dtype='float16,float16')

mins = min(ddict, key=itemgetter(*[tuple(i) for b in indA for i in b.flatten()]))
maxs = max(ddict, key=itemgetter(*[tuple(i) for b in indA for i in b.flatten()]))

我尝试使用上述代码获取以下代码的输出: min1 = min([1,2,3,4,8,9,10,4]) & min2 = min([2,3,4,5,8,9,10,2,6,8,19,1,31,4])max1= max([1,2,3,4,8,9,10,4]) & max2 = max([2,3,4,5,8,9,10,2,6,8,19,1,31,4]) 我想计算 numpy 数组中每个 2D 数组的最小值和最大值。有什么解决方法?为什么我的代码不起作用?它会给我一个错误,TypeError: tuple indices must be integers or slices, not tuple

1
我想找到给定键集的最小值和最大值 - 你的意思是说,对于每个元组集合(似乎更像是列表),你想要知道与这些元组键相关联的所有列表中的最小值?您能解释一下它与您所期望得到的输出有什么关系吗?(此外,您的示例输出可能缺少括号,但由于我不理解您的输出应该表示什么,因此我没有自己进行更正) - Grismar
@Grismar 我编辑了问题。我的意思是,字典(元组)的键是2D数组的条目(每个3D数组的行)。我想要将ddict的值合并为3D数组的每一行(我们需要合并与4个键对应的值,因为我们得到了一个2*2的数组),并找到最小值和最大值。我需要对3D数组的所有行都执行此操作。这有意义吗? - Shew
是的,那很有道理 - 请参见下面的答案。如果这不是你所需要的,请随意发表评论。 - Grismar
@Grismar,代码在我提供的输入数据上可以运行,但在我的真实代码上无法运行。我意识到当我粘贴代码时犯了一个大错误。我忘记将数组制作成结构化数组。当我在新数据上使用您的代码时,我得到了“TypeError: unhashable type: 'writeable void-scalar'”的错误提示。 - Shew
我已经查看并更新了答案。 - Grismar
1个回答

3

我认为这就是你想要的内容:

import numpy as np

# I've reformatted your example data, to make it a bit clearer
# no change in content though, just different whitespace
# whether d is a dict or defaultdict doesn't matter
d = {
    (1.0, 1.0): [1, 2, 3, 4],
    (1.0, 2.5): [2, 3, 4, 5],
    (1.0, 3.75): [],
    (1.5, 1.0): [8, 9, 10],
    (1.5, 2.5): [2, 6, 8, 19, 1, 31],
    (1.5, 3.75): [4]
}

# indA is just an array of indices, avoid capitals in variable names
indices = np.array(
    [
        [[(1.0, 1.0), (1.0, 3.75)], [(1.5, 1.0), (1.5, 3.75)]],
        [[(1.0, 2.5), (1.5, 1.0)], [(1.5, 2.5), (1.5, 3.75)]]
    ])

for group in indices:
    # you flattened each grouping of indices, but that flattens
    # the tuples you need intact as well:
    print('not: ', group.flatten())
    # Instead, you just want all the tuples:
    print('but: ', group.reshape(-1, group.shape[-1]))

# with that knowledge, this is how you can get the lists you want
# the min and max for
for group in indices:
    group = group.reshape(-1, group.shape[-1])
    values = list(x for key in group for x in d[tuple(key)])
    print(values)

# So, the solution:
result = [
    (min(vals), max(vals)) for vals in (
        list(x for key in grp.reshape(-1, grp.shape[-1]) for x in d[tuple(key)])
        for grp in indices
    )
]
print(result)

输出:

not:  [1.   1.   1.   3.75 1.5  1.   1.5  3.75]
but:  [[1.   1.  ]
 [1.   3.75]
 [1.5  1.  ]
 [1.5  3.75]]
not:  [1.   2.5  1.5  1.   1.5  2.5  1.5  3.75]
but:  [[1.   2.5 ]
 [1.5  1.  ]
 [1.5  2.5 ]
 [1.5  3.75]]
[1, 2, 3, 4, 8, 9, 10, 4]
[2, 3, 4, 5, 8, 9, 10, 2, 6, 8, 19, 1, 31, 4]
[(1, 10), (1, 31)]

也就是说,[(1, 10), (1, 31)] 就是您想要的结果,其中 1 是第一组索引的组合值中最小值,10 是同一组数值的最大值,等等。

以下是一些关键行的解释:

values = list(x for key in group for x in d[tuple(key)])

该代码通过循环遍历组中每对键/值对,并使用它们作为字典d的索引来构建一个组合值列表。但是,由于在重新整形后键将成为一个ndarray,因此首先将其传递给tuple()函数,以便正确地索引dict。它循环遍历检索到的值,并将每个值x添加到结果列表中。
解决方案在单个推导式中完成:
[
    (min(vals), max(vals)) for vals in (
        list(x for key in grp.reshape(-1, grp.shape[-1]) for x in d[tuple(key)])
        for grp in indices
    )
]

外层括号表示正在构建一个列表。 (min(vals),max(vals))vals的最小值和最大值的元组,vals循环遍历内部推导式。 内部推导式是一个生成器(用圆括号而不是方括号),为indices中每个组生成列表,就像上面解释的那样。

编辑:您更新了问题,添加了 dtype 到索引中,使其成为结构化数组,如下所示:

indices = np.array(
    [
        [[(1.0, 1.0), (1.0, 3.75)], [(1.5, 1.0), (1.5, 3.75)]],
        [[(1.0, 2.5), (1.5, 1.0)], [(1.5, 2.5), (1.5, 3.75)]]
    ], dtype='float16,float16')

为了应对这种变化并使解决方案仍然可行,您可以简单地使用非结构化副本进行操作:
unstructured_indices = rf.structured_to_unstructured(indices)
for group in unstructured_indices:
    group = group.reshape(-1, group.shape[-1])
    values = list(x for key in group for x in d[tuple(key)])
    print(values)

解决方案如下:

result = [
    (min(vals), max(vals)) for vals in (
        list(x for key in grp.reshape(-1, grp.shape[-1]) for x in d[tuple(key)])
        for grp in rf.structured_to_unstructured(indices)
    )
]

代码在我提供的输入数据上可以运行,但在我的真实代码上无法运行。我意识到当我粘贴代码时犯了一个大错误,我忘记将数组制作成结构化数组。当我在新数据上使用您的代码时,出现了“TypeError: unhashable type: 'writeable void-scalar'”的错误。 - Shew
错误是难免的 - 我已经更新了答案以应对这种情况。 - Grismar

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接