add .gitignore to saved_models
This commit is contained in:
@@ -1,86 +1,86 @@
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
# Largest random shift, in pixels, applied to each digit along x and y.
X_RAND_VALUE = 2
Y_RAND_VALUE = 1
# Largest absolute rotation applied to the rendered image.
ROTATE_ANGLE = 3

# Background colour candidates; one is chosen at random for every image.
# NOTE(review): these tuples are drawn through PIL but the array is saved with
# cv2.imwrite — confirm which channel order (RGB or BGR) is intended.
BG_COLORS = [
    (33, 40, 45),
    (36, 51, 62),
    (35, 37, 154),
    (0, 38, 202),
    (239, 255, 255),
    (241, 255, 255),
]

# Digit colour candidates; one is chosen at random for every digit.
DIGIT_COLORS = [
    (34, 199, 253),
    (25, 214, 253),
]
|
||||
|
||||
def generate_4digit_image():
    """Render a random 4-digit string on a plain coloured background.

    Returns:
        tuple: ``(image, label)`` where ``image`` is the cropped picture as a
        NumPy array and ``label`` is the 4-character digit string drawn on it.
    """
    background = random.choice(BG_COLORS)
    typeface = ImageFont.truetype("arial.ttf", random.randint(24, 30))

    # Oversized 50x160 canvas leaves a margin around the digits for the
    # jitter and rotation applied below.
    backdrop = np.full((50, 160, 3), background, dtype=np.uint8)
    picture = Image.fromarray(backdrop)
    pen = ImageDraw.Draw(picture)

    label_chars = []
    for slot in range(4):
        char = str(random.randint(0, 9))
        label_chars.append(char)
        dx = random.randint(-X_RAND_VALUE, X_RAND_VALUE)
        dy = random.randint(-Y_RAND_VALUE, Y_RAND_VALUE)
        ink = random.choice(DIGIT_COLORS)
        # Digits sit on a 32-pixel pitch starting at x=20, nudged by the jitter.
        origin = (20 + slot * 32 + dx, 12 + dy)
        pen.text(origin, char, font=typeface, fill=ink)

    tilt = random.uniform(-ROTATE_ANGLE, ROTATE_ANGLE)
    tilted = picture.rotate(tilt, expand=True, fillcolor=background)
    # Fixed crop box (left, upper, right, lower): a 128x32 window of the canvas.
    cropped = tilted.crop((20, 10, 148, 42))

    return np.array(cropped), ''.join(label_chars)
|
||||
|
||||
def _write_dataset(output_dir, num_samples):
    """Write ``num_samples`` generated images and a ``labels.csv`` index.

    Args:
        output_dir: Directory receiving the ``NNNN.jpg`` images and
            ``labels.csv``; created when missing.
        num_samples: Number of images to generate.

    Raises:
        IOError: If ``cv2.imwrite`` reports that an image was not written.
    """
    os.makedirs(output_dir, exist_ok=True)
    with open(f'{output_dir}/labels.csv', 'w') as f:
        f.write("filename,words\n")
        for i in range(num_samples):
            img, label = generate_4digit_image()
            # Keep the label 4 characters wide even if it ever loses
            # leading zeros.
            label = str(label).zfill(4)
            filename = f"{i:04d}.jpg"
            img_path = f'{output_dir}/{filename}'
            # cv2.imwrite signals failure through its return value; without
            # this check a label row could point at a missing image.
            if not cv2.imwrite(img_path, img):
                raise IOError(f"failed to write image: {img_path}")
            f.write(f"{filename},{label}\n")


def generate_train_dataset(num_samples=1000, output_dir='4digit_train'):
    """Generate the training split (images plus ``labels.csv``).

    Args:
        num_samples: Number of images to generate.
        output_dir: Target directory for the split.
    """
    _write_dataset(output_dir, num_samples)


def generate_valid_dataset(num_samples=200, output_dir='4digit_valid'):
    """Generate the validation split (images plus ``labels.csv``).

    Args:
        num_samples: Number of images to generate.
        output_dir: Target directory for the split.
    """
    _write_dataset(output_dir, num_samples)


def generate_eval_dataset(num_samples=200, output_dir='4digit_eval'):
    """Generate the evaluation split (images plus ``labels.csv``).

    Args:
        num_samples: Number of images to generate.
        output_dir: Target directory for the split.
    """
    _write_dataset(output_dir, num_samples)
|
||||
if __name__ == "__main__":
    # Build the training and evaluation splits when run as a script.
    for build_split in (generate_train_dataset, generate_eval_dataset):
        build_split()
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
# Largest random shift, in pixels, applied to each digit along x and y.
X_RAND_VALUE = 2
Y_RAND_VALUE = 1
# Largest absolute rotation applied to the rendered image.
ROTATE_ANGLE = 3

# Background colour candidates; one is chosen at random for every image.
# NOTE(review): these tuples are drawn through PIL but the array is saved with
# cv2.imwrite — confirm which channel order (RGB or BGR) is intended.
BG_COLORS = [
    (33, 40, 45),
    (36, 51, 62),
    (35, 37, 154),
    (0, 38, 202),
    (239, 255, 255),
    (241, 255, 255),
]

# Digit colour candidates; one is chosen at random for every digit.
DIGIT_COLORS = [
    (34, 199, 253),
    (25, 214, 253),
]
|
||||
|
||||
def generate_4digit_image():
    """Render a random 4-digit string on a plain coloured background.

    Returns:
        tuple: ``(image, label)`` where ``image`` is the cropped picture as a
        NumPy array and ``label`` is the 4-character digit string drawn on it.
    """
    background = random.choice(BG_COLORS)
    typeface = ImageFont.truetype("arial.ttf", random.randint(24, 30))

    # Oversized 50x160 canvas leaves a margin around the digits for the
    # jitter and rotation applied below.
    backdrop = np.full((50, 160, 3), background, dtype=np.uint8)
    picture = Image.fromarray(backdrop)
    pen = ImageDraw.Draw(picture)

    label_chars = []
    for slot in range(4):
        char = str(random.randint(0, 9))
        label_chars.append(char)
        dx = random.randint(-X_RAND_VALUE, X_RAND_VALUE)
        dy = random.randint(-Y_RAND_VALUE, Y_RAND_VALUE)
        ink = random.choice(DIGIT_COLORS)
        # Digits sit on a 32-pixel pitch starting at x=20, nudged by the jitter.
        origin = (20 + slot * 32 + dx, 12 + dy)
        pen.text(origin, char, font=typeface, fill=ink)

    tilt = random.uniform(-ROTATE_ANGLE, ROTATE_ANGLE)
    tilted = picture.rotate(tilt, expand=True, fillcolor=background)
    # Fixed crop box (left, upper, right, lower): a 128x32 window of the canvas.
    cropped = tilted.crop((20, 10, 148, 42))

    return np.array(cropped), ''.join(label_chars)
|
||||
|
||||
def _write_dataset(output_dir, num_samples):
    """Write ``num_samples`` generated images and a ``labels.csv`` index.

    Args:
        output_dir: Directory receiving the ``NNNN.jpg`` images and
            ``labels.csv``; created when missing.
        num_samples: Number of images to generate.

    Raises:
        IOError: If ``cv2.imwrite`` reports that an image was not written.
    """
    os.makedirs(output_dir, exist_ok=True)
    with open(f'{output_dir}/labels.csv', 'w') as f:
        f.write("filename,words\n")
        for i in range(num_samples):
            img, label = generate_4digit_image()
            # Keep the label 4 characters wide even if it ever loses
            # leading zeros.
            label = str(label).zfill(4)
            filename = f"{i:04d}.jpg"
            img_path = f'{output_dir}/{filename}'
            # cv2.imwrite signals failure through its return value; without
            # this check a label row could point at a missing image.
            if not cv2.imwrite(img_path, img):
                raise IOError(f"failed to write image: {img_path}")
            f.write(f"{filename},{label}\n")


def generate_train_dataset(num_samples=1000, output_dir='4digit_train'):
    """Generate the training split (images plus ``labels.csv``).

    Args:
        num_samples: Number of images to generate.
        output_dir: Target directory for the split.
    """
    _write_dataset(output_dir, num_samples)


def generate_valid_dataset(num_samples=200, output_dir='4digit_valid'):
    """Generate the validation split (images plus ``labels.csv``).

    Args:
        num_samples: Number of images to generate.
        output_dir: Target directory for the split.
    """
    _write_dataset(output_dir, num_samples)


def generate_eval_dataset(num_samples=200, output_dir='4digit_eval'):
    """Generate the evaluation split (images plus ``labels.csv``).

    Args:
        num_samples: Number of images to generate.
        output_dir: Target directory for the split.
    """
    _write_dataset(output_dir, num_samples)
|
||||
if __name__ == "__main__":
    # Build the training, evaluation and validation splits when run as a script.
    for build_split in (
        generate_train_dataset,
        generate_eval_dataset,
        generate_valid_dataset,
    ):
        build_split()
|
||||
@@ -1,87 +1,87 @@
|
||||
|
||||
import torch
|
||||
import argparse
|
||||
from model import Model
|
||||
import os
|
||||
import torch.backends.cudnn as cudnn
|
||||
import yaml
|
||||
from utils import AttrDict
|
||||
import pandas as pd
|
||||
from utils import CTCLabelConverter, AttnLabelConverter, Averager
|
||||
|
||||
|
||||
# Turn cuDNN benchmark mode on and leave deterministic mode off.
cudnn.benchmark, cudnn.deterministic = True, False
|
||||
|
||||
def get_config(file_path):
    """Load the YAML configuration and derive the recognition character set.

    The character set (``opt.character``) is chosen as follows:

    * ``lang_char`` and ``symbol`` both equal the string ``'None'``: use
      ``opt.number`` only.
    * only ``lang_char`` equals ``'None'``: collect every character found in
      the ``words`` column of each selected dataset's ``labels.csv``.
    * otherwise: concatenate ``number``, ``symbol`` and ``lang_char``.

    The function also creates ``./saved_models/<experiment_name>``, sets
    ``opt.num_class`` from the label converter and prints the character set.

    Args:
        file_path: Path of the YAML configuration file.

    Returns:
        The configuration wrapped in ``AttrDict`` with ``character`` and
        ``num_class`` filled in.
    """
    with open(file_path, 'r', encoding="utf8") as stream:
        opt = yaml.safe_load(stream)
    opt = AttrDict(opt)

    if opt.lang_char == 'None' and opt.symbol == 'None':
        opt.character = opt.number
    elif opt.lang_char == 'None':
        charset = set()
        for data in opt['select_data'].split('-'):
            csv_path = os.path.join(opt['train_data'], data, 'labels.csv')
            # The regex separator splits each row at the first comma only, so
            # commas inside the label text stay part of the 'words' column.
            df = pd.read_csv(
                csv_path,
                sep='^([^,]+),',
                engine='python',
                dtype={'words': str},
                usecols=['filename', 'words'],
                keep_default_na=False,
            )
            charset.update(''.join(df['words']))
        opt.character = ''.join(sorted(charset))
    else:
        opt.character = opt.number + opt.symbol + opt.lang_char

    # One call is enough: exist_ok=True already makes it idempotent.
    os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)

    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    print(f"converter.character : {converter.character}")
    print(f"字符集: {opt.character}")
    print(f"字符集长度: {opt.num_class}")
    return opt
|
||||
|
||||
def parse_args():
    """Parse the command-line options of the PyTorch-to-ONNX converter.

    Returns:
        argparse.Namespace: with ``input`` (checkpoint path), ``output``
        (ONNX path) and ``opset`` (integer opset version).
    """
    parser = argparse.ArgumentParser(description='PyTorch模型转ONNX格式工具')
    # (flag, type, default, help) for every supported option.
    option_table = (
        ('--input', str, 'digit_cnn.pth',
         '输入PyTorch模型路径 (默认: digit_cnn.pth)'),
        ('--output', str, 'digit_cnn.onnx',
         '输出ONNX模型路径 (默认: digit_cnn.onnx)'),
        ('--opset', int, 11,
         'ONNX算子集版本 (默认: 11)'),
    )
    for flag, value_type, fallback, description in option_table:
        parser.add_argument(flag, type=value_type, default=fallback,
                            help=description)
    return parser.parse_args()
|
||||
|
||||
def convert_to_onnx(input_path, output_path, opset_version, opt):
    """Export the recognition model to ONNX using the weights in ``input_path``.

    Args:
        input_path: Path of the PyTorch checkpoint (a state dict) to load.
        output_path: Destination path of the exported ONNX file.
        opset_version: ONNX opset version passed to ``torch.onnx.export``.
        opt: Configuration object used to build ``Model``. Its
            ``input_channel`` is overwritten with 3 when ``opt.rgb`` is truthy.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt).to(device)

    # Load the trained weights into the model; without load_state_dict the
    # export would contain the freshly initialised parameters. map_location
    # lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(input_path, map_location=device)
    # Checkpoints saved from a DataParallel-wrapped model prefix every key
    # with "module."; strip it so the keys match the bare model.
    if (isinstance(checkpoint, dict) and checkpoint
            and all(key.startswith('module.') for key in checkpoint)):
        checkpoint = {
            key[len('module.'):]: value for key, value in checkpoint.items()
        }
    model.load_state_dict(checkpoint)
    model.eval()

    # Tracing input: batch of 1, 32x128 pixels, channel count taken from the
    # configuration so it always matches the model that was just built.
    dummy_input = torch.randn(1, opt.input_channel, 32, 128).to(device)

    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        # Only the batch dimension is exported as dynamic.
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )
    print(f"模型已成功转换为 {output_path} (opset {opset_version})")
|
||||
|
||||
if __name__ == '__main__':
    # Read the CLI options, load the fixed 4-digit config, then export.
    cli = parse_args()
    config = get_config("config_files/4digit_config.yaml")
    convert_to_onnx(cli.input, cli.output, cli.opset, config)
|
||||
|
||||
import torch
|
||||
import argparse
|
||||
from model import Model
|
||||
import os
|
||||
import torch.backends.cudnn as cudnn
|
||||
import yaml
|
||||
from utils import AttrDict
|
||||
import pandas as pd
|
||||
from utils import CTCLabelConverter, AttnLabelConverter, Averager
|
||||
|
||||
|
||||
# Turn cuDNN benchmark mode on and leave deterministic mode off.
cudnn.benchmark, cudnn.deterministic = True, False
|
||||
|
||||
def get_config(file_path):
    """Load the YAML configuration and derive the recognition character set.

    The character set (``opt.character``) is chosen as follows:

    * ``lang_char`` and ``symbol`` both equal the string ``'None'``: use
      ``opt.number`` only.
    * only ``lang_char`` equals ``'None'``: collect every character found in
      the ``words`` column of each selected dataset's ``labels.csv``.
    * otherwise: concatenate ``number``, ``symbol`` and ``lang_char``.

    The function also creates ``./saved_models/<experiment_name>``, sets
    ``opt.num_class`` from the label converter and prints the character set.

    Args:
        file_path: Path of the YAML configuration file.

    Returns:
        The configuration wrapped in ``AttrDict`` with ``character`` and
        ``num_class`` filled in.
    """
    with open(file_path, 'r', encoding="utf8") as stream:
        opt = yaml.safe_load(stream)
    opt = AttrDict(opt)

    if opt.lang_char == 'None' and opt.symbol == 'None':
        opt.character = opt.number
    elif opt.lang_char == 'None':
        charset = set()
        for data in opt['select_data'].split('-'):
            csv_path = os.path.join(opt['train_data'], data, 'labels.csv')
            # The regex separator splits each row at the first comma only, so
            # commas inside the label text stay part of the 'words' column.
            df = pd.read_csv(
                csv_path,
                sep='^([^,]+),',
                engine='python',
                dtype={'words': str},
                usecols=['filename', 'words'],
                keep_default_na=False,
            )
            charset.update(''.join(df['words']))
        opt.character = ''.join(sorted(charset))
    else:
        opt.character = opt.number + opt.symbol + opt.lang_char

    # One call is enough: exist_ok=True already makes it idempotent.
    os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)

    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    print(f"converter.character : {converter.character}")
    print(f"字符集: {opt.character}")
    print(f"字符集长度: {opt.num_class}")
    return opt
|
||||
|
||||
def parse_args():
    """Parse the command-line options of the PyTorch-to-ONNX converter.

    Returns:
        argparse.Namespace: with ``input`` (checkpoint path), ``output``
        (ONNX path) and ``opset`` (integer opset version).
    """
    parser = argparse.ArgumentParser(description='PyTorch模型转ONNX格式工具')
    # (flag, type, default, help) for every supported option.
    option_table = (
        ('--input', str, 'digit_cnn.pth',
         '输入PyTorch模型路径 (默认: digit_cnn.pth)'),
        ('--output', str, 'digit_cnn.onnx',
         '输出ONNX模型路径 (默认: digit_cnn.onnx)'),
        ('--opset', int, 11,
         'ONNX算子集版本 (默认: 11)'),
    )
    for flag, value_type, fallback, description in option_table:
        parser.add_argument(flag, type=value_type, default=fallback,
                            help=description)
    return parser.parse_args()
|
||||
|
||||
def convert_to_onnx(input_path, output_path, opset_version, opt):
    """Export the recognition model to ONNX using the weights in ``input_path``.

    Args:
        input_path: Path of the PyTorch checkpoint (a state dict) to load.
        output_path: Destination path of the exported ONNX file.
        opset_version: ONNX opset version passed to ``torch.onnx.export``.
        opt: Configuration object used to build ``Model``. Its
            ``input_channel`` is overwritten with 3 when ``opt.rgb`` is truthy.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt).to(device)

    # Load the trained weights into the model; without load_state_dict the
    # export would contain the freshly initialised parameters. map_location
    # lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(input_path, map_location=device)
    # Checkpoints saved from a DataParallel-wrapped model prefix every key
    # with "module."; strip it so the keys match the bare model.
    if (isinstance(checkpoint, dict) and checkpoint
            and all(key.startswith('module.') for key in checkpoint)):
        checkpoint = {
            key[len('module.'):]: value for key, value in checkpoint.items()
        }
    model.load_state_dict(checkpoint)
    model.eval()

    # Tracing input: batch of 1, 32x128 pixels, channel count taken from the
    # configuration so it always matches the model that was just built.
    dummy_input = torch.randn(1, opt.input_channel, 32, 128).to(device)

    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        # Only the batch dimension is exported as dynamic.
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )
    print(f"模型已成功转换为 {output_path} (opset {opset_version})")
|
||||
|
||||
if __name__ == '__main__':
    # Read the CLI options, load the fixed 4-digit config, then export.
    cli = parse_args()
    config = get_config("config_files/4digit_config.yaml")
    convert_to_onnx(cli.input, cli.output, cli.opset, config)
|
||||
|
||||
1
trainer/saved_models/.gitignore
vendored
Normal file
1
trainer/saved_models/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
# Ignore everything in this directory except this file.
# (A leading "./" is not understood by gitignore, so "./**" matched nothing.)
*
!.gitignore
|
||||
Reference in New Issue
Block a user