Dataset testing OK; trainer.py changes take effect
This commit is contained in:
@@ -42,7 +42,6 @@ class Batch_Balanced_Dataset(object):
|
||||
log.write(dashed_line + '\n')
|
||||
print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}')
|
||||
log.write(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n')
|
||||
print(f"len(opt.select_data): {len(opt.select_data)}, len(opt.batch_ratio): {len(opt.batch_ratio)}")
|
||||
assert len(opt.select_data) == len(opt.batch_ratio)
|
||||
|
||||
_AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, contrast_adjust = opt.contrast_adjust)
|
||||
@@ -54,7 +53,6 @@ class Batch_Balanced_Dataset(object):
|
||||
_batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1)
|
||||
print(dashed_line)
|
||||
log.write(dashed_line + '\n')
|
||||
print(f"selected_d: {selected_d}, batch_ratio: {batch_ratio_d}, batch_size: {_batch_size}")
|
||||
_dataset, _dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d])
|
||||
total_number_dataset = len(_dataset)
|
||||
log.write(_dataset_log)
|
||||
@@ -100,12 +98,12 @@ class Batch_Balanced_Dataset(object):
|
||||
|
||||
for i, data_loader_iter in enumerate(self.dataloader_iter_list):
|
||||
try:
|
||||
image, text = data_loader_iter.next()
|
||||
image, text = next(data_loader_iter)
|
||||
balanced_batch_images.append(image)
|
||||
balanced_batch_texts += text
|
||||
except StopIteration:
|
||||
self.dataloader_iter_list[i] = iter(self.data_loader_list[i])
|
||||
image, text = self.dataloader_iter_list[i].next()
|
||||
image, text = next(self.dataloader_iter_list[i])
|
||||
balanced_batch_images.append(image)
|
||||
balanced_batch_texts += text
|
||||
except ValueError:
|
||||
@@ -119,7 +117,7 @@ class Batch_Balanced_Dataset(object):
|
||||
def hierarchical_dataset(root, opt, select_data='/'):
|
||||
""" select_data='/' contains all sub-directory of root directory """
|
||||
dataset_list = []
|
||||
dataset_log = f'dataset_root: {root}\t dataset[0]: {select_data[0]}'
|
||||
dataset_log = f'dataset_root: {root}\t dataset: {select_data[0]}'
|
||||
print(dataset_log)
|
||||
dataset_log += '\n'
|
||||
for dirpath, dirnames, filenames in os.walk(root+'/'):
|
||||
@@ -148,12 +146,8 @@ class OCRDataset(Dataset):
|
||||
self.root = root
|
||||
self.opt = opt
|
||||
print(root)
|
||||
print(f"Loading dataset from {root}...")
|
||||
print(opt)
|
||||
self.df = pd.read_csv(os.path.join(root,'labels.csv'), sep='^([^,]+),', engine='python', usecols=['filename', 'words'], keep_default_na=False)
|
||||
|
||||
self.nSamples = len(self.df)
|
||||
print(f"Number of samples: {self.nSamples}")
|
||||
|
||||
if self.opt.data_filtering_off:
|
||||
self.filtered_index_list = [index + 1 for index in range(self.nSamples)]
|
||||
@@ -286,4 +280,4 @@ def tensor2im(image_tensor, imtype=np.uint8):
|
||||
|
||||
def save_image(image_numpy, image_path):
|
||||
image_pil = Image.fromarray(image_numpy)
|
||||
image_pil.save(image_path)
|
||||
image_pil.save(image_path)
|
||||
Reference in New Issue
Block a user