PyTorch Logistic Regression, dataset: UCI Iris Data Set

Import dependencies

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import visdom
torch.__version__, torch.manual_seed(33)
('1.12.1+cu102', <torch._C.Generator at 0x7f753ca60e90>)

Load the dataset

path = "/workspace/disk1/datasets/scalar/iris.data.csv"
# NOTE: read_csv is called without header=None, so the first data record
# is consumed as the column header (hence 149 rows and the odd column names)
data = pd.read_csv(path)
data.head(3)

   5.1  3.5  1.4  0.2  Iris-setosa
0  4.9  3.0  1.4  0.2  Iris-setosa
1  4.7  3.2  1.3  0.2  Iris-setosa
2  4.6  3.1  1.5  0.2  Iris-setosa
len(data)
149
set(data["Iris-setosa"].values)
{'Iris-setosa', 'Iris-versicolor', 'Iris-virginica'}
data_setosa = data[data["Iris-setosa"] == "Iris-setosa"]
data_versicolor = data[data["Iris-setosa"] == "Iris-versicolor"]
data_virginica = data[data["Iris-setosa"] == "Iris-virginica"]
len(data_setosa), len(data_versicolor), len(data_virginica)
(49, 50, 50)
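The count of 49 setosa rows confirms that one record was consumed as the header. If all 150 rows are needed, a minimal sketch that supplies explicit column names (the names below are my own choice, not from the original):

# Hypothetical alternative: read with header=None so no record is lost;
# the column names are assumptions for illustration
cols = ["sepal_length", "sepal_width", "petal_length", "petal_width", "species"]
iris = pd.read_csv(path, header=None, names=cols)
len(iris)  # -> 150, with 50 samples per class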
# Keep only the two non-setosa classes and encode the labels as 0/1
data_using = data[data["Iris-setosa"] != "Iris-setosa"]
X = data_using.iloc[:, :-1]
Y = data_using.iloc[:, -1].replace("Iris-versicolor", 0).replace("Iris-virginica", 1)
Y.unique()
array([0, 1])
X = torch.from_numpy(X.values).type(torch.float32)
X.shape
torch.Size([100, 4])
Y = torch.from_numpy(Y.values.reshape(-1, 1)).type(torch.float32)
Y.shape
torch.Size([100, 1])
# Hold out the last 20% of the rows as a test set
# NOTE: the rows are ordered by class (versicolor first, then virginica),
# so this unshuffled split puts only Iris-virginica samples in the test set
train_set_ratio = 0.8
train_set_num = int(train_set_ratio * X.shape[0])
X_train = X[:train_set_num]
Y_train = Y[:train_set_num]
X_test = X[train_set_num:]
Y_test = Y[train_set_num:]
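Because of that ordering, the test set contains only Iris-virginica samples, which is why accuracy starts at 0 in the training log below. A minimal sketch of a shuffled split (not applied in the original run, so the outputs below reflect the unshuffled version):

# Hypothetical alternative: shuffle before splitting so that both classes
# appear in the test set
perm = torch.randperm(X.shape[0])
X_shuffled, Y_shuffled = X[perm], Y[perm]
X_train, X_test = X_shuffled[:train_set_num], X_shuffled[train_set_num:]
Y_train, Y_test = Y_shuffled[:train_set_num], Y_shuffled[train_set_num:]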

Build the model

model = nn.Sequential(
    nn.Linear(in_features=X.shape[1], out_features=Y.shape[1]),
    nn.Sigmoid(),
)
model
Sequential(
  (0): Linear(in_features=4, out_features=1, bias=True)
  (1): Sigmoid()
)
loss_fn = nn.BCELoss()
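For reference, BCELoss computes the binary cross-entropy between the predicted probability ŷ and the label y: ℓ(y, ŷ) = -[y·log(ŷ) + (1 - y)·log(1 - ŷ)]. An equivalent and numerically more stable variant (a suggestion, not used here) drops the Sigmoid from the model and folds it into the loss:

# Hypothetical variant: output raw logits and let the loss apply the
# sigmoid internally, avoiding log(0) issues
logit_model = nn.Linear(in_features=4, out_features=1)
logit_loss_fn = nn.BCEWithLogitsLoss()
# at prediction time: torch.sigmoid(logit_model(x)) > 0.5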
optimizer = optim.Adam(model.parameters(), lr=0.001)
batch = 20
num_of_batch = X_train.shape[0] // batch  # 80 // 20 = 4 mini-batches per epoch
epochs = 20000
# Use the visdom visualization module (from the PyTorch ecosystem) to
# visualize how the training loss (and test accuracy) change during training
viz = visdom.Visdom(
    server="http://localhost",
    port=8097,
    base_url="/visdom",
    username="jinzhongxu",
    password="123123",
)
win = "logistic regression loss"
opts = dict(
    title="train_losses",
    xlabel="epoch",
    ylabel="loss",
    markers=True,
    legend=["loss", "acc"],
)
# Initialize the window with a dummy point; the training loop appends to it
viz.line(
    [[0.0, 0.0]],
    [0.0],
    win=win,
    opts=opts,
)
Setting up a new session...

'logistic regression loss'
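Note that a visdom server must already be listening before Visdom(...) can connect. It is typically started in a separate shell, mirroring the port and base_url used above (with -enable_login when username/password authentication is configured); a sketch of the launch command:

python -m visdom.server -port 8097 -base_url /visdom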
for epoch in range(epochs):
    losses = []
    for n in range(num_of_batch):
        start = n * batch
        end = (n + 1) * batch
        x = X_train[start:end]
        y = Y_train[start:end]
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        losses.append(loss.data)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if epoch % 100 == 0:
        # Test accuracy: threshold the predicted probabilities at 0.5
        acc = (
            (model(X_test).data.numpy() > 0.5).astype(np.int8) == Y_test.numpy()
        ).mean()
        print(
            f"epoch: {str(epoch).rjust(len(str(epochs)), '0')}, loss:{np.mean(losses):.5f}, acc: {acc:.3f}"
        )
        # Append the mean epoch loss and test accuracy to the visdom plot
        viz.line(
            [[np.mean(losses), acc]],
            [epoch],
            win=win,
            update="append",
        )
epoch: 00000, loss:0.81945, acc: 0.000
epoch: 00100, loss:0.72729, acc: 0.000
epoch: 00200, loss:0.71364, acc: 0.000
epoch: 00300, loss:0.70016, acc: 0.000
epoch: 00400, loss:0.68695, acc: 0.000
epoch: 00500, loss:0.67402, acc: 0.000
epoch: 00600, loss:0.66138, acc: 0.000
epoch: 00700, loss:0.64902, acc: 0.000
epoch: 00800, loss:0.63694, acc: 0.000
epoch: 00900, loss:0.62514, acc: 0.000
epoch: 01000, loss:0.61361, acc: 0.000
epoch: 01100, loss:0.60233, acc: 0.000
epoch: 01200, loss:0.59132, acc: 0.100
epoch: 01300, loss:0.58056, acc: 0.250
epoch: 01400, loss:0.57005, acc: 0.350
epoch: 01500, loss:0.55979, acc: 0.400
epoch: 01600, loss:0.54976, acc: 0.450
epoch: 01700, loss:0.53998, acc: 0.450
epoch: 01800, loss:0.53042, acc: 0.500
epoch: 01900, loss:0.52110, acc: 0.600
epoch: 02000, loss:0.51199, acc: 0.650
epoch: 02100, loss:0.50311, acc: 0.650
epoch: 02200, loss:0.49444, acc: 0.700
epoch: 02300, loss:0.48598, acc: 0.700
epoch: 02400, loss:0.47773, acc: 0.700
epoch: 02500, loss:0.46968, acc: 0.700
epoch: 02600, loss:0.46183, acc: 0.700
epoch: 02700, loss:0.45417, acc: 0.700
epoch: 02800, loss:0.44670, acc: 0.700
epoch: 02900, loss:0.43942, acc: 0.700
epoch: 03000, loss:0.43232, acc: 0.700
epoch: 03100, loss:0.42539, acc: 0.700
epoch: 03200, loss:0.41863, acc: 0.700
epoch: 03300, loss:0.41204, acc: 0.700
epoch: 03400, loss:0.40561, acc: 0.750
epoch: 03500, loss:0.39935, acc: 0.800
epoch: 03600, loss:0.39324, acc: 0.800
epoch: 03700, loss:0.38728, acc: 0.800
epoch: 03800, loss:0.38146, acc: 0.800
epoch: 03900, loss:0.37579, acc: 0.800
epoch: 04000, loss:0.37026, acc: 0.800
epoch: 04100, loss:0.36487, acc: 0.800
epoch: 04200, loss:0.35961, acc: 0.800
epoch: 04300, loss:0.35447, acc: 0.800
epoch: 04400, loss:0.34947, acc: 0.800
epoch: 04500, loss:0.34458, acc: 0.800
epoch: 04600, loss:0.33982, acc: 0.800
epoch: 04700, loss:0.33516, acc: 0.800
epoch: 04800, loss:0.33063, acc: 0.800
epoch: 04900, loss:0.32620, acc: 0.850
epoch: 05000, loss:0.32187, acc: 0.850
epoch: 05100, loss:0.31765, acc: 0.850
epoch: 05200, loss:0.31353, acc: 0.850
epoch: 05300, loss:0.30951, acc: 0.850
epoch: 05400, loss:0.30559, acc: 0.850
epoch: 05500, loss:0.30175, acc: 0.850
epoch: 05600, loss:0.29801, acc: 0.850
epoch: 05700, loss:0.29435, acc: 0.850
epoch: 05800, loss:0.29078, acc: 0.850
epoch: 05900, loss:0.28729, acc: 0.850
epoch: 06000, loss:0.28388, acc: 0.850
epoch: 06100, loss:0.28055, acc: 0.850
epoch: 06200, loss:0.27730, acc: 0.850
epoch: 06300, loss:0.27412, acc: 0.850
epoch: 06400, loss:0.27101, acc: 0.850
epoch: 06500, loss:0.26797, acc: 0.850
epoch: 06600, loss:0.26500, acc: 0.850
epoch: 06700, loss:0.26209, acc: 0.850
epoch: 06800, loss:0.25925, acc: 0.850
epoch: 06900, loss:0.25647, acc: 0.850
epoch: 07000, loss:0.25376, acc: 0.850
epoch: 07100, loss:0.25110, acc: 0.850
epoch: 07200, loss:0.24850, acc: 0.850
epoch: 07300, loss:0.24596, acc: 0.850
epoch: 07400, loss:0.24347, acc: 0.900
epoch: 07500, loss:0.24103, acc: 0.900
epoch: 07600, loss:0.23865, acc: 0.900
epoch: 07700, loss:0.23631, acc: 0.900
epoch: 07800, loss:0.23403, acc: 0.900
epoch: 07900, loss:0.23179, acc: 0.900
epoch: 08000, loss:0.22960, acc: 0.900
epoch: 08100, loss:0.22746, acc: 0.950
epoch: 08200, loss:0.22535, acc: 0.950
epoch: 08300, loss:0.22330, acc: 0.950
epoch: 08400, loss:0.22128, acc: 0.950
epoch: 08500, loss:0.21930, acc: 0.950
epoch: 08600, loss:0.21737, acc: 0.950
epoch: 08700, loss:0.21547, acc: 0.950
epoch: 08800, loss:0.21361, acc: 0.950
epoch: 08900, loss:0.21179, acc: 0.950
epoch: 09000, loss:0.21000, acc: 0.950
epoch: 09100, loss:0.20825, acc: 0.950
epoch: 09200, loss:0.20653, acc: 0.950
epoch: 09300, loss:0.20484, acc: 0.950
epoch: 09400, loss:0.20319, acc: 0.950
epoch: 09500, loss:0.20157, acc: 0.950
epoch: 09600, loss:0.19998, acc: 0.950
epoch: 09700, loss:0.19842, acc: 0.950
epoch: 09800, loss:0.19689, acc: 0.950
epoch: 09900, loss:0.19538, acc: 0.950
epoch: 10000, loss:0.19391, acc: 0.950
epoch: 10100, loss:0.19246, acc: 0.950
epoch: 10200, loss:0.19104, acc: 0.950
epoch: 10300, loss:0.18964, acc: 0.950
epoch: 10400, loss:0.18827, acc: 0.950
epoch: 10500, loss:0.18693, acc: 0.950
epoch: 10600, loss:0.18561, acc: 0.950
epoch: 10700, loss:0.18431, acc: 0.950
epoch: 10800, loss:0.18303, acc: 0.950
epoch: 10900, loss:0.18178, acc: 0.950
epoch: 11000, loss:0.18055, acc: 0.950
epoch: 11100, loss:0.17934, acc: 0.950
epoch: 11200, loss:0.17815, acc: 0.950
epoch: 11300, loss:0.17698, acc: 0.950
epoch: 11400, loss:0.17583, acc: 0.950
epoch: 11500, loss:0.17471, acc: 0.950
epoch: 11600, loss:0.17360, acc: 0.950
epoch: 11700, loss:0.17250, acc: 0.950
epoch: 11800, loss:0.17143, acc: 0.950
epoch: 11900, loss:0.17038, acc: 0.950
epoch: 12000, loss:0.16934, acc: 0.950
epoch: 12100, loss:0.16832, acc: 0.950
epoch: 12200, loss:0.16731, acc: 0.950
epoch: 12300, loss:0.16632, acc: 0.950
epoch: 12400, loss:0.16535, acc: 0.950
epoch: 12500, loss:0.16440, acc: 0.950
epoch: 12600, loss:0.16345, acc: 0.950
epoch: 12700, loss:0.16253, acc: 0.950
epoch: 12800, loss:0.16162, acc: 0.950
epoch: 12900, loss:0.16072, acc: 0.950
epoch: 13000, loss:0.15983, acc: 0.950
epoch: 13100, loss:0.15896, acc: 0.950
epoch: 13200, loss:0.15811, acc: 0.950
epoch: 13300, loss:0.15726, acc: 0.950
epoch: 13400, loss:0.15643, acc: 0.950
epoch: 13500, loss:0.15561, acc: 0.950
epoch: 13600, loss:0.15481, acc: 0.950
epoch: 13700, loss:0.15401, acc: 0.950
epoch: 13800, loss:0.15323, acc: 0.950
epoch: 13900, loss:0.15246, acc: 0.950
epoch: 14000, loss:0.15170, acc: 0.950
epoch: 14100, loss:0.15095, acc: 0.950
epoch: 14200, loss:0.15022, acc: 0.950
epoch: 14300, loss:0.14949, acc: 0.950
epoch: 14400, loss:0.14877, acc: 0.950
epoch: 14500, loss:0.14807, acc: 0.950
epoch: 14600, loss:0.14737, acc: 0.950
epoch: 14700, loss:0.14669, acc: 0.950
epoch: 14800, loss:0.14601, acc: 0.950
epoch: 14900, loss:0.14534, acc: 0.950
epoch: 15000, loss:0.14468, acc: 0.950
epoch: 15100, loss:0.14404, acc: 0.950
epoch: 15200, loss:0.14340, acc: 0.950
epoch: 15300, loss:0.14276, acc: 0.950
epoch: 15400, loss:0.14214, acc: 0.950
epoch: 15500, loss:0.14153, acc: 0.950
epoch: 15600, loss:0.14092, acc: 0.950
epoch: 15700, loss:0.14032, acc: 0.950
epoch: 15800, loss:0.13973, acc: 0.950
epoch: 15900, loss:0.13915, acc: 0.950
epoch: 16000, loss:0.13858, acc: 0.950
epoch: 16100, loss:0.13801, acc: 0.950
epoch: 16200, loss:0.13745, acc: 0.950
epoch: 16300, loss:0.13690, acc: 0.950
epoch: 16400, loss:0.13635, acc: 0.950
epoch: 16500, loss:0.13581, acc: 0.950
epoch: 16600, loss:0.13528, acc: 0.950
epoch: 16700, loss:0.13476, acc: 0.950
epoch: 16800, loss:0.13424, acc: 0.950
epoch: 16900, loss:0.13373, acc: 0.950
epoch: 17000, loss:0.13322, acc: 0.950
epoch: 17100, loss:0.13272, acc: 0.950
epoch: 17200, loss:0.13223, acc: 0.950
epoch: 17300, loss:0.13174, acc: 0.950
epoch: 17400, loss:0.13126, acc: 0.950
epoch: 17500, loss:0.13079, acc: 0.950
epoch: 17600, loss:0.13032, acc: 0.950
epoch: 17700, loss:0.12985, acc: 0.950
epoch: 17800, loss:0.12940, acc: 0.950
epoch: 17900, loss:0.12894, acc: 0.950
epoch: 18000, loss:0.12850, acc: 0.950
epoch: 18100, loss:0.12805, acc: 0.950
epoch: 18200, loss:0.12762, acc: 0.950
epoch: 18300, loss:0.12719, acc: 0.950
epoch: 18400, loss:0.12676, acc: 0.950
epoch: 18500, loss:0.12634, acc: 0.950
epoch: 18600, loss:0.12592, acc: 0.950
epoch: 18700, loss:0.12551, acc: 0.950
epoch: 18800, loss:0.12510, acc: 0.950
epoch: 18900, loss:0.12470, acc: 0.950
epoch: 19000, loss:0.12430, acc: 0.950
epoch: 19100, loss:0.12390, acc: 0.950
epoch: 19200, loss:0.12351, acc: 0.950
epoch: 19300, loss:0.12313, acc: 0.950
epoch: 19400, loss:0.12275, acc: 0.950
epoch: 19500, loss:0.12237, acc: 0.950
epoch: 19600, loss:0.12200, acc: 0.950
epoch: 19700, loss:0.12163, acc: 0.950
epoch: 19800, loss:0.12126, acc: 0.950
epoch: 19900, loss:0.12090, acc: 0.950
model.state_dict()
OrderedDict([('0.weight', tensor([[-1.4480, -2.8220,  2.5209,  5.8656]])),
             ('0.bias', tensor([-5.3296]))])
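With the learned weights, prediction reduces to sigmoid(Wx + b) > 0.5. A minimal inference sketch on a single sample (the measurements are made up for illustration):

# Hypothetical example: classify one flower with the trained model
sample = torch.tensor([[6.0, 2.9, 4.5, 1.5]])  # made-up measurements
with torch.no_grad():
    prob = model(sample)  # predicted probability of Iris-virginica (class 1)
print("Iris-virginica" if prob.item() > 0.5 else "Iris-versicolor")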

Training loss and test-set accuracy curves:

[figure: visdom line plot of the training loss and test accuracy over epochs]