"""Export a trained recognition model from a PyTorch checkpoint to ONNX.

The network architecture and its character set are rebuilt from a YAML
training config (``get_config``); the trained weights are then loaded from
the checkpoint given by ``--input`` and the traced graph is written to
``--output``.
"""
import argparse
import os

import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import yaml

from model import Model
from utils import AttrDict
from utils import CTCLabelConverter, AttnLabelConverter, Averager

cudnn.benchmark = True
cudnn.deterministic = False

# Config used when --config is not given (kept for backward compatibility).
DEFAULT_CONFIG_PATH = "config_files/4digit_config.yaml"


def get_config(file_path):
    """Load a training YAML config and derive the character set from it.

    Args:
        file_path: Path to the YAML config file.

    Returns:
        The config wrapped in ``AttrDict`` with ``character`` (the set of
        recognizable characters) and ``num_class`` (length of the label
        converter's alphabet) filled in.

    Side effects:
        Creates ``./saved_models/<experiment_name>`` if it does not exist and
        prints the resulting character set.
    """
    with open(file_path, 'r', encoding="utf8") as stream:
        opt = yaml.safe_load(stream)
    opt = AttrDict(opt)

    # Unset fields are stored as the literal string 'None' in the YAML.
    if opt.lang_char == 'None' and opt.symbol == 'None':
        opt.character = opt.number
    elif opt.lang_char == 'None':
        # No explicit language characters: use every character that occurs
        # in the labels of the selected ('-'-separated) training datasets.
        characters = set()
        for data in opt['select_data'].split('-'):
            csv_path = os.path.join(opt['train_data'], data, 'labels.csv')
            # The anchored separator splits each row at its first comma only
            # (presumably so that labels may themselves contain commas).
            df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python',
                             dtype={'words': str},
                             usecols=['filename', 'words'],
                             keep_default_na=False)
            characters.update(''.join(df['words']))
        opt.character = ''.join(sorted(characters))
    else:
        opt.character = opt.number + opt.symbol + opt.lang_char

    os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)

    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    # The converter may add special tokens, so the class count comes from its
    # alphabet rather than from opt.character directly.
    opt.num_class = len(converter.character)

    print(f"converter.character : {converter.character}")
    print(f"字符集: {opt.character}")
    print(f"字符集长度: {opt.num_class}")
    return opt


def parse_args(argv=None):
    """Parse the converter's command-line options.

    Args:
        argv: Argument list to parse; ``None`` (the default) parses
            ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``input``, ``output``, ``opset`` and
        ``config``.
    """
    parser = argparse.ArgumentParser(description='PyTorch模型转ONNX格式工具')
    parser.add_argument('--input', type=str, default='digit_cnn.pth',
                        help='输入PyTorch模型路径 (默认: digit_cnn.pth)')
    parser.add_argument('--output', type=str, default='digit_cnn.onnx',
                        help='输出ONNX模型路径 (默认: digit_cnn.onnx)')
    parser.add_argument('--opset', type=int, default=11,
                        help='ONNX算子集版本 (默认: 11)')
    parser.add_argument('--config', type=str, default=DEFAULT_CONFIG_PATH,
                        help=f'训练配置YAML文件路径 (默认: {DEFAULT_CONFIG_PATH})')
    return parser.parse_args(argv)


def _load_weights(model, input_path, device):
    """Load the checkpoint at ``input_path`` into ``model`` in place.

    Raises:
        RuntimeError: from ``load_state_dict`` when the checkpoint keys or
            shapes do not match the model built from the config.
    """
    checkpoint = torch.load(input_path, map_location=device)
    if isinstance(checkpoint, torch.nn.Module):
        # A whole pickled module was saved instead of a state dict.
        checkpoint = checkpoint.state_dict()
    # Checkpoints saved from a torch.nn.DataParallel wrapper prefix every key
    # with 'module.'; strip it so the keys match the bare Model.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }
    # strict loading (the default) fails loudly instead of exporting a
    # partially-initialised model.
    model.load_state_dict(state_dict)


def convert_to_onnx(input_path, output_path, opset_version, opt):
    """Load trained weights into ``Model(opt)`` and export it to ONNX.

    Args:
        input_path: Path of the PyTorch checkpoint (a state dict, optionally
            saved from a DataParallel wrapper).
        output_path: Destination ``.onnx`` file.
        opset_version: ONNX operator-set version for ``torch.onnx.export``.
        opt: Config returned by ``get_config``. ``opt.input_channel`` is
            overwritten with 3 when ``opt.rgb`` is true.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt).to(device)
    _load_weights(model, input_path, device)
    model.eval()

    # The channel count must match what the model was built with (it is only
    # 3 for RGB configs). The 32x128 spatial size is baked into the exported
    # graph; only the batch axis is dynamic.
    # NOTE(review): confirm 32x128 matches the config's training image size.
    dummy_input = torch.randn(1, opt.input_channel, 32, 128).to(device)
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'},
        },
    )
    print(f"模型已成功转换为 {output_path} (opset {opset_version})")


def main(argv=None):
    """Command-line entry point: parse args, load config, export the model."""
    args = parse_args(argv)
    opt = get_config(args.config)
    convert_to_onnx(args.input, args.output, args.opset, opt)


if __name__ == '__main__':
    main()