testing dataset
This commit is contained in:
30
trainer/trainer.py
Normal file
30
trainer/trainer.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import os
|
||||
import torch.backends.cudnn as cudnn
|
||||
import yaml
|
||||
from train import train
|
||||
from utils import AttrDict
|
||||
import pandas as pd
|
||||
cudnn.benchmark = True
|
||||
cudnn.deterministic = False
|
||||
def get_config(file_path):
|
||||
with open(file_path, 'r', encoding="utf8") as stream:
|
||||
opt = yaml.safe_load(stream)
|
||||
opt = AttrDict(opt)
|
||||
if opt.lang_char == 'None':
|
||||
characters = ''
|
||||
for data in opt['select_data'].split('-'):
|
||||
csv_path = os.path.join(opt['train_data'], data, 'labels.csv')
|
||||
df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python', usecols=['filename', 'words'], keep_default_na=False)
|
||||
all_char = ''.join(df['words'])
|
||||
characters += ''.join(set(all_char))
|
||||
characters = sorted(set(characters))
|
||||
opt.character= ''.join(characters)
|
||||
else:
|
||||
opt.character = opt.number + opt.symbol + opt.lang_char
|
||||
os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)
|
||||
return opt
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Load configuration
|
||||
opt = get_config("config_files/en_filtered_config.yaml")
|
||||
train(opt, amp=False)
|
||||
Reference in New Issue
Block a user