testing dataset ok; trainer.py take effects
This commit is contained in:
46
trainer/config_files/4digit_config.yaml
Normal file
46
trainer/config_files/4digit_config.yaml
Normal file
@@ -0,0 +1,46 @@
|
||||
number: '0123456789'
|
||||
experiment_name: '4digit'
|
||||
symbol: ""
|
||||
lang_char: ''
|
||||
train_data: 'all_data'
|
||||
valid_data: 'all_data/4digit_valid'
|
||||
manualSeed: 1111
|
||||
workers: 6
|
||||
batch_size: 32 #32
|
||||
num_iter: 3000
|
||||
valInterval: 5
|
||||
# saved_model: '' #'saved_models/en_filtered/iter_300000.pth'
|
||||
svaed_model: 'saved_models/4digit/iter_3000.pth'
|
||||
FT: False
|
||||
optim: False # default is Adadelta
|
||||
lr: 1.
|
||||
beta1: 0.9
|
||||
rho: 0.95
|
||||
eps: 0.00000001
|
||||
grad_clip: 5
|
||||
#Data processing
|
||||
select_data: '4digit_train' # this is dataset folder in train_data
|
||||
batch_ratio: '1'
|
||||
total_data_usage_ratio: 1.0
|
||||
batch_max_length: 34
|
||||
imgH: 32
|
||||
imgW: 128
|
||||
rgb: True
|
||||
contrast_adjust: False
|
||||
sensitive: True
|
||||
PAD: True
|
||||
contrast_adjust: 0.0
|
||||
data_filtering_off: False
|
||||
# Model Architecture
|
||||
Transformation: 'TPS'
|
||||
FeatureExtraction: 'ResNet'
|
||||
SequenceModeling: 'BiLSTM'
|
||||
Prediction: 'CTC'
|
||||
num_fiducial: 20
|
||||
input_channel: 1
|
||||
output_channel: 256
|
||||
hidden_size: 256
|
||||
decode: 'greedy'
|
||||
new_prediction: False
|
||||
freeze_FeatureFxtraction: False
|
||||
freeze_SequenceModeling: False
|
||||
@@ -7,8 +7,8 @@ valid_data: 'all_data/valid'
|
||||
manualSeed: 1111
|
||||
workers: 6
|
||||
batch_size: 32 #32
|
||||
num_iter: 300000
|
||||
valInterval: 20000
|
||||
num_iter: 300
|
||||
valInterval: 5
|
||||
saved_model: '' #'saved_models/en_filtered/iter_300000.pth'
|
||||
FT: False
|
||||
optim: False # default is Adadelta
|
||||
@@ -31,8 +31,8 @@ PAD: True
|
||||
contrast_adjust: 0.0
|
||||
data_filtering_off: False
|
||||
# Model Architecture
|
||||
Transformation: 'ResNet'
|
||||
FeatureExtraction: 'VGG'
|
||||
Transformation: 'None'
|
||||
FeatureExtraction: 'ResNet'
|
||||
SequenceModeling: 'BiLSTM'
|
||||
Prediction: 'CTC'
|
||||
num_fiducial: 20
|
||||
|
||||
@@ -42,7 +42,6 @@ class Batch_Balanced_Dataset(object):
|
||||
log.write(dashed_line + '\n')
|
||||
print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}')
|
||||
log.write(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n')
|
||||
print(f"len(opt.select_data): {len(opt.select_data)}, len(opt.batch_ratio): {len(opt.batch_ratio)}")
|
||||
assert len(opt.select_data) == len(opt.batch_ratio)
|
||||
|
||||
_AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, contrast_adjust = opt.contrast_adjust)
|
||||
@@ -54,7 +53,6 @@ class Batch_Balanced_Dataset(object):
|
||||
_batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1)
|
||||
print(dashed_line)
|
||||
log.write(dashed_line + '\n')
|
||||
print(f"selected_d: {selected_d}, batch_ratio: {batch_ratio_d}, batch_size: {_batch_size}")
|
||||
_dataset, _dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d])
|
||||
total_number_dataset = len(_dataset)
|
||||
log.write(_dataset_log)
|
||||
@@ -100,12 +98,12 @@ class Batch_Balanced_Dataset(object):
|
||||
|
||||
for i, data_loader_iter in enumerate(self.dataloader_iter_list):
|
||||
try:
|
||||
image, text = data_loader_iter.next()
|
||||
image, text = next(data_loader_iter)
|
||||
balanced_batch_images.append(image)
|
||||
balanced_batch_texts += text
|
||||
except StopIteration:
|
||||
self.dataloader_iter_list[i] = iter(self.data_loader_list[i])
|
||||
image, text = self.dataloader_iter_list[i].next()
|
||||
image, text = next(self.dataloader_iter_list[i])
|
||||
balanced_batch_images.append(image)
|
||||
balanced_batch_texts += text
|
||||
except ValueError:
|
||||
@@ -119,7 +117,7 @@ class Batch_Balanced_Dataset(object):
|
||||
def hierarchical_dataset(root, opt, select_data='/'):
|
||||
""" select_data='/' contains all sub-directory of root directory """
|
||||
dataset_list = []
|
||||
dataset_log = f'dataset_root: {root}\t dataset[0]: {select_data[0]}'
|
||||
dataset_log = f'dataset_root: {root}\t dataset: {select_data[0]}'
|
||||
print(dataset_log)
|
||||
dataset_log += '\n'
|
||||
for dirpath, dirnames, filenames in os.walk(root+'/'):
|
||||
@@ -148,12 +146,8 @@ class OCRDataset(Dataset):
|
||||
self.root = root
|
||||
self.opt = opt
|
||||
print(root)
|
||||
print(f"Loading dataset from {root}...")
|
||||
print(opt)
|
||||
self.df = pd.read_csv(os.path.join(root,'labels.csv'), sep='^([^,]+),', engine='python', usecols=['filename', 'words'], keep_default_na=False)
|
||||
|
||||
self.nSamples = len(self.df)
|
||||
print(f"Number of samples: {self.nSamples}")
|
||||
|
||||
if self.opt.data_filtering_off:
|
||||
self.filtered_index_list = [index + 1 for index in range(self.nSamples)]
|
||||
@@ -286,4 +280,4 @@ def tensor2im(image_tensor, imtype=np.uint8):
|
||||
|
||||
def save_image(image_numpy, image_path):
|
||||
image_pil = Image.fromarray(image_numpy)
|
||||
image_pil.save(image_path)
|
||||
image_pil.save(image_path)
|
||||
BIN
trainer/saved_models/en_filtered/best_accuracy.pth
Normal file
BIN
trainer/saved_models/en_filtered/best_accuracy.pth
Normal file
Binary file not shown.
BIN
trainer/saved_models/en_filtered/best_norm_ED.pth
Normal file
BIN
trainer/saved_models/en_filtered/best_norm_ED.pth
Normal file
Binary file not shown.
@@ -1,102 +1,30 @@
|
||||
--------------------------------------------------------------------------------
|
||||
dataset_root: all_data/en_sample
|
||||
opt.select_data: ['all_data/en_sample/train']
|
||||
opt.batch_ratio: ['1']
|
||||
--------------------------------------------------------------------------------
|
||||
--------------------------------------------------------------------------------
|
||||
dataset_root: all_data/en_sample
|
||||
opt.select_data: ['all_data/en_sample/train']
|
||||
opt.batch_ratio: ['1']
|
||||
--------------------------------------------------------------------------------
|
||||
--------------------------------------------------------------------------------
|
||||
dataset_root: all_data/en_sample
|
||||
opt.select_data: ['all_data/en_sample/train']
|
||||
opt.batch_ratio: ['1']
|
||||
--------------------------------------------------------------------------------
|
||||
--------------------------------------------------------------------------------
|
||||
dataset_root: all_data/en_sample
|
||||
opt.select_data: ['all_data/en_sample/train']
|
||||
opt.batch_ratio: ['1']
|
||||
--------------------------------------------------------------------------------
|
||||
--------------------------------------------------------------------------------
|
||||
dataset_root: all_data/en_sample
|
||||
opt.select_data: ['all_data/en_sample/train']
|
||||
opt.batch_ratio: ['1']
|
||||
--------------------------------------------------------------------------------
|
||||
--------------------------------------------------------------------------------
|
||||
dataset_root: all_data/en_sample
|
||||
opt.select_data: ['all_data/en_sample/train']
|
||||
opt.batch_ratio: ['1']
|
||||
--------------------------------------------------------------------------------
|
||||
--------------------------------------------------------------------------------
|
||||
dataset_root: all_data/
|
||||
dataset_root: all_data
|
||||
opt.select_data: ['train']
|
||||
opt.batch_ratio: ['1']
|
||||
--------------------------------------------------------------------------------
|
||||
dataset_root: all_data/ dataset[0]: train
|
||||
dataset_root: all_data dataset: train
|
||||
sub-directory: /train num samples: 688
|
||||
num total samples of train: 688 x 1.0 (total_data_usage_ratio) = 688
|
||||
num samples of train per batch: 32 x 1.0 (batch_ratio) = 32
|
||||
--------------------------------------------------------------------------------
|
||||
Total_batch_size: 32 = 32
|
||||
--------------------------------------------------------------------------------
|
||||
dataset_root: all_data/valid dataset: /
|
||||
sub-directory: /. num samples: 194
|
||||
--------------------------------------------------------------------------------
|
||||
--------------------------------------------------------------------------------
|
||||
dataset_root: all_data
|
||||
opt.select_data: ['train']
|
||||
opt.batch_ratio: ['1']
|
||||
--------------------------------------------------------------------------------
|
||||
dataset_root: all_data dataset[0]: train
|
||||
dataset_root: all_data dataset: train
|
||||
sub-directory: /train num samples: 688
|
||||
num total samples of train: 688 x 1.0 (total_data_usage_ratio) = 688
|
||||
num samples of train per batch: 32 x 1.0 (batch_ratio) = 32
|
||||
--------------------------------------------------------------------------------
|
||||
Total_batch_size: 32 = 32
|
||||
--------------------------------------------------------------------------------
|
||||
--------------------------------------------------------------------------------
|
||||
dataset_root: all_data
|
||||
opt.select_data: ['train']
|
||||
opt.batch_ratio: ['1']
|
||||
--------------------------------------------------------------------------------
|
||||
dataset_root: all_data dataset[0]: train
|
||||
sub-directory: /train num samples: 688
|
||||
num total samples of train: 688 x 1.0 (total_data_usage_ratio) = 688
|
||||
num samples of train per batch: 32 x 1.0 (batch_ratio) = 32
|
||||
--------------------------------------------------------------------------------
|
||||
Total_batch_size: 32 = 32
|
||||
--------------------------------------------------------------------------------
|
||||
--------------------------------------------------------------------------------
|
||||
dataset_root: all_data
|
||||
opt.select_data: ['train']
|
||||
opt.batch_ratio: ['1']
|
||||
--------------------------------------------------------------------------------
|
||||
dataset_root: all_data dataset[0]: train
|
||||
sub-directory: /train num samples: 688
|
||||
num total samples of train: 688 x 1.0 (total_data_usage_ratio) = 688
|
||||
num samples of train per batch: 32 x 1.0 (batch_ratio) = 32
|
||||
--------------------------------------------------------------------------------
|
||||
Total_batch_size: 32 = 32
|
||||
--------------------------------------------------------------------------------
|
||||
--------------------------------------------------------------------------------
|
||||
dataset_root: all_data
|
||||
opt.select_data: ['train']
|
||||
opt.batch_ratio: ['1']
|
||||
--------------------------------------------------------------------------------
|
||||
dataset_root: all_data dataset[0]: train
|
||||
sub-directory: /train num samples: 688
|
||||
num total samples of train: 688 x 1.0 (total_data_usage_ratio) = 688
|
||||
num samples of train per batch: 32 x 1.0 (batch_ratio) = 32
|
||||
--------------------------------------------------------------------------------
|
||||
Total_batch_size: 32 = 32
|
||||
--------------------------------------------------------------------------------
|
||||
--------------------------------------------------------------------------------
|
||||
dataset_root: all_data
|
||||
opt.select_data: ['train']
|
||||
opt.batch_ratio: ['1']
|
||||
--------------------------------------------------------------------------------
|
||||
dataset_root: all_data dataset[0]: train
|
||||
sub-directory: /train num samples: 688
|
||||
num total samples of train: 688 x 1.0 (total_data_usage_ratio) = 688
|
||||
num samples of train per batch: 32 x 1.0 (batch_ratio) = 32
|
||||
--------------------------------------------------------------------------------
|
||||
Total_batch_size: 32 = 32
|
||||
dataset_root: all_data/valid dataset: /
|
||||
sub-directory: /. num samples: 194
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
9
trainer/saved_models/en_filtered/log_train.txt
Normal file
9
trainer/saved_models/en_filtered/log_train.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
[5/300] Train loss: 28.86501, Valid loss: 21.44183, Elapsed_time: 3.73978
|
||||
Current_accuracy : 0.000, Current_norm_ED : 0.0445
|
||||
Best_accuracy : 0.000, Best_norm_ED : 0.0445
|
||||
--------------------------------------------------------------------------------
|
||||
Ground Truth | Prediction | Confidence Score & T/F
|
||||
--------------------------------------------------------------------------------
|
||||
"Karjalan outcrop ""Sven""" | eee | 0.0000 False
|
||||
ESPN Modernize | ee | 0.0000 False
|
||||
--------------------------------------------------------------------------------
|
||||
92
trainer/saved_models/en_filtered/opt.txt
Normal file
92
trainer/saved_models/en_filtered/opt.txt
Normal file
@@ -0,0 +1,92 @@
|
||||
------------ Options -------------
|
||||
number: 0123456789
|
||||
symbol: !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~ €
|
||||
lang_char: ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
|
||||
experiment_name: en_filtered
|
||||
train_data: all_data
|
||||
valid_data: all_data/valid
|
||||
manualSeed: 1111
|
||||
workers: 6
|
||||
batch_size: 32
|
||||
num_iter: 300
|
||||
valInterval: 5
|
||||
saved_model:
|
||||
FT: False
|
||||
optim: False
|
||||
lr: 1.0
|
||||
beta1: 0.9
|
||||
rho: 0.95
|
||||
eps: 1e-08
|
||||
grad_clip: 5
|
||||
select_data: ['train']
|
||||
batch_ratio: ['1']
|
||||
total_data_usage_ratio: 1.0
|
||||
batch_max_length: 34
|
||||
imgH: 64
|
||||
imgW: 600
|
||||
rgb: False
|
||||
contrast_adjust: 0.0
|
||||
sensitive: True
|
||||
PAD: True
|
||||
data_filtering_off: False
|
||||
Transformation: None
|
||||
FeatureExtraction: ResNet
|
||||
SequenceModeling: BiLSTM
|
||||
Prediction: CTC
|
||||
num_fiducial: 20
|
||||
input_channel: 1
|
||||
output_channel: 256
|
||||
hidden_size: 256
|
||||
decode: greedy
|
||||
new_prediction: False
|
||||
freeze_FeatureFxtraction: False
|
||||
freeze_SequenceModeling: False
|
||||
character: 0123456789!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~ €ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
|
||||
num_class: 97
|
||||
---------------------------------------
|
||||
------------ Options -------------
|
||||
number: 0123456789
|
||||
symbol: !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~ €
|
||||
lang_char: ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
|
||||
experiment_name: en_filtered
|
||||
train_data: all_data
|
||||
valid_data: all_data/valid
|
||||
manualSeed: 1111
|
||||
workers: 6
|
||||
batch_size: 32
|
||||
num_iter: 300
|
||||
valInterval: 5
|
||||
saved_model:
|
||||
FT: False
|
||||
optim: False
|
||||
lr: 1.0
|
||||
beta1: 0.9
|
||||
rho: 0.95
|
||||
eps: 1e-08
|
||||
grad_clip: 5
|
||||
select_data: ['train']
|
||||
batch_ratio: ['1']
|
||||
total_data_usage_ratio: 1.0
|
||||
batch_max_length: 34
|
||||
imgH: 64
|
||||
imgW: 600
|
||||
rgb: False
|
||||
contrast_adjust: 0.0
|
||||
sensitive: True
|
||||
PAD: True
|
||||
data_filtering_off: False
|
||||
Transformation: None
|
||||
FeatureExtraction: ResNet
|
||||
SequenceModeling: BiLSTM
|
||||
Prediction: CTC
|
||||
num_fiducial: 20
|
||||
input_channel: 1
|
||||
output_channel: 256
|
||||
hidden_size: 256
|
||||
decode: greedy
|
||||
new_prediction: False
|
||||
freeze_FeatureFxtraction: False
|
||||
freeze_SequenceModeling: False
|
||||
character: 0123456789!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~ €ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
|
||||
num_class: 97
|
||||
---------------------------------------
|
||||
@@ -68,7 +68,7 @@ def train(opt, show_number = 2, amp=False):
|
||||
model = Model(opt)
|
||||
print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
|
||||
opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
|
||||
opt.SequenceModeling, opt.Prediction)
|
||||
opt.SequenceModeling, opt.Prediction, opt.saved_model)
|
||||
|
||||
if opt.saved_model != '':
|
||||
pretrained_dict = torch.load(opt.saved_model)
|
||||
|
||||
@@ -93,7 +93,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "easyocr",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -107,7 +107,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.11"
|
||||
"version": "3.8.20"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -26,5 +26,6 @@ def get_config(file_path):
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Load configuration
|
||||
# opt = get_config("config_files/4digit_config.yaml")
|
||||
opt = get_config("config_files/en_filtered_config.yaml")
|
||||
train(opt, amp=False)
|
||||
Reference in New Issue
Block a user