tqdm不显示进度条

32

我正在使用tqdm库,但它没有给我进度条,而是只输出类似于这样的内容,告诉我迭代次数:

251it [01:44, 2.39it/s]

你知道代码为什么会这样做吗?我认为可能是因为我把一个生成器传递给它了,但我之前用过一些生成器是可以正常工作的。我以前从未真正涉及过tqdm格式化。以下是源代码的部分内容:

train_iter = zip(train_x, train_y) #train_x and train_y are just lists of elements
....
def train(train_iter, model, criterion, optimizer):
    model.train()
    total_loss = 0
    for x, y in tqdm(train_iter):
        x = x.transpose(0, 1)
        y = y.transpose(0, 1)
        optimizer.zero_grad()
        bloss = model.forward(x, y, criterion)   
        bloss.backward()
        torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
        optimizer.step()        
        total_loss += bloss.data[0]
    return total_loss

1
希望 train_iter 是可迭代的并且不为 None。尝试检查 train_iter 数据及其类型。 - Reck
3个回答

47
< p >< code > tqdm < /code > 需要知道将执行多少次迭代(总数)才能显示进度条。 < p > 您可以尝试这样做:< /p >
from tqdm import tqdm

train_x = range(100)
train_y = range(200)

train_iter = zip(train_x, train_y)

# Notice `train_iter` can only be iter over once, so i get `total` in this way.
total = min(len(train_x), len(train_y))

with tqdm(total=total) as pbar:
    for item in train_iter:
        # do something ...
        pbar.update(1)

这里也有同样的问题,正在从psql的结果集中进行迭代。 - arilwan
1
对于我的情况,我在total值上犯了一个错误,tqdm update传递了总数,进度条闪过,然后回退到这种样式:251it [01:44, 2.39it/s] - xiaobing
确认 @xiaobing 所说的,如果超过总数,tqdm 将会回退到那种样式。注意,由于溢出,总数可能最终成为负数,因此如果 tqdm 立即还原,请检查一下。 - David Cian

15

将“total”参数填充为长度对我有效。现在进度条已经出现。

from tqdm import tqdm

# ...
for imgs, targets in tqdm( train_dataloader, total=len(train_dataloader)):
   # ...

0

@Dogus的回答更自然地使用了tqdm,但您需要确保您的数据加载器(如果是自定义迭代器)也公开了len方法。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接