diff --git a/trainer/all_data/generate_digits_random_fs_bg_fg.py b/trainer/all_data/generate_digits_random_fs_bg_fg.py index a7eabcd..5f9fd02 100644 --- a/trainer/all_data/generate_digits_random_fs_bg_fg.py +++ b/trainer/all_data/generate_digits_random_fs_bg_fg.py @@ -1,86 +1,86 @@ - -import cv2 -import numpy as np -import os -import random -from PIL import Image, ImageDraw, ImageFont - -X_RAND_VALUE = 2 -Y_RAND_VALUE = 1 -ROTATE_ANGLE = 3 - - -BG_COLORS = [ - (33, 40, 45), (36, 51, 62), (35, 37, 154), - (0, 38, 202), (239, 255, 255), (241, 255, 255) -] - -DIGIT_COLORS = [(34, 199, 253), (25, 214, 253)] - -def generate_4digit_image(): - bg_color = random.choice(BG_COLORS) - font_size = random.randint(24, 30) - font = ImageFont.truetype("arial.ttf", font_size) - - # 扩大画布尺寸(50x160)提供足够缓冲空间 - canvas = np.zeros((50, 160, 3), dtype=np.uint8) - canvas[:,:] = bg_color - pil_img = Image.fromarray(canvas) - draw = ImageDraw.Draw(pil_img) - - digits = [] - for i in range(4): - digit = str(random.randint(0, 9)) - digits.append(digit) - x_offset = random.randint(-X_RAND_VALUE, X_RAND_VALUE) - y_offset = random.randint(-Y_RAND_VALUE, Y_RAND_VALUE) - digit_color = random.choice(DIGIT_COLORS) - # 调整数字绘制位置到画布中心区域 - draw.text((20+i*32+x_offset, 12+y_offset), digit, - font=font, fill=digit_color) - - angle = random.uniform(-ROTATE_ANGLE, ROTATE_ANGLE) - rotated = pil_img.rotate(angle, expand=True, fillcolor=bg_color) - # 安全裁剪区域(从扩大后的画布中心裁剪) - rotated = rotated.crop((20, 10, 148, 42)) - - return np.array(rotated), ''.join(digits) - -def generate_train_dataset(num_samples=1000): - os.makedirs('4digit_train', exist_ok=True) - with open('4digit_train/labels.csv', 'w') as f: - f.write(f"filename,words\n") - for i in range(num_samples): - img, label = generate_4digit_image() - # print(f"type of label : {type(label)}") - label = str(label).zfill(4) - img_path = f'4digit_train/{i:04d}.jpg' - cv2.imwrite(img_path, img) - f.write(f"{i:04d}.jpg,{label}\n") - -def generate_valid_dataset(num_samples=200): - os.makedirs('4digit_valid', exist_ok=True) - with open('4digit_valid/labels.csv', 'w') as f: - f.write(f"filename,words\n") - for i in range(num_samples): - img, label = generate_4digit_image() - label = str(label).zfill(4) - - img_path = f'4digit_valid/{i:04d}.jpg' - cv2.imwrite(img_path, img) - f.write(f"{i:04d}.jpg,{label}\n") - -def generate_eval_dataset(num_samples=200): - os.makedirs('4digit_eval', exist_ok=True) - with open('4digit_eval/labels.csv', 'w') as f: - f.write(f"filename,words\n") - for i in range(num_samples): - img, label = generate_4digit_image() - label = str(label).zfill(4) - img_path = f'4digit_eval/{i:04d}.jpg' - cv2.imwrite(img_path, img) - f.write(f"{i:04d}.jpg,{label}\n") -if __name__ == "__main__": - generate_train_dataset() - generate_eval_dataset() + +import cv2 +import numpy as np +import os +import random +from PIL import Image, ImageDraw, ImageFont + +X_RAND_VALUE = 2 +Y_RAND_VALUE = 1 +ROTATE_ANGLE = 3 + + +BG_COLORS = [ + (33, 40, 45), (36, 51, 62), (35, 37, 154), + (0, 38, 202), (239, 255, 255), (241, 255, 255) +] + +DIGIT_COLORS = [(34, 199, 253), (25, 214, 253)] + +def generate_4digit_image(): + bg_color = random.choice(BG_COLORS) + font_size = random.randint(24, 30) + font = ImageFont.truetype("arial.ttf", font_size) + + # 扩大画布尺寸(50x160)提供足够缓冲空间 + canvas = np.zeros((50, 160, 3), dtype=np.uint8) + canvas[:,:] = bg_color + pil_img = Image.fromarray(canvas) + draw = ImageDraw.Draw(pil_img) + + digits = [] + for i in range(4): + digit = str(random.randint(0, 9)) + digits.append(digit) + x_offset = random.randint(-X_RAND_VALUE, X_RAND_VALUE) + y_offset = random.randint(-Y_RAND_VALUE, Y_RAND_VALUE) + digit_color = random.choice(DIGIT_COLORS) + # 调整数字绘制位置到画布中心区域 + draw.text((20+i*32+x_offset, 12+y_offset), digit, + font=font, fill=digit_color) + + angle = random.uniform(-ROTATE_ANGLE, ROTATE_ANGLE) + rotated = pil_img.rotate(angle, expand=True, fillcolor=bg_color) + # 安全裁剪区域(从扩大后的画布中心裁剪) + rotated = rotated.crop((20, 10, 148, 42)) + + return np.array(rotated), ''.join(digits) + +def generate_train_dataset(num_samples=1000): + os.makedirs('4digit_train', exist_ok=True) + with open('4digit_train/labels.csv', 'w') as f: + f.write(f"filename,words\n") + for i in range(num_samples): + img, label = generate_4digit_image() + # print(f"type of label : {type(label)}") + label = str(label).zfill(4) + img_path = f'4digit_train/{i:04d}.jpg' + cv2.imwrite(img_path, img) + f.write(f"{i:04d}.jpg,{label}\n") + +def generate_valid_dataset(num_samples=200): + os.makedirs('4digit_valid', exist_ok=True) + with open('4digit_valid/labels.csv', 'w') as f: + f.write(f"filename,words\n") + for i in range(num_samples): + img, label = generate_4digit_image() + label = str(label).zfill(4) + + img_path = f'4digit_valid/{i:04d}.jpg' + cv2.imwrite(img_path, img) + f.write(f"{i:04d}.jpg,{label}\n") + +def generate_eval_dataset(num_samples=200): + os.makedirs('4digit_eval', exist_ok=True) + with open('4digit_eval/labels.csv', 'w') as f: + f.write(f"filename,words\n") + for i in range(num_samples): + img, label = generate_4digit_image() + label = str(label).zfill(4) + img_path = f'4digit_eval/{i:04d}.jpg' + cv2.imwrite(img_path, img) + f.write(f"{i:04d}.jpg,{label}\n") +if __name__ == "__main__": + generate_train_dataset() + generate_eval_dataset() generate_valid_dataset() \ No newline at end of file diff --git a/trainer/export_onnx.py b/trainer/export_onnx.py index 5d969c7..58ce38c 100644 --- a/trainer/export_onnx.py +++ b/trainer/export_onnx.py @@ -1,87 +1,87 @@ - -import torch -import argparse -from model import Model -import os -import torch.backends.cudnn as cudnn -import yaml -from utils import AttrDict -import pandas as pd -from utils import CTCLabelConverter, AttnLabelConverter, Averager - - -cudnn.benchmark = True -cudnn.deterministic = False - -def get_config(file_path): - with open(file_path, 'r', encoding="utf8") as stream: - opt = yaml.safe_load(stream) - opt = AttrDict(opt) - - if opt.lang_char == 'None' and opt.symbol=='None': - opt.character = opt.number - elif opt.lang_char == 'None': - characters = '' - for data in opt['select_data'].split('-'): - csv_path = os.path.join(opt['train_data'], data, 'labels.csv') - df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python',dtype={'words': str}, usecols=['filename', 'words'], keep_default_na=False) - all_char = ''.join(df['words']) - characters += ''.join(set(all_char)) - characters = sorted(set(characters)) - opt.character= ''.join(characters) - else: - opt.character = opt.number + opt.symbol + opt.lang_char - os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True) - if 'CTC' in opt.Prediction: - converter = CTCLabelConverter(opt.character) - else: - converter = AttnLabelConverter(opt.character) - opt.num_class = len(converter.character) - print(f"converter.character : {converter.character}") - print(f"字符集: {opt.character}") - print(f"字符集长度: {opt.num_class}") - os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True) - return opt - -def parse_args(): - parser = argparse.ArgumentParser(description='PyTorch模型转ONNX格式工具') - parser.add_argument('--input', type=str, default='digit_cnn.pth', - help='输入PyTorch模型路径 (默认: digit_cnn.pth)') - parser.add_argument('--output', type=str, default='digit_cnn.onnx', - help='输出ONNX模型路径 (默认: digit_cnn.onnx)') - parser.add_argument('--opset', type=int, default=11, - help='ONNX算子集版本 (默认: 11)') - return parser.parse_args() - -def convert_to_onnx(input_path, output_path, opset_version,opt): - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - if opt.rgb: - opt.input_channel = 3 - model = Model(opt).to(device) - torch.load(input_path) - model.eval() - - dummy_input = torch.randn(1, 3, 32, 128).to(device) - - torch.onnx.export( - model, - dummy_input, - output_path, - export_params=True, - opset_version=opset_version, - do_constant_folding=True, - input_names=['input'], - output_names=['output'], - dynamic_axes={ - 'input': {0: 'batch_size'}, - 'output': {0: 'batch_size'} - } - ) - print(f"模型已成功转换为 {output_path} (opset {opset_version})") - -if __name__ == '__main__': - args = parse_args() - opt = get_config("config_files/4digit_config.yaml") - - convert_to_onnx(args.input, args.output, args.opset,opt) + +import torch +import argparse +from model import Model +import os +import torch.backends.cudnn as cudnn +import yaml +from utils import AttrDict +import pandas as pd +from utils import CTCLabelConverter, AttnLabelConverter, Averager + + +cudnn.benchmark = True +cudnn.deterministic = False + +def get_config(file_path): + with open(file_path, 'r', encoding="utf8") as stream: + opt = yaml.safe_load(stream) + opt = AttrDict(opt) + + if opt.lang_char == 'None' and opt.symbol=='None': + opt.character = opt.number + elif opt.lang_char == 'None': + characters = '' + for data in opt['select_data'].split('-'): + csv_path = os.path.join(opt['train_data'], data, 'labels.csv') + df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python',dtype={'words': str}, usecols=['filename', 'words'], keep_default_na=False) + all_char = ''.join(df['words']) + characters += ''.join(set(all_char)) + characters = sorted(set(characters)) + opt.character= ''.join(characters) + else: + opt.character = opt.number + opt.symbol + opt.lang_char + os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True) + if 'CTC' in opt.Prediction: + converter = CTCLabelConverter(opt.character) + else: + converter = AttnLabelConverter(opt.character) + opt.num_class = len(converter.character) + print(f"converter.character : {converter.character}") + print(f"字符集: {opt.character}") + print(f"字符集长度: {opt.num_class}") + os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True) + return opt + +def parse_args(): + parser = argparse.ArgumentParser(description='PyTorch模型转ONNX格式工具') + parser.add_argument('--input', type=str, default='digit_cnn.pth', + help='输入PyTorch模型路径 (默认: digit_cnn.pth)') + parser.add_argument('--output', type=str, default='digit_cnn.onnx', + help='输出ONNX模型路径 (默认: digit_cnn.onnx)') + parser.add_argument('--opset', type=int, default=11, + help='ONNX算子集版本 (默认: 11)') + return parser.parse_args() + +def convert_to_onnx(input_path, output_path, opset_version,opt): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + if opt.rgb: + opt.input_channel = 3 + model = Model(opt).to(device) + torch.load(input_path) + model.eval() + + dummy_input = torch.randn(1, 3, 32, 128).to(device) + + torch.onnx.export( + model, + dummy_input, + output_path, + export_params=True, + opset_version=opset_version, + do_constant_folding=True, + input_names=['input'], + output_names=['output'], + dynamic_axes={ + 'input': {0: 'batch_size'}, + 'output': {0: 'batch_size'} + } + ) + print(f"模型已成功转换为 {output_path} (opset {opset_version})") + +if __name__ == '__main__': + args = parse_args() + opt = get_config("config_files/4digit_config.yaml") + + convert_to_onnx(args.input, args.output, args.opset,opt) diff --git a/trainer/saved_models/.gitignore b/trainer/saved_models/.gitignore new file mode 100644 index 0000000..12355e3 --- /dev/null +++ b/trainer/saved_models/.gitignore @@ -0,0 +1 @@ +./** \ No newline at end of file