# Source listing metadata: 88 lines, 3.1 KiB, Python.
|
|
import torch
|
|
import argparse
|
|
from model import Model
|
|
import os
|
|
import torch.backends.cudnn as cudnn
|
|
import yaml
|
|
from utils import AttrDict
|
|
import pandas as pd
|
|
from utils import CTCLabelConverter, AttnLabelConverter, Averager
|
|
|
|
|
|
# cuDNN settings for this process: non-deterministic kernels are allowed,
# and benchmark mode lets cuDNN auto-tune and pick the fastest algorithms.
cudnn.deterministic = False
cudnn.benchmark = True
|
|
|
|
def get_config(file_path):
    """Load a YAML training config and derive its character set and class count.

    Args:
        file_path: Path to a UTF-8 encoded YAML config file.

    Returns:
        The parsed config wrapped in ``AttrDict``, with two extra fields:
        ``character`` (the string of recognisable characters) and
        ``num_class`` (``len(converter.character)`` of the label converter
        selected by ``opt.Prediction``).

    Side effects:
        Reads label CSV files when ``opt.lang_char == 'None'`` and
        ``opt.symbol != 'None'``, creates ``./saved_models/<experiment_name>``
        and prints the resulting character set and its length.
    """
    with open(file_path, 'r', encoding="utf8") as stream:
        opt = AttrDict(yaml.safe_load(stream))

    # The YAML uses the literal string 'None' (not null) to mean "unset".
    if opt.lang_char == 'None' and opt.symbol == 'None':
        # Neither language characters nor symbols configured: opt.number alone.
        opt.character = opt.number
    elif opt.lang_char == 'None':
        # No explicit language characters: derive the set from the labels.
        opt.character = _collect_label_characters(opt)
    else:
        opt.character = opt.number + opt.symbol + opt.lang_char

    os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)

    # num_class is taken from the converter's own character list rather than
    # opt.character; presumably the converter adds special tokens — confirm
    # in utils.
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    print(f"converter.character : {converter.character}")
    print(f"字符集: {opt.character}")
    print(f"字符集长度: {opt.num_class}")
    return opt


def _collect_label_characters(opt):
    """Return every distinct character found in the training labels, sorted.

    For each name in the '-'-separated ``opt['select_data']`` string, reads
    ``<opt['train_data']>/<name>/labels.csv`` and gathers the characters of
    its ``words`` column.
    """
    seen = set()
    for data in opt['select_data'].split('-'):
        csv_path = os.path.join(opt['train_data'], data, 'labels.csv')
        # The anchored regex separator splits each line only at its first
        # comma, presumably so label text may itself contain commas.
        # keep_default_na=False keeps labels such as "NA" as literal text.
        df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python',
                         dtype={'words': str}, usecols=['filename', 'words'],
                         keep_default_na=False)
        seen.update(''.join(df['words']))
    return ''.join(sorted(seen))
|
|
|
|
def parse_args():
    """Parse the command line of the PyTorch -> ONNX converter.

    Options:
        --input:  path of the PyTorch checkpoint (default ``digit_cnn.pth``).
        --output: path of the ONNX file to write (default ``digit_cnn.onnx``).
        --opset:  ONNX operator-set version, as an int (default 11).

    Returns:
        ``argparse.Namespace`` with attributes ``input``, ``output``, ``opset``.
    """
    cli = argparse.ArgumentParser(description='PyTorch模型转ONNX格式工具')
    option_specs = (
        ('--input', str, 'digit_cnn.pth', '输入PyTorch模型路径 (默认: digit_cnn.pth)'),
        ('--output', str, 'digit_cnn.onnx', '输出ONNX模型路径 (默认: digit_cnn.onnx)'),
        ('--opset', int, 11, 'ONNX算子集版本 (默认: 11)'),
    )
    for flag, value_type, fallback, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback, help=help_text)
    return cli.parse_args()
|
|
|
|
def convert_to_onnx(input_path, output_path, opset_version, opt,
                    input_size=(32, 128)):
    """Build ``Model(opt)``, load trained weights into it and export to ONNX.

    Args:
        input_path: File written by ``torch.save``. Either a state dict
            (optionally with every key prefixed by ``module.``) or a pickled
            ``torch.nn.Module``.
        output_path: Destination path of the ``.onnx`` file.
        opset_version: ONNX operator-set version passed to the exporter.
        opt: Config as returned by ``get_config``. If ``opt.rgb`` is truthy,
            ``opt.input_channel`` is forced to 3 before the model is built.
            ``opt.input_channel`` also sets the channel count of the tracing
            input.
        input_size: ``(height, width)`` of the tracing input. Only the batch
            dimension is exported as dynamic, so these stay fixed in the
            exported graph.

    Raises:
        RuntimeError: from ``load_state_dict`` when the checkpoint's keys or
            tensor shapes do not match ``Model(opt)``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt).to(device)

    # The trained weights must be copied into the model; otherwise the
    # export contains only the random initial weights. map_location lets a
    # GPU-trained checkpoint load on a CPU-only machine.
    # NOTE(review): on torch>=2.6 torch.load defaults to weights_only=True, so
    # a pickled full nn.Module needs weights_only=False — only for trusted files.
    checkpoint = torch.load(input_path, map_location=device)
    model.load_state_dict(_extract_state_dict(checkpoint))
    model.eval()

    height, width = input_size
    # The tracing input must have the channel count the model was built
    # with; a fixed 3 mismatches models built with another input_channel.
    dummy_input = torch.randn(1, opt.input_channel, height, width,
                              device=device)

    # NOTE(review): assumes Model.forward takes the image tensor alone —
    # confirm; if it needs more inputs, pass them together as a tuple.
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'},
        },
    )
    print(f"模型已成功转换为 {output_path} (opset {opset_version})")


def _extract_state_dict(checkpoint):
    """Return a state dict accepted by ``load_state_dict`` from a loaded checkpoint.

    A pickled ``torch.nn.Module`` yields its own ``state_dict()``. A mapping
    whose keys all start with ``module.`` (the prefix left by wrapping a model
    in ``nn.DataParallel``) has that prefix stripped. Any other mapping is
    returned unchanged.
    """
    if isinstance(checkpoint, torch.nn.Module):
        return checkpoint.state_dict()
    prefix = 'module.'
    keys = list(checkpoint.keys())
    if keys and all(isinstance(k, str) and k.startswith(prefix) for k in keys):
        return {k[len(prefix):]: v for k, v in checkpoint.items()}
    return checkpoint
|
|
|
|
if __name__ == '__main__':
    # Script entry point: CLI options select the checkpoint, the output file
    # and the opset; the model architecture comes from a fixed YAML config.
    config_path = "config_files/4digit_config.yaml"
    args = parse_args()
    opt = get_config(config_path)
    convert_to_onnx(
        input_path=args.input,
        output_path=args.output,
        opset_version=args.opset,
        opt=opt,
    )
|