我实现了一些MaxPool2d的代码(与PyTorch进行了比较,能够正确运行)。但是,在对MNIST数据集进行测试时,这个函数(updateOutput)需要很长时间才能完成。如何使用NumPy优化此代码?
class MaxPool2d(Module):
def __init__(self, kernel_size):
super(MaxPool2d, self).__init__()
self.kernel_size = kernel_size
self.gradInput = None
def updateOutput(self, input):
#print("MaxPool updateOutput")
#start_time = time.time()
kernel = self.kernel_size
poolH = input.shape[2] // kernel
poolW = input.shape[3] // kernel
self.output = np.zeros((input.shape[0],
input.shape[1],
poolH,
poolW))
self.index = np.zeros((input.shape[0],
input.shape[1],
poolH,
poolW,
2),
dtype='int32')
for i in range(input.shape[0]):
for j in range(input.shape[1]):
for k in range(0, input.shape[2] - kernel+1, kernel):
for m in range(0, input.shape[3] - kernel+1, kernel):
M = input[i, j, k : k+kernel, m : m+kernel]
self.output[i, j, k // kernel, m // kernel] = M.max()
self.index[i, j, k // kernel, m // kernel] = np.array(np.unravel_index(M.argmax(), M.shape)) + np.array((k, m))
#print(f"time: {time.time() - start_time:.3f}s")
return self.output
输入形状 = (批量大小,输入通道数,高度,宽度)
输出形状 = (批量大小,输出通道数,高度 // 卷积核大小,宽度 // 卷积核大小)