{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training\n",
    "\n",
    "Loads a YAML experiment config, derives `opt.character` (from the training labels when `lang_char` is `None`), creates the experiment's `./saved_models/<experiment_name>` directory, and calls `train`.\n",
    "\n",
    "**Note:** no random seeds are set in this notebook and cuDNN is configured non-deterministically, so repeated runs may differ. Check whether `train` seeds its RNGs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "import torch.backends.cudnn as cudnn\n",
    "import yaml\n",
    "\n",
    "from train import train\n",
    "from utils import AttrDict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment config to train with; edit this path to switch experiments.\n",
    "CONFIG_PATH = \"config_files/en_filtered_config.yaml\"\n",
    "\n",
    "# Let cuDNN benchmark and pick the fastest kernels for the observed input\n",
    "# shapes. With deterministic=False, runs are not bit-for-bit reproducible.\n",
    "cudnn.benchmark = True\n",
    "cudnn.deterministic = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load config\n",
    "\n",
    "`get_config` parses the YAML and fills in `opt.character`. It is built from the training labels when `lang_char` is `None`, and from `number + symbol + lang_char` otherwise."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "def get_config(file_path):\n",
    "    \"\"\"Load a YAML training config and derive its character set.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    file_path : str\n",
    "        Path to the YAML config file, read as UTF-8.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    AttrDict\n",
    "        The parsed config with ``character`` filled in.\n",
    "\n",
    "        If ``lang_char`` is the string ``'None'`` or a YAML null, ``character``\n",
    "        is the sorted set of every character in the ``words`` column of\n",
    "        ``labels.csv``. Those files are read from ``train_data/<name>`` for\n",
    "        each ``-``-separated name in ``select_data``.\n",
    "\n",
    "        Otherwise ``character`` is ``number + symbol + lang_char``.\n",
    "\n",
    "    Side effect: creates ``./saved_models/<experiment_name>`` if missing.\n",
    "    \"\"\"\n",
    "    with open(file_path, 'r', encoding=\"utf8\") as stream:\n",
    "        opt = yaml.safe_load(stream)\n",
    "    opt = AttrDict(opt)\n",
    "    # safe_load turns a YAML null (`null` / `~`) into Python None, which would\n",
    "    # crash the string concatenation below, so treat it like the string 'None'.\n",
    "    if opt.lang_char is None or opt.lang_char == 'None':\n",
    "        charset = set()\n",
    "        for data in opt['select_data'].split('-'):\n",
    "            csv_path = os.path.join(opt['train_data'], data, 'labels.csv')\n",
    "            # Split each row only at its first comma, so labels may themselves\n",
    "            # contain commas. The capture group keeps the filename, and usecols\n",
    "            # drops the empty leading field that re.split produces.\n",
    "            # keep_default_na=False keeps labels like \"NA\" or \"null\" as text.\n",
    "            # dtype=str keeps the column textual even if every label is numeric;\n",
    "            # otherwise pandas infers ints, dropping leading zeros and breaking\n",
    "            # the character scan.\n",
    "            df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python',\n",
    "                             usecols=['filename', 'words'], keep_default_na=False,\n",
    "                             dtype={'words': str})\n",
    "            for word in df['words']:\n",
    "                charset.update(word)\n",
    "        opt.character = ''.join(sorted(charset))\n",
    "    else:\n",
    "        opt.character = opt.number + opt.symbol + opt.lang_char\n",
    "    os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)\n",
    "    return opt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt = get_config(CONFIG_PATH)\n",
    "# NOTE(review): amp=False presumably disables mixed-precision training; confirm in train.py.\n",
    "train(opt, amp=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}