抱歉标题描述不太好,但我不确定更好的方式来描述这个问题。
最近我看了Andrej Kaparthy的构建GPT视频,非常棒。现在我正在尝试重构代码时,注意到他使用self()作为函数,并且很好奇它是为什么以及究竟做了什么。
这里是代码,我特别想知道generate函数:
class BigramLanguageModel(nn.Module):
def __init__(self, vocab_size):
super().__init__()
# each token directly reads off the logits for the next token from a lookup table
self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
def forward(self, idx, targets=None):
# idx and targets are both (B,T) tensor of integers
logits = self.token_embedding_table(idx) # (B,T,C)
if targets is None:
loss = None
else:
B, T, C = logits.shape
logits = logits.view(B*T, C)
targets = targets.view(B*T)
loss = F.cross_entropy(logits, targets)
return logits, loss
def generate(self, idx, max_new_tokens):
# idx is (B, T) array of indices in the current context
for _ in range(max_new_tokens):
# get the predictions
logits, loss = self(idx)
# focus only on the last time step
logits = logits[:, -1, :] # becomes (B, C)
# apply softmax to get probabilities
probs = F.softmax(logits, dim=-1) # (B, C)
# sample from the distribution
idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
# append sampled index to the running sequence
idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
return idx
所以在我看来,他是通过使用self()来调用类中定义的forward函数。这样做正确吗?如果是这样,为什么他不使用forward(idx)
呢?感谢您的帮助!
nn
?... - Kelly Bundyself
是类的一个实例。它正在“调用该实例”。就是这样。可以假设该实例是可调用的。由于您没有定义__call__
方法,因此我们只能推测它是继承而来的。 - juanpa.arrivillaga__call__
方法的实现。 - juanpa.arrivillaga