add .gitignore to saved_models

This commit is contained in:
2025-07-21 19:06:13 +08:00
parent 5343c29e1b
commit 898b1a59f3
3 changed files with 173 additions and 172 deletions

View File

@@ -1,86 +1,86 @@
import cv2
import numpy as np
import os
import random
from PIL import Image, ImageDraw, ImageFont
# Per-digit placement jitter in pixels (x/y) and max rotation angle in degrees.
X_RAND_VALUE = 2
Y_RAND_VALUE = 1
ROTATE_ANGLE = 3
# Candidate background colors. NOTE(review): images are saved via cv2.imwrite
# (BGR order) but drawn with PIL (RGB order) — confirm these tuples were tuned
# for the channel order actually observed on disk.
BG_COLORS = [
    (33, 40, 45), (36, 51, 62), (35, 37, 154),
    (0, 38, 202), (239, 255, 255), (241, 255, 255)
]
# Candidate foreground colors for the digits themselves.
DIGIT_COLORS = [(34, 199, 253), (25, 214, 253)]
def generate_4digit_image():
    """Render a random 4-digit string onto a randomly colored canvas.

    Returns:
        tuple[np.ndarray, str]: a 32x128x3 uint8 image and the 4-character
        digit label drawn on it.
    """
    bg_color = random.choice(BG_COLORS)
    font_size = random.randint(24, 30)
    # Fall back to PIL's built-in bitmap font when Arial is unavailable
    # (e.g. Linux hosts without msttcorefonts) instead of raising OSError.
    try:
        font = ImageFont.truetype("arial.ttf", font_size)
    except OSError:
        font = ImageFont.load_default()
    # Oversized canvas (50x160) leaves a buffer so rotation never clips digits.
    canvas = np.zeros((50, 160, 3), dtype=np.uint8)
    canvas[:, :] = bg_color
    pil_img = Image.fromarray(canvas)
    draw = ImageDraw.Draw(pil_img)
    digits = []
    for i in range(4):
        digit = str(random.randint(0, 9))
        digits.append(digit)
        x_offset = random.randint(-X_RAND_VALUE, X_RAND_VALUE)
        y_offset = random.randint(-Y_RAND_VALUE, Y_RAND_VALUE)
        digit_color = random.choice(DIGIT_COLORS)
        # Place digits near the canvas center, 32 px apart, with small jitter.
        draw.text((20 + i * 32 + x_offset, 12 + y_offset), digit,
                  font=font, fill=digit_color)
    angle = random.uniform(-ROTATE_ANGLE, ROTATE_ANGLE)
    rotated = pil_img.rotate(angle, expand=True, fillcolor=bg_color)
    # Crop a safe 128x32 region from the enlarged canvas interior.
    rotated = rotated.crop((20, 10, 148, 42))
    return np.array(rotated), ''.join(digits)
def _write_dataset(dir_name, num_samples):
    """Generate `num_samples` labeled images in `dir_name` plus a labels.csv.

    The CSV header is `filename,words`, matching the reader's expectations.
    """
    os.makedirs(dir_name, exist_ok=True)
    with open(f'{dir_name}/labels.csv', 'w') as f:
        f.write("filename,words\n")
        for i in range(num_samples):
            img, label = generate_4digit_image()
            # Label is already 4 chars; zfill is a guard against any future
            # generator change that returns a shorter string.
            label = str(label).zfill(4)
            img_path = f'{dir_name}/{i:04d}.jpg'
            cv2.imwrite(img_path, img)
            f.write(f"{i:04d}.jpg,{label}\n")

def generate_train_dataset(num_samples=1000):
    """Write the training split to ./4digit_train."""
    _write_dataset('4digit_train', num_samples)

def generate_valid_dataset(num_samples=200):
    """Write the validation split to ./4digit_valid."""
    _write_dataset('4digit_valid', num_samples)

def generate_eval_dataset(num_samples=200):
    """Write the evaluation split to ./4digit_eval."""
    _write_dataset('4digit_eval', num_samples)
if __name__ == "__main__":
    # Generate all three splits; the validation split was previously
    # never generated even though generate_valid_dataset exists.
    generate_train_dataset()
    generate_eval_dataset()
    generate_valid_dataset()
import cv2
import numpy as np
import os
import random
from PIL import Image, ImageDraw, ImageFont
# Per-digit placement jitter in pixels (x/y) and max rotation angle in degrees.
X_RAND_VALUE = 2
Y_RAND_VALUE = 1
ROTATE_ANGLE = 3
# Candidate background colors. NOTE(review): images are saved via cv2.imwrite
# (BGR order) but drawn with PIL (RGB order) — confirm these tuples were tuned
# for the channel order actually observed on disk.
BG_COLORS = [
    (33, 40, 45), (36, 51, 62), (35, 37, 154),
    (0, 38, 202), (239, 255, 255), (241, 255, 255)
]
# Candidate foreground colors for the digits themselves.
DIGIT_COLORS = [(34, 199, 253), (25, 214, 253)]
def generate_4digit_image():
    """Render a random 4-digit string onto a randomly colored canvas.

    Returns:
        tuple[np.ndarray, str]: a 32x128x3 uint8 image and the 4-character
        digit label drawn on it.
    """
    bg_color = random.choice(BG_COLORS)
    font_size = random.randint(24, 30)
    # Fall back to PIL's built-in bitmap font when Arial is unavailable
    # (e.g. Linux hosts without msttcorefonts) instead of raising OSError.
    try:
        font = ImageFont.truetype("arial.ttf", font_size)
    except OSError:
        font = ImageFont.load_default()
    # Oversized canvas (50x160) leaves a buffer so rotation never clips digits.
    canvas = np.zeros((50, 160, 3), dtype=np.uint8)
    canvas[:, :] = bg_color
    pil_img = Image.fromarray(canvas)
    draw = ImageDraw.Draw(pil_img)
    digits = []
    for i in range(4):
        digit = str(random.randint(0, 9))
        digits.append(digit)
        x_offset = random.randint(-X_RAND_VALUE, X_RAND_VALUE)
        y_offset = random.randint(-Y_RAND_VALUE, Y_RAND_VALUE)
        digit_color = random.choice(DIGIT_COLORS)
        # Place digits near the canvas center, 32 px apart, with small jitter.
        draw.text((20 + i * 32 + x_offset, 12 + y_offset), digit,
                  font=font, fill=digit_color)
    angle = random.uniform(-ROTATE_ANGLE, ROTATE_ANGLE)
    rotated = pil_img.rotate(angle, expand=True, fillcolor=bg_color)
    # Crop a safe 128x32 region from the enlarged canvas interior.
    rotated = rotated.crop((20, 10, 148, 42))
    return np.array(rotated), ''.join(digits)
def _write_dataset(dir_name, num_samples):
    """Generate `num_samples` labeled images in `dir_name` plus a labels.csv.

    The CSV header is `filename,words`, matching the reader's expectations.
    """
    os.makedirs(dir_name, exist_ok=True)
    with open(f'{dir_name}/labels.csv', 'w') as f:
        f.write("filename,words\n")
        for i in range(num_samples):
            img, label = generate_4digit_image()
            # Label is already 4 chars; zfill is a guard against any future
            # generator change that returns a shorter string.
            label = str(label).zfill(4)
            img_path = f'{dir_name}/{i:04d}.jpg'
            cv2.imwrite(img_path, img)
            f.write(f"{i:04d}.jpg,{label}\n")

def generate_train_dataset(num_samples=1000):
    """Write the training split to ./4digit_train."""
    _write_dataset('4digit_train', num_samples)

def generate_valid_dataset(num_samples=200):
    """Write the validation split to ./4digit_valid."""
    _write_dataset('4digit_valid', num_samples)

def generate_eval_dataset(num_samples=200):
    """Write the evaluation split to ./4digit_eval."""
    _write_dataset('4digit_eval', num_samples)
if __name__ == "__main__":
    # Build all three dataset splits with their default sample counts
    # (1000 train, 200 eval, 200 valid).
    generate_train_dataset()
    generate_eval_dataset()
    generate_valid_dataset()

View File

@@ -1,87 +1,87 @@
import torch
import argparse
from model import Model
import os
import torch.backends.cudnn as cudnn
import yaml
from utils import AttrDict
import pandas as pd
from utils import CTCLabelConverter, AttnLabelConverter, Averager
# Let cuDNN autotune kernels for fixed input sizes; determinism is not
# required for a one-shot export, so trade it for speed.
cudnn.benchmark = True
cudnn.deterministic = False
def get_config(file_path):
    """Load a YAML training config and derive the model character set.

    Args:
        file_path: path to the YAML config file.

    Returns:
        AttrDict with `character` and `num_class` populated. Also creates
        ./saved_models/<experiment_name>/ as a side effect.
    """
    with open(file_path, 'r', encoding="utf8") as stream:
        opt = yaml.safe_load(stream)
    opt = AttrDict(opt)
    if opt.lang_char == 'None' and opt.symbol == 'None':
        # Digits-only model: character set comes straight from the config.
        opt.character = opt.number
    elif opt.lang_char == 'None':
        # Derive the character set from the labels of every selected dataset.
        characters = ''
        for data in opt['select_data'].split('-'):
            csv_path = os.path.join(opt['train_data'], data, 'labels.csv')
            # Regex separator splits on the FIRST comma only, so label text
            # may itself contain commas.
            df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python', dtype={'words': str}, usecols=['filename', 'words'], keep_default_na=False)
            all_char = ''.join(df['words'])
            characters += ''.join(set(all_char))
        characters = sorted(set(characters))
        opt.character = ''.join(characters)
    else:
        opt.character = opt.number + opt.symbol + opt.lang_char
    # Create the checkpoint directory once (the original called makedirs twice).
    os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    print(f"converter.character : {converter.character}")
    print(f"字符集: {opt.character}")
    print(f"字符集长度: {opt.num_class}")
    return opt
def parse_args():
    """Build and parse the command-line options for the converter."""
    parser = argparse.ArgumentParser(description='PyTorch模型转ONNX格式工具')
    option_table = (
        ('--input', str, 'digit_cnn.pth', '输入PyTorch模型路径 (默认: digit_cnn.pth)'),
        ('--output', str, 'digit_cnn.onnx', '输出ONNX模型路径 (默认: digit_cnn.onnx)'),
        ('--opset', int, 11, 'ONNX算子集版本 (默认: 11)'),
    )
    # Register each option from the table to keep flag/type/default/help aligned.
    for flag, value_type, default_value, help_text in option_table:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return parser.parse_args()
def convert_to_onnx(input_path, output_path, opset_version, opt):
    """Export a trained PyTorch model to ONNX.

    Args:
        input_path: path to the saved PyTorch checkpoint.
        output_path: destination .onnx file path.
        opset_version: ONNX opset version to target.
        opt: parsed config (AttrDict) describing the model architecture.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt).to(device)
    # BUG FIX: the original called torch.load() and discarded the result,
    # exporting an *untrained*, randomly initialized model. Actually load
    # the checkpoint into the model.
    checkpoint = torch.load(input_path, map_location=device)
    if isinstance(checkpoint, dict):
        model.load_state_dict(checkpoint)
    else:
        # Checkpoint stored the whole module rather than a state dict.
        model = checkpoint.to(device)
    model.eval()
    # NOTE(review): assumes a 3x32x128 input; confirm against opt.imgH/imgW.
    dummy_input = torch.randn(1, 3, 32, 128).to(device)
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )
    print(f"模型已成功转换为 {output_path} (opset {opset_version})")
if __name__ == '__main__':
    # Parse CLI options, load the model config, then run the export.
    args = parse_args()
    opt = get_config("config_files/4digit_config.yaml")
    convert_to_onnx(args.input, args.output, args.opset,opt)
import torch
import argparse
from model import Model
import os
import torch.backends.cudnn as cudnn
import yaml
from utils import AttrDict
import pandas as pd
from utils import CTCLabelConverter, AttnLabelConverter, Averager
# Let cuDNN autotune kernels for fixed input sizes; determinism is not
# required for a one-shot export, so trade it for speed.
cudnn.benchmark = True
cudnn.deterministic = False
def get_config(file_path):
    """Load a YAML training config and derive the model character set.

    Args:
        file_path: path to the YAML config file.

    Returns:
        AttrDict with `character` and `num_class` populated. Also creates
        ./saved_models/<experiment_name>/ as a side effect.
    """
    with open(file_path, 'r', encoding="utf8") as stream:
        opt = yaml.safe_load(stream)
    opt = AttrDict(opt)
    if opt.lang_char == 'None' and opt.symbol == 'None':
        # Digits-only model: character set comes straight from the config.
        opt.character = opt.number
    elif opt.lang_char == 'None':
        # Derive the character set from the labels of every selected dataset.
        characters = ''
        for data in opt['select_data'].split('-'):
            csv_path = os.path.join(opt['train_data'], data, 'labels.csv')
            # Regex separator splits on the FIRST comma only, so label text
            # may itself contain commas.
            df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python', dtype={'words': str}, usecols=['filename', 'words'], keep_default_na=False)
            all_char = ''.join(df['words'])
            characters += ''.join(set(all_char))
        characters = sorted(set(characters))
        opt.character = ''.join(characters)
    else:
        opt.character = opt.number + opt.symbol + opt.lang_char
    # Create the checkpoint directory once (the original called makedirs twice).
    os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    print(f"converter.character : {converter.character}")
    print(f"字符集: {opt.character}")
    print(f"字符集长度: {opt.num_class}")
    return opt
def parse_args():
    """Declare and parse the converter's command-line interface."""
    parser = argparse.ArgumentParser(description='PyTorch模型转ONNX格式工具')
    # One add_argument call per option, keyword-grouped for readability.
    parser.add_argument(
        '--input',
        type=str,
        default='digit_cnn.pth',
        help='输入PyTorch模型路径 (默认: digit_cnn.pth)',
    )
    parser.add_argument(
        '--output',
        type=str,
        default='digit_cnn.onnx',
        help='输出ONNX模型路径 (默认: digit_cnn.onnx)',
    )
    parser.add_argument(
        '--opset',
        type=int,
        default=11,
        help='ONNX算子集版本 (默认: 11)',
    )
    return parser.parse_args()
def convert_to_onnx(input_path, output_path, opset_version, opt):
    """Export a trained PyTorch model to ONNX.

    Args:
        input_path: path to the saved PyTorch checkpoint.
        output_path: destination .onnx file path.
        opset_version: ONNX opset version to target.
        opt: parsed config (AttrDict) describing the model architecture.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt).to(device)
    # BUG FIX: the original called torch.load() and discarded the result,
    # exporting an *untrained*, randomly initialized model. Actually load
    # the checkpoint into the model.
    checkpoint = torch.load(input_path, map_location=device)
    if isinstance(checkpoint, dict):
        model.load_state_dict(checkpoint)
    else:
        # Checkpoint stored the whole module rather than a state dict.
        model = checkpoint.to(device)
    model.eval()
    # NOTE(review): assumes a 3x32x128 input; confirm against opt.imgH/imgW.
    dummy_input = torch.randn(1, 3, 32, 128).to(device)
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )
    print(f"模型已成功转换为 {output_path} (opset {opset_version})")
if __name__ == '__main__':
    # Parse CLI options, load the model config, then run the export.
    args = parse_args()
    opt = get_config("config_files/4digit_config.yaml")
    convert_to_onnx(args.input, args.output, args.opset,opt)

1
trainer/saved_models/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
*
!.gitignore