# File: easyocr/trainer/export_onnx.py
# (Python, 88 lines, 3.0 KiB)

import torch
import argparse
from model import Model
import os
import torch.backends.cudnn as cudnn
import yaml
from utils import AttrDict
import pandas as pd
from utils import CTCLabelConverter, AttnLabelConverter, Averager
# Module-level cuDNN configuration, applied as soon as this script is imported:
# enable the cuDNN autotuner (benchmark) and do not force deterministic kernels.
cudnn.benchmark, cudnn.deterministic = True, False
def _collect_characters(opt):
    """Return every character used in the training labels, sorted, as one string.

    ``opt['select_data']`` is split on '-' and, for each part, the file
    ``<opt['train_data']>/<part>/labels.csv`` is read; all characters of its
    ``words`` column are gathered into a single set.
    """
    seen = set()
    for data in opt['select_data'].split('-'):
        csv_path = os.path.join(opt['train_data'], data, 'labels.csv')
        # The regex separator is anchored at the start of the line, so only the
        # first comma on each line splits fields (the label text may contain
        # commas). Regex separators require the python engine.
        # keep_default_na=False keeps labels such as "NA" as literal text.
        df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python',
                         dtype={'words': str}, usecols=['filename', 'words'],
                         keep_default_na=False)
        seen.update(''.join(df['words']))
    return ''.join(sorted(seen))


def get_config(file_path):
    """Load a YAML training config and derive the character set and class count.

    Args:
        file_path: Path to a UTF-8 YAML config file.

    Returns:
        The config as an ``AttrDict`` with two derived fields:
        ``character`` (the recognizable characters) and ``num_class``
        (``len(converter.character)`` for the label converter selected by
        ``opt.Prediction``).

    Character-set rules (the YAML fields hold the literal string 'None'
    when unset):
        * ``lang_char`` and ``symbol`` both 'None' -> ``opt.number`` only.
        * only ``lang_char`` 'None' -> every character found in the training
          ``labels.csv`` files (``opt.number``/``opt.symbol`` are not used).
        * otherwise -> ``opt.number + opt.symbol + opt.lang_char``.

    Side effects: prints the character set and creates
    ``./saved_models/<experiment_name>`` if it does not exist.
    """
    with open(file_path, 'r', encoding="utf8") as stream:
        opt = AttrDict(yaml.safe_load(stream))

    if opt.lang_char == 'None' and opt.symbol == 'None':
        opt.character = opt.number
    elif opt.lang_char == 'None':
        opt.character = _collect_characters(opt)
    else:
        opt.character = opt.number + opt.symbol + opt.lang_char

    os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)

    # num_class comes from the converter's own character list, which may
    # differ in length from opt.character (presumably the converter adds
    # special tokens -- confirm in utils.py).
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    print(f"converter.character : {converter.character}")
    print(f"字符集: {opt.character}")  # "Character set: ..."
    print(f"字符集长度: {opt.num_class}")  # "Character set length: ..."
    return opt
def parse_args():
    """Parse the command-line options of the export script.

    Options:
        --input   Path of the PyTorch checkpoint to read (str).
        --output  Path of the ONNX file to write (str).
        --opset   ONNX opset version to target (int).

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description='PyTorch模型转ONNX格式工具')
    # (flag, value type, default, help text)
    option_table = (
        ('--input', str, 'digit_cnn.pth',
         '输入PyTorch模型路径 (默认: digit_cnn.pth)'),
        ('--output', str, 'digit_cnn.onnx',
         '输出ONNX模型路径 (默认: digit_cnn.onnx)'),
        ('--opset', int, 11,
         'ONNX算子集版本 (默认: 11)'),
    )
    for flag, value_type, fallback, help_text in option_table:
        parser.add_argument(flag, type=value_type, default=fallback,
                            help=help_text)
    return parser.parse_args()
def _opt_value(opt, name, default):
    """Return ``opt.<name>``, or *default* when the attribute is missing or None."""
    value = getattr(opt, name, None)
    return default if value is None else value


def _load_weights(model, input_path, device):
    """Load the checkpoint at *input_path* into *model* in place.

    Accepts either a pickled ``nn.Module`` or a state dict. If every key
    carries a ``module.`` prefix (as in the state dict of an
    ``nn.DataParallel`` wrapper), the prefix is stripped so the keys match the
    bare model.
    """
    # SECURITY: torch.load unpickles the file; only export checkpoints you trust.
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load(input_path, map_location=device)
    if isinstance(checkpoint, torch.nn.Module):
        checkpoint = checkpoint.state_dict()
    prefix = 'module.'
    if checkpoint and all(isinstance(key, str) and key.startswith(prefix)
                          for key in checkpoint):
        checkpoint = {key[len(prefix):]: value
                      for key, value in checkpoint.items()}
    # strict=True (the default) makes a mismatched checkpoint fail loudly
    # instead of silently exporting partially random weights.
    model.load_state_dict(checkpoint)


def convert_to_onnx(input_path, output_path, opset_version, opt):
    """Build ``Model(opt)``, load trained weights, and export it to ONNX.

    Args:
        input_path: Path of the PyTorch checkpoint (state dict or full module).
        output_path: Destination path of the ONNX file.
        opset_version: ONNX opset version passed to ``torch.onnx.export``.
        opt: Config object from ``get_config``. Mutated: ``opt.input_channel``
            is forced to 3 when ``opt.rgb`` is truthy.

    The exported graph has one input named 'input' and one output named
    'output'; only dimension 0 (batch) is dynamic.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt).to(device)
    _load_weights(model, input_path, device)
    model.eval()

    # Trace with the geometry the model is configured for. Channel, height and
    # width are static in the exported graph, so they must match inference
    # inputs. Fall back to the previous constants (3x32x128) when the config
    # does not define them.
    channels = _opt_value(opt, 'input_channel', 3)
    height = _opt_value(opt, 'imgH', 32)
    width = _opt_value(opt, 'imgW', 128)
    dummy_input = torch.randn(1, channels, height, width, device=device)

    # NOTE(review): only the image tensor is passed to the model. If
    # Model.forward also requires a text argument, tracing will fail here --
    # confirm against model.py.
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )
    print(f"模型已成功转换为 {output_path} (opset {opset_version})")  # "Model successfully converted to ..."
if __name__ == '__main__':
    # Script entry point: read CLI flags, load the training config from a
    # fixed path, then export the model to ONNX.
    cli = parse_args()
    config = get_config("config_files/4digit_config.yaml")
    convert_to_onnx(
        input_path=cli.input,
        output_path=cli.output,
        opset_version=cli.opset,
        opt=config,
    )