4digit_training done
This commit is contained in:
@@ -10,11 +10,14 @@ def get_config(file_path):
|
||||
with open(file_path, 'r', encoding="utf8") as stream:
|
||||
opt = yaml.safe_load(stream)
|
||||
opt = AttrDict(opt)
|
||||
if opt.lang_char == 'None':
|
||||
|
||||
if opt.lang_char == 'None' and opt.symbol=='None':
|
||||
opt.character = opt.number
|
||||
elif opt.lang_char == 'None':
|
||||
characters = ''
|
||||
for data in opt['select_data'].split('-'):
|
||||
csv_path = os.path.join(opt['train_data'], data, 'labels.csv')
|
||||
df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python', usecols=['filename', 'words'], keep_default_na=False)
|
||||
df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python',dtype={'words': str}, usecols=['filename', 'words'], keep_default_na=False)
|
||||
all_char = ''.join(df['words'])
|
||||
characters += ''.join(set(all_char))
|
||||
characters = sorted(set(characters))
|
||||
@@ -26,6 +29,6 @@ def get_config(file_path):
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Load configuration
|
||||
# opt = get_config("config_files/4digit_config.yaml")
|
||||
opt = get_config("config_files/en_filtered_config.yaml")
|
||||
# opt = get_config("config_files/en_filtered_config.yaml")
|
||||
opt = get_config("config_files/4digit_config.yaml")
|
||||
train(opt, amp=False)
|
||||
Reference in New Issue
Block a user