Pytorch与视觉竞赛入门2.2-PyTorch常见的损失函数和优化器使用

PyTorch常见的损失函数和优化器使用

损失函数

损失函数

第三章 PyTorch的主要组成模块/3.5 损失函数.md

torch.nn-loss function

如何选择loss function

参考:深度学习中常见的激活函数与损失函数的选择与介绍

问题类型 最后一层激活函数 损失函数
二分类问题 sigmoid binary_crossentropy
多分类、单标签问题 softmax categorical_crossentropy
多分类、多标签问题 sigmoid binary_crossentropy
回归到任意值 mse
回归到 0~1 范围内的值 sigmoid mse 或 binary_crossentropy

优化器

参考:TORCH.OPTIM

实操

回归任务选MSEloss,选用Adam优化器,用ExponentialLR(每个 epoch 用 gamma 衰减每个参数组的学习率)和MultiStepLR(一旦 epoch 数达到里程碑之一,每个参数组的学习率就会衰减 gamma)

大多数学习率调度器都是链式调度器,即每个调度器都被一个接一个地应用于前一个调度器获得的学习率。

可以使用scheduler1.get_lr()查看学习率。

import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
from torch.autograd import Variable


class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(1, 1) # an affine operation: y = Wx + b
def forward(self, x):
x=self.fc1(x)
return x
net = Net()
net

'''神经网络的结构是这样的
Net (
(fc1): Linear (1 -> 1)
)
'''
x = torch.arange(0, 100, 0.01,dtype=torch.float32)
y = (10 * x + 5 + np.random.normal(0, 1, x.size())).float()

batch_size = 100
w = torch.randn((1,), requires_grad=True,dtype=torch.float32)
b = torch.randn((1,), requires_grad=True,dtype=torch.float32)

iter_time=100
num_epochs = 5
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), 0.1)
scheduler1 = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
scheduler2 = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)

for epoch in range(num_epochs):
for t in range(iter_time):
train_x=x[batch_size*t:batch_size*(t+1)].clone()
train_y = y[batch_size * t:batch_size * (t + 1)].clone()
train_x=train_x.unsqueeze(1)
train_y=train_y.unsqueeze(1)
inputs = Variable(train_x)
target = Variable(train_y)

# forward
out = net(inputs) # 前向传播
loss = criterion(out, target) # 计算loss
# backward
optimizer.zero_grad() # 梯度归零
loss.backward() # 方向传播
optimizer.step() # 更新参数

print('Epoch[{}/{}], loss: {:.6f}'.format(epoch+1,t+1,loss.data))
scheduler1.step()
scheduler2.step()
#学习率为0.1
#w=9.9628
#b=8.8283
#学习率为0.5
#w=9.9656
#b=8.0902
#学习率为0.01
#w=3.9123
#b=3.0882