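"""Training script for a scene-text recognition model (deep-text-recognition-benchmark
style pipeline: Transformation -> FeatureExtraction -> SequenceModeling -> Prediction),
with a CTC or attention prediction head, batch-balanced training data, and optional
mixed-precision (AMP) training."""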
import os
import sys
import time
import random

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torch.utils.data
from torch.cuda.amp import autocast, GradScaler

import numpy as np

from utils import CTCLabelConverter, AttnLabelConverter, Averager
from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset
from model import Model
from test import validation

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def count_parameters(model):
    print("Modules, Parameters")
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        param = parameter.numel()
        total_params += param
        print(name, param)
    print(f"Total Trainable Params: {total_params}")
    return total_params

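# train() wires the whole pipeline together: it builds the batch-balanced
# training set and a validation loader, constructs (or restores) the model,
# picks the loss and optimizer to match the prediction head, then iterates
# until opt.num_iter, validating every opt.valInterval steps.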
def train(opt, show_number=2, amp=False):
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    print(f"opt.select_data: {opt.select_data}")
    print(f"opt.batch_ratio: {opt.batch_ratio}")

    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a', encoding="utf8")
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD,
                                      contrast_adjust=opt.contrast_adjust)
    print(f"opt.valid_data : {opt.valid_data}")
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=min(32, opt.batch_size),
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers), prefetch_factor=512,
        collate_fn=AlignCollate_valid, pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

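    # The label converter must match the prediction head: CTCLabelConverter maps
    # characters to indices with a reserved blank for CTC, while AttnLabelConverter
    # wraps labels with [GO]/[s] start and end tokens for the attention decoder
    # (these tokens are why targets are shifted and '[s]' is pruned further below).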
""" model configuration """
|
|
if 'CTC' in opt.Prediction:
|
|
converter = CTCLabelConverter(opt.character)
|
|
else:
|
|
converter = AttnLabelConverter(opt.character)
|
|
opt.num_class = len(converter.character)
|
|
|
|
if opt.rgb:
|
|
opt.input_channel = 3
|
|
model = Model(opt)
|
|
print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
|
|
opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
|
|
opt.SequenceModeling, opt.Prediction)
|
|
|
|
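    # Note on loading: checkpoints are saved from a DataParallel wrapper, so their
    # state-dict keys carry a 'module.' prefix (e.g. 'module.Prediction.weight');
    # the model is wrapped in DataParallel before load_state_dict for this reason.
    # With opt.new_prediction, the final linear layer is first resized to fit the
    # checkpoint, and only replaced/re-initialized for opt.num_class after loading.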
    if opt.saved_model != '':
        pretrained_dict = torch.load(opt.saved_model)
        if opt.new_prediction:
            model.Prediction = nn.Linear(model.SequenceModeling_output,
                                         len(pretrained_dict['module.Prediction.weight']))

        model = torch.nn.DataParallel(model).to(device)
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(pretrained_dict, strict=False)
        else:
            model.load_state_dict(pretrained_dict)
        if opt.new_prediction:
            model.module.Prediction = nn.Linear(model.module.SequenceModeling_output, opt.num_class)
            for name, param in model.module.Prediction.named_parameters():
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            model = model.to(device)
    else:
        # weight initialization
        for name, param in model.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception:  # e.g. batchnorm weights are 1-D, which kaiming init rejects
                if 'weight' in name:
                    param.data.fill_(1)
                continue
        model = torch.nn.DataParallel(model).to(device)

    model.train()
    print("Model:")
    print(model)
    count_parameters(model)

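    # The criterion mirrors the head choice: CTCLoss for CTC heads, where
    # zero_infinity=True zeroes the infinite losses (and their gradients) that
    # occur when a target is longer than the input sequence; CrossEntropyLoss
    # for attention heads, with index 0 (the [GO] token) ignored.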
""" setup loss """
|
|
if 'CTC' in opt.Prediction:
|
|
criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
|
|
else:
|
|
criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0
|
|
# loss averager
|
|
loss_avg = Averager()
|
|
|
|
    # freeze some layers (note: 'freeze_FeatureFxtraction' is spelled this way
    # in the option name, so it is kept as-is here)
    try:
        if opt.freeze_FeatureFxtraction:
            for param in model.module.FeatureExtraction.parameters():
                param.requires_grad = False
        if opt.freeze_SequenceModeling:
            for param in model.module.SequenceModeling.parameters():
                param.requires_grad = False
    except AttributeError:
        # configurations without these flags simply skip freezing
        pass

    # keep only the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.optim == 'adam':
        # optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
        # note: with the line above commented out, Adam runs with PyTorch's
        # default hyperparameters; opt.lr/opt.rho/opt.eps only affect Adadelta
        optimizer = optim.Adam(filtered_parameters)
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

""" final options """
|
|
# print(opt)
|
|
with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a', encoding="utf8") as opt_file:
|
|
opt_log = '------------ Options -------------\n'
|
|
args = vars(opt)
|
|
for k, v in args.items():
|
|
opt_log += f'{str(k)}: {str(v)}\n'
|
|
opt_log += '---------------------------------------\n'
|
|
print(opt_log)
|
|
opt_file.write(opt_log)
|
|
|
|
""" start training """
|
|
start_iter = 0
|
|
if opt.saved_model != '':
|
|
try:
|
|
start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
|
|
print(f'continue to train, start_iter: {start_iter}')
|
|
except:
|
|
pass
|
|
|
|
start_time = time.time()
|
|
best_accuracy = -1
|
|
best_norm_ED = -1
|
|
i = start_iter
|
|
|
|
scaler = GradScaler()
|
|
t1= time.time()
|
|
|
|
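    # The training loop follows the standard torch.cuda.amp recipe when amp=True:
    # forward pass under autocast(), scale the loss before backward(), unscale_()
    # before gradient clipping so the clip threshold applies to the true gradient
    # magnitudes, then scaler.step() and scaler.update().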
    while True:
        # train part
        optimizer.zero_grad(set_to_none=True)

        if amp:
            with autocast():
                image_tensors, labels = train_dataset.get_batch()
                image = image_tensors.to(device)
                text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
                batch_size = image.size(0)

                if 'CTC' in opt.Prediction:
                    preds = model(image, text).log_softmax(2)  # CTCLoss expects log-probabilities
                    preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                    preds = preds.permute(1, 0, 2)  # (N, T, C) -> (T, N, C) as CTCLoss expects
                    # sidestep cuDNN's restricted CTC kernel for the loss computation
                    torch.backends.cudnn.enabled = False
                    cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
                    torch.backends.cudnn.enabled = True
                else:
                    preds = model(image, text[:, :-1])  # align with Attention.forward
                    target = text[:, 1:]  # without [GO] symbol
                    cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))
            scaler.scale(cost).backward()
            scaler.unscale_(optimizer)  # unscale before clipping so the threshold sees true magnitudes
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            image_tensors, labels = train_dataset.get_batch()
            image = image_tensors.to(device)
            text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
            batch_size = image.size(0)
            if 'CTC' in opt.Prediction:
                preds = model(image, text).log_softmax(2)
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                preds = preds.permute(1, 0, 2)
                print(f"preds shape : {preds.shape}")
                torch.backends.cudnn.enabled = False
                cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
                torch.backends.cudnn.enabled = True
            else:
                preds = model(image, text[:, :-1])  # align with Attention.forward
                target = text[:, 1:]  # without [GO] symbol
                cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))
            cost.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()
        loss_avg.add(cost)

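        # Every opt.valInterval iterations the model is switched to eval mode and
        # scored by validation() from test.py, which returns the validation loss,
        # word-level accuracy, normalized edit distance (norm_ED), and per-sample
        # predictions with confidence scores; the best checkpoint under each of
        # the two metrics is saved separately below.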
        # validation part
        if (i % opt.valInterval == 0) and (i != 0):
            print('training time: ', time.time() - t1)
            t1 = time.time()
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a', encoding="utf8") as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, \
                        infer_time, length_of_data = validation(model, criterion, valid_loader, converter, opt, device)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.4f}'

                # keep the best models (on the validation set) under two criteria:
                # word-level accuracy and normalized edit distance
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.4f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'

                # cap show_number so the randint range below cannot go negative
                # when the validation set is smaller than show_number
                show_number = min(show_number, len(labels))

                start = random.randint(0, len(labels) - show_number)
                print(f"start index for showing results: {start}")
                print(f"labels length: {len(labels)}")
                print(f"labels : {labels}")
                for gt, pred, confidence in zip(labels[start:start + show_number],
                                                preds[start:start + show_number],
                                                confidence_score[start:start + show_number]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]  # prune after the end-of-sequence token

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')
                print('validation time: ', time.time() - t1)
                t1 = time.time()
        # save a checkpoint every 10,000 iterations
        if (i + 1) % 10000 == 0:
            torch.save(
                model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1