Add .gitignore in the all_data directory
This commit is contained in:
1
trainer/all_data/.gitignore
vendored
Normal file
1
trainer/all_data/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
4digit*/**
|
||||
BIN
trainer/all_data/arial.ttf
Normal file
BIN
trainer/all_data/arial.ttf
Normal file
Binary file not shown.
1
trainer/all_data/folder.txt
Normal file
1
trainer/all_data/folder.txt
Normal file
@@ -0,0 +1 @@
|
||||
place dataset folder here
|
||||
86
trainer/all_data/generate_digits_random_fs_bg_fg.py
Normal file
86
trainer/all_data/generate_digits_random_fs_bg_fg.py
Normal file
@@ -0,0 +1,86 @@
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
# Largest per-digit position jitter (in pixels) applied by generate_4digit_image;
# the offset is drawn uniformly from [-value, +value] on each axis.
X_RAND_VALUE = 2
Y_RAND_VALUE = 1

# Largest rotation (in degrees, either direction) applied to the finished image.
ROTATE_ANGLE = 3

# Candidate background fills; one is picked at random per image.
# NOTE(review): channel order (RGB vs BGR) is not stated anywhere - confirm.
BG_COLORS = [
    (33, 40, 45),
    (36, 51, 62),
    (35, 37, 154),
    (0, 38, 202),
    (239, 255, 255),
    (241, 255, 255),
]

# Candidate digit colors; one is picked at random per digit.
DIGIT_COLORS = [
    (34, 199, 253),
    (25, 214, 253),
]
|
||||
|
||||
def generate_4digit_image():
    """Render four random digits on a colored strip and return it with its label.

    A background color and font size are chosen at random, four digits are
    drawn with small random position jitter and per-digit colors, then the
    whole picture is rotated by a small random angle and cropped.

    Returns:
        tuple: ``(image, label)`` where ``image`` is the cropped picture as a
        NumPy array (crop box is 128 wide by 32 high) and ``label`` is the
        four drawn digits joined into one string.

    NOTE(review): the array comes straight from a PIL image, while the callers
    save it with cv2.imwrite - confirm the intended channel order of the
    color constants.
    """
    background = random.choice(BG_COLORS)
    point_size = random.randint(24, 30)
    typeface = ImageFont.truetype("arial.ttf", point_size)

    # Oversized 50x160 canvas so rotation has spare margin before cropping.
    backdrop = np.full((50, 160, 3), background, dtype=np.uint8)
    picture = Image.fromarray(backdrop)
    pen = ImageDraw.Draw(picture)

    chars = []
    for slot in range(4):
        char = str(random.randint(0, 9))
        chars.append(char)
        dx = random.randint(-X_RAND_VALUE, X_RAND_VALUE)
        dy = random.randint(-Y_RAND_VALUE, Y_RAND_VALUE)
        ink = random.choice(DIGIT_COLORS)
        # Digits sit 32 px apart, shifted toward the middle of the canvas.
        position = (20 + slot * 32 + dx, 12 + dy)
        pen.text(position, char, font=typeface, fill=ink)

    tilt = random.uniform(-ROTATE_ANGLE, ROTATE_ANGLE)
    tilted = picture.rotate(tilt, expand=True, fillcolor=background)

    # Fixed crop window taken from inside the enlarged canvas.
    window = tilted.crop((20, 10, 148, 42))

    return np.array(window), ''.join(chars)
|
||||
|
||||
def _write_4digit_dataset(out_dir, num_samples):
    """Generate ``num_samples`` images into ``out_dir`` plus a ``labels.csv``.

    Images are named ``0000.jpg``, ``0001.jpg``, ... and ``labels.csv`` starts
    with a ``filename,words`` header followed by one ``<file>,<label>`` row per
    image.

    Raises:
        OSError: if cv2.imwrite reports that an image could not be written, so
            the label file never references an image that is missing on disk.
    """
    os.makedirs(out_dir, exist_ok=True)
    with open(f'{out_dir}/labels.csv', 'w') as f:
        f.write("filename,words\n")
        for i in range(num_samples):
            # The label is already a 4-character digit string.
            img, label = generate_4digit_image()
            filename = f'{i:04d}.jpg'
            img_path = f'{out_dir}/{filename}'
            # cv2.imwrite signals some failures by returning False, not raising.
            if not cv2.imwrite(img_path, img):
                raise OSError(f"failed to write image: {img_path}")
            f.write(f"{filename},{label}\n")


def generate_train_dataset(num_samples=1000):
    """Create the training set in ``4digit_train`` (images and labels.csv)."""
    _write_4digit_dataset('4digit_train', num_samples)


def generate_valid_dataset(num_samples=200):
    """Create the validation set in ``4digit_valid`` (images and labels.csv)."""
    _write_4digit_dataset('4digit_valid', num_samples)


def generate_eval_dataset(num_samples=200):
    """Create the evaluation set in ``4digit_eval`` (images and labels.csv)."""
    _write_4digit_dataset('4digit_eval', num_samples)
|
||||
if __name__ == "__main__":
    # Build the three datasets in the same order as before: train, eval, valid.
    for build_dataset in (
        generate_train_dataset,
        generate_eval_dataset,
        generate_valid_dataset,
    ):
        build_dataset()
|
||||
47
trainer/all_data/split_dataset.py
Normal file
47
trainer/all_data/split_dataset.py
Normal file
@@ -0,0 +1,47 @@
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import csv
|
||||
|
||||
def split_dataset(labels_path, img_source_dir, train_dir='train', valid_dir='valid',
                  train_count=700):
    """Split a labelled image folder into training and validation folders.

    The first ``train_count`` samples listed in ``labels_path`` go to
    ``train_dir`` and the rest to ``valid_dir``. Each target folder gets the
    copied images and its own ``labels.csv``.

    Args:
        labels_path: CSV file with one ``<image name>,<label>`` row per sample.
            A leading ``filename,words`` header row is recognised, copied to
            both output CSVs, and not counted as a sample.
        img_source_dir: Directory containing the images named in the CSV.
        train_dir: Output directory for the training split.
        valid_dir: Output directory for the validation split.
        train_count: Number of leading samples assigned to the training split.

    Blank lines are skipped. A sample whose image is missing still gets its
    label row written, and a warning is printed.
    """
    # Create the target folders.
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(valid_dir, exist_ok=True)

    with open(labels_path, 'r') as f:
        lines = f.readlines()

    train_labels = os.path.join(train_dir, 'labels.csv')
    valid_labels = os.path.join(valid_dir, 'labels.csv')

    # Context managers guarantee both CSVs are closed even if a copy fails.
    with open(train_labels, 'w', newline='') as train_csv, \
            open(valid_labels, 'w', newline='') as valid_csv:
        train_writer = csv.writer(train_csv)
        valid_writer = csv.writer(valid_csv)

        sample_index = 0
        header_allowed = True
        for line in lines:
            parts = line.strip().split(',')
            img_name = parts[0].strip()
            if not img_name:
                # An empty name would resolve to the source directory itself.
                continue

            if header_allowed:
                header_allowed = False
                if [p.strip() for p in parts] == ['filename', 'words']:
                    # Both splits need the header; it is not a sample.
                    train_writer.writerow(['filename', 'words'])
                    valid_writer.writerow(['filename', 'words'])
                    continue

            label = parts[1] if len(parts) > 1 else ''
            src_path = os.path.join(img_source_dir, img_name)
            print(f"处理图片: {img_name}, 标签: {label}")

            if sample_index < train_count:  # training split
                dst_path = os.path.join(train_dir, img_name)
                train_writer.writerow([img_name, label])
            else:  # validation split
                dst_path = os.path.join(valid_dir, img_name)
                valid_writer.writerow([img_name, label])
            sample_index += 1

            if os.path.exists(src_path):
                shutil.copy2(src_path, dst_path)
            else:
                print(f"警告:源图片不存在 {src_path}")
|
||||
|
||||
# Example usage: only runs when this file is executed as a script, so that
# importing split_dataset from elsewhere does not trigger a split.
if __name__ == "__main__":
    split_dataset(
        labels_path='en_sample/labels.csv',
        img_source_dir='en_sample',
    )
|
||||
Reference in New Issue
Block a user