import os
import shutil
import csv


def split_dataset(labels_path, img_source_dir, train_dir='train',
                  valid_dir='valid', train_count=700):
    """Split a labelled image set into a training and a validation folder.

    Each non-blank line of ``labels_path`` is split on commas; the first
    field (stripped) is the image file name and the second field, if
    present, is its label. The first ``train_count`` usable rows go to
    ``train_dir`` and every later row goes to ``valid_dir``. For each row
    the ``[image name, label]`` pair is appended to ``labels.csv`` inside
    the target folder and the image is copied there from
    ``img_source_dir``.

    Args:
        labels_path: Path of the comma-separated labels file to read.
        img_source_dir: Directory that holds the images named in the file.
        train_dir: Output directory for the training split (created if
            missing).
        valid_dir: Output directory for the validation split (created if
            missing).
        train_count: Number of leading rows assigned to the training
            split; the default of 700 matches the previous fixed value.

    Notes:
        * Every line is treated as data, including a header line if the
          file has one.
        * Only the first two comma-separated fields are used, so a label
          that itself contains a comma is cut at that comma.
        * Blank lines and rows without an image name are skipped and do
          not count towards ``train_count``.
        * A row whose image is missing is still written to ``labels.csv``;
          only a warning is printed.
        * Files are opened with the platform default text encoding.
    """
    # Create the destination folders.
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(valid_dir, exist_ok=True)

    train_labels_path = os.path.join(train_dir, 'labels.csv')
    valid_labels_path = os.path.join(valid_dir, 'labels.csv')

    # A single ``with`` guarantees all three files are closed even when
    # reading or copying fails part-way through.
    with open(train_labels_path, 'w', newline='') as train_csv, \
            open(valid_labels_path, 'w', newline='') as valid_csv, \
            open(labels_path, 'r') as labels_file:
        # Set up the CSV writers.
        train_writer = csv.writer(train_csv)
        valid_writer = csv.writer(valid_csv)

        rows_seen = 0  # counts usable rows only, so blank lines don't shift the split
        for line in labels_file:
            parts = line.strip().split(',')
            img_name = parts[0].strip()
            if not img_name:
                # Without a name the source path would be the image
                # directory itself, which cannot be copied as a file.
                continue
            label = parts[1] if len(parts) > 1 else ''

            src_path = os.path.join(img_source_dir, img_name)
            print(f"处理图片: {img_name}, 标签: {label}")

            if rows_seen < train_count:
                # Training split
                target_dir, writer = train_dir, train_writer
            else:
                # Validation split
                target_dir, writer = valid_dir, valid_writer
            rows_seen += 1

            writer.writerow([img_name, label])

            # isfile (not exists) so that a directory is never handed to copy2.
            if os.path.isfile(src_path):
                shutil.copy2(src_path, os.path.join(target_dir, img_name))
            else:
                print(f"警告:源图片不存在 {src_path}")


if __name__ == "__main__":
    # Usage example
    split_dataset(
        labels_path='en_sample/labels.csv',
        img_source_dir='en_sample'
    )