How to efficiently retrieve the indices of maximum values in a Torch tensor?

29

Assume I have a torch tensor, for example of the following shape:

x = torch.rand(20, 1, 120, 120)
Now I want to get the indices of the maximum value of each 120x120 matrix. To simplify the problem, I would first use x.squeeze() to work with shape [20, 120, 120]. I would then like to get a torch tensor which is a list of indices with shape [20, 2].
How can I do this fast?

Why do you need a [20, 2] matrix? Do you want the maximum along the rows and the columns for every one of the 120 * 120 matrices? - Kashyap
Yes, or in other words: for each of the 20 120 * 120 matrices, I want the [x, y] coordinates of the cell with the maximum value. - Chris
Use torch.topk() if you want the indices of the top-k elements. - tejasvi88
Does this answer your question? Extracting the top-k value-indices from a 1-D Tensor - iacob
6 Answers

19

torch.topk() is what you are looking for. From the docs,

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

Returns the k largest elements of the given input tensor along a given dimension.

  • If dim is not given, the last dimension of the input is chosen.

  • If largest is False then the k smallest elements are returned.

  • A namedtuple of (values, indices) is returned, where the indices are the indices of the elements in the original input tensor.

  • The boolean option sorted, if True, will make sure that the returned k elements are themselves sorted.
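
For example, a minimal sketch of what topk returns, here the top-2 values of a 1D tensor and their positions along the last dimension:

import torch

x = torch.tensor([1.0, 3.0, 2.0, 5.0, 4.0])
values, indices = torch.topk(x, k=2)
print(values)   # tensor([5., 4.])
print(indices)  # tensor([3, 4])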


7
This is a useful function, but it does not answer the original question. The OP wanted to get the indices of the maximum element in each of the 20 120x120 matrices. That is, she wanted 20 2D coordinates, one per matrix. topk only returns the index of the maximum element in the dimension being maximized. - user118967
Note that topk's documentation is confusing regarding the meaning of the returned indices. It gives the impression that the function provides indices into the original tensor, when in fact it returns the index in the maximized dimension only. See the pytorch issue https://github.com/pytorch/pytorch/issues/50331#issue-782748956 that seeks to clarify it. - user118967
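
To illustrate the point made in these comments, a small sketch: applied to the flattened matrices, topk only yields flat indices, which still have to be unraveled into (row, column) pairs by hand:

import torch

x = torch.rand(20, 120, 120)
values, idx = torch.topk(x.view(20, -1), k=1)  # idx has shape (20, 1): flat indices only
rows, cols = idx // 120, idx % 120             # the 2D coordinates still have to be computed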

10
If I understand you correctly, you don't want the values but the indices. Unfortunately there is no out-of-the-box solution. There is an argmax() function, but I cannot see how to get it to do exactly what you want.
So here is a small workaround; the efficiency should also be okay since we're just dividing tensors:
import torch

n = torch.tensor(4)
d = torch.tensor(4)
x = torch.rand(n, 1, d, d)
m = x.view(n, -1).argmax(1)
# argmax() only returns the index into the flattened matrix block,
# so we have to compute the row and column indices ourselves using
# // (integer division) and % (modulo); on integer tensors plain /
# performs true division in current PyTorch and would return floats
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
print(x)
print(indices)

n represents your first dimension, and d the two last dimensions. I use smaller numbers here to show the result, but of course it also works for n=20 and d=120:

n = torch.tensor(20)
d = torch.tensor(120)
x = torch.rand(n, 1, d, d)
m = x.view(n, -1).argmax(1)
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
#print(x)
print(indices)

Here is the output for n=4 and d=4:

tensor([[[[0.3699, 0.3584, 0.4940, 0.8618],
          [0.6767, 0.7439, 0.5984, 0.5499],
          [0.8465, 0.7276, 0.3078, 0.3882],
          [0.1001, 0.0705, 0.2007, 0.4051]]],


        [[[0.7520, 0.4528, 0.0525, 0.9253],
          [0.6946, 0.0318, 0.5650, 0.7385],
          [0.0671, 0.6493, 0.3243, 0.2383],
          [0.6119, 0.7762, 0.9687, 0.0896]]],


        [[[0.3504, 0.7431, 0.8336, 0.0336],
          [0.8208, 0.9051, 0.1681, 0.8722],
          [0.5751, 0.7903, 0.0046, 0.1471],
          [0.4875, 0.1592, 0.2783, 0.6338]]],


        [[[0.9398, 0.7589, 0.6645, 0.8017],
          [0.9469, 0.2822, 0.9042, 0.2516],
          [0.2576, 0.3852, 0.7349, 0.2806],
          [0.7062, 0.1214, 0.0922, 0.1385]]]])
tensor([[0, 3],
        [3, 2],
        [1, 1],
        [1, 0]])

I hope this is what you wanted to get! :)
Edit:
Here is a slightly modified version, which might be minimally faster (probably not by much), but it is a bit simpler and prettier:
Instead of using this, as before:
m = x.view(n, -1).argmax(1)
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)

use a version where the necessary reshaping is already done on the argmax values:

m = x.view(n, -1).argmax(1).view(-1, 1)
indices = torch.cat((m // d, m % d), dim=1)

But as mentioned in the comments, I don't think it is possible to get much more out of it.

If getting the last possible bit of performance improvement out of it is really important to you, you can implement the above function as a low-level extension of pytorch (e.g. in C++).

This would give you just one function you can call and would avoid slow python code.

https://pytorch.org/tutorials/advanced/cpp_extension.html


Yes, that is the output I want. I modified it to convert m with .float() and then use // when dividing by d. What you proposed is an unraveling, similar to numpy.unravel_indices(). If you can think of an even faster way, it would be even better of course. - Chris
@Chris I just did a short timing test. Actually I think it is quite efficient; I guess there is currently no faster way: calling argmax() itself takes about 10 times as long as computing the indices in the next line - on CPU; I can check the GPU later as well. But the operations are really simple and straightforward, so even though it is a workaround, it should be quite efficient from a theoretical perspective too. - MBT
1
No, it is not slow in any way; I needed only about 5.5 ms on a Tesla Volta. I just need to max it out, but I agree: argmax is a linear operation, since tensors are unordered, so this is probably the slowest component and not possible to speed up. - Chris
@Chris I made a small edit at the end with a slightly better version. But I would not expect anything noticeable in terms of performance, probably about half a nanosecond faster at most :). If squeezing the most out of it is really important to you, you may want to write a custom extension in C++. But considering how small the snippet is, the gain probably would not be very big either. - MBT
Thanks, it works well. I also made a mistake in my evaluation; it seems it took only 0.5 ms rather than 5 ms. - Chris
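
For anyone who wants to reproduce such a timing comparison, here is a sketch using torch.utils.benchmark (not part of the original discussion; the numbers above come from the commenters' own setups):

import torch
import torch.utils.benchmark as benchmark

n, d = 20, 120
x = torch.rand(n, 1, d, d)
m = x.view(n, -1).argmax(1)
# time the argmax call and the index computation separately
t_argmax = benchmark.Timer(stmt="x.view(20, -1).argmax(1)", globals={"x": x})
t_unravel = benchmark.Timer(stmt="torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)",
                            globals={"torch": torch, "m": m, "d": d})
print(t_argmax.timeit(100))
print(t_unravel.timeit(100))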

4
Here is an unravel_index implementation in torch:
from typing import Tuple

import torch


def unravel_index(
    indices: torch.LongTensor,
    shape: Tuple[int, ...],
) -> torch.LongTensor:
    r"""Converts flat indices into unraveled coordinates in a target shape.

    This is a `torch` implementation of `numpy.unravel_index`.

    Args:
        indices: A tensor of (flat) indices, (*, N).
        shape: The targeted shape, (D,).

    Returns:
        The unraveled coordinates, (*, N, D).
    """

    coord = []

    for dim in reversed(shape):
        coord.append(indices % dim)
        indices = indices // dim

    coord = torch.stack(coord[::-1], dim=-1)

    return coord

Then, you can use the torch.argmax function to obtain the indices of the "flattened" tensor:

y = x.view(20, -1)
indices = torch.argmax(y, dim=-1)  # without dim, argmax would reduce over the whole tensor
indices.shape  # (20,)

And unravel the indices with the unravel_index function:

indices = unravel_index(indices, x.shape[-2:])
indices.shape  # (20, 2)
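
Putting the two steps together for the shape from the question (a small sketch; unravel_index is the function defined above), including a check that the coordinates really point at the per-matrix maxima:

x = torch.rand(20, 1, 120, 120).squeeze(1)       # (20, 120, 120)
flat_idx = torch.argmax(x.view(20, -1), dim=-1)  # (20,)
coords = unravel_index(flat_idx, x.shape[-2:])   # (20, 2)
assert (x[torch.arange(20), coords[:, 0], coords[:, 1]]
        == x.view(20, -1).max(dim=-1).values).all()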

This is the closest to a real, general answer! To more directly answer the original question of how to obtain the indices of the maximum values, you might want to edit it to show how to use argmax to obtain the indices first, and then unravel them. - user118967
I ended up having to code the connection with argmax myself, so please check my answer. Feel free to incorporate what I did in yours. - user118967

0

The accepted answer only works for the given example.

tejasvi88's answer is interesting but does not help answering the original question (as explained in my comment there).

I believe Francois's answer is the closest, because it deals with the more generic case (any number of dimensions). However, it does not connect with argmax, and the example shown does not illustrate that function's ability to deal with batches.

So I will build upon Francois's answer here and add code for connecting to argmax. I write a new function batch_argmax that returns the indices of the maximum values within a batch. The batch may be organized in multiple dimensions. I also include some test cases for illustration:

from math import prod

import torch

# unravel_indices is the unravel_index function from Francois's answer above,
# which already works elementwise on a whole tensor of flat indices
unravel_indices = unravel_index


def batch_argmax(tensor, batch_dim=1):
    """
    Assumes that dimensions of tensor up to batch_dim are "batch dimensions"
    and returns the indices of the max element of each "batch row".
    More precisely, returns tensor `a` such that, for each index v of tensor.shape[:batch_dim], a[v] is
    the indices of the max element of tensor[v].
    """
    if batch_dim >= len(tensor.shape):
        raise NoArgMaxIndices()
    batch_shape = tensor.shape[:batch_dim]
    non_batch_shape = tensor.shape[batch_dim:]
    flat_non_batch_size = prod(non_batch_shape)
    tensor_with_flat_non_batch_portion = tensor.reshape(*batch_shape, flat_non_batch_size)

    dimension_of_indices = len(non_batch_shape)

    # We now have each batch row flattened in the last dimension of tensor_with_flat_non_batch_portion,
    # so we can invoke its argmax(dim=-1) method. However, that method throws an exception if the tensor
    # is empty. We cover that case first.
    if tensor_with_flat_non_batch_portion.numel() == 0:
        # If empty, either the batch dimensions or the non-batch dimensions are empty
        batch_size = prod(batch_shape)
        if batch_size == 0:  # if batch dimensions are empty
            # return empty tensor of appropriate shape
            batch_of_unraveled_indices = torch.ones(*batch_shape, dimension_of_indices).long()  # 'ones' is irrelevant as it will be empty
        else:  # non-batch dimensions are empty, so argmax indices are undefined
            raise NoArgMaxIndices()
    else:   # We actually have elements to maximize, so we search for them
        indices_of_non_batch_portion = tensor_with_flat_non_batch_portion.argmax(dim=-1)
        batch_of_unraveled_indices = unravel_indices(indices_of_non_batch_portion, non_batch_shape)

    if dimension_of_indices == 1:
        # above function makes each unraveled index of a n-D tensor a n-long tensor
        # however indices of 1D tensors are typically represented by scalars, so we squeeze them in this case.
        batch_of_unraveled_indices = batch_of_unraveled_indices.squeeze(dim=-1)
    return batch_of_unraveled_indices


class NoArgMaxIndices(BaseException):

    def __init__(self):
        super(NoArgMaxIndices, self).__init__(
            "no argmax indices: batch_argmax requires non-batch shape to be non-empty")

And here are the tests:

def test_basic():
    # a simple array
    tensor = torch.tensor([0, 1, 2, 3, 4])
    batch_dim = 0
    expected = torch.tensor(4)
    run_test(tensor, batch_dim, expected)

    # making batch_dim = 1 renders the non-batch portion empty and argmax indices undefined
    tensor = torch.tensor([0, 1, 2, 3, 4])
    batch_dim = 1
    check_that_exception_is_thrown(lambda: batch_argmax(tensor, batch_dim), NoArgMaxIndices)

    # now a batch of arrays
    tensor = torch.tensor([[1, 2, 3], [6, 5, 4]])
    batch_dim = 1
    expected = torch.tensor([2, 0])
    run_test(tensor, batch_dim, expected)

    # Now we have an empty batch with non-batch 3-dim arrays' shape (the arrays are actually non-existent)
    tensor = torch.ones(0, 3)  # 'ones' is irrelevant since this is empty
    batch_dim = 1
    # empty batch of the right shape: just the batch dimension 0, since indices of arrays are scalar (0D)
    expected = torch.ones(0)
    run_test(tensor, batch_dim, expected)

    # Now we have an empty batch with non-batch matrices' shape (the matrices are actually non-existent)
    tensor = torch.ones(0, 3, 2)  # 'ones' is irrelevant since this is empty
    batch_dim = 1
    # empty batch of the right shape: the batch and two dimension for the indices since we have 2D matrices
    expected = torch.ones(0, 2)
    run_test(tensor, batch_dim, expected)

    # a batch of 2D matrices:
    tensor = torch.tensor([[[1, 2, 3], [6, 5, 4]], [[2, 3, 1], [4, 5, 6]]])
    batch_dim = 1
    expected = torch.tensor([[1, 0], [1, 2]])  # coordinates of two 6's, one in each 2D matrix
    run_test(tensor, batch_dim, expected)

    # same as before, but testing that batch_dim supports negative values
    tensor = torch.tensor([[[1, 2, 3], [6, 5, 4]], [[2, 3, 1], [4, 5, 6]]])
    batch_dim = -2
    expected = torch.tensor([[1, 0], [1, 2]])
    run_test(tensor, batch_dim, expected)

    # Same data, but a 2-dimensional batch of 1D arrays!
    tensor = torch.tensor([[[1, 2, 3], [6, 5, 4]], [[2, 3, 1], [4, 5, 6]]])
    batch_dim = 2
    expected = torch.tensor([[2, 0], [1, 2]])  # coordinates of 3, 6, 3, and 6
    run_test(tensor, batch_dim, expected)

    # same as before, but testing that batch_dim supports negative values
    tensor = torch.tensor([[[1, 2, 3], [6, 5, 4]], [[2, 3, 1], [4, 5, 6]]])
    batch_dim = -1
    expected = torch.tensor([[2, 0], [1, 2]])
    run_test(tensor, batch_dim, expected)


def run_test(tensor, batch_dim, expected):
    actual = batch_argmax(tensor, batch_dim)
    print(f"batch_argmax of {tensor} with batch_dim {batch_dim} is\n{actual}\nExpected:\n{expected}")
    assert actual.shape == expected.shape
    assert actual.eq(expected).all()

def check_that_exception_is_thrown(thunk, exception_type):
    if isinstance(exception_type, BaseException):
        raise Exception(f"check_that_exception_is_thrown received an exception instance rather than an exception type: "
                        f"{exception_type}")
    try:
        thunk()
        raise AssertionError(f"Should have thrown {exception_type}")
    except exception_type:
        pass
    except Exception as e:
        raise AssertionError(f"Should have thrown {exception_type} but instead threw {e}")

0

I have a straightforward workaround, though it is not the best solution for computing the per-item max in a batch. A simple workaround may be:

# suppose the tensor is of shape (3,2,2), 
>>> a = torch.randn(3, 2, 2)
>>> a
tensor([[[ 0.1450, -1.3480],
         [-0.3339, -0.5133]],

        [[ 0.6867, -0.2972],
         [ 0.8768,  0.0844]],

        [[-2.3115, -0.4549],
         [-1.5074, -0.8706]]])

# then perform batch-wise max
>>> torch.stack([(a[i]==torch.max(a[i])).nonzero() for i in range(a.size(0))], dim=0)

tensor([[[0, 0]],

        [[1, 0]],

        [[0, 1]]])
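
Note that nonzero() returns one row per occurrence of the maximum, so if a maximum value appears more than once within a matrix, the per-item results have different shapes and torch.stack fails. A sketch that keeps only the first occurrence and drops the extra dimension (assuming that is acceptable):

coords = torch.stack([(a[i] == a[i].max()).nonzero()[0] for i in range(a.size(0))])
coords.shape  # (3, 2)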

-2
# ps is assumed to be a (1, N) tensor of class scores, e.g. the output of a digit classifier
ps = ps.numpy()
ps = ps.tolist()

mx = [max(l) for l in ps]
mx = max(mx)
for i in range(len(ps[0])):
    if mx == ps[0][i]:
        print("The digit is " + str(i))
        break

This works perfectly for me.

