我正在尝试运行PyTorch CIFAR10图像分类的教程,链接在这里 - http://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py
我做了一个小改变,现在使用另一个数据集。我有来自Wikiart数据集的图像,想按艺术家对它们进行分类(标签=艺术家姓名)。
下面是网络(Net)的代码 -
这行代码:
现在,我不确定如何更改我的Net中的Conv2d以与
我得到了这个错误:
我尝试过的事情:
1) CIFAR10教程使用一个我没有使用的变换。 我无法将其纳入我的代码中。
2) 将
任何有关如何在PyTorch中计算输入和输出大小或自动重塑张量的资源都将非常感激。 我刚开始学习Torch,我觉得大小计算很复杂。
下面是网络(Net)的代码 -
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16*5*5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
接下来是代码的这一部分,我开始训练网络。
for epoch in range(2):
running_loss = 0.0
for i, data in enumerate(wiki_train_dataloader, 0):
inputs, labels = data['image'], data['class']
print(inputs.shape)
inputs, labels = Variable(inputs), Variable(labels)
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.data[0]
if i % 2000 == 1999: # print every 2000 mini-batches
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
这行代码:
print(inputs.shape)
在我的Wikiart数据集上返回torch.Size([4, 32, 32, 3])
,而在原始的CIFAR10示例中,它打印了torch.Size([4, 3, 32, 32])
。现在,我不确定如何更改我的Net中的Conv2d以与
torch.Size([4, 32, 32, 3])
兼容。我得到了这个错误:
RuntimeError: Given input size: (3 x 32 x 3). Calculated output size: (6 x 28 x -1). Output size is too small at /opt/conda/conda-bld/pytorch_1503965122592/work/torch/lib/THNN/generic/SpatialConvolutionMM.c:45
当读取Wikiart数据集的图像时,我将它们调整为(32,32)大小并且是三通道图像。我尝试过的事情:
1) CIFAR10教程使用一个我没有使用的变换。 我无法将其纳入我的代码中。
2) 将
self.conv2 = nn.Conv2d(6, 16, 5)
更改为self.conv2 = nn.Conv2d(3, 6, 5)
。 这给了我与上述相同的错误。 我只改变了这个来看看错误消息是否改变。任何有关如何在PyTorch中计算输入和输出大小或自动重塑张量的资源都将非常感激。 我刚开始学习Torch,我觉得大小计算很复杂。