哪些函数或模块需要连续的输入?

4
据我所理解,每当某些函数或模块需要一个连续的张量时,您需要显式地调用tensor.contiguous()。否则,您会遇到以下异常:
RuntimeError: invalid argument 1: input is not contiguous at .../src/torch/lib/TH/generic/THTensor.c:231

例如:通过哪些功能或模块需要连续的输入? 这是否有记录?
或者换句话说,什么情况下需要调用contiguous
例如,Conv1d,它需要连续的输入吗?文档没有提到这一点。当文档没有提到这一点时,这总是意味着它不需要连续的输入吗?
(我记得在Theano中,任何获取非连续输入的操作都会自动将其转换为连续的输入。)
3个回答

2
通过源代码的进一步挖掘,似乎只有view函数在传递非连续输入时明确引发异常。

人们会期望使用张量视图的任何操作都有可能因为非连续输入而失败。实际上,似乎大多数或所有这些函数:

(a.) 支持非连续块实现(见下面的示例),即张量迭代器可以处理指向内存中各个数据块的多个指针,但可能会牺牲性能,或者

(b.) 调用.contiguous()包装操作(一个这样的例子显示在这里,是对torch.tensor.diagflat()的调用)。reshape本质上是viewcontiguous()包装形式。

通过扩展,似乎 view 相对于 reshape 的主要优点就是当张量意外地不连续时,可以显式地触发异常,而代码会默默地处理这种差异,以性能为代价。
这个结论基于以下几点:
  1. 测试了所有具有非连续输入的 Tensor View 操作。
  2. 源代码分析了其他感兴趣的非 Tensor View 函数(例如 Conv1D),其中包括在所有非平凡输入情况下必要的调用 contiguous
  3. 从 PyTorch 的设计哲学中推断出来,它是一种简单、有时慢速、易于使用的语言。
  4. Pytorch Discuss 上进行了交叉发布。
  5. 广泛审查了涉及非连续错误的 Web 报告错误,所有这些错误都围绕着对 view 的问题调用。
我没有全面测试所有 PyTorch 函数,因为有成千上万个函数。
(a.)的例子:
import torch
import numpy
import time

# allocation 
start = time.time()
test = torch.rand([10000,1000,100])
torch.cuda.synchronize()
end = time.time()
print("Allocation took {} sec. Data is at address {}. Contiguous: 
{}".format(end - 
start,test.storage().data_ptr(),test.is_contiguous()))

# view of a contiguous tensor
start = time.time()
test.view(-1)
torch.cuda.synchronize()
end = time.time()
print("view() took {} sec. Data is at address {}. Contiguous: 
{}".format(end - 
start,test.storage().data_ptr(),test.is_contiguous()))


# diagonal() on a contiguous tensor
start = time.time()
test.diagonal()
torch.cuda.synchronize()
end = time.time()
print("diagonal() took {} sec. Data is at address {}. Contiguous: 
{}".format(end - 
start,test.storage().data_ptr(),test.is_contiguous()))


# Diagonal and a few tensor view ops on a non-contiguous tensor
test = test[::2,::2,::2]    # indexing is a Tensor View op 
resulting in a non-contiguous output
print(test.is_contiguous()) # False
start = time.time()
test = test.unsqueeze(-1).expand([test.shape[0],test.shape[1],test.shape[2],100]).diagonal()
torch.cuda.synchronize()
end = time.time()
print("non-contiguous tensor ops() took {} sec. Data is at 
address {}. Contiguous: {}".format(end - 
start,test.storage().data_ptr(),test.is_contiguous()))

# reshape, which requires a tensor copy operation to new memory
start = time.time()
test = test.reshape(-1) + 1.0
torch.cuda.synchronize()
end = time.time()
print("reshape() took {} sec. Data is at address {}. Contiguous: {}".format(end - start,test.storage().data_ptr(),test.is_contiguous()))

以下是输出内容:
Allocation took 4.269254922866821 sec. Data is at address 139863636672576. Contiguous: True
view() took 0.0002810955047607422 sec. Data is at address 139863636672576. Contiguous: True
diagonal() took 6.532669067382812e-05 sec. Data is at address 139863636672576. Contiguous: True
False
non-contiguous tensor ops() took 0.00011277198791503906 sec. Data is at address 139863636672576. Contiguous: False
reshape() took 0.13828253746032715 sec. Data is at address 94781254337664. Contiguous: True

在第4个块中,有一些张量视图操作是在一个非连续的输入张量上执行的。这些操作能够正常运行,并且保持数据在相同的内存地址中,比需要将数据复制到新的内存地址(例如第5个块中的reshape)的操作运行得更快。因此,看起来这些操作是以一种处理非连续输入而不需要数据复制的方式实现的。

我在询问哪些操作不支持非连续输入。所以你举了viewreshape作为例子。还有什么?这就是我的问题。我不是在问不支持多维输入的操作或其他类似的情况。这些都有很好的文档说明。 - Albert
答案是任何使用张量视图的东西。 - DerekG
另外,@Albert,reshape 明确不需要连续的输入,它是一种惰性复制操作,如果可能的话执行与 view 相同的功能,否则将数据复制到新的张量中。 - DerekG
我不理解这个答案。我在询问需要连续输入的操作列表。您之前写道view需要连续输入,那么这不是正确的吗?现在您改变了答案,但它并没有真正回答我的问题。有哪些操作、函数或模块需要连续输入?您能给出一些例子吗? - Albert
Pytorch 明确提供了一系列使用 Tensor Views 的操作,这些操作需要连续的输入。完整的函数列表在答案开头的链接中给出。答案的其余部分试图解释对于那些不熟悉 Pytorch 这些方面的人来说,需要连续张量/张量视图的原因。 - DerekG
显示剩余6条评论

0

来自pytorch文档:contiguous() → Tensor。返回一个包含与self张量相同数据的连续张量。如果self张量是连续的,则此函数返回self张量。


这不是问题的关键。 - Albert
目前你的回答不够清晰,请编辑并添加更多细节,以帮助其他人理解它如何回答问题。你可以在帮助中心找到有关如何撰写好答案的更多信息。 - Community

0

我认为没有一个完整的列表。这取决于您如何实现张量处理函数。

如果您查看有关编写C++和CUDA扩展的教程,您会发现典型的pytorch CUDA操作如下:

  • C++接口,带有torch::Tensor参数。该类提供了访问/操作张量数据的API。
  • CUDA内核,带有float*参数。这些指针直接指向存储张量数据的内存。

显然,使用指针处理张量中的数据比使用张量类的API更有效率。但最好使用具有连续内存布局(或至少是规则布局)的指针。

我相信原则上即使没有连续的数据,也可以通过指针操纵数据,只要给出足够的内存布局信息。但是您必须考虑各种布局,并且代码可能会更加繁琐。

Facebook可能有某些技巧,使一些内置操作适用于非连续数据(我对此并不了解),但大多数自定义扩展模块要求输入是连续的。


是的,对于所有自定义用户实现,你永远无法知道(尽管如果你将正确的步幅传递给内核并且不直接访问索引,则仍然没问题,但教程没有这样做)。但是我想知道内置函数。例如,Conv1d和许多其他功能。那么这只是通过试错吗? - Albert

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接