PyTorch image classification, using the built-in MNIST dataset.
 
Loading the dataset

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import visdom
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms

torch.manual_seed(33)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device
 
device(type='cuda', index=0)
 
train_ds = torchvision.datasets.MNIST(
    "/workspace/disk1/datasets/",
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_ds = torchvision.datasets.MNIST(
    "/workspace/disk1/datasets/",
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)
 
batch_size = 64
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=batch_size)
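As a quick sanity check (a sketch added here, not in the original post), the standard len() calls confirm the expected sizes:

print(len(train_ds), len(test_ds))  # 60000 10000 (MNIST's standard splits)
print(len(train_dl))                # ceil(60000 / 64) = 938 batches per epoch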
 
imgs, labels = next(iter(train_dl))

imgs.shape

torch.Size([64, 1, 28, 28])

labels.shape

torch.Size([64])
 
plt.figure(figsize=(batch_size, 1))
for i, img in enumerate(imgs):
    img_np = img.numpy().squeeze()
    plt.subplot(1, batch_size, i + 1)
    plt.imshow(img_np, cmap="gray")
    plt.axis("off")
 
 
labels

tensor([3, 0, 5, 2, 3, 4, 5, 9, 1, 7, 4, 7, 8, 4, 2, 1, 7, 9, 8, 3, 4, 9, 7, 5,
        0, 2, 4, 2, 5, 7, 6, 4, 2, 8, 8, 5, 6, 0, 6, 4, 9, 5, 9, 9, 9, 4, 9, 8,
        8, 6, 9, 3, 2, 2, 2, 5, 0, 4, 9, 3, 0, 8, 3, 2])
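As an aside, torchvision also ships a grid helper that renders the same batch preview in a single call; a minimal sketch (not part of the original post, reusing the imgs batch above):

from torchvision.utils import make_grid

# make_grid tiles the (64, 1, 28, 28) batch into one (3, H, W) image tensor
grid = make_grid(imgs, nrow=16, padding=2)
plt.figure(figsize=(8, 2))
plt.imshow(grid.permute(1, 2, 0))  # channels-last for matplotlib
plt.axis("off")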
 
Building the model

class MLPModel(nn.Module):
    def __init__(self):
        super(MLPModel, self).__init__()
        self.linear1 = nn.Linear(28 * 28, 128)
        self.linear2 = nn.Linear(128, 64)
        self.linear3 = nn.Linear(64, 10)

    def forward(self, inputs):
        x = inputs.view(-1, 1 * 28 * 28)
        x = torch.relu(self.linear1(x))
        x = torch.relu(self.linear2(x))
        logits = self.linear3(x)
        return logits
 
model = MLPModel().to(device)
model
 
MLPModel(
  (linear1): Linear(in_features=784, out_features=128, bias=True)
  (linear2): Linear(in_features=128, out_features=64, bias=True)
  (linear3): Linear(in_features=64, out_features=10, bias=True)
)
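The repr above shows the layer shapes; the trainable parameter count can be verified with a one-liner (a sketch, not in the original post):

# 784*128 + 128 + 128*64 + 64 + 64*10 + 10 = 109,386 trainable parameters
print(sum(p.numel() for p in model.parameters() if p.requires_grad))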
 
loss_fn = torch.nn.CrossEntropyLoss()
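CrossEntropyLoss expects raw logits: it applies log-softmax internally and then computes the negative log-likelihood, which is why the model returns logits without a final softmax. A minimal check of that equivalence (a sketch, not in the original post):

logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))
manual = torch.nn.functional.nll_loss(logits.log_softmax(dim=1), targets)
print(torch.allclose(loss_fn(logits, targets), manual))  # True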
 
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
optimizer
 
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 0
)
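The run below keeps lr=0.0001 fixed for all 50 epochs. If learning-rate decay were wanted, a scheduler could be attached to this optimizer; a sketch (not used in the original run):

# halve the learning rate every 10 epochs; call scheduler.step() once per epoch
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)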
 
def train(dl, model, loss_fn, optimizer):
    size = len(dl.dataset)
    num_batches = len(dl)
    train_loss, correct = 0, 0
    for x, y in dl:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            train_loss += loss.item()
    correct /= size
    train_loss /= num_batches
    return correct, train_loss
 
def test(dl, model, loss_fn):
    size = len(dl.dataset)
    num_batches = len(dl)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for x, y in dl:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            loss = loss_fn(pred, y)
            test_loss += loss.item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    correct /= size
    test_loss /= num_batches
    return correct, test_loss
 
viz = visdom.Visdom(
    server="http://localhost",
    port=8097,
    base_url="/visdom",
    username="jinzhongxu",
    password="123123",
)
win = "mnist"
opts = dict(
    title="MNIST",
    xlabel="epoch",
    ylabel="loss and acc",
    markers=True,
    legend=["train_loss", "train_acc", "test_loss", "test_acc"],
)
viz.line(
    [[0.0, 0.0, 0.0, 0.0]],
    [0.0],
    win=win,
    opts=opts,
)
 
Setting up a new session...
'mnist'
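Note that visdom is a client/server tool: the call above only succeeds if a visdom server is already listening on the configured host and port. It is normally started in a separate shell with flags mirroring the client settings above (standard visdom usage, not shown in the original post):

python -m visdom.server -port 8097 -base_url /visdom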
 
epochs = 50
train_loss = []
train_acc = []
test_loss = []
test_acc = []
for epoch in range(epochs):
    epoch_acc, epoch_loss = train(
        dl=train_dl, model=model, loss_fn=loss_fn, optimizer=optimizer
    )
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    epoch_test_acc, epoch_test_loss = test(dl=test_dl, model=model, loss_fn=loss_fn)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)
    print(
        f"epoch={epoch:2d}, train_loss={epoch_loss:.5f}, train_acc={epoch_acc:.5f}, test_loss={epoch_test_loss:.5f}, test_acc={epoch_test_acc:.5f}"
    )
    viz.line(
        [[epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc]],
        [epoch],
        win=win,
        update="append",
    )
print("done!")
 
epoch= 0, train_loss=0.87553, train_acc=0.78492, test_loss=0.37957, test_acc=0.89600
epoch= 1, train_loss=0.34057, train_acc=0.90542, test_loss=0.29167, test_acc=0.91640
epoch= 2, train_loss=0.28640, train_acc=0.91793, test_loss=0.26019, test_acc=0.92450
epoch= 3, train_loss=0.25685, train_acc=0.92650, test_loss=0.24007, test_acc=0.92930
epoch= 4, train_loss=0.23372, train_acc=0.93333, test_loss=0.21881, test_acc=0.93560
epoch= 5, train_loss=0.21370, train_acc=0.93912, test_loss=0.20176, test_acc=0.93990
epoch= 6, train_loss=0.19669, train_acc=0.94297, test_loss=0.18557, test_acc=0.94430
epoch= 7, train_loss=0.18122, train_acc=0.94835, test_loss=0.17462, test_acc=0.94620
epoch= 8, train_loss=0.16796, train_acc=0.95127, test_loss=0.16615, test_acc=0.94940
epoch= 9, train_loss=0.15605, train_acc=0.95473, test_loss=0.15155, test_acc=0.95460
epoch=10, train_loss=0.14516, train_acc=0.95775, test_loss=0.14506, test_acc=0.95670
epoch=11, train_loss=0.13557, train_acc=0.96103, test_loss=0.13445, test_acc=0.95970
epoch=12, train_loss=0.12738, train_acc=0.96342, test_loss=0.13094, test_acc=0.96010
epoch=13, train_loss=0.11912, train_acc=0.96610, test_loss=0.12227, test_acc=0.96210
epoch=14, train_loss=0.11219, train_acc=0.96753, test_loss=0.11731, test_acc=0.96430
epoch=15, train_loss=0.10571, train_acc=0.96947, test_loss=0.11181, test_acc=0.96500
epoch=16, train_loss=0.09996, train_acc=0.97147, test_loss=0.10745, test_acc=0.96670
epoch=17, train_loss=0.09438, train_acc=0.97308, test_loss=0.10555, test_acc=0.96800
epoch=18, train_loss=0.08965, train_acc=0.97438, test_loss=0.10191, test_acc=0.96840
epoch=19, train_loss=0.08477, train_acc=0.97557, test_loss=0.09853, test_acc=0.96930
epoch=20, train_loss=0.08065, train_acc=0.97690, test_loss=0.09546, test_acc=0.96970
epoch=21, train_loss=0.07642, train_acc=0.97827, test_loss=0.09460, test_acc=0.97060
epoch=22, train_loss=0.07243, train_acc=0.97918, test_loss=0.09040, test_acc=0.97170
epoch=23, train_loss=0.06898, train_acc=0.98013, test_loss=0.08840, test_acc=0.97270
epoch=24, train_loss=0.06559, train_acc=0.98123, test_loss=0.08831, test_acc=0.97240
epoch=25, train_loss=0.06238, train_acc=0.98242, test_loss=0.08451, test_acc=0.97450
epoch=26, train_loss=0.05947, train_acc=0.98308, test_loss=0.08525, test_acc=0.97340
epoch=27, train_loss=0.05665, train_acc=0.98370, test_loss=0.08331, test_acc=0.97420
epoch=28, train_loss=0.05389, train_acc=0.98510, test_loss=0.08325, test_acc=0.97480
epoch=29, train_loss=0.05153, train_acc=0.98535, test_loss=0.08162, test_acc=0.97450
epoch=30, train_loss=0.04908, train_acc=0.98628, test_loss=0.07992, test_acc=0.97540
epoch=31, train_loss=0.04710, train_acc=0.98658, test_loss=0.07741, test_acc=0.97600
epoch=32, train_loss=0.04476, train_acc=0.98773, test_loss=0.07945, test_acc=0.97460
epoch=33, train_loss=0.04273, train_acc=0.98813, test_loss=0.07803, test_acc=0.97500
epoch=34, train_loss=0.04049, train_acc=0.98873, test_loss=0.07625, test_acc=0.97520
epoch=35, train_loss=0.03883, train_acc=0.98968, test_loss=0.07546, test_acc=0.97660
epoch=36, train_loss=0.03686, train_acc=0.99037, test_loss=0.07731, test_acc=0.97510
epoch=37, train_loss=0.03529, train_acc=0.99060, test_loss=0.07601, test_acc=0.97570
epoch=38, train_loss=0.03339, train_acc=0.99118, test_loss=0.07800, test_acc=0.97490
epoch=39, train_loss=0.03212, train_acc=0.99150, test_loss=0.07530, test_acc=0.97650
epoch=40, train_loss=0.03038, train_acc=0.99222, test_loss=0.07336, test_acc=0.97610
epoch=41, train_loss=0.02889, train_acc=0.99262, test_loss=0.07662, test_acc=0.97680
epoch=42, train_loss=0.02742, train_acc=0.99350, test_loss=0.07404, test_acc=0.97700
epoch=43, train_loss=0.02625, train_acc=0.99347, test_loss=0.07493, test_acc=0.97660
epoch=44, train_loss=0.02480, train_acc=0.99420, test_loss=0.07400, test_acc=0.97710
epoch=45, train_loss=0.02360, train_acc=0.99417, test_loss=0.07704, test_acc=0.97700
epoch=46, train_loss=0.02243, train_acc=0.99490, test_loss=0.07595, test_acc=0.97790
epoch=47, train_loss=0.02117, train_acc=0.99525, test_loss=0.07470, test_acc=0.97700
epoch=48, train_loss=0.02004, train_acc=0.99557, test_loss=0.07563, test_acc=0.97650
epoch=49, train_loss=0.01895, train_acc=0.99598, test_loss=0.07576, test_acc=0.97750
done!
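With training finished at roughly 97.8% test accuracy, the weights can be saved and reloaded for later inference; a sketch (the filename is an assumption, not from the post):

torch.save(model.state_dict(), "mnist_mlp.pth")  # hypothetical filename

# restore into a fresh instance and predict on one test batch
restored = MLPModel().to(device)
restored.load_state_dict(torch.load("mnist_mlp.pth", map_location=device))
restored.eval()
with torch.no_grad():
    x, y = next(iter(test_dl))
    pred = restored(x.to(device)).argmax(1)
    print(pred[:8].cpu(), y[:8])  # predicted vs. true labels for 8 test digits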
 
Loss and test accuracy curves:
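The original post embeds the visdom chart here. The same curves can be reproduced offline from the recorded lists with matplotlib; a sketch:

epochs_range = range(epochs)
plt.figure(figsize=(8, 4))
plt.plot(epochs_range, train_loss, label="train_loss")
plt.plot(epochs_range, train_acc, label="train_acc")
plt.plot(epochs_range, test_loss, label="test_loss")
plt.plot(epochs_range, test_acc, label="test_acc")
plt.xlabel("epoch")
plt.ylabel("loss and acc")
plt.legend()
plt.show()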