import os
import math
import sys
import argparse

import joblib
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
def train_one_epoch(model, optimizer, data_loader, device, epoch, tb_writer):
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    mean_loss = torch.zeros(1).to(device)
    mean_acc = torch.zeros(1).to(device)
    optimizer.zero_grad()
    num_samples = len(data_loader.dataset)  # read the dataset size before wrapping in tqdm
    data_loader = tqdm(data_loader)
    for iteration, data in enumerate(data_loader):
        batch, labels = data
        pred = model(batch.to(device))
        loss = loss_function(pred, labels.to(device))
        loss.backward()
        # running average of the loss over all iterations seen so far
        mean_loss = (mean_loss * iteration + loss.detach()) / (iteration + 1)
        pred = torch.max(pred, dim=1)[1]
        iter_acc = torch.eq(pred, labels.to(device)).sum()
        mean_acc += iter_acc  # accumulate the number of correct predictions
        if iteration % 50 == 0:
            data_loader.desc = "[epoch {}] mean loss {}".format(epoch, round(mean_loss.item(), 3))
        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)
        optimizer.step()
        optimizer.zero_grad()

        tags = ["train_loss", "train_accuracy", "learning_rate"]
        for tag, value in zip(tags, [mean_loss.item(),
                                     iter_acc.item() / labels.size(0),
                                     optimizer.param_groups[0]["lr"]]):
            tb_writer.add_scalar('Train/%s' % tag, value, iteration)

    # convert the accumulated correct count into a mean accuracy over the epoch
    return mean_loss.item(), mean_acc.item() / num_samples
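
# Hypothetical smoke test for train_one_epoch on random data (CPU-only sketch;
# none of these names exist elsewhere in this script):
#
#   smoke_model = torch.nn.Linear(8, 5)
#   smoke_data = torch.utils.data.TensorDataset(torch.randn(32, 8),
#                                               torch.randint(0, 5, (32,)))
#   smoke_loader = torch.utils.data.DataLoader(smoke_data, batch_size=16)
#   smoke_writer = SummaryWriter(log_dir='/tmp/smoke')
#   smoke_opt = optim.SGD(smoke_model.parameters(), lr=0.01)
#   train_one_epoch(smoke_model, smoke_opt, smoke_loader,
#                   torch.device('cpu'), epoch=0, tb_writer=smoke_writer)
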
@torch.no_grad()
def evaluate(model, data_loader, device, best_acc=-1):
    model.eval()
    sum_num = torch.zeros(1).to(device)
    num_samples = len(data_loader.dataset)
    data_loader = tqdm(data_loader, desc="validation...")
    bad_case = []
    for step, data in enumerate(data_loader):
        batch, labels = data
        pred = model(batch.to(device))
        pred = torch.max(pred, dim=1)[1]
        tmp = torch.eq(pred, labels.to(device))
        sum_num += tmp.sum()
        wrong = ~tmp.cpu()  # move the mask to CPU so it can index the CPU-side batch
        bad_case.append((batch[wrong], labels[wrong]))

    acc = sum_num.item() / num_samples
    # keep the misclassified samples only when the accuracy improves on the best so far
    if best_acc < acc:
        joblib.dump(bad_case, 'bad_case.pkl')
    return acc
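
# The bad_case dump above can be inspected offline; a minimal sketch
# (the file name comes from evaluate, everything else is hypothetical):
#
#   bad_case = joblib.load('bad_case.pkl')  # list of (images, labels) pairs, one per batch
#   images, labels = bad_case[0]            # misclassified samples from the first batch
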
def main(args, logger):
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    # Model, train_loader and val_loader are assumed to be defined elsewhere in the project
    model = Model().to(device)
    if args.freeze_layers:
        print("freeze layers except fc layer.")
        for name, para in model.named_parameters():
            if "fc" not in name:
                para.requires_grad_(False)
    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=0.005)
    # cosine schedule: the LR factor decays from 1 at epoch 0 to args.lrf at the final epoch
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    logger.logger.info('start training......\n')
    tb_writer = SummaryWriter(log_dir=args.log_dir)
    init_input = torch.zeros((1, 3, 224, 224), device=device)
    tb_writer.add_graph(model, init_input)

    os.makedirs(args.save_dir, exist_ok=True)  # make sure the checkpoint directory exists
    best_acc = 0
    for epoch in range(args.epochs):
        mean_loss, mean_acc = train_one_epoch(model=model,
                                              optimizer=optimizer,
                                              data_loader=train_loader,
                                              device=device,
                                              epoch=epoch,
                                              tb_writer=tb_writer)
        scheduler.step()
        acc = evaluate(model=model, data_loader=val_loader, device=device, best_acc=best_acc)
        tb_writer.add_scalar('Validation/val_accuracy', acc, epoch)
        logger.logger.info('%d epoch train mean loss: %.2f \n' % (epoch, mean_loss))
        logger.logger.info('%d epoch train mean acc: %.2f \n' % (epoch, mean_acc))
        logger.logger.info('%d epoch validation acc: %.2f \n' % (epoch, acc))

        if epoch % args.save_epoch == 0:
            checkpoint = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'epoch': epoch,
                'best_acc': best_acc,
            }
            torch.save(checkpoint, os.path.join(args.save_dir, 'checkpoint-%d.pt' % epoch))
            logger.logger.info('save model %d succeeded......\n' % epoch)

        if best_acc < acc:
            best_acc = acc
            logger.logger.info('best model in %d epoch, train mean acc: %.2f \n' % (epoch, mean_acc))
            logger.logger.info('best model in %d epoch, validation acc: %.2f \n' % (epoch, acc))
            checkpoint = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'epoch': epoch,
                'best_acc': best_acc,
            }
            torch.save(checkpoint, os.path.join(args.save_dir, 'best_checkpoint.pt'))
            logger.logger.info('save best model succeeded......\n')

        fig = None  # placeholder: optionally build a predictions-vs-actuals figure here
        if fig is not None:
            tb_writer.add_figure("predictions vs. actuals", figure=fig, global_step=epoch)
        tb_writer.add_histogram(tag="conv1", values=model.conv1.weight, global_step=epoch)
        tb_writer.add_histogram(tag="layer1/block0/conv1",
                                values=model.layer1[0].conv1.weight, global_step=epoch)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--lrf', type=float, default=0.1)
    parser.add_argument('--save_epoch', type=int, default=3)
    # log_dir and save_dir defaults are assumptions; the originals were copy-paste errors
    parser.add_argument('--log_dir', type=str, default='./runs')
    parser.add_argument('--save_dir', type=str, default='./weights')
    data_root = "/home/data_set/"
    parser.add_argument('--data-path', type=str, default=data_root)
    # argparse's type=bool is a pitfall (any non-empty string parses as True); use a flag
    parser.add_argument('--freeze-layers', action='store_true', default=False)
    parser.add_argument('--device', default='cuda', help='device id (i.e. 0 or 0,1 or cpu)')
    opt = parser.parse_args()
    logger = Logger()  # Logger is assumed to be provided elsewhere in the project
    main(opt, logger)
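
# Example invocation (the script name and data path are placeholders):
#   python train.py --epochs 30 --batch-size 16 --lr 0.001 --data-path /home/data_set/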