Files
easyocr/trainer/craft/train.py

480 lines
17 KiB
Python
Raw Permalink Normal View History

2025-07-10 19:42:57 +08:00
# -*- coding: utf-8 -*-
import argparse
import os
import shutil
import time
import multiprocessing as mp
import yaml
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from config.load_config import load_yaml, DotDict
from data.dataset import SynthTextDataSet, CustomDataset
from loss.mseloss import Maploss_v2, Maploss_v3
from model.craft import CRAFT
from eval import main_eval
from metrics.eval_det_iou import DetectionIoUEvaluator
from utils.util import copyStateDict, save_parser
class Trainer(object):
def __init__(self, config, gpu, mode):
self.config = config
self.gpu = gpu
self.mode = mode
self.net_param = self.get_load_param(gpu)
def get_synth_loader(self):
dataset = SynthTextDataSet(
output_size=self.config.train.data.output_size,
data_dir=self.config.train.synth_data_dir,
saved_gt_dir=None,
mean=self.config.train.data.mean,
variance=self.config.train.data.variance,
gauss_init_size=self.config.train.data.gauss_init_size,
gauss_sigma=self.config.train.data.gauss_sigma,
enlarge_region=self.config.train.data.enlarge_region,
enlarge_affinity=self.config.train.data.enlarge_affinity,
aug=self.config.train.data.syn_aug,
vis_test_dir=self.config.vis_test_dir,
vis_opt=self.config.train.data.vis_opt,
sample=self.config.train.data.syn_sample,
)
syn_loader = torch.utils.data.DataLoader(
dataset,
batch_size=self.config.train.batch_size // self.config.train.synth_ratio,
shuffle=False,
num_workers=self.config.train.num_workers,
drop_last=True,
pin_memory=True,
)
return syn_loader
def get_custom_dataset(self):
custom_dataset = CustomDataset(
output_size=self.config.train.data.output_size,
data_dir=self.config.data_root_dir,
saved_gt_dir=None,
mean=self.config.train.data.mean,
variance=self.config.train.data.variance,
gauss_init_size=self.config.train.data.gauss_init_size,
gauss_sigma=self.config.train.data.gauss_sigma,
enlarge_region=self.config.train.data.enlarge_region,
enlarge_affinity=self.config.train.data.enlarge_affinity,
watershed_param=self.config.train.data.watershed,
aug=self.config.train.data.custom_aug,
vis_test_dir=self.config.vis_test_dir,
sample=self.config.train.data.custom_sample,
vis_opt=self.config.train.data.vis_opt,
pseudo_vis_opt=self.config.train.data.pseudo_vis_opt,
do_not_care_label=self.config.train.data.do_not_care_label,
)
return custom_dataset
def get_load_param(self, gpu):
if self.config.train.ckpt_path is not None:
map_location = "cuda:%d" % gpu
param = torch.load(self.config.train.ckpt_path, map_location=map_location)
else:
param = None
return param
def adjust_learning_rate(self, optimizer, gamma, step, lr):
lr = lr * (gamma ** step)
for param_group in optimizer.param_groups:
param_group["lr"] = lr
return param_group["lr"]
def get_loss(self):
if self.config.train.loss == 2:
criterion = Maploss_v2()
elif self.config.train.loss == 3:
criterion = Maploss_v3()
else:
raise Exception("Undefined loss")
return criterion
def iou_eval(self, dataset, train_step, buffer, model):
test_config = DotDict(self.config.test[dataset])
val_result_dir = os.path.join(
self.config.results_dir, "{}/{}".format(dataset + "_iou", str(train_step))
)
evaluator = DetectionIoUEvaluator()
metrics = main_eval(
None,
self.config.train.backbone,
test_config,
evaluator,
val_result_dir,
buffer,
model,
self.mode,
)
if self.gpu == 0 and self.config.wandb_opt:
wandb.log(
{
"{} iou Recall".format(dataset): np.round(metrics["recall"], 3),
"{} iou Precision".format(dataset): np.round(
metrics["precision"], 3
),
"{} iou F1-score".format(dataset): np.round(metrics["hmean"], 3),
}
)
def train(self, buffer_dict):
torch.cuda.set_device(self.gpu)
# MODEL -------------------------------------------------------------------------------------------------------#
# SUPERVISION model
if self.config.mode == "weak_supervision":
if self.config.train.backbone == "vgg":
supervision_model = CRAFT(pretrained=False, amp=self.config.train.amp)
else:
raise Exception("Undefined architecture")
supervision_device = self.gpu
if self.config.train.ckpt_path is not None:
supervision_param = self.get_load_param(supervision_device)
supervision_model.load_state_dict(
copyStateDict(supervision_param["craft"])
)
supervision_model = supervision_model.to(f"cuda:{supervision_device}")
print(f"Supervision model loading on : gpu {supervision_device}")
else:
supervision_model, supervision_device = None, None
# TRAIN model
if self.config.train.backbone == "vgg":
craft = CRAFT(pretrained=False, amp=self.config.train.amp)
else:
raise Exception("Undefined architecture")
if self.config.train.ckpt_path is not None:
craft.load_state_dict(copyStateDict(self.net_param["craft"]))
craft = craft.cuda()
craft = torch.nn.DataParallel(craft)
torch.backends.cudnn.benchmark = True
# DATASET -----------------------------------------------------------------------------------------------------#
if self.config.train.use_synthtext:
trn_syn_loader = self.get_synth_loader()
batch_syn = iter(trn_syn_loader)
if self.config.train.real_dataset == "custom":
trn_real_dataset = self.get_custom_dataset()
else:
raise Exception("Undefined dataset")
if self.config.mode == "weak_supervision":
trn_real_dataset.update_model(supervision_model)
trn_real_dataset.update_device(supervision_device)
trn_real_loader = torch.utils.data.DataLoader(
trn_real_dataset,
batch_size=self.config.train.batch_size,
shuffle=False,
num_workers=self.config.train.num_workers,
drop_last=False,
pin_memory=True,
)
# OPTIMIZER ---------------------------------------------------------------------------------------------------#
optimizer = optim.Adam(
craft.parameters(),
lr=self.config.train.lr,
weight_decay=self.config.train.weight_decay,
)
if self.config.train.ckpt_path is not None and self.config.train.st_iter != 0:
optimizer.load_state_dict(copyStateDict(self.net_param["optimizer"]))
self.config.train.st_iter = self.net_param["optimizer"]["state"][0]["step"]
self.config.train.lr = self.net_param["optimizer"]["param_groups"][0]["lr"]
# LOSS --------------------------------------------------------------------------------------------------------#
# mixed precision
if self.config.train.amp:
scaler = torch.cuda.amp.GradScaler()
if (
self.config.train.ckpt_path is not None
and self.config.train.st_iter != 0
):
scaler.load_state_dict(copyStateDict(self.net_param["scaler"]))
else:
scaler = None
criterion = self.get_loss()
# TRAIN -------------------------------------------------------------------------------------------------------#
train_step = self.config.train.st_iter
whole_training_step = self.config.train.end_iter
update_lr_rate_step = 0
training_lr = self.config.train.lr
loss_value = 0
batch_time = 0
start_time = time.time()
print(
"================================ Train start ================================"
)
while train_step < whole_training_step:
for (
index,
(
images,
region_scores,
affinity_scores,
confidence_masks,
),
) in enumerate(trn_real_loader):
craft.train()
if train_step > 0 and train_step % self.config.train.lr_decay == 0:
update_lr_rate_step += 1
training_lr = self.adjust_learning_rate(
optimizer,
self.config.train.gamma,
update_lr_rate_step,
self.config.train.lr,
)
images = images.cuda(non_blocking=True)
region_scores = region_scores.cuda(non_blocking=True)
affinity_scores = affinity_scores.cuda(non_blocking=True)
confidence_masks = confidence_masks.cuda(non_blocking=True)
if self.config.train.use_synthtext:
# Synth image load
syn_image, syn_region_label, syn_affi_label, syn_confidence_mask = next(
batch_syn
)
syn_image = syn_image.cuda(non_blocking=True)
syn_region_label = syn_region_label.cuda(non_blocking=True)
syn_affi_label = syn_affi_label.cuda(non_blocking=True)
syn_confidence_mask = syn_confidence_mask.cuda(non_blocking=True)
# concat syn & custom image
images = torch.cat((syn_image, images), 0)
region_image_label = torch.cat(
(syn_region_label, region_scores), 0
)
affinity_image_label = torch.cat((syn_affi_label, affinity_scores), 0)
confidence_mask_label = torch.cat(
(syn_confidence_mask, confidence_masks), 0
)
else:
region_image_label = region_scores
affinity_image_label = affinity_scores
confidence_mask_label = confidence_masks
if self.config.train.amp:
with torch.cuda.amp.autocast():
output, _ = craft(images)
out1 = output[:, :, :, 0]
out2 = output[:, :, :, 1]
loss = criterion(
region_image_label,
affinity_image_label,
out1,
out2,
confidence_mask_label,
self.config.train.neg_rto,
self.config.train.n_min_neg,
)
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
output, _ = craft(images)
out1 = output[:, :, :, 0]
out2 = output[:, :, :, 1]
loss = criterion(
region_image_label,
affinity_image_label,
out1,
out2,
confidence_mask_label,
self.config.train.neg_rto,
)
optimizer.zero_grad()
loss.backward()
optimizer.step()
end_time = time.time()
loss_value += loss.item()
batch_time += end_time - start_time
if train_step > 0 and train_step % 5 == 0:
mean_loss = loss_value / 5
loss_value = 0
avg_batch_time = batch_time / 5
batch_time = 0
print(
"{}, training_step: {}|{}, learning rate: {:.8f}, "
"training_loss: {:.5f}, avg_batch_time: {:.5f}".format(
time.strftime(
"%Y-%m-%d:%H:%M:%S", time.localtime(time.time())
),
train_step,
whole_training_step,
training_lr,
mean_loss,
avg_batch_time,
)
)
if self.config.wandb_opt:
wandb.log({"train_step": train_step, "mean_loss": mean_loss})
if (
train_step % self.config.train.eval_interval == 0
and train_step != 0
):
craft.eval()
print("Saving state, index:", train_step)
save_param_dic = {
"iter": train_step,
"craft": craft.state_dict(),
"optimizer": optimizer.state_dict(),
}
save_param_path = (
self.config.results_dir
+ "/CRAFT_clr_"
+ repr(train_step)
+ ".pth"
)
if self.config.train.amp:
save_param_dic["scaler"] = scaler.state_dict()
save_param_path = (
self.config.results_dir
+ "/CRAFT_clr_amp_"
+ repr(train_step)
+ ".pth"
)
torch.save(save_param_dic, save_param_path)
# validation
self.iou_eval(
"custom_data",
train_step,
buffer_dict["custom_data"],
craft,
)
train_step += 1
if train_step >= whole_training_step:
break
if self.config.mode == "weak_supervision":
state_dict = craft.module.state_dict()
supervision_model.load_state_dict(state_dict)
trn_real_dataset.update_model(supervision_model)
# save last model
save_param_dic = {
"iter": train_step,
"craft": craft.state_dict(),
"optimizer": optimizer.state_dict(),
}
save_param_path = (
self.config.results_dir + "/CRAFT_clr_" + repr(train_step) + ".pth"
)
if self.config.train.amp:
save_param_dic["scaler"] = scaler.state_dict()
save_param_path = (
self.config.results_dir
+ "/CRAFT_clr_amp_"
+ repr(train_step)
+ ".pth"
)
torch.save(save_param_dic, save_param_path)
def main():
parser = argparse.ArgumentParser(description="CRAFT custom data train")
parser.add_argument(
"--yaml",
"--yaml_file_name",
default="custom_data_train",
type=str,
help="Load configuration",
)
parser.add_argument(
"--port", "--use ddp port", default="2346", type=str, help="Port number"
)
args = parser.parse_args()
# load configure
exp_name = args.yaml
config = load_yaml(args.yaml)
print("-" * 20 + " Options " + "-" * 20)
print(yaml.dump(config))
print("-" * 40)
# Make result_dir
res_dir = os.path.join(config["results_dir"], args.yaml)
config["results_dir"] = res_dir
if not os.path.exists(res_dir):
os.makedirs(res_dir)
# Duplicate yaml file to result_dir
shutil.copy(
"config/" + args.yaml + ".yaml", os.path.join(res_dir, args.yaml) + ".yaml"
)
if config["mode"] == "weak_supervision":
mode = "weak_supervision"
else:
mode = None
# Apply config to wandb
if config["wandb_opt"]:
wandb.init(project="craft-stage2", entity="user_name", name=exp_name)
wandb.config.update(config)
config = DotDict(config)
# Start train
buffer_dict = {"custom_data":None}
trainer = Trainer(config, 0, mode)
trainer.train(buffer_dict)
if config["wandb_opt"]:
wandb.finish()
if __name__ == "__main__":
main()