PyTorch 多层感知机
PyTorch 多层感知机,数据集来自:nivedithabhandary/HR-Analytics
数据预处理
python
1 | import matplotlib.pyplot as plt |
<torch._C.Generator at 0x7fabe84e8e90>
python
1 | path = "/workspace/disk1/datasets/scalar/HR_comma_sep.csv" |
python
1 | data = pd.read_csv(path) |
| | satisfaction_level | last_evaluation | number_project | average_montly_hours | time_spend_company | Work_accident | left | promotion_last_5years | sales | salary |
---|---|---|---|---|---|---|---|---|---|---|
0 | 0.38 | 0.53 | 2 | 157 | 3 | 0 | 1 | 0 | sales | low |
1 | 0.80 | 0.86 | 5 | 262 | 6 | 0 | 1 | 0 | sales | medium |
2 | 0.11 | 0.88 | 7 | 272 | 4 | 0 | 1 | 0 | sales | medium |
python
1 | data = data.rename(columns={"sales": "part"}) |
python
1 | data.info() |
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 14999 entries, 0 to 14998
Data columns (total 10 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 satisfaction_level 14999 non-null float64
1 last_evaluation 14999 non-null float64
2 number_project 14999 non-null int64
3 average_montly_hours 14999 non-null int64
4 time_spend_company 14999 non-null int64
5 Work_accident 14999 non-null int64
6 left 14999 non-null int64
7 promotion_last_5years 14999 non-null int64
8 part 14999 non-null object
9 salary 14999 non-null object
dtypes: float64(2), int64(6), object(2)
memory usage: 1.1+ MB
python
1 | data.part.unique() |
array(['sales', 'accounting', 'hr', 'technical', 'support', 'management',
'IT', 'product_mng', 'marketing', 'RandD'], dtype=object)
python
1 | data.salary.unique() |
array(['low', 'medium', 'high'], dtype=object)
python
1 | data.groupby(["salary", "part"]).size() |
salary part
high IT 83
RandD 51
accounting 74
hr 45
management 225
marketing 80
product_mng 68
sales 269
support 141
technical 201
low IT 609
RandD 364
accounting 358
hr 335
management 180
marketing 402
product_mng 451
sales 2099
support 1146
technical 1372
medium IT 535
RandD 372
accounting 335
hr 359
management 225
marketing 376
product_mng 383
sales 1772
support 942
technical 1147
dtype: int64
python
1 | data = data.join(pd.get_dummies(data.salary)) |
python
1 | data.head(3) |
| | satisfaction_level | last_evaluation | number_project | average_montly_hours | time_spend_company | Work_accident | left | promotion_last_5years | high | low | ... | IT | RandD | accounting | hr | management | marketing | product_mng | sales | support | technical |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.38 | 0.53 | 2 | 157 | 3 | 0 | 1 | 0 | 0 | 1 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
1 | 0.80 | 0.86 | 5 | 262 | 6 | 0 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
2 | 0.11 | 0.88 | 7 | 272 | 4 | 0 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
3 rows × 21 columns
python
1 | data.columns |
Index(['satisfaction_level', 'last_evaluation', 'number_project',
'average_montly_hours', 'time_spend_company', 'Work_accident', 'left',
'promotion_last_5years', 'high', 'low', 'medium', 'IT', 'RandD',
'accounting', 'hr', 'management', 'marketing', 'product_mng', 'sales',
'support', 'technical'],
dtype='object')
python
1 | data.shape |
(14999, 21)
python
1 | data.left.unique() |
array([1, 0])
python
1 | # 数据不均衡 |
0 11428
1 3571
Name: left, dtype: int64
python
1 | data.left.value_counts() / len(data) |
0 0.761917
1 0.238083
Name: left, dtype: float64
python
1 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
python
1 | Y_data = data.left.values.reshape(-1, 1) |
python
1 | # 训练集和验证集划分 |
python
1 | HRdataset = TensorDataset(X_train_, Y_train_) |
创建模型
python
1 | import torch.nn.functional as F |
python
1 | class MLPModel(nn.Module): |
python
1 | model = MLPModel() |
MLPModel(
(linear_1): Linear(in_features=20, out_features=64, bias=True)
(linear_2): Linear(in_features=64, out_features=64, bias=True)
(linear_3): Linear(in_features=64, out_features=1, bias=True)
)
python
1 | optimizer = optim.Adam(model.parameters(), lr=0.0001) |
python
1 | loss_fn = nn.BCELoss() |
python
1 | # 使用可视化工具 visdom(独立于 PyTorch 的第三方库)可视化训练损失变化情况 |
Setting up a new session...
'Multilayer Perceptron loss'
python
1 | batch = 64 |
epoch: 0, loss:0.5702827572822571
epoch: 10, loss:0.5380138754844666
epoch: 20, loss:0.48325881361961365
epoch: 30, loss:0.43280646204948425
epoch: 40, loss:0.40077832341194153
epoch: 50, loss:0.3730575442314148
epoch: 60, loss:0.3515182137489319
epoch: 70, loss:0.333321213722229
epoch: 80, loss:0.3213912844657898
epoch: 90, loss:0.3078565299510956
epoch: 100, loss:0.3066919445991516
epoch: 110, loss:0.28822797536849976
epoch: 120, loss:0.2822719216346741
epoch: 130, loss:0.2698499262332916
epoch: 140, loss:0.2673708200454712
epoch: 150, loss:0.2626533508300781
epoch: 160, loss:0.25525325536727905
epoch: 170, loss:0.25129374861717224
epoch: 180, loss:0.2505483329296112
epoch: 190, loss:0.23713809251785278
epoch: 200, loss:0.24216558039188385
epoch: 210, loss:0.22742848098278046
epoch: 220, loss:0.2271704226732254
epoch: 230, loss:0.2252519577741623
epoch: 240, loss:0.2193305790424347
epoch: 250, loss:0.21628950536251068
epoch: 260, loss:0.2177349030971527
epoch: 270, loss:0.2127998024225235
epoch: 280, loss:0.21056963503360748
epoch: 290, loss:0.2106262445449829
epoch: 300, loss:0.2119056135416031
epoch: 310, loss:0.21744157373905182
epoch: 320, loss:0.2134646326303482
epoch: 330, loss:0.20614351332187653
epoch: 340, loss:0.21047110855579376
epoch: 350, loss:0.20926110446453094
epoch: 360, loss:0.21647755801677704
epoch: 370, loss:0.21862877905368805
epoch: 380, loss:0.20591312646865845
epoch: 390, loss:0.2010645568370819
epoch: 400, loss:0.2050144225358963
epoch: 410, loss:0.1988362967967987
epoch: 420, loss:0.2004001885652542
epoch: 430, loss:0.1973736584186554
epoch: 440, loss:0.22496183216571808
epoch: 450, loss:0.19971159100532532
epoch: 460, loss:0.19692614674568176
epoch: 470, loss:0.19598537683486938
epoch: 480, loss:0.1948203295469284
epoch: 490, loss:0.19280420243740082
epoch: 500, loss:0.19459198415279388
epoch: 510, loss:0.19615624845027924
epoch: 520, loss:0.19460928440093994
epoch: 530, loss:0.19126252830028534
epoch: 540, loss:0.1938866376876831
epoch: 550, loss:0.19190427660942078
epoch: 560, loss:0.18747270107269287
epoch: 570, loss:0.18757325410842896
epoch: 580, loss:0.1988757699728012
epoch: 590, loss:0.19326147437095642
epoch: 600, loss:0.20158174633979797
epoch: 610, loss:0.18407107889652252
epoch: 620, loss:0.1833522766828537
epoch: 630, loss:0.18116755783557892
epoch: 640, loss:0.1835246980190277
epoch: 650, loss:0.18026278913021088
epoch: 660, loss:0.1771889179944992
epoch: 670, loss:0.1763589233160019
epoch: 680, loss:0.17852289974689484
epoch: 690, loss:0.17640434205532074
epoch: 700, loss:0.17527474462985992
epoch: 710, loss:0.18816079199314117
epoch: 720, loss:0.17597466707229614
epoch: 730, loss:0.17061764001846313
epoch: 740, loss:0.1716785877943039
epoch: 750, loss:0.16623708605766296
epoch: 760, loss:0.16821976006031036
epoch: 770, loss:0.163313090801239
epoch: 780, loss:0.1642996072769165
epoch: 790, loss:0.1681404858827591
epoch: 800, loss:0.16040602326393127
epoch: 810, loss:0.1622639298439026
epoch: 820, loss:0.16376478970050812
epoch: 830, loss:0.15491259098052979
epoch: 840, loss:0.15919749438762665
epoch: 850, loss:0.15590602159500122
epoch: 860, loss:0.15480470657348633
epoch: 870, loss:0.1528528481721878
epoch: 880, loss:0.1514061689376831
epoch: 890, loss:0.15681587159633636
epoch: 900, loss:0.1545121818780899
epoch: 910, loss:0.14895710349082947
epoch: 920, loss:0.14539872109889984
epoch: 930, loss:0.1600637137889862
epoch: 940, loss:0.1435800939798355
epoch: 950, loss:0.14283594489097595
epoch: 960, loss:0.14213287830352783
epoch: 970, loss:0.1424899399280548
epoch: 980, loss:0.14058111608028412
epoch: 990, loss:0.1395450234413147
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 J. Xu!
评论