我正在使用tqdm库,但它没有给我进度条,而是只输出类似于这样的内容,告诉我迭代次数:
251it [01:44, 2.39it/s]
你知道代码为什么会这样做吗?我认为可能是因为我把一个生成器传递给它了,但我之前用过一些生成器是可以正常工作的。我以前从未真正涉及过tqdm格式化。以下是源代码的部分内容:
train_iter = zip(train_x, train_y) #train_x and train_y are just lists of elements
....
def train(train_iter, model, criterion, optimizer):
model.train()
total_loss = 0
for x, y in tqdm(train_iter):
x = x.transpose(0, 1)
y = y.transpose(0, 1)
optimizer.zero_grad()
bloss = model.forward(x, y, criterion)
bloss.backward()
torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
optimizer.step()
total_loss += bloss.data[0]
return total_loss