PyTorch 逻辑回归,数据集:UCI Iris Data Set
导入依赖包 1 2 3 4 5 6 7 import matplotlib.pyplot as pltimport numpy as npimport pandas as pdimport torchimport torch.nn as nnimport torch.optim as optimimport visdom
1 torch.__version__, torch.manual_seed(33 )
('1.12.1+cu102', <torch._C.Generator at 0x7f753ca60e90>)
加载数据集 1 path = "/workspace/disk1/datasets/scalar/iris.data.csv"
1 2 data = pd.read_csv(path) data.head(3 )
5.1
3.5
1.4
0.2
Iris-setosa
0
4.9
3.0
1.4
0.2
Iris-setosa
1
4.7
3.2
1.3
0.2
Iris-setosa
2
4.6
3.1
1.5
0.2
Iris-setosa
149
1 set (data["Iris-setosa" ].values)
{'Iris-setosa', 'Iris-versicolor', 'Iris-virginica'}
1 2 3 data_setosa = data[data["Iris-setosa" ] == "Iris-setosa" ] data_versicolor = data[data["Iris-setosa" ] == "Iris-versicolor" ] data_virginica = data[data["Iris-setosa" ] == "Iris-virginica" ]
1 len (data_setosa), len (data_versicolor), len (data_virginica)
(49, 50, 50)
1 2 3 4 data_using = data[data["Iris-setosa" ] != "Iris-setosa" ] X = data_using.iloc[:, :-1 ] Y = data_using.iloc[:, -1 ].replace("Iris-versicolor" , 0 ).replace("Iris-virginica" , 1 ) Y.unique()
array([0, 1])
1 2 X = torch.from_numpy(X.values).type (torch.float32) X.shape
torch.Size([100, 4])
1 2 Y = torch.from_numpy(Y.values.reshape(-1 , 1 )).type (torch.float32) Y.shape
torch.Size([100, 1])
1 2 3 4 5 6 train_set_ratio = 0.8 train_set_num = int (train_set_ratio * X.shape[0 ]) X_train = X[:train_set_num] Y_train = Y[:train_set_num] X_test = X[train_set_num:] Y_test = Y[train_set_num:]
构建模型 1 2 3 4 model = nn.Sequential( nn.Linear(in_features=X.shape[1 ], out_features=Y.shape[1 ]), nn.Sigmoid(), )
Sequential(
(0): Linear(in_features=4, out_features=1, bias=True)
(1): Sigmoid()
)
1 optimizer = optim.Adam(model.parameters(), lr=0.001 )
1 2 3 batch = 20 num_of_batch = X_train.shape[0 ] // batch epochs = 20000
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 viz = visdom.Visdom( server="http://localhost" , port=8097 , base_url="/visdom" , username="jinzhongxu" , password="123123" , ) win = "logistic regression loss" opts = dict ( title="train_losses" , xlabel="epoch" , ylabel="loss" , markers=True , legend=["loss" , "acc" ], ) viz.line( [ [ 0.0 , 0.0 , ] ], [0.0 ], win=win, opts=opts, )
Setting up a new session...
'logistic regression loss'
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 for epoch in range (epochs): losses = [] for n in range (num_of_batch): start = n * batch end = (n + 1 ) * batch x = X_train[start:end] y = Y_train[start:end] y_pred = model(x) loss = loss_fn(y_pred, y) losses.append(loss.data) optimizer.zero_grad() loss.backward() optimizer.step() if epoch % 100 == 0 : acc = ( (model(X_test).data.numpy() > 0.5 ).astype(np.int8) == Y_test.numpy() ).mean() print ( f"epoch: {str (epoch).rjust(len (str (epochs)), '0' )} , loss:{np.mean(losses):.5 f} , acc: {acc:.3 f} " ) viz.line( [ [ np.mean(losses), acc, ] ], [epoch], win=win, update="append" , )
epoch: 00000, loss:0.81945, acc: 0.000
epoch: 00100, loss:0.72729, acc: 0.000
epoch: 00200, loss:0.71364, acc: 0.000
epoch: 00300, loss:0.70016, acc: 0.000
epoch: 00400, loss:0.68695, acc: 0.000
epoch: 00500, loss:0.67402, acc: 0.000
epoch: 00600, loss:0.66138, acc: 0.000
epoch: 00700, loss:0.64902, acc: 0.000
epoch: 00800, loss:0.63694, acc: 0.000
epoch: 00900, loss:0.62514, acc: 0.000
epoch: 01000, loss:0.61361, acc: 0.000
epoch: 01100, loss:0.60233, acc: 0.000
epoch: 01200, loss:0.59132, acc: 0.100
epoch: 01300, loss:0.58056, acc: 0.250
epoch: 01400, loss:0.57005, acc: 0.350
epoch: 01500, loss:0.55979, acc: 0.400
epoch: 01600, loss:0.54976, acc: 0.450
epoch: 01700, loss:0.53998, acc: 0.450
epoch: 01800, loss:0.53042, acc: 0.500
epoch: 01900, loss:0.52110, acc: 0.600
epoch: 02000, loss:0.51199, acc: 0.650
epoch: 02100, loss:0.50311, acc: 0.650
epoch: 02200, loss:0.49444, acc: 0.700
epoch: 02300, loss:0.48598, acc: 0.700
epoch: 02400, loss:0.47773, acc: 0.700
epoch: 02500, loss:0.46968, acc: 0.700
epoch: 02600, loss:0.46183, acc: 0.700
epoch: 02700, loss:0.45417, acc: 0.700
epoch: 02800, loss:0.44670, acc: 0.700
epoch: 02900, loss:0.43942, acc: 0.700
epoch: 03000, loss:0.43232, acc: 0.700
epoch: 03100, loss:0.42539, acc: 0.700
epoch: 03200, loss:0.41863, acc: 0.700
epoch: 03300, loss:0.41204, acc: 0.700
epoch: 03400, loss:0.40561, acc: 0.750
epoch: 03500, loss:0.39935, acc: 0.800
epoch: 03600, loss:0.39324, acc: 0.800
epoch: 03700, loss:0.38728, acc: 0.800
epoch: 03800, loss:0.38146, acc: 0.800
epoch: 03900, loss:0.37579, acc: 0.800
epoch: 04000, loss:0.37026, acc: 0.800
epoch: 04100, loss:0.36487, acc: 0.800
epoch: 04200, loss:0.35961, acc: 0.800
epoch: 04300, loss:0.35447, acc: 0.800
epoch: 04400, loss:0.34947, acc: 0.800
epoch: 04500, loss:0.34458, acc: 0.800
epoch: 04600, loss:0.33982, acc: 0.800
epoch: 04700, loss:0.33516, acc: 0.800
epoch: 04800, loss:0.33063, acc: 0.800
epoch: 04900, loss:0.32620, acc: 0.850
epoch: 05000, loss:0.32187, acc: 0.850
epoch: 05100, loss:0.31765, acc: 0.850
epoch: 05200, loss:0.31353, acc: 0.850
epoch: 05300, loss:0.30951, acc: 0.850
epoch: 05400, loss:0.30559, acc: 0.850
epoch: 05500, loss:0.30175, acc: 0.850
epoch: 05600, loss:0.29801, acc: 0.850
epoch: 05700, loss:0.29435, acc: 0.850
epoch: 05800, loss:0.29078, acc: 0.850
epoch: 05900, loss:0.28729, acc: 0.850
epoch: 06000, loss:0.28388, acc: 0.850
epoch: 06100, loss:0.28055, acc: 0.850
epoch: 06200, loss:0.27730, acc: 0.850
epoch: 06300, loss:0.27412, acc: 0.850
epoch: 06400, loss:0.27101, acc: 0.850
epoch: 06500, loss:0.26797, acc: 0.850
epoch: 06600, loss:0.26500, acc: 0.850
epoch: 06700, loss:0.26209, acc: 0.850
epoch: 06800, loss:0.25925, acc: 0.850
epoch: 06900, loss:0.25647, acc: 0.850
epoch: 07000, loss:0.25376, acc: 0.850
epoch: 07100, loss:0.25110, acc: 0.850
epoch: 07200, loss:0.24850, acc: 0.850
epoch: 07300, loss:0.24596, acc: 0.850
epoch: 07400, loss:0.24347, acc: 0.900
epoch: 07500, loss:0.24103, acc: 0.900
epoch: 07600, loss:0.23865, acc: 0.900
epoch: 07700, loss:0.23631, acc: 0.900
epoch: 07800, loss:0.23403, acc: 0.900
epoch: 07900, loss:0.23179, acc: 0.900
epoch: 08000, loss:0.22960, acc: 0.900
epoch: 08100, loss:0.22746, acc: 0.950
epoch: 08200, loss:0.22535, acc: 0.950
epoch: 08300, loss:0.22330, acc: 0.950
epoch: 08400, loss:0.22128, acc: 0.950
epoch: 08500, loss:0.21930, acc: 0.950
epoch: 08600, loss:0.21737, acc: 0.950
epoch: 08700, loss:0.21547, acc: 0.950
epoch: 08800, loss:0.21361, acc: 0.950
epoch: 08900, loss:0.21179, acc: 0.950
epoch: 09000, loss:0.21000, acc: 0.950
epoch: 09100, loss:0.20825, acc: 0.950
epoch: 09200, loss:0.20653, acc: 0.950
epoch: 09300, loss:0.20484, acc: 0.950
epoch: 09400, loss:0.20319, acc: 0.950
epoch: 09500, loss:0.20157, acc: 0.950
epoch: 09600, loss:0.19998, acc: 0.950
epoch: 09700, loss:0.19842, acc: 0.950
epoch: 09800, loss:0.19689, acc: 0.950
epoch: 09900, loss:0.19538, acc: 0.950
epoch: 10000, loss:0.19391, acc: 0.950
epoch: 10100, loss:0.19246, acc: 0.950
epoch: 10200, loss:0.19104, acc: 0.950
epoch: 10300, loss:0.18964, acc: 0.950
epoch: 10400, loss:0.18827, acc: 0.950
epoch: 10500, loss:0.18693, acc: 0.950
epoch: 10600, loss:0.18561, acc: 0.950
epoch: 10700, loss:0.18431, acc: 0.950
epoch: 10800, loss:0.18303, acc: 0.950
epoch: 10900, loss:0.18178, acc: 0.950
epoch: 11000, loss:0.18055, acc: 0.950
epoch: 11100, loss:0.17934, acc: 0.950
epoch: 11200, loss:0.17815, acc: 0.950
epoch: 11300, loss:0.17698, acc: 0.950
epoch: 11400, loss:0.17583, acc: 0.950
epoch: 11500, loss:0.17471, acc: 0.950
epoch: 11600, loss:0.17360, acc: 0.950
epoch: 11700, loss:0.17250, acc: 0.950
epoch: 11800, loss:0.17143, acc: 0.950
epoch: 11900, loss:0.17038, acc: 0.950
epoch: 12000, loss:0.16934, acc: 0.950
epoch: 12100, loss:0.16832, acc: 0.950
epoch: 12200, loss:0.16731, acc: 0.950
epoch: 12300, loss:0.16632, acc: 0.950
epoch: 12400, loss:0.16535, acc: 0.950
epoch: 12500, loss:0.16440, acc: 0.950
epoch: 12600, loss:0.16345, acc: 0.950
epoch: 12700, loss:0.16253, acc: 0.950
epoch: 12800, loss:0.16162, acc: 0.950
epoch: 12900, loss:0.16072, acc: 0.950
epoch: 13000, loss:0.15983, acc: 0.950
epoch: 13100, loss:0.15896, acc: 0.950
epoch: 13200, loss:0.15811, acc: 0.950
epoch: 13300, loss:0.15726, acc: 0.950
epoch: 13400, loss:0.15643, acc: 0.950
epoch: 13500, loss:0.15561, acc: 0.950
epoch: 13600, loss:0.15481, acc: 0.950
epoch: 13700, loss:0.15401, acc: 0.950
epoch: 13800, loss:0.15323, acc: 0.950
epoch: 13900, loss:0.15246, acc: 0.950
epoch: 14000, loss:0.15170, acc: 0.950
epoch: 14100, loss:0.15095, acc: 0.950
epoch: 14200, loss:0.15022, acc: 0.950
epoch: 14300, loss:0.14949, acc: 0.950
epoch: 14400, loss:0.14877, acc: 0.950
epoch: 14500, loss:0.14807, acc: 0.950
epoch: 14600, loss:0.14737, acc: 0.950
epoch: 14700, loss:0.14669, acc: 0.950
epoch: 14800, loss:0.14601, acc: 0.950
epoch: 14900, loss:0.14534, acc: 0.950
epoch: 15000, loss:0.14468, acc: 0.950
epoch: 15100, loss:0.14404, acc: 0.950
epoch: 15200, loss:0.14340, acc: 0.950
epoch: 15300, loss:0.14276, acc: 0.950
epoch: 15400, loss:0.14214, acc: 0.950
epoch: 15500, loss:0.14153, acc: 0.950
epoch: 15600, loss:0.14092, acc: 0.950
epoch: 15700, loss:0.14032, acc: 0.950
epoch: 15800, loss:0.13973, acc: 0.950
epoch: 15900, loss:0.13915, acc: 0.950
epoch: 16000, loss:0.13858, acc: 0.950
epoch: 16100, loss:0.13801, acc: 0.950
epoch: 16200, loss:0.13745, acc: 0.950
epoch: 16300, loss:0.13690, acc: 0.950
epoch: 16400, loss:0.13635, acc: 0.950
epoch: 16500, loss:0.13581, acc: 0.950
epoch: 16600, loss:0.13528, acc: 0.950
epoch: 16700, loss:0.13476, acc: 0.950
epoch: 16800, loss:0.13424, acc: 0.950
epoch: 16900, loss:0.13373, acc: 0.950
epoch: 17000, loss:0.13322, acc: 0.950
epoch: 17100, loss:0.13272, acc: 0.950
epoch: 17200, loss:0.13223, acc: 0.950
epoch: 17300, loss:0.13174, acc: 0.950
epoch: 17400, loss:0.13126, acc: 0.950
epoch: 17500, loss:0.13079, acc: 0.950
epoch: 17600, loss:0.13032, acc: 0.950
epoch: 17700, loss:0.12985, acc: 0.950
epoch: 17800, loss:0.12940, acc: 0.950
epoch: 17900, loss:0.12894, acc: 0.950
epoch: 18000, loss:0.12850, acc: 0.950
epoch: 18100, loss:0.12805, acc: 0.950
epoch: 18200, loss:0.12762, acc: 0.950
epoch: 18300, loss:0.12719, acc: 0.950
epoch: 18400, loss:0.12676, acc: 0.950
epoch: 18500, loss:0.12634, acc: 0.950
epoch: 18600, loss:0.12592, acc: 0.950
epoch: 18700, loss:0.12551, acc: 0.950
epoch: 18800, loss:0.12510, acc: 0.950
epoch: 18900, loss:0.12470, acc: 0.950
epoch: 19000, loss:0.12430, acc: 0.950
epoch: 19100, loss:0.12390, acc: 0.950
epoch: 19200, loss:0.12351, acc: 0.950
epoch: 19300, loss:0.12313, acc: 0.950
epoch: 19400, loss:0.12275, acc: 0.950
epoch: 19500, loss:0.12237, acc: 0.950
epoch: 19600, loss:0.12200, acc: 0.950
epoch: 19700, loss:0.12163, acc: 0.950
epoch: 19800, loss:0.12126, acc: 0.950
epoch: 19900, loss:0.12090, acc: 0.950
OrderedDict([('0.weight', tensor([[-1.4480, -2.8220, 2.5209, 5.8656]])),
('0.bias', tensor([-5.3296]))])
训练损失和测试集预测准确率曲线: