# This function is saved in the d2lzh_pytorch package for later reuse.
def train(train_iter, test_iter, net, loss, optimizer, device, num_epochs):
    """Train ``net`` with ``optimizer`` and report metrics after every epoch.

    For each epoch this runs one pass over ``train_iter``, performing a
    forward pass, backward pass and optimizer step per mini-batch. It then
    prints the epoch's mean batch loss, training accuracy, the accuracy
    returned by ``d2l.evaluate_accuracy`` on ``test_iter``, and the epoch's
    wall-clock time.

    Args:
        train_iter: Iterable of ``(X, y)`` mini-batches used for training.
        test_iter: Iterable of ``(X, y)`` mini-batches handed to
            ``d2l.evaluate_accuracy`` at the end of each epoch.
        net: Model to train; it is moved to ``device`` and called on each
            batch (presumably a ``torch.nn.Module`` -- confirm with callers).
        loss: Callable ``loss(y_hat, y)`` returning a single-element tensor.
        optimizer: Optimizer over ``net``'s parameters; ``zero_grad()`` and
            ``step()`` are called once per batch.
        device: Device that the model and every batch are moved to.
        num_epochs: Number of passes over ``train_iter``.

    Returns:
        None. Progress is reported with ``print``.
    """
    net = net.to(device)
    print("training on ", device)
    for epoch in range(num_epochs):
        # All counters are per epoch. In particular batch_count must be reset
        # here so the printed loss is this epoch's mean batch loss, not this
        # epoch's loss sum divided by every batch seen since training began.
        train_l_sum, train_acc_sum, n, batch_count = 0.0, 0.0, 0, 0
        start = time.time()
        # Evaluation at the end of the previous epoch may have switched the
        # model to eval mode; make sure training-time behaviour is active.
        net.train()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            # Clear stale gradients, backpropagate this batch, update weights.
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = d2l.evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n,
                 test_acc, time.time() - start))
%% Below, type any markdown to display in the Graffiti tip.
%% Then run this cell to save it.
# Build the training iterator (with training-time augmentations) and the test
# iterator (with test-time augmentations). These were previously fused onto a
# single line, which is a SyntaxError in Python.
train_iter = load_cifar10(True, train_augs, batch_size)
test_iter = load_cifar10(False, test_augs, batch_size)
training on cpu
epoch 1, loss 1.3790, train acc 0.504, test acc 0.554, time 195.8 sec
epoch 2, loss 0.4992, train acc 0.646, test acc 0.592, time 192.5 sec
epoch 3, loss 0.2821, train acc 0.702, test acc 0.657, time 193.7 sec
epoch 4, loss 0.1859, train acc 0.739, test acc 0.693, time 195.4 sec
epoch 5, loss 0.1349, train acc 0.766, test acc 0.688, time 192.6 sec
epoch 6, loss 0.1022, train acc 0.786, test acc 0.701, time 200.2 sec
epoch 7, loss 0.0797, train acc 0.806, test acc 0.720, time 191.8 sec
epoch 8, loss 0.0633, train acc 0.825, test acc 0.695, time 198.6 sec
epoch 9, loss 0.0524, train acc 0.836, test acc 0.693, time 192.1 sec
epoch 10, loss 0.0437, train acc 0.850, test acc 0.769, time 196.3 sec