add .gitignore to saved_models

This commit is contained in:
2025-07-21 19:06:13 +08:00
parent 5343c29e1b
commit 898b1a59f3
3 changed files with 173 additions and 172 deletions

View File

@@ -1,86 +1,86 @@
import cv2 import cv2
import numpy as np import numpy as np
import os import os
import random import random
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
# Largest random shift, in pixels, applied to each digit along x and y.
X_RAND_VALUE = 2
Y_RAND_VALUE = 1

# Largest rotation, in degrees, applied to the finished image (either direction).
ROTATE_ANGLE = 3

# Background colours; one is picked at random per generated image.
BG_COLORS = [
    (33, 40, 45),
    (36, 51, 62),
    (35, 37, 154),
    (0, 38, 202),
    (239, 255, 255),
    (241, 255, 255),
]

# Digit colours; one is picked at random for every digit drawn.
DIGIT_COLORS = [
    (34, 199, 253),
    (25, 214, 253),
]
def generate_4digit_image():
    """Render four random digits onto a coloured background.

    Each digit gets a small random positional jitter and its own colour; the
    finished picture is rotated by a small random angle and cropped to a
    128x32 (width x height) window.

    Returns:
        tuple: ``(image, label)`` where ``image`` is the cropped picture as a
        NumPy array and ``label`` is the four drawn digits as a string.
    """
    background = random.choice(BG_COLORS)
    typeface = ImageFont.truetype("arial.ttf", random.randint(24, 30))

    # Oversized 50x160 canvas: leaves a margin so the later rotation and crop
    # have enough buffer space around the digits.
    base = np.full((50, 160, 3), background, dtype=np.uint8)
    picture = Image.fromarray(base)
    pen = ImageDraw.Draw(picture)

    drawn = []
    for slot in range(4):
        char = str(random.randint(0, 9))
        drawn.append(char)
        dx = random.randint(-X_RAND_VALUE, X_RAND_VALUE)
        dy = random.randint(-Y_RAND_VALUE, Y_RAND_VALUE)
        ink = random.choice(DIGIT_COLORS)
        # Digits are laid out 32 px apart, positioned in the central region of
        # the canvas.
        pen.text((20 + slot * 32 + dx, 12 + dy), char, font=typeface, fill=ink)

    tilt = random.uniform(-ROTATE_ANGLE, ROTATE_ANGLE)
    tilted = picture.rotate(tilt, expand=True, fillcolor=background)

    # Safe crop window taken from the middle of the enlarged canvas.
    window = tilted.crop((20, 10, 148, 42))
    return np.array(window), ''.join(drawn)
def _write_dataset(directory, num_samples):
    """Generate ``num_samples`` images plus a ``labels.csv`` inside ``directory``.

    Images are named ``0000.jpg``, ``0001.jpg``, ... and ``labels.csv`` has a
    ``filename,words`` header followed by one row per image.

    Raises:
        OSError: if OpenCV reports that an image could not be written.
    """
    os.makedirs(directory, exist_ok=True)
    with open(f'{directory}/labels.csv', 'w') as f:
        f.write("filename,words\n")
        for i in range(num_samples):
            img, label = generate_4digit_image()
            # Keep the label four characters wide even if it were ever
            # produced as a number with leading zeros dropped.
            label = str(label).zfill(4)
            filename = f"{i:04d}.jpg"
            img_path = f'{directory}/{filename}'
            # cv2.imwrite signals failure through its return value rather than
            # an exception; without this check the CSV would reference a file
            # that does not exist.
            if not cv2.imwrite(img_path, img):
                raise OSError(f"failed to write image: {img_path}")
            f.write(f"{filename},{label}\n")


def generate_train_dataset(num_samples=1000):
    """Write the training split into ``4digit_train/``."""
    _write_dataset('4digit_train', num_samples)


def generate_valid_dataset(num_samples=200):
    """Write the validation split into ``4digit_valid/``."""
    _write_dataset('4digit_valid', num_samples)


def generate_eval_dataset(num_samples=200):
    """Write the evaluation split into ``4digit_eval/``."""
    _write_dataset('4digit_eval', num_samples)
if __name__ == "__main__":
    # Build every split in turn: train, then eval, then valid.
    for build_split in (
        generate_train_dataset,
        generate_eval_dataset,
        generate_valid_dataset,
    ):
        build_split()

View File

@@ -1,87 +1,87 @@
import torch import torch
import argparse import argparse
from model import Model from model import Model
import os import os
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
import yaml import yaml
from utils import AttrDict from utils import AttrDict
import pandas as pd import pandas as pd
from utils import CTCLabelConverter, AttnLabelConverter, Averager from utils import CTCLabelConverter, AttnLabelConverter, Averager
# cuDNN flags set once at import time: benchmarking on, deterministic mode off.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
def get_config(file_path):
    """Load the YAML config at ``file_path`` and derive the character set.

    The character set (``opt.character``) is chosen as follows:

    * ``lang_char`` and ``symbol`` both ``'None'``: use ``opt.number`` only.
    * only ``lang_char`` is ``'None'``: collect every distinct character found
      in the ``words`` column of each selected dataset's ``labels.csv``.
    * otherwise: ``number + symbol + lang_char``.

    Also creates ``./saved_models/<experiment_name>`` and sets
    ``opt.num_class`` from the label converter picked by ``opt.Prediction``.

    Returns:
        AttrDict: the loaded options with ``character`` and ``num_class`` set.
    """
    with open(file_path, 'r', encoding="utf8") as stream:
        opt = AttrDict(yaml.safe_load(stream))

    if opt.lang_char == 'None' and opt.symbol == 'None':
        opt.character = opt.number
    elif opt.lang_char == 'None':
        seen = set()
        for data in opt['select_data'].split('-'):
            csv_path = os.path.join(opt['train_data'], data, 'labels.csv')
            # The regex separator only matches the first comma of a line, so a
            # label that itself contains commas is kept in one column.
            df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python',
                             dtype={'words': str},
                             usecols=['filename', 'words'],
                             keep_default_na=False)
            seen.update(''.join(df['words']))
        opt.character = ''.join(sorted(seen))
    else:
        opt.character = opt.number + opt.symbol + opt.lang_char

    os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)

    # The converter is built only to learn the final class count (its
    # character list may be longer than opt.character).
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    print(f"converter.character : {converter.character}")
    print(f"字符集: {opt.character}")
    print(f"字符集长度: {opt.num_class}")
    return opt
def parse_args(argv=None):
    """Parse the command-line options of the PyTorch-to-ONNX converter.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv``.

    Returns:
        argparse.Namespace: with ``input`` (source checkpoint path),
        ``output`` (destination ONNX path) and ``opset`` (ONNX opset version).
    """
    parser = argparse.ArgumentParser(description='PyTorch模型转ONNX格式工具')
    parser.add_argument('--input', type=str, default='digit_cnn.pth',
                        help='输入PyTorch模型路径 (默认: digit_cnn.pth)')
    parser.add_argument('--output', type=str, default='digit_cnn.onnx',
                        help='输出ONNX模型路径 (默认: digit_cnn.onnx)')
    parser.add_argument('--opset', type=int, default=11,
                        help='ONNX算子集版本 (默认: 11)')
    return parser.parse_args(argv)
def convert_to_onnx(input_path, output_path, opset_version,opt):
    """Export the trained recognition model to an ONNX file.

    Args:
        input_path: Path of the saved PyTorch checkpoint to load.
        output_path: Path the ONNX file is written to.
        opset_version: ONNX opset version passed to the exporter.
        opt: Model options; ``opt.rgb`` forces ``opt.input_channel`` to 3.

    The exported graph has one input named ``input`` and one output named
    ``output``, both with a dynamic batch dimension.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt).to(device)

    # Actually load the trained weights into the model. map_location lets a
    # checkpoint saved on GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(input_path, map_location=device)
    if isinstance(checkpoint, torch.nn.Module):
        checkpoint = checkpoint.state_dict()
    # Checkpoints saved from a DataParallel wrapper prefix every key with
    # "module."; strip it so the keys match the bare model.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }
    model.load_state_dict(state_dict)
    model.eval()

    # Tracing input: one 32x128 image whose channel count follows the model
    # configuration (3 when no channel count is configured).
    channels = getattr(opt, 'input_channel', 3)
    dummy_input = torch.randn(1, channels, 32, 128).to(device)

    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )
    print(f"模型已成功转换为 {output_path} (opset {opset_version})")
if __name__ == '__main__':
    # Command-line paths and opset, plus the model options from the YAML config.
    cli = parse_args()
    config = get_config("config_files/4digit_config.yaml")
    convert_to_onnx(
        input_path=cli.input,
        output_path=cli.output,
        opset_version=cli.opset,
        opt=config,
    )

1
trainer/saved_models/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
*