PyTorch与深度学习自查手册4-训练、可视化、日志输出、保存模型

训练和验证(包含可视化、日志、保存模型)

模型初始化和 dataloader 都准备好以后,就正式进入训练部分。

训练部分包括:

  1. 及时的日志记录

  2. TensorBoard 可视化日志
  3. 数据输入

  4. 前向传播
  5. loss计算
  6. 反向传播
  7. 权重更新

  8. 每隔固定的 epoch 进行一次验证
  9. 保存最佳模型(并输出 bad case)

日志记录

利用 logging 模块在控制台实时打印运行信息,同时把日志写入文件。

from config import  *
import logging # 引入logging模块
import os.path
class Logger:
    """Configure the root logger to write both to a log file and to the console.

    A log file named ``<YYYYmmddHHMM>.log`` is created under ``<cwd>/Logs/``
    (the directory is created if it is missing). Records are formatted with
    timestamp, source file, line number and level.

    Args:
        mode: file mode for the log file ('w' truncates, 'a' appends).

    NOTE(review): handlers are attached to the *root* logger, so constructing
    Logger more than once in the same process duplicates every log line.
    """

    def __init__(self, mode='w'):
        # `time` is not imported in this snippet's header, so import it here.
        import time

        # Step 1: take the root logger and set the global level switch.
        self.logger = logging.getLogger()
        self.logger.setLevel(logging.INFO)

        # Step 2: a handler that writes to a timestamped file under ./Logs/.
        rq = time.strftime('%Y%m%d%H%M', time.localtime(time.time()))
        log_path = os.path.join(os.getcwd(), 'Logs')
        # FileHandler does not create missing directories on its own.
        os.makedirs(log_path, exist_ok=True)
        logfile = os.path.join(log_path, rq + '.log')
        # Explicit UTF-8 so non-ASCII messages are written the same on every OS.
        fh = logging.FileHandler(logfile, mode=mode, encoding='utf-8')
        # DEBUG here is effectively moot: the logger itself drops records below INFO.
        fh.setLevel(logging.DEBUG)

        # Step 3: the output format shared by both handlers.
        formatter = logging.Formatter("%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s")
        fh.setFormatter(formatter)

        # Step 4: attach the file handler and a console handler to the logger.
        self.logger.addHandler(fh)
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)  # level switch for console output
        ch.setFormatter(formatter)
        self.logger.addHandler(ch)

完整的训练流程

import os
import math
import argparse
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import torch.optim.lr_scheduler as lr_scheduler
import sys
from tqdm import tqdm
import torch

def train_one_epoch(model, optimizer, data_loader, device, epoch, tb_writer):
    """Run one training epoch, reporting progress to tqdm and TensorBoard.

    Args:
        model: the network; switched to train mode here.
        optimizer: optimizer stepping the model's trainable parameters.
        data_loader: iterable of ``(batch, labels)`` pairs. If it supports
            ``len()``, that length is used to build a global TensorBoard step.
        device: device the batch and labels are moved to.
        epoch: current epoch index (progress bar text and global step).
        tb_writer: a ``SummaryWriter`` receiving per-iteration scalars.

    Returns:
        ``(mean_loss, accuracy)``: the running mean of the batch losses and the
        fraction of correctly classified training samples in this epoch.

    Raises:
        SystemExit: when a non-finite loss is encountered.
    """
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    mean_loss = torch.zeros(1).to(device)
    num_correct = torch.zeros(1).to(device)
    num_seen = 0
    optimizer.zero_grad()

    try:
        num_batches = len(data_loader)
    except TypeError:  # e.g. an IterableDataset without __len__
        num_batches = 0

    tags = ["train_loss", "train_accuracy", "learning_rate"]
    data_loader = tqdm(data_loader)
    for iteration, data in enumerate(data_loader):
        batch, labels = data
        labels = labels.to(device)
        pred = model(batch.to(device))

        loss = loss_function(pred, labels)
        # Abort before backward so a NaN/Inf loss never reaches the weights.
        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)
        loss.backward()

        # Incremental running mean over the batches seen so far.
        mean_loss = (mean_loss * iteration + loss.detach()) / (iteration + 1)

        pred_class = torch.max(pred, dim=1)[1]
        batch_correct = torch.eq(pred_class, labels).sum()
        num_correct += batch_correct
        num_seen += labels.size(0)

        # Refresh the progress-bar text with the running mean loss.
        if iteration % 50 == 0:
            data_loader.desc = "[epoch {}] mean loss {}".format(epoch, round(mean_loss.item(), 3))

        optimizer.step()
        optimizer.zero_grad()

        # A global step keeps successive epochs from overwriting each other's curves.
        global_step = epoch * num_batches + iteration
        batch_acc = batch_correct.item() / max(labels.size(0), 1)
        for tag, value in zip(tags, [mean_loss.item(), batch_acc, optimizer.param_groups[0]["lr"]]):
            # add_scalar takes a single float; '/' groups the curves under "Train".
            tb_writer.add_scalar('Train/%s' % tag, value, global_step)

    accuracy = num_correct.item() / num_seen if num_seen else 0.0
    return mean_loss.item(), accuracy


@torch.no_grad()
def evaluate(model, data_loader, device, best_acc=-1, bad_case_path='bad_case.pkl'):
    """Compute classification accuracy on a validation loader and dump bad cases.

    Args:
        model: the classifier; switched to eval mode here.
        data_loader: iterable of ``(batch, labels)`` pairs.
        device: device the inputs and labels are moved to for the forward pass.
        best_acc: best accuracy seen so far. Misclassified samples are written
            to ``bad_case_path`` only when this evaluation beats it (the default
            of -1 therefore always dumps).
        bad_case_path: file that receives a list of ``(inputs, labels)`` tuples,
            one per batch that contained at least one error. The list is written
            with ``torch.save`` and can be read back with ``torch.load``.

    Returns:
        The fraction of correctly classified samples, or 0.0 if the loader is empty.
    """
    model.eval()

    # Counts of correct predictions and of samples actually evaluated.
    sum_num = torch.zeros(1).to(device)
    num_samples = 0

    # Show validation progress.
    data_loader = tqdm(data_loader, desc="validation...")
    bad_case = []
    for batch, labels in data_loader:
        pred = model(batch.to(device))
        pred = torch.max(pred, dim=1)[1]
        correct = torch.eq(pred, labels.to(device))
        sum_num += correct.sum()
        num_samples += labels.size(0)

        wrong = ~correct
        if wrong.any():
            # The mask lives on `device`. Move it to each tensor's device so
            # boolean indexing also works when `device` is a GPU.
            bad_case.append((batch[wrong.to(batch.device)], labels[wrong.to(labels.device)]))

    # Fraction of correct predictions (guard against an empty loader).
    acc = sum_num.item() / num_samples if num_samples else 0.0
    if best_acc < acc:
        torch.save(bad_case, bad_case_path)
    return acc

def _build_checkpoint(model, optimizer, scheduler, epoch, best_acc, num_params):
    """Bundle everything needed to resume training into one dict for torch.save."""
    return {
        'model_state_dict': model.state_dict(),          # learnable model parameters
        'optimizer_state_dict': optimizer.state_dict(),  # optimizer state
        'scheduler_state_dict': scheduler.state_dict(),  # LR scheduler state
        'epoch': epoch,
        # Kept under the key name the resume code reads; here it holds the
        # best validation accuracy so far.
        'best_val_mae': best_acc,
        'best_val_acc': best_acc,
        'num_params': num_params,
    }


def main(args, logger):
    """Train a classifier, validating every epoch and checkpointing periodically.

    Args:
        args: parsed command-line namespace. Reads ``device``, ``freeze_layers``,
            ``lr``, ``lrf``, ``epochs``, ``log_dir``, ``save_epoch`` and,
            if present, ``save_dir`` (defaults to ``./checkpoints``).
        logger: a ``Logger`` instance; messages go through ``logger.logger``.

    NOTE(review): ``Model``, ``train_loader`` and ``val_loader`` are not defined
    in this snippet. They are assumed to exist at module level (e.g. via
    ``from config import *``) — confirm.
    """
    device = torch.device(args.device)
    save_dir = getattr(args, 'save_dir', './checkpoints')
    os.makedirs(save_dir, exist_ok=True)

    # Instantiate the model on the target device. This happens before building
    # the optimizer and before tracing the graph with a device-resident input.
    model = Model().to(device)
    num_params = sum(p.numel() for p in model.parameters())

    # Optionally freeze every weight except the final fully connected layer.
    if args.freeze_layers:
        print("freeze layers except fc layer.")
        for name, para in model.named_parameters():
            if "fc" not in name:
                para.requires_grad_(False)

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=0.005)

    def lf(x):
        # Cosine schedule: the LR multiplier goes from 1 at epoch 0 to args.lrf at args.epochs.
        return ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf

    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    logger.logger.info('start training......\n')

    tb_writer = SummaryWriter(log_dir=args.log_dir)
    # Write the model graph to TensorBoard.
    init_input = torch.zeros((1, 3, 224, 224), device=device)
    tb_writer.add_graph(model, init_input)

    best_acc = 0  # best validation accuracy so far
    try:
        for epoch in range(args.epochs):
            mean_loss, mean_acc = train_one_epoch(model=model,
                                                  optimizer=optimizer,
                                                  data_loader=train_loader,
                                                  device=device,
                                                  epoch=epoch,
                                                  tb_writer=tb_writer)
            # Update the learning rate once per epoch.
            scheduler.step()

            # Validate. Passing best_acc means bad cases are only dumped on improvement.
            acc = evaluate(model=model,
                           data_loader=val_loader,
                           device=device,
                           best_acc=best_acc)

            tb_writer.add_scalar('Validation/val_accuracy', acc, epoch)
            logger.logger.info('%d epoch train mean loss: %.2f \n' % (epoch, mean_loss))
            logger.logger.info('%d epoch train mean acc: %.2f \n' % (epoch, mean_acc))
            logger.logger.info('%d epoch validation acc: %.2f \n' % (epoch, acc))

            # Periodic checkpoint every args.save_epoch epochs.
            if args.save_epoch > 0 and epoch % args.save_epoch == 0:
                checkpoint = _build_checkpoint(model, optimizer, scheduler, epoch, best_acc, num_params)
                torch.save(checkpoint, os.path.join(save_dir, 'checkpoint-%d.pt' % epoch))
                logger.logger.info('save model %d successed......\n' % epoch)

            # Save the best model.
            if best_acc < acc:
                best_acc = acc
                logger.logger.info('best model in %d epoch, train mean acc: %.2f \n' % (epoch, mean_acc))
                logger.logger.info('best model in %d epoch, validation acc: %.2f \n' % (epoch, acc))
                checkpoint = _build_checkpoint(model, optimizer, scheduler, epoch, best_acc, num_params)
                torch.save(checkpoint, os.path.join(save_dir, 'best_checkpoint.pt'))
                logger.logger.info('save best model successed......\n')

            # Visualise prediction images. This is a placeholder: plug in a
            # matplotlib figure here. `...` would be Ellipsis (not None) and crash add_figure.
            fig = None
            if fig is not None:
                tb_writer.add_figure("predictions vs. actuals",
                                     figure=fig,
                                     global_step=epoch)

            # Weight histograms across epochs. These assume ResNet-style
            # attribute names, so they are skipped when the model lacks them.
            conv1 = getattr(model, 'conv1', None)
            if conv1 is not None:
                tb_writer.add_histogram(tag="conv1",
                                        values=conv1.weight,
                                        global_step=epoch)
            layer1 = getattr(model, 'layer1', None)
            if layer1 is not None:
                tb_writer.add_histogram(tag="layer1/block0/conv1",
                                        values=layer1[0].conv1.weight,
                                        global_step=epoch)
    finally:
        # Flush pending events to disk even if training aborts.
        tb_writer.close()
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.001)
    # Final learning rate as a fraction of --lr at the end of the cosine schedule.
    parser.add_argument('--lrf', type=float, default=0.1)
    # Write a periodic checkpoint every N epochs.
    parser.add_argument('--save_epoch', type=int, default=3)
    # TensorBoard directory. None lets SummaryWriter choose runs/<datetime_host>.
    parser.add_argument('--log_dir', type=str, default=None)
    # Directory for checkpoint-*.pt and best_checkpoint.pt.
    parser.add_argument('--save_dir', type=str, default='./checkpoints')
    # Root directory of the dataset.
    data_root = "/home/data_set/"
    parser.add_argument('--data-path', type=str, default=data_root)

    # --freeze-layers: freeze every layer except the fully connected head. This
    # is useful when fine-tuning a pretrained model and speeds up training.
    # `type=bool` would treat any non-empty string (even "False") as True, so
    # parse the value explicitly. A bare `--freeze-layers` also means True.
    parser.add_argument('--freeze-layers',
                        type=lambda s: s.strip().lower() in ('1', 'true', 'yes', 'y'),
                        nargs='?', const=True, default=False)
    parser.add_argument('--device', default='cuda', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    logger = Logger()
    main(opt, logger)

模型保存和断点继续训练

在训练模型过程中,对模型进行保存是很重要的。

核心包括两个内容:

  1. 最优模型保存机制;
  2. 如何从断点加载模型继续训练。

保存最优模型的常见做法是:每隔固定间隔在验证集上评估一次,根据选定的指标,把训练过程中表现最优的模型保存下来。checkpoint 中需要保存以下内容:

  1. 核心:模型参数、优化器参数、scheduler参数;
  2. 其他:训练epoch、模型超参数、模型评价指标。

如果只需要保存和加载模型本身,还有其他方式,可参考:PyTorch之保存加载模型 - 简书 (jianshu.com)

# Total parameter count, stored alongside the training state for reference.
num_params = sum(param.numel() for param in model.parameters())

# Everything needed to resume training: the core state dicts plus bookkeeping.
checkpoint_path = os.path.join(args.save_dir, 'checkpoint.pt')
checkpoint = dict(
    model_state_dict=model.state_dict(),          # learnable model parameters
    optimizer_state_dict=optimizer.state_dict(),  # optimizer state
    scheduler_state_dict=scheduler.state_dict(),  # LR scheduler state
    epoch=epoch,
    best_val_mae=best_valid_mae,
    num_params=num_params,
)
torch.save(checkpoint, checkpoint_path)

要实现断点继续训练,只需把上次保存的 checkpoint 加载进来,恢复模型、优化器和 scheduler 的状态,然后继续训练即可。注意:加载时使用的键名必须与保存时一致。

# Resume training from a checkpoint saved with the keys used when saving:
# model_state_dict / optimizer_state_dict / scheduler_state_dict / epoch / best_val_mae.
path_checkpoint = "./model_parameter/test/ckpt_best_50.pth"  # checkpoint path
# map_location='cpu' lets a checkpoint saved on a GPU be loaded anywhere.
# load_state_dict then copies the tensors onto the live objects' devices.
checkpoint = torch.load(path_checkpoint, map_location='cpu')
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
scheduler = ThelrscheduleClass(*args, **kwargs)

model.load_state_dict(checkpoint['model_state_dict'])          # learnable parameters
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # optimizer state
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])  # LR scheduler state
best_valid_mae = checkpoint['best_val_mae']                    # best metric so far
# 'epoch' records the last finished epoch, so continue with the next one.
start_epoch = checkpoint['epoch'] + 1

# Same epoch range convention as the training loop: epochs 0 .. args.epochs - 1.
for epoch in range(start_epoch, args.epochs):
    #……
    train_mae = train(model, device, train_loader, optimizer, scheduler, criterion_fn)
    #……

预测

@torch.no_grad()
def predict(model, data_loader, device):
    """Return the predicted class index for every sample in ``data_loader``.

    Args:
        model: trained classifier; switched to eval mode here.
        data_loader: iterable yielding input tensors only (each item is passed
            straight to ``.to(device)``, so no ``(inputs, labels)`` tuples).
            It does not need a ``__len__``.
        device: device the batches are moved to for the forward pass.

    Returns:
        A flat list of 0-d CPU tensors, one predicted class index per sample,
        in loader order.
    """
    model.eval()

    # Show prediction progress.
    data_loader = tqdm(data_loader, desc="validation...")
    predictions = []
    for batch in data_loader:
        logits = model(batch.to(device))
        # Move results to the CPU so the accumulated list does not hold device memory.
        predictions.extend(torch.argmax(logits, dim=1).cpu())
    return predictions