Given an integer PyTorch tensor, e.g. of shape torch.Size([N, C, H, W]): is there an efficient way to hash the tensor's elements so that I get outputs in [-MAX_INT32, +MAX_INT32] or [0, MAX_INT32], and that runs fast on the GPU? I would also like to be able to take output % N and have each element distributed uniformly, or nearly uniformly, between 0 and N. A naive version that works on the CPU:
def hash_tensor(tensor):
    return hash(tuple(tensor.reshape(-1).tolist()))

>>> hash_tensor(torch.arange(8)) == hash_tensor(torch.arange(8))
True
>>> hash_tensor(torch.arange(8)) == hash_tensor(torch.arange(8) + 42)
False
Note that the built-in hash output is not stable across re-runs of the interpreter.

import torch
from torch import Tensor
MULTIPLIER = 6364136223846793005  # multiplier of a well-known 64-bit LCG (Knuth's MMIX, also used by PCG)
INCREMENT = 1
MODULUS = 2**64

def hash_tensor(x: Tensor) -> Tensor:
    assert x.dtype == torch.int64
    while x.ndim > 0:
        x = _reduce_last_axis(x)
    return x

@torch.no_grad()
def _reduce_last_axis(x: Tensor) -> Tensor:
    assert x.dtype == torch.int64
    acc = torch.zeros_like(x[..., 0])
    for i in range(x.shape[-1]):
        acc *= MULTIPLIER
        acc += INCREMENT
        acc += x[..., i]
        # acc %= MODULUS  # Not really necessary: int64 arithmetic already wraps mod 2**64.
    return acc
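Since _reduce_last_axis folds only the trailing dimension, stopping the reduction early gives one hash per slice, and taking that modulo n buckets the slices. A minimal usage sketch (self-contained, so the LCG fold from above is repeated here; the names and n = 10 are illustrative):

```python
import torch

MULTIPLIER = 6364136223846793005
INCREMENT = 1

def reduce_last_axis(x: torch.Tensor) -> torch.Tensor:
    # Same LCG fold as _reduce_last_axis above, repeated so this
    # sketch runs on its own.
    acc = torch.zeros_like(x[..., 0])
    for i in range(x.shape[-1]):
        acc = acc * MULTIPLIER + INCREMENT + x[..., i]
    return acc

x = torch.arange(16, dtype=torch.int64).reshape(4, 4)
row_hashes = reduce_last_axis(x)  # one int64 hash per row, shape (4,)

# torch's % follows Python sign semantics, so hash % n lands in [0, n)
# even when the hash is negative.
buckets = row_hashes % 10
```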
x = torch.arange(8 * 8 * 8 * 8, dtype=torch.int64).reshape(8, 8, 8, 8)
>>> hash_tensor(x)
tensor(2150010819114838296)
>>> hash_tensor(x - 1)
tensor(2225417619311630616)
>>> hash_tensor(x * 2)
tensor(-4806624569712897768)
# Zeroes are OK:
>>> hash_tensor(torch.zeros((8, 8, 8, 8), dtype=torch.int64))
tensor(9106646207942574360)
# "Breaks" if you do a single axis of -1's...
>>> hash_tensor(torch.full((8,), -1))
tensor(0)
# But multiple axes is OKish:
>>> hash_tensor(torch.full((8, 8, 8, 8), -1))
tensor(9182053008139366680)
What about hash(tensor)?

hash(tensor) is identity-based, so hash(tensor) == hash(tensor.clone()) does not hold for equal-valued tensors. Maybe hash(tuple(tensor.reshape(-1).tolist())) is a bit better.
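The identity-based behaviour mentioned in the comment above is easy to confirm; a quick sketch:

```python
import torch

t = torch.arange(4)

# torch.Tensor falls back to identity-based hashing, so an
# equal-valued copy hashes differently...
assert hash(t) != hash(t.clone())

# ...while a value-based tuple hash is the same for both.
assert hash(tuple(t.tolist())) == hash(tuple(t.clone().tolist()))
```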