4digit_training done

This commit is contained in:
HT
2025-07-11 12:21:59 +08:00
parent cfa52e5f2e
commit c10b0719c7
11 changed files with 5770 additions and 12 deletions

View File

@@ -10,11 +10,14 @@ def get_config(file_path):
 with open(file_path, 'r', encoding="utf8") as stream:
 opt = yaml.safe_load(stream)
 opt = AttrDict(opt)
-if opt.lang_char == 'None':
+if opt.lang_char == 'None' and opt.symbol=='None':
+opt.character = opt.number
+elif opt.lang_char == 'None':
 characters = ''
 for data in opt['select_data'].split('-'):
 csv_path = os.path.join(opt['train_data'], data, 'labels.csv')
-df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python', usecols=['filename', 'words'], keep_default_na=False)
+df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python',dtype={'words': str}, usecols=['filename', 'words'], keep_default_na=False)
 all_char = ''.join(df['words'])
 characters += ''.join(set(all_char))
 characters = sorted(set(characters))
@@ -26,6 +29,6 @@ def get_config(file_path):
 if __name__ == "__main__":
 # Load configuration
-# opt = get_config("config_files/4digit_config.yaml")
-opt = get_config("config_files/en_filtered_config.yaml")
+# opt = get_config("config_files/en_filtered_config.yaml")
+opt = get_config("config_files/4digit_config.yaml")
 train(opt, amp=False)