# easyocr/trainer/all_data/split_dataset.py
import os
import shutil
import csv
def _parse_label_line(line):
    """Split one labels-file line into ``(img_name, label)``, or None if blank.

    Only the first comma separates the image name from the label, so labels
    that themselves contain commas are kept intact.
    """
    stripped = line.strip()
    if not stripped:
        # Blank lines would otherwise yield an empty image name, which joins to
        # the source directory itself and breaks the copy step.
        return None
    img_name, _, label = stripped.partition(',')
    return img_name.strip(), label


def split_dataset(labels_path, img_source_dir, train_dir='train', valid_dir='valid',
                  train_count=700, has_header=False, encoding='utf-8'):
    """Split a labels file and its images into a training and a validation set.

    Each non-blank line of ``labels_path`` is parsed as ``<image name>,<label>``
    (only the first comma separates the two, so labels may contain commas).
    The first ``train_count`` records go to ``train_dir`` and the rest go to
    ``valid_dir``. For every record, a ``[img_name, label]`` row is written to
    the target directory's ``labels.csv``. The image is then copied from
    ``img_source_dir`` into that directory with ``shutil.copy2``. Any existing
    ``labels.csv`` in the target directories is overwritten.

    Args:
        labels_path: Path of the input labels file.
        img_source_dir: Directory containing the images named in the labels file.
        train_dir: Output directory for the training split (created if missing).
        valid_dir: Output directory for the validation split (created if missing).
        train_count: Number of records assigned to the training split. The
            default of 700 matches the original hard-coded split point.
        has_header: If True, the first non-blank line is treated as a header
            row. It is written to both output CSVs and is neither counted nor
            copied. With the default False, a header line is handled like any
            other record.
        encoding: Text encoding used to read the labels file and write the
            output CSVs.

    A source image that is missing (or is not a regular file) triggers a
    printed warning and is skipped. Its CSV row has already been written.
    """
    # Create the target folders.
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(valid_dir, exist_ok=True)

    # Read the whole input before opening any output. Opening an output for
    # writing truncates it, which would wipe the input if they are the same file.
    with open(labels_path, 'r', encoding=encoding) as f:
        records = [rec for rec in map(_parse_label_line, f) if rec is not None]

    header = None
    if has_header and records:
        header, records = records[0], records[1:]

    train_labels = os.path.join(train_dir, 'labels.csv')
    valid_labels = os.path.join(valid_dir, 'labels.csv')
    # Context managers guarantee both CSVs are flushed and closed, even on error.
    with open(train_labels, 'w', newline='', encoding=encoding) as train_csv, \
            open(valid_labels, 'w', newline='', encoding=encoding) as valid_csv:
        train_writer = csv.writer(train_csv)
        valid_writer = csv.writer(valid_csv)
        if header is not None:
            train_writer.writerow(header)
            valid_writer.writerow(header)

        for index, (img_name, label) in enumerate(records):
            src_path = os.path.join(img_source_dir, img_name)
            print(f"处理图片: {img_name}, 标签: {label}")
            if index < train_count:  # training split
                target_dir, writer = train_dir, train_writer
            else:  # validation split
                target_dir, writer = valid_dir, valid_writer
            writer.writerow([img_name, label])
            dst_path = os.path.join(target_dir, img_name)
            # isfile (not exists): a directory path must not reach copy2.
            if os.path.isfile(src_path):
                shutil.copy2(src_path, dst_path)
            else:
                print(f"警告:源图片不存在 {src_path}")
# Usage example. It runs only when this file is executed as a script, so
# importing split_dataset elsewhere does not trigger any file copying.
if __name__ == '__main__':
    split_dataset(
        labels_path='en_sample/labels.csv',
        img_source_dir='en_sample'
    )