Image classification in PyTorch, using the built-in MNIST dataset.

Load the dataset

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import visdom
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms

torch.manual_seed(33)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device
device(type='cuda', index=0)
train_ds = torchvision.datasets.MNIST(
    "/workspace/disk1/datasets/",
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

test_ds = torchvision.datasets.MNIST(
    "/workspace/disk1/datasets/",
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)
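ToTensor alone scales the raw pixels to [0, 1]. If you wanted zero-centered inputs, a Normalize step could be composed on top, as sketched below; the mean/std values are the commonly quoted MNIST statistics, assumed here rather than recomputed from this download:

# Optional sketch: normalization on top of ToTensor.
# 0.1307 / 0.3081 are the widely cited MNIST mean/std (assumed, not recomputed here).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])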
batch_size = 64
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=batch_size)
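Note that shuffle=True is set only on the training loader, so each epoch visits the samples in a fresh order; the test loader keeps a fixed order, which is fine because evaluation does not depend on ordering.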
imgs, labels = next(iter(train_dl))
imgs.shape
torch.Size([64, 1, 28, 28])
labels.shape
torch.Size([64])
plt.figure(figsize=(batch_size, 1))
for i, img in enumerate(imgs):
    img_np = img.numpy().squeeze()
    plt.subplot(1, batch_size, i + 1)
    plt.imshow(img_np, cmap="gray")
    plt.axis("off")

(figure: the 64 digits of the batch displayed in a single row)

labels.data
tensor([3, 0, 5, 2, 3, 4, 5, 9, 1, 7, 4, 7, 8, 4, 2, 1, 7, 9, 8, 3, 4, 9, 7, 5,
        0, 2, 4, 2, 5, 7, 6, 4, 2, 8, 8, 5, 6, 0, 6, 4, 9, 5, 9, 9, 9, 4, 9, 8,
        8, 6, 9, 3, 2, 2, 2, 5, 0, 4, 9, 3, 0, 8, 3, 2])

Create the model

class MLPModel(nn.Module):
    def __init__(self):
        super(MLPModel, self).__init__()
        self.linear1 = nn.Linear(28 * 28, 128)
        self.linear2 = nn.Linear(128, 64)
        self.linear3 = nn.Linear(64, 10)

    def forward(self, inputs):
        x = inputs.view(-1, 1 * 28 * 28)  # flatten [N, 1, 28, 28] -> [N, 784]
        x = torch.relu(self.linear1(x))
        x = torch.relu(self.linear2(x))
        logits = self.linear3(x)  # raw logits; CrossEntropyLoss applies softmax itself
        return logits
model = MLPModel().to(device)
model
MLPModel(
  (linear1): Linear(in_features=784, out_features=128, bias=True)
  (linear2): Linear(in_features=128, out_features=64, bias=True)
  (linear3): Linear(in_features=64, out_features=10, bias=True)
)
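As a quick sanity check (not part of the original post), the trainable parameter count can be verified against the layer shapes: (784·128 + 128) + (128·64 + 64) + (64·10 + 10) = 109,386.

# Sanity check: count trainable parameters
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_params)  # 109386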
loss_fn = torch.nn.CrossEntropyLoss()
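nn.CrossEntropyLoss combines log-softmax with negative log-likelihood, which is why forward returns raw logits instead of probabilities. A tiny illustration with made-up numbers:

# Illustration only: CrossEntropyLoss expects raw logits and integer class indices
dummy_logits = torch.tensor([[2.0, 0.5, -1.0]])  # one sample, three classes
dummy_target = torch.tensor([0])                 # the true class is index 0
print(loss_fn(dummy_logits, dummy_target))       # ~0.24, small because class 0 scores highest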
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
optimizer
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 0
)
def train(dl, model, loss_fn, optimizer):
    size = len(dl.dataset)
    num_batches = len(dl)

    train_loss, correct = 0, 0

    for x, y in dl:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # accumulate metrics without tracking gradients
        with torch.no_grad():
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            train_loss += loss.item()

    correct /= size            # accuracy: averaged over samples
    train_loss /= num_batches  # loss: averaged over batches
    return correct, train_loss
def test(dl, model, loss_fn):
    size = len(dl.dataset)
    num_batches = len(dl)

    test_loss, correct = 0, 0

    # no gradients needed for evaluation (this MLP has no dropout/batchnorm,
    # so model.eval() is not strictly required here)
    with torch.no_grad():
        for x, y in dl:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            loss = loss_fn(pred, y)
            test_loss += loss.item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    correct /= size
    test_loss /= num_batches
    return correct, test_loss
# Use visdom (a visualization tool commonly paired with PyTorch) to track
# how the losses and accuracies change during training
viz = visdom.Visdom(
    server="http://localhost",
    port=8097,
    base_url="/visdom",
    username="jinzhongxu",
    password="123123",
)
win = "mnist"
opts = dict(
    title="MNIST",
    xlabel="epoch",
    ylabel="loss and acc",
    markers=True,
    legend=["train_loss", "train_acc", "test_loss", "test_acc"],
)
viz.line(
    [[0.0, 0.0, 0.0, 0.0]],
    [0.0],
    win=win,
    opts=opts,
)
Setting up a new session...

'mnist'
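This assumes a visdom server is already running and reachable at localhost:8097 (typically launched beforehand with python -m visdom.server); the base_url, username, and password must match that server's configuration. On success viz.line returns the window id, hence the 'mnist' output above.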
epochs = 50

train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    epoch_acc, epoch_loss = train(
        dl=train_dl, model=model, loss_fn=loss_fn, optimizer=optimizer
    )
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    epoch_test_acc, epoch_test_loss = test(dl=test_dl, model=model, loss_fn=loss_fn)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)
    print(
        f"epoch={epoch:2d}, train_loss={epoch_loss:.5f}, train_acc={epoch_acc:.5f}, test_loss={epoch_test_loss:.5f}, test_acc={epoch_test_acc:.5f}"
    )
    viz.line(
        [[epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc]],
        [epoch],
        win=win,
        update="append",
    )

print("done!")
epoch= 0, train_loss=0.87553, train_acc=0.78492, test_loss=0.37957, test_acc=0.89600
epoch= 1, train_loss=0.34057, train_acc=0.90542, test_loss=0.29167, test_acc=0.91640
epoch= 2, train_loss=0.28640, train_acc=0.91793, test_loss=0.26019, test_acc=0.92450
epoch= 3, train_loss=0.25685, train_acc=0.92650, test_loss=0.24007, test_acc=0.92930
epoch= 4, train_loss=0.23372, train_acc=0.93333, test_loss=0.21881, test_acc=0.93560
epoch= 5, train_loss=0.21370, train_acc=0.93912, test_loss=0.20176, test_acc=0.93990
epoch= 6, train_loss=0.19669, train_acc=0.94297, test_loss=0.18557, test_acc=0.94430
epoch= 7, train_loss=0.18122, train_acc=0.94835, test_loss=0.17462, test_acc=0.94620
epoch= 8, train_loss=0.16796, train_acc=0.95127, test_loss=0.16615, test_acc=0.94940
epoch= 9, train_loss=0.15605, train_acc=0.95473, test_loss=0.15155, test_acc=0.95460
epoch=10, train_loss=0.14516, train_acc=0.95775, test_loss=0.14506, test_acc=0.95670
epoch=11, train_loss=0.13557, train_acc=0.96103, test_loss=0.13445, test_acc=0.95970
epoch=12, train_loss=0.12738, train_acc=0.96342, test_loss=0.13094, test_acc=0.96010
epoch=13, train_loss=0.11912, train_acc=0.96610, test_loss=0.12227, test_acc=0.96210
epoch=14, train_loss=0.11219, train_acc=0.96753, test_loss=0.11731, test_acc=0.96430
epoch=15, train_loss=0.10571, train_acc=0.96947, test_loss=0.11181, test_acc=0.96500
epoch=16, train_loss=0.09996, train_acc=0.97147, test_loss=0.10745, test_acc=0.96670
epoch=17, train_loss=0.09438, train_acc=0.97308, test_loss=0.10555, test_acc=0.96800
epoch=18, train_loss=0.08965, train_acc=0.97438, test_loss=0.10191, test_acc=0.96840
epoch=19, train_loss=0.08477, train_acc=0.97557, test_loss=0.09853, test_acc=0.96930
epoch=20, train_loss=0.08065, train_acc=0.97690, test_loss=0.09546, test_acc=0.96970
epoch=21, train_loss=0.07642, train_acc=0.97827, test_loss=0.09460, test_acc=0.97060
epoch=22, train_loss=0.07243, train_acc=0.97918, test_loss=0.09040, test_acc=0.97170
epoch=23, train_loss=0.06898, train_acc=0.98013, test_loss=0.08840, test_acc=0.97270
epoch=24, train_loss=0.06559, train_acc=0.98123, test_loss=0.08831, test_acc=0.97240
epoch=25, train_loss=0.06238, train_acc=0.98242, test_loss=0.08451, test_acc=0.97450
epoch=26, train_loss=0.05947, train_acc=0.98308, test_loss=0.08525, test_acc=0.97340
epoch=27, train_loss=0.05665, train_acc=0.98370, test_loss=0.08331, test_acc=0.97420
epoch=28, train_loss=0.05389, train_acc=0.98510, test_loss=0.08325, test_acc=0.97480
epoch=29, train_loss=0.05153, train_acc=0.98535, test_loss=0.08162, test_acc=0.97450
epoch=30, train_loss=0.04908, train_acc=0.98628, test_loss=0.07992, test_acc=0.97540
epoch=31, train_loss=0.04710, train_acc=0.98658, test_loss=0.07741, test_acc=0.97600
epoch=32, train_loss=0.04476, train_acc=0.98773, test_loss=0.07945, test_acc=0.97460
epoch=33, train_loss=0.04273, train_acc=0.98813, test_loss=0.07803, test_acc=0.97500
epoch=34, train_loss=0.04049, train_acc=0.98873, test_loss=0.07625, test_acc=0.97520
epoch=35, train_loss=0.03883, train_acc=0.98968, test_loss=0.07546, test_acc=0.97660
epoch=36, train_loss=0.03686, train_acc=0.99037, test_loss=0.07731, test_acc=0.97510
epoch=37, train_loss=0.03529, train_acc=0.99060, test_loss=0.07601, test_acc=0.97570
epoch=38, train_loss=0.03339, train_acc=0.99118, test_loss=0.07800, test_acc=0.97490
epoch=39, train_loss=0.03212, train_acc=0.99150, test_loss=0.07530, test_acc=0.97650
epoch=40, train_loss=0.03038, train_acc=0.99222, test_loss=0.07336, test_acc=0.97610
epoch=41, train_loss=0.02889, train_acc=0.99262, test_loss=0.07662, test_acc=0.97680
epoch=42, train_loss=0.02742, train_acc=0.99350, test_loss=0.07404, test_acc=0.97700
epoch=43, train_loss=0.02625, train_acc=0.99347, test_loss=0.07493, test_acc=0.97660
epoch=44, train_loss=0.02480, train_acc=0.99420, test_loss=0.07400, test_acc=0.97710
epoch=45, train_loss=0.02360, train_acc=0.99417, test_loss=0.07704, test_acc=0.97700
epoch=46, train_loss=0.02243, train_acc=0.99490, test_loss=0.07595, test_acc=0.97790
epoch=47, train_loss=0.02117, train_acc=0.99525, test_loss=0.07470, test_acc=0.97700
epoch=48, train_loss=0.02004, train_acc=0.99557, test_loss=0.07563, test_acc=0.97650
epoch=49, train_loss=0.01895, train_acc=0.99598, test_loss=0.07576, test_acc=0.97750
done!
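After 50 epochs the MLP reaches about 99.6% training accuracy but only about 97.7% test accuracy, and the test loss stops improving after roughly epoch 40 while the training loss keeps falling, the usual sign of mild overfitting.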

Loss and test accuracy curves:
(figure: train/test loss and accuracy versus epoch, from the visdom window)
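The same curves can also be drawn as a static matplotlib figure from the four lists collected during training; a minimal sketch:

# Minimal sketch: re-plot the collected metrics with matplotlib instead of visdom
plt.plot(range(epochs), train_loss, label="train_loss")
plt.plot(range(epochs), test_loss, label="test_loss")
plt.plot(range(epochs), train_acc, label="train_acc")
plt.plot(range(epochs), test_acc, label="test_acc")
plt.xlabel("epoch")
plt.legend()
plt.show()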