add .gitignore in the alldata directory

This commit is contained in:
2025-07-21 18:37:41 +08:00
parent c10b0719c7
commit 5343c29e1b
16 changed files with 9860 additions and 4 deletions

87
trainer/export_onnx.py Normal file
View File

@@ -0,0 +1,87 @@
import torch
import argparse
from model import Model
import os
import torch.backends.cudnn as cudnn
import yaml
from utils import AttrDict
import pandas as pd
from utils import CTCLabelConverter, AttnLabelConverter, Averager
cudnn.benchmark = True
cudnn.deterministic = False
def get_config(file_path):
with open(file_path, 'r', encoding="utf8") as stream:
opt = yaml.safe_load(stream)
opt = AttrDict(opt)
if opt.lang_char == 'None' and opt.symbol=='None':
opt.character = opt.number
elif opt.lang_char == 'None':
characters = ''
for data in opt['select_data'].split('-'):
csv_path = os.path.join(opt['train_data'], data, 'labels.csv')
df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python',dtype={'words': str}, usecols=['filename', 'words'], keep_default_na=False)
all_char = ''.join(df['words'])
characters += ''.join(set(all_char))
characters = sorted(set(characters))
opt.character= ''.join(characters)
else:
opt.character = opt.number + opt.symbol + opt.lang_char
os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)
if 'CTC' in opt.Prediction:
converter = CTCLabelConverter(opt.character)
else:
converter = AttnLabelConverter(opt.character)
opt.num_class = len(converter.character)
print(f"converter.character : {converter.character}")
print(f"字符集: {opt.character}")
print(f"字符集长度: {opt.num_class}")
os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)
return opt
def parse_args():
parser = argparse.ArgumentParser(description='PyTorch模型转ONNX格式工具')
parser.add_argument('--input', type=str, default='digit_cnn.pth',
help='输入PyTorch模型路径 (默认: digit_cnn.pth)')
parser.add_argument('--output', type=str, default='digit_cnn.onnx',
help='输出ONNX模型路径 (默认: digit_cnn.onnx)')
parser.add_argument('--opset', type=int, default=11,
help='ONNX算子集版本 (默认: 11)')
return parser.parse_args()
def convert_to_onnx(input_path, output_path, opset_version,opt):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if opt.rgb:
opt.input_channel = 3
model = Model(opt).to(device)
torch.load(input_path)
model.eval()
dummy_input = torch.randn(1, 3, 32, 128).to(device)
torch.onnx.export(
model,
dummy_input,
output_path,
export_params=True,
opset_version=opset_version,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
dynamic_axes={
'input': {0: 'batch_size'},
'output': {0: 'batch_size'}
}
)
print(f"模型已成功转换为 {output_path} (opset {opset_version})")
if __name__ == '__main__':
args = parse_args()
opt = get_config("config_files/4digit_config.yaml")
convert_to_onnx(args.input, args.output, args.opset,opt)