# Source listing metadata (from the code viewer this file was copied from): 48 lines, 1.4 KiB, Python.
import os
|
|
import shutil
|
|
import csv
|
|
|
|
def _read_label_rows(labels_path, encoding):
    """Parse ``labels_path`` into a list of ``(image_name, label)`` tuples.

    Blank (or whitespace-only) lines are skipped. The image name is stripped
    of surrounding whitespace. Everything after the first field is the label:
    quoted CSV fields are unquoted, and unquoted extra commas are re-joined so
    a label such as ``hello, world`` is not truncated. Trailing whitespace on
    the label is removed, as the original line-based parser did.
    """
    rows = []
    with open(labels_path, 'r', newline='', encoding=encoding) as f:
        for row in csv.reader(f):
            # csv.reader yields [] for an empty line; also skip whitespace-only rows,
            # which would otherwise produce an empty image name.
            if not any(cell.strip() for cell in row):
                continue
            img_name = row[0].strip()
            label = ','.join(row[1:]).rstrip()
            rows.append((img_name, label))
    return rows


def split_dataset(labels_path, img_source_dir, train_dir='train', valid_dir='valid',
                  train_count=700, encoding='utf-8'):
    """Split a labelled image folder into a training and a validation set.

    ``labels_path`` is read as CSV rows of ``image_name,label``. The first
    ``train_count`` non-blank rows are assigned to ``train_dir`` and all
    remaining rows to ``valid_dir``. For each row the image is copied from
    ``img_source_dir`` with ``shutil.copy2``, and ``image_name,label`` is
    appended to the destination directory's own ``labels.csv``.

    Args:
        labels_path: Path to the source labels CSV.
        img_source_dir: Directory that the image names in the CSV are relative to.
        train_dir: Output directory for the training split (created if missing).
        valid_dir: Output directory for the validation split (created if missing).
        train_count: Number of leading rows assigned to the training split.
        encoding: Text encoding used to read and write the label files.

    Notes:
        A row whose source image is missing (or is not a regular file) is
        still written to its split's ``labels.csv``. A warning is printed and
        nothing is copied for that row.
    """
    # Create the output folders.
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(valid_dir, exist_ok=True)

    # Read the whole input before opening the outputs for writing, so an input
    # that is itself one of the destination labels.csv files is not truncated.
    rows = _read_label_rows(labels_path, encoding)

    train_labels = os.path.join(train_dir, 'labels.csv')
    valid_labels = os.path.join(valid_dir, 'labels.csv')
    # Context managers guarantee both output files are flushed and closed,
    # even if a copy below raises.
    with open(train_labels, 'w', newline='', encoding=encoding) as train_csv, \
            open(valid_labels, 'w', newline='', encoding=encoding) as valid_csv:
        train_writer = csv.writer(train_csv)
        valid_writer = csv.writer(valid_csv)

        for i, (img_name, label) in enumerate(rows):
            src_path = os.path.join(img_source_dir, img_name)
            print(f"处理图片: {img_name}, 标签: {label}")

            if i < train_count:  # training set
                dst_dir, writer = train_dir, train_writer
            else:  # validation set
                dst_dir, writer = valid_dir, valid_writer
            writer.writerow([img_name, label])

            # isfile (not exists): never try to copy a directory, e.g. when the
            # image name is empty and src_path collapses to img_source_dir.
            if os.path.isfile(src_path):
                dst_path = os.path.join(dst_dir, img_name)
                # Image names may contain subfolders; make sure they exist.
                os.makedirs(os.path.dirname(dst_path) or '.', exist_ok=True)
                shutil.copy2(src_path, dst_path)
            else:
                print(f"警告:源图片不存在 {src_path}")
|
|
|
|
# Example usage. Guarded so that importing this module does not copy files.
if __name__ == '__main__':
    split_dataset(
        labels_path='en_sample/labels.csv',
        img_source_dir='en_sample',
    )