add .gitignore in the alldata directory
This commit is contained in:
47
trainer/all_data/split_dataset.py
Normal file
47
trainer/all_data/split_dataset.py
Normal file
@@ -0,0 +1,47 @@
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import csv
|
||||
|
||||
def split_dataset(labels_path, img_source_dir, train_dir='train', valid_dir='valid'):
|
||||
# 创建目标文件夹
|
||||
os.makedirs(train_dir, exist_ok=True)
|
||||
os.makedirs(valid_dir, exist_ok=True)
|
||||
|
||||
# 初始化CSV写入器
|
||||
train_csv = open(os.path.join(train_dir, 'labels.csv'), 'w', newline='')
|
||||
valid_csv = open(os.path.join(valid_dir, 'labels.csv'), 'w', newline='')
|
||||
train_writer = csv.writer(train_csv)
|
||||
valid_writer = csv.writer(valid_csv)
|
||||
|
||||
with open(labels_path, 'r') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
parts = line.strip().split(',')
|
||||
img_name = parts[0].strip()
|
||||
|
||||
|
||||
label = parts[1] if len(parts) > 1 else ''
|
||||
src_path = os.path.join(img_source_dir, img_name)
|
||||
print(f"处理图片: {img_name}, 标签: {label}")
|
||||
if i < 700: # 训练集
|
||||
dst_path = os.path.join(train_dir, img_name)
|
||||
train_writer.writerow([img_name, label])
|
||||
else: # 验证集
|
||||
dst_path = os.path.join(valid_dir, img_name)
|
||||
valid_writer.writerow([img_name, label])
|
||||
|
||||
if os.path.exists(src_path):
|
||||
shutil.copy2(src_path, dst_path)
|
||||
else:
|
||||
print(f"警告:源图片不存在 {src_path}")
|
||||
|
||||
train_csv.close()
|
||||
valid_csv.close()
|
||||
|
||||
# 使用示例
|
||||
split_dataset(
|
||||
labels_path='en_sample/labels.csv',
|
||||
img_source_dir='en_sample'
|
||||
)
|
||||
Reference in New Issue
Block a user