我正在阅读《使用fastai和PyTorch进行深度学习》一书,对于嵌入模块的作用仍然感到有些困惑。它似乎是一个简短且简单的网络,但我无法理解嵌入与没有偏置的线性函数之间的区别。我知道它执行了某种更快的计算版本的点积,其中一个矩阵是one-hot编码矩阵,另一个矩阵是嵌入矩阵。它这样做是为了选择数据的一部分?请指出我的错误所在。以下是书中展示的一个简单网络。
class DotProduct(Module):
def __init__(self, n_users, n_movies, n_factors):
self.user_factors = Embedding(n_users, n_factors)
self.movie_factors = Embedding(n_movies, n_factors)
def forward(self, x):
users = self.user_factors(x[:,0])
movies = self.movie_factors(x[:,1])
return (users * movies).sum(dim=1)