add .gitignore in the all_data directory

This commit is contained in:
2025-07-21 18:37:41 +08:00
parent c10b0719c7
commit 5343c29e1b
16 changed files with 9860 additions and 4 deletions

2
.gitignore vendored
View File

@@ -107,4 +107,4 @@ ENV/
.vscode/
.vs/
.idea/
trainer/all_data/**

1
trainer/all_data/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
4digit*/**

BIN
trainer/all_data/arial.ttf Normal file

Binary file not shown.

View File

@@ -0,0 +1 @@
place dataset folder here

View File

@@ -0,0 +1,86 @@
import cv2
import numpy as np
import os
import random
from PIL import Image, ImageDraw, ImageFont
X_RAND_VALUE = 2
Y_RAND_VALUE = 1
ROTATE_ANGLE = 3
BG_COLORS = [
(33, 40, 45), (36, 51, 62), (35, 37, 154),
(0, 38, 202), (239, 255, 255), (241, 255, 255)
]
DIGIT_COLORS = [(34, 199, 253), (25, 214, 253)]
def generate_4digit_image():
    """Render one synthetic 4-digit image for OCR training.

    Returns:
        tuple[np.ndarray, str]: the cropped 32x128x3 uint8 image and the
        4-character digit string drawn on it.

    NOTE(review): PIL draws in RGB order but callers save the array with
    cv2.imwrite, which expects BGR — confirm BG_COLORS/DIGIT_COLORS are
    intentionally specified for the resulting channel order.
    """
    bg_color = random.choice(BG_COLORS)
    font_size = random.randint(24, 30)
    # Font file is resolved relative to the current working directory.
    font = ImageFont.truetype("arial.ttf", font_size)
    # Oversized canvas (50x160) leaves buffer space for rotation and cropping.
    canvas = np.zeros((50, 160, 3), dtype=np.uint8)
    canvas[:,:] = bg_color
    pil_img = Image.fromarray(canvas)
    draw = ImageDraw.Draw(pil_img)
    digits = []
    for i in range(4):
        digit = str(random.randint(0, 9))
        digits.append(digit)
        x_offset = random.randint(-X_RAND_VALUE, X_RAND_VALUE)
        y_offset = random.randint(-Y_RAND_VALUE, Y_RAND_VALUE)
        digit_color = random.choice(DIGIT_COLORS)
        # Draw each digit near the canvas centre with small random jitter.
        draw.text((20+i*32+x_offset, 12+y_offset), digit,
                  font=font, fill=digit_color)
    angle = random.uniform(-ROTATE_ANGLE, ROTATE_ANGLE)
    rotated = pil_img.rotate(angle, expand=True, fillcolor=bg_color)
    # Safe crop region taken from the centre of the enlarged canvas.
    rotated = rotated.crop((20, 10, 148, 42))
    return np.array(rotated), ''.join(digits)
def _generate_dataset(out_dir, num_samples):
    """Write *num_samples* labelled 4-digit images plus a labels.csv into *out_dir*.

    Shared worker for the three public generators below, which were
    previously three copy-pasted bodies differing only in the output
    directory and sample count.
    """
    os.makedirs(out_dir, exist_ok=True)
    with open(f'{out_dir}/labels.csv', 'w') as f:
        f.write("filename,words\n")
        for i in range(num_samples):
            img, label = generate_4digit_image()
            # Labels are already 4 characters; zfill kept as a cheap guard.
            label = str(label).zfill(4)
            cv2.imwrite(f'{out_dir}/{i:04d}.jpg', img)
            f.write(f"{i:04d}.jpg,{label}\n")


def generate_train_dataset(num_samples=1000):
    """Generate the training split under 4digit_train/."""
    _generate_dataset('4digit_train', num_samples)


def generate_valid_dataset(num_samples=200):
    """Generate the validation split under 4digit_valid/."""
    _generate_dataset('4digit_valid', num_samples)


def generate_eval_dataset(num_samples=200):
    """Generate the evaluation split under 4digit_eval/."""
    _generate_dataset('4digit_eval', num_samples)


if __name__ == "__main__":
    generate_train_dataset()
    generate_eval_dataset()
    generate_valid_dataset()

View File

@@ -0,0 +1,47 @@
import os
import shutil
import csv
def split_dataset(labels_path, img_source_dir, train_dir='train', valid_dir='valid',
                  train_count=700):
    """Split a labelled image set into a train and a valid subset.

    The first *train_count* data rows of *labels_path* (and their image
    files from *img_source_dir*) are copied into *train_dir*; the rest go
    into *valid_dir*. Each output directory gets its own labels.csv with
    a ``filename,words`` header.

    Fixes over the previous version:
    - CSV handles are managed with ``with`` so they are closed even on error.
    - An optional header row in the source labels.csv is no longer treated
      as a data row (it used to emit a bogus "missing image" warning and
      skew the split count).
    - The 700-row split point is now a parameter (default unchanged).
    """
    # Create the target folders.
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(valid_dir, exist_ok=True)
    with open(labels_path, 'r') as f:
        lines = [ln.strip() for ln in f if ln.strip()]
    # Drop an optional header row so it is not copied as data.
    if lines and lines[0].split(',')[0].strip().lower() == 'filename':
        lines = lines[1:]
    with open(os.path.join(train_dir, 'labels.csv'), 'w', newline='') as train_csv, \
         open(os.path.join(valid_dir, 'labels.csv'), 'w', newline='') as valid_csv:
        train_writer = csv.writer(train_csv)
        valid_writer = csv.writer(valid_csv)
        train_writer.writerow(['filename', 'words'])
        valid_writer.writerow(['filename', 'words'])
        for i, line in enumerate(lines):
            parts = line.split(',')
            img_name = parts[0].strip()
            label = parts[1] if len(parts) > 1 else ''
            src_path = os.path.join(img_source_dir, img_name)
            print(f"处理图片: {img_name}, 标签: {label}")
            if i < train_count:  # training split
                dst_dir, writer = train_dir, train_writer
            else:  # validation split
                dst_dir, writer = valid_dir, valid_writer
            writer.writerow([img_name, label])
            if os.path.exists(src_path):
                shutil.copy2(src_path, os.path.join(dst_dir, img_name))
            else:
                print(f"警告:源图片不存在 {src_path}")


# Usage example; guarded so importing this module has no side effects.
if __name__ == '__main__':
    split_dataset(
        labels_path='en_sample/labels.csv',
        img_source_dir='en_sample'
    )

View File

@@ -8,8 +8,8 @@ manualSeed: 1111
workers: 6
batch_size: 32 #32
num_iter: 3000
valInterval: 5
saved_model: '' #'saved_models/en_filtered/iter_300000.pth'
valInterval: 10
saved_model: 'saved_models/4digit/best_accuracy.pth'
FT: False
optim: False # default is Adadelta
lr: 1.

87
trainer/export_onnx.py Normal file
View File

@@ -0,0 +1,87 @@
import torch
import argparse
from model import Model
import os
import torch.backends.cudnn as cudnn
import yaml
from utils import AttrDict
import pandas as pd
from utils import CTCLabelConverter, AttnLabelConverter, Averager
cudnn.benchmark = True
cudnn.deterministic = False
def get_config(file_path):
    """Load a YAML training config and derive the recognizer character set.

    The character set is chosen from, in order:
    - ``opt.number`` alone when both ``lang_char`` and ``symbol`` are 'None';
    - the union of characters found in the selected datasets' labels.csv
      files when only ``lang_char`` is 'None';
    - ``number + symbol + lang_char`` otherwise.

    Also creates ``./saved_models/<experiment_name>/`` and sets
    ``opt.num_class`` from the label converter.

    Args:
        file_path: path to the YAML config file.

    Returns:
        AttrDict: the parsed options, augmented with ``character`` and
        ``num_class``.
    """
    with open(file_path, 'r', encoding="utf8") as stream:
        opt = yaml.safe_load(stream)
    opt = AttrDict(opt)
    if opt.lang_char == 'None' and opt.symbol == 'None':
        # Digits-only model.
        opt.character = opt.number
    elif opt.lang_char == 'None':
        # Derive the charset from every label in the selected datasets.
        characters = ''
        for data in opt['select_data'].split('-'):
            csv_path = os.path.join(opt['train_data'], data, 'labels.csv')
            # sep='^([^,]+),' splits only on the first comma so labels may
            # themselves contain commas.
            df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python', dtype={'words': str}, usecols=['filename', 'words'], keep_default_na=False)
            all_char = ''.join(df['words'])
            characters += ''.join(set(all_char))
        characters = sorted(set(characters))
        opt.character = ''.join(characters)
    else:
        opt.character = opt.number + opt.symbol + opt.lang_char
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    print(f"converter.character : {converter.character}")
    print(f"字符集: {opt.character}")
    print(f"字符集长度: {opt.num_class}")
    # Create the experiment output directory once (the original called
    # os.makedirs twice with identical arguments).
    os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)
    return opt
def parse_args():
    """Parse command-line options for the PyTorch-to-ONNX conversion tool."""
    parser = argparse.ArgumentParser(description='PyTorch模型转ONNX格式工具')
    # (flag, type, default, help) — table-driven registration keeps the
    # options and their defaults visible at a glance.
    option_table = (
        ('--input', str, 'digit_cnn.pth',
         '输入PyTorch模型路径 (默认: digit_cnn.pth)'),
        ('--output', str, 'digit_cnn.onnx',
         '输出ONNX模型路径 (默认: digit_cnn.onnx)'),
        ('--opset', int, 11,
         'ONNX算子集版本 (默认: 11)'),
    )
    for flag, value_type, default, help_text in option_table:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)
    return parser.parse_args()
def convert_to_onnx(input_path, output_path, opset_version, opt):
    """Export a trained recognizer checkpoint to ONNX.

    Args:
        input_path: path to the PyTorch checkpoint (.pth state dict).
        output_path: destination .onnx file.
        opset_version: ONNX operator-set version to target.
        opt: parsed training config (needs rgb/input_channel; imgH/imgW
            are used for the dummy input when present).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt).to(device)
    # Bug fix: the original called torch.load() and discarded the result,
    # so the exported graph contained freshly-initialized random weights.
    state_dict = torch.load(input_path, map_location=device)
    # Checkpoints saved from a DataParallel-wrapped model prefix keys with
    # 'module.'; strip it if present (no-op otherwise).
    state_dict = {(k[7:] if k.startswith('module.') else k): v
                  for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model.eval()
    # Dummy input matching the training geometry (N, C, H, W); was
    # hard-coded to 3x32x128 even for single-channel configs.
    dummy_input = torch.randn(
        1, opt.input_channel, getattr(opt, 'imgH', 32), getattr(opt, 'imgW', 128)
    ).to(device)
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )
    print(f"模型已成功转换为 {output_path} (opset {opset_version})")
if __name__ == '__main__':
    # Resolve CLI options, load the training config, then run the export.
    cli_args = parse_args()
    config = get_config("config_files/4digit_config.yaml")
    convert_to_onnx(cli_args.input, cli_args.output, cli_args.opset, config)

View File

@@ -128,3 +128,138 @@ Total_batch_size: 32 = 32
dataset_root: all_data/4digit_valid dataset: /
sub-directory: /. num samples: 200
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
dataset_root: all_data
opt.select_data: ['4digit_train']
opt.batch_ratio: ['1']
--------------------------------------------------------------------------------
dataset_root: all_data dataset: 4digit_train
sub-directory: /4digit_train num samples: 1000
num total samples of 4digit_train: 1000 x 1.0 (total_data_usage_ratio) = 1000
num samples of 4digit_train per batch: 32 x 1.0 (batch_ratio) = 32
--------------------------------------------------------------------------------
Total_batch_size: 32 = 32
--------------------------------------------------------------------------------
dataset_root: all_data/4digit_valid dataset: /
sub-directory: /. num samples: 200
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
dataset_root: all_data
opt.select_data: ['4digit_train']
opt.batch_ratio: ['1']
--------------------------------------------------------------------------------
dataset_root: all_data dataset: 4digit_train
sub-directory: /4digit_train num samples: 1000
num total samples of 4digit_train: 1000 x 1.0 (total_data_usage_ratio) = 1000
num samples of 4digit_train per batch: 32 x 1.0 (batch_ratio) = 32
--------------------------------------------------------------------------------
Total_batch_size: 32 = 32
--------------------------------------------------------------------------------
dataset_root: all_data/4digit_valid dataset: /
sub-directory: /. num samples: 200
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
dataset_root: all_data
opt.select_data: ['4digit_train']
opt.batch_ratio: ['1']
--------------------------------------------------------------------------------
dataset_root: all_data dataset: 4digit_train
sub-directory: /4digit_train num samples: 1000
num total samples of 4digit_train: 1000 x 1.0 (total_data_usage_ratio) = 1000
num samples of 4digit_train per batch: 32 x 1.0 (batch_ratio) = 32
--------------------------------------------------------------------------------
Total_batch_size: 32 = 32
--------------------------------------------------------------------------------
dataset_root: all_data/4digit_valid dataset: /
sub-directory: /. num samples: 200
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
dataset_root: all_data
opt.select_data: ['4digit_train']
opt.batch_ratio: ['1']
--------------------------------------------------------------------------------
dataset_root: all_data dataset: 4digit_train
sub-directory: /4digit_train num samples: 1000
num total samples of 4digit_train: 1000 x 1.0 (total_data_usage_ratio) = 1000
num samples of 4digit_train per batch: 32 x 1.0 (batch_ratio) = 32
--------------------------------------------------------------------------------
Total_batch_size: 32 = 32
--------------------------------------------------------------------------------
dataset_root: all_data/4digit_valid dataset: /
sub-directory: /. num samples: 200
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
dataset_root: all_data
opt.select_data: ['4digit_train']
opt.batch_ratio: ['1']
--------------------------------------------------------------------------------
dataset_root: all_data dataset: 4digit_train
sub-directory: /4digit_train num samples: 1000
num total samples of 4digit_train: 1000 x 1.0 (total_data_usage_ratio) = 1000
num samples of 4digit_train per batch: 32 x 1.0 (batch_ratio) = 32
--------------------------------------------------------------------------------
Total_batch_size: 32 = 32
--------------------------------------------------------------------------------
dataset_root: all_data/4digit_valid dataset: /
sub-directory: /. num samples: 200
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
dataset_root: all_data
opt.select_data: ['4digit_train']
opt.batch_ratio: ['1']
--------------------------------------------------------------------------------
dataset_root: all_data dataset: 4digit_train
sub-directory: /4digit_train num samples: 1000
num total samples of 4digit_train: 1000 x 1.0 (total_data_usage_ratio) = 1000
num samples of 4digit_train per batch: 32 x 1.0 (batch_ratio) = 32
--------------------------------------------------------------------------------
Total_batch_size: 32 = 32
--------------------------------------------------------------------------------
dataset_root: all_data/4digit_valid dataset: /
sub-directory: /. num samples: 200
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
dataset_root: all_data
opt.select_data: ['4digit_train']
opt.batch_ratio: ['1']
--------------------------------------------------------------------------------
dataset_root: all_data dataset: 4digit_train
sub-directory: /4digit_train num samples: 1000
num total samples of 4digit_train: 1000 x 1.0 (total_data_usage_ratio) = 1000
num samples of 4digit_train per batch: 32 x 1.0 (batch_ratio) = 32
--------------------------------------------------------------------------------
Total_batch_size: 32 = 32
--------------------------------------------------------------------------------
dataset_root: all_data/4digit_valid dataset: /
sub-directory: /. num samples: 200
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
dataset_root: all_data
opt.select_data: ['4digit_train']
opt.batch_ratio: ['1']
--------------------------------------------------------------------------------
dataset_root: all_data dataset: 4digit_train
sub-directory: /4digit_train num samples: 1000
num total samples of 4digit_train: 1000 x 1.0 (total_data_usage_ratio) = 1000
num samples of 4digit_train per batch: 32 x 1.0 (batch_ratio) = 32
--------------------------------------------------------------------------------
Total_batch_size: 32 = 32
--------------------------------------------------------------------------------
dataset_root: all_data/4digit_valid dataset: /
sub-directory: /. num samples: 200
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
dataset_root: all_data
opt.select_data: ['4digit_train']
opt.batch_ratio: ['1']
--------------------------------------------------------------------------------
dataset_root: all_data dataset: 4digit_train
sub-directory: /4digit_train num samples: 1000
num total samples of 4digit_train: 1000 x 1.0 (total_data_usage_ratio) = 1000
num samples of 4digit_train per batch: 32 x 1.0 (batch_ratio) = 32
--------------------------------------------------------------------------------
Total_batch_size: 32 = 32
--------------------------------------------------------------------------------
dataset_root: all_data/4digit_valid dataset: /
sub-directory: /. num samples: 200
--------------------------------------------------------------------------------

File diff suppressed because it is too large Load Diff

View File

@@ -90,3 +90,417 @@ freeze_SequenceModeling: False
character: 0123456789
num_class: 11
---------------------------------------
------------ Options -------------
number: 0123456789
experiment_name: 4digit
symbol: None
lang_char: None
train_data: all_data
valid_data: all_data/4digit_valid
manualSeed: 1111
workers: 6
batch_size: 32
num_iter: 3000
valInterval: 5
saved_model:
FT: False
optim: False
lr: 1.0
beta1: 0.9
rho: 0.95
eps: 1e-08
grad_clip: 5
select_data: ['4digit_train']
batch_ratio: ['1']
total_data_usage_ratio: 1.0
batch_max_length: 34
imgH: 32
imgW: 128
rgb: True
contrast_adjust: 0.0
sensitive: True
PAD: True
data_filtering_off: False
Transformation: TPS
FeatureExtraction: ResNet
SequenceModeling: BiLSTM
Prediction: CTC
num_fiducial: 20
input_channel: 3
output_channel: 256
hidden_size: 256
decode: greedy
new_prediction: False
freeze_FeatureFxtraction: False
freeze_SequenceModeling: False
character: 0123456789
num_class: 11
---------------------------------------
------------ Options -------------
number: 0123456789
experiment_name: 4digit
symbol: None
lang_char: None
train_data: all_data
valid_data: all_data/4digit_valid
manualSeed: 1111
workers: 6
batch_size: 32
num_iter: 3000
valInterval: 5
saved_model:
FT: False
optim: False
lr: 1.0
beta1: 0.9
rho: 0.95
eps: 1e-08
grad_clip: 5
select_data: ['4digit_train']
batch_ratio: ['1']
total_data_usage_ratio: 1.0
batch_max_length: 34
imgH: 32
imgW: 128
rgb: True
contrast_adjust: 0.0
sensitive: True
PAD: True
data_filtering_off: False
Transformation: TPS
FeatureExtraction: ResNet
SequenceModeling: BiLSTM
Prediction: CTC
num_fiducial: 20
input_channel: 3
output_channel: 256
hidden_size: 256
decode: greedy
new_prediction: False
freeze_FeatureFxtraction: False
freeze_SequenceModeling: False
character: 0123456789
num_class: 11
---------------------------------------
------------ Options -------------
number: 0123456789
experiment_name: 4digit
symbol: None
lang_char: None
train_data: all_data
valid_data: all_data/4digit_valid
manualSeed: 1111
workers: 6
batch_size: 32
num_iter: 3000
valInterval: 5
saved_model:
FT: False
optim: False
lr: 1.0
beta1: 0.9
rho: 0.95
eps: 1e-08
grad_clip: 5
select_data: ['4digit_train']
batch_ratio: ['1']
total_data_usage_ratio: 1.0
batch_max_length: 34
imgH: 32
imgW: 128
rgb: True
contrast_adjust: 0.0
sensitive: True
PAD: True
data_filtering_off: False
Transformation: TPS
FeatureExtraction: ResNet
SequenceModeling: BiLSTM
Prediction: CTC
num_fiducial: 20
input_channel: 3
output_channel: 256
hidden_size: 256
decode: greedy
new_prediction: False
freeze_FeatureFxtraction: False
freeze_SequenceModeling: False
character: 0123456789
num_class: 11
---------------------------------------
------------ Options -------------
number: 0123456789
experiment_name: 4digit
symbol: None
lang_char: None
train_data: all_data
valid_data: all_data/4digit_valid
manualSeed: 1111
workers: 6
batch_size: 32
num_iter: 3000
valInterval: 5
saved_model:
FT: False
optim: False
lr: 1.0
beta1: 0.9
rho: 0.95
eps: 1e-08
grad_clip: 5
select_data: ['4digit_train']
batch_ratio: ['1']
total_data_usage_ratio: 1.0
batch_max_length: 34
imgH: 32
imgW: 128
rgb: True
contrast_adjust: 0.0
sensitive: True
PAD: True
data_filtering_off: False
Transformation: TPS
FeatureExtraction: ResNet
SequenceModeling: BiLSTM
Prediction: CTC
num_fiducial: 20
input_channel: 3
output_channel: 256
hidden_size: 256
decode: greedy
new_prediction: False
freeze_FeatureFxtraction: False
freeze_SequenceModeling: False
character: 0123456789
num_class: 11
---------------------------------------
------------ Options -------------
number: 0123456789
experiment_name: 4digit
symbol: None
lang_char: None
train_data: all_data
valid_data: all_data/4digit_valid
manualSeed: 1111
workers: 6
batch_size: 32
num_iter: 3000
valInterval: 5
saved_model:
FT: False
optim: False
lr: 1.0
beta1: 0.9
rho: 0.95
eps: 1e-08
grad_clip: 5
select_data: ['4digit_train']
batch_ratio: ['1']
total_data_usage_ratio: 1.0
batch_max_length: 34
imgH: 32
imgW: 128
rgb: True
contrast_adjust: 0.0
sensitive: True
PAD: True
data_filtering_off: False
Transformation: TPS
FeatureExtraction: ResNet
SequenceModeling: BiLSTM
Prediction: CTC
num_fiducial: 20
input_channel: 3
output_channel: 256
hidden_size: 256
decode: greedy
new_prediction: False
freeze_FeatureFxtraction: False
freeze_SequenceModeling: False
character: 0123456789
num_class: 11
---------------------------------------
------------ Options -------------
number: 0123456789
experiment_name: 4digit
symbol: None
lang_char: None
train_data: all_data
valid_data: all_data/4digit_valid
manualSeed: 1111
workers: 6
batch_size: 32
num_iter: 3000
valInterval: 5
saved_model: saved_models/4digit/best_accuracy.pth
FT: False
optim: False
lr: 1.0
beta1: 0.9
rho: 0.95
eps: 1e-08
grad_clip: 5
select_data: ['4digit_train']
batch_ratio: ['1']
total_data_usage_ratio: 1.0
batch_max_length: 34
imgH: 32
imgW: 128
rgb: True
contrast_adjust: 0.0
sensitive: True
PAD: True
data_filtering_off: False
Transformation: TPS
FeatureExtraction: ResNet
SequenceModeling: BiLSTM
Prediction: CTC
num_fiducial: 20
input_channel: 3
output_channel: 256
hidden_size: 256
decode: greedy
new_prediction: False
freeze_FeatureFxtraction: False
freeze_SequenceModeling: False
character: 0123456789
num_class: 11
---------------------------------------
------------ Options -------------
number: 0123456789
experiment_name: 4digit
symbol: None
lang_char: None
train_data: all_data
valid_data: all_data/4digit_valid
manualSeed: 1111
workers: 6
batch_size: 32
num_iter: 3000
valInterval: 5
saved_model: saved_models/4digit/best_accuracy.pth
FT: False
optim: False
lr: 1.0
beta1: 0.9
rho: 0.95
eps: 1e-08
grad_clip: 5
select_data: ['4digit_train']
batch_ratio: ['1']
total_data_usage_ratio: 1.0
batch_max_length: 34
imgH: 32
imgW: 128
rgb: True
contrast_adjust: 0.0
sensitive: True
PAD: True
data_filtering_off: False
Transformation: TPS
FeatureExtraction: ResNet
SequenceModeling: BiLSTM
Prediction: CTC
num_fiducial: 20
input_channel: 3
output_channel: 256
hidden_size: 256
decode: greedy
new_prediction: False
freeze_FeatureFxtraction: False
freeze_SequenceModeling: False
character: 0123456789
num_class: 11
---------------------------------------
------------ Options -------------
number: 0123456789
experiment_name: 4digit
symbol: None
lang_char: None
train_data: all_data
valid_data: all_data/4digit_valid
manualSeed: 1111
workers: 6
batch_size: 32
num_iter: 3000
valInterval: 5
saved_model: saved_models/4digit/best_accuracy.pth
FT: False
optim: False
lr: 1.0
beta1: 0.9
rho: 0.95
eps: 1e-08
grad_clip: 5
select_data: ['4digit_train']
batch_ratio: ['1']
total_data_usage_ratio: 1.0
batch_max_length: 34
imgH: 32
imgW: 128
rgb: True
contrast_adjust: 0.0
sensitive: True
PAD: True
data_filtering_off: False
Transformation: TPS
FeatureExtraction: ResNet
SequenceModeling: BiLSTM
Prediction: CTC
num_fiducial: 20
input_channel: 3
output_channel: 256
hidden_size: 256
decode: greedy
new_prediction: False
freeze_FeatureFxtraction: False
freeze_SequenceModeling: False
character: 0123456789
num_class: 11
---------------------------------------
------------ Options -------------
number: 0123456789
experiment_name: 4digit
symbol: None
lang_char: None
train_data: all_data
valid_data: all_data/4digit_valid
manualSeed: 1111
workers: 6
batch_size: 32
num_iter: 3000
valInterval: 10
saved_model: saved_models/4digit/best_accuracy.pth
FT: False
optim: False
lr: 1.0
beta1: 0.9
rho: 0.95
eps: 1e-08
grad_clip: 5
select_data: ['4digit_train']
batch_ratio: ['1']
total_data_usage_ratio: 1.0
batch_max_length: 34
imgH: 32
imgW: 128
rgb: True
contrast_adjust: 0.0
sensitive: True
PAD: True
data_filtering_off: False
Transformation: TPS
FeatureExtraction: ResNet
SequenceModeling: BiLSTM
Prediction: CTC
num_fiducial: 20
input_channel: 3
output_channel: 256
hidden_size: 256
decode: greedy
new_prediction: False
freeze_FeatureFxtraction: False
freeze_SequenceModeling: False
character: 0123456789
num_class: 11
---------------------------------------

View File

@@ -29,12 +29,12 @@ def validation(model, criterion, evaluation_loader, converter, opt, device):
# For max length prediction
length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)
text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length)
start_time = time.time()
if 'CTC' in opt.Prediction:
preds = model(image, text_for_pred)
# print(f"preds shape : {preds.shape}")
forward_time = time.time() - start_time
# Calculate evaluation loss for CTC decoder.

View File

@@ -213,6 +213,7 @@ def train(opt, show_number = 2, amp=False):
preds = model(image, text).log_softmax(2)
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
preds = preds.permute(1, 0, 2)
print(f"preds shape : {preds.shape}")
torch.backends.cudnn.enabled = False
cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
torch.backends.cudnn.enabled = True
@@ -265,6 +266,9 @@ def train(opt, show_number = 2, amp=False):
#show_number = min(show_number, len(labels))
start = random.randint(0,len(labels) - show_number )
print(f"start index for showing results: {start}")
print(f"labels length: {len(labels)}")
print(f"labels : {labels}")
for gt, pred, confidence in zip(labels[start:start+show_number], preds[start:start+show_number], confidence_score[start:start+show_number]):
if 'Attn' in opt.Prediction:
gt = gt[:gt.find('[s]')]