Files
easyocr/trainer/craft/eval.py

382 lines
12 KiB
Python
Raw Normal View History

2025-07-10 19:42:57 +08:00
# -*- coding: utf-8 -*-
import argparse
import os
import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from tqdm import tqdm
import wandb
from config.load_config import load_yaml, DotDict
from model.craft import CRAFT
from metrics.eval_det_iou import DetectionIoUEvaluator
from utils.inference_boxes import (
test_net,
load_icdar2015_gt,
load_icdar2013_gt,
load_synthtext_gt,
)
from utils.util import copyStateDict
def save_result_synth(img_file, img, pre_output, pre_box, gt_box=None, result_dir=""):
    """Draw predicted and ground-truth boxes on an image and save two JPEGs.

    Two files are written into ``result_dir``: ``res_<name>.jpg`` (the image
    with boxes drawn on it) and ``res_<name>_box.jpg`` (the composite returned
    by ``overlay``), where ``<name>`` is the base name of ``img_file`` without
    its extension.

    Args:
        img_file: Path of the source image; only its base name is used.
        img: Image data; converted with ``np.array`` before drawing.
        pre_output: Indexable whose items 0 and 1 are handed to ``overlay`` as
            the region and affinity maps.
        pre_box: Iterable of predicted boxes, drawn in green.
        gt_box: Optional sequence of dicts with a ``"points"`` entry, drawn
            with colour ``(0, 0, 255)``. Skipped when ``None``.
        result_dir: Output directory. An empty string means the current
            working directory (paths are built with ``os.path.join``).
    """
    img = np.array(img)
    img_copy = img.copy()
    region = pre_output[0]
    affinity = pre_output[1]

    # Base name (without extension) used to build the output file names.
    filename = os.path.splitext(os.path.basename(img_file))[0]

    # Draw bounding boxes for predictions, colour green.
    for box in pre_box:
        poly = np.array(box).astype(np.int32).reshape((-1, 1, 2))
        try:
            cv2.polylines(img, [poly], True, color=(0, 255, 0), thickness=2)
        except Exception as exc:
            # Drawing is best-effort: report the failure and keep going, but
            # do not swallow KeyboardInterrupt/SystemExit like a bare except.
            print(f"Warning: could not draw predicted box for {img_file}: {exc}")

    # Draw bounding boxes for ground truth, colour red.
    if gt_box is not None:
        for gt in gt_box:
            gt_poly = np.array(gt["points"]).astype(np.int32).reshape((-1, 1, 2))
            cv2.polylines(img, [gt_poly], True, color=(0, 0, 255), thickness=2)

    # Composite of the untouched image, boxes and score maps.
    overlay_img = overlay(img_copy, region, affinity, pre_box)

    # Save result images.
    res_img_path = os.path.join(result_dir, "res_" + filename + ".jpg")
    cv2.imwrite(res_img_path, img)
    overlay_image_path = os.path.join(result_dir, "res_" + filename + "_box.jpg")
    cv2.imwrite(overlay_image_path, overlay_img)
def save_result_2015(img_file, img, pre_output, pre_box, gt_box, result_dir):
    """Draw predicted and ground-truth boxes on an image and save two JPEGs.

    Two files are written into ``result_dir``: ``res_<name>.jpg`` (the image
    with boxes drawn on it) and ``res_<name>_box.jpg`` (the composite returned
    by ``overlay``), where ``<name>`` is the base name of ``img_file`` without
    its extension.

    Args:
        img_file: Path of the source image; only its base name is used.
        img: Image data; converted with ``np.array`` before drawing.
        pre_output: Indexable whose items 0 and 1 are handed to ``overlay`` as
            the region and affinity maps.
        pre_box: Iterable of predicted boxes, drawn in green.
        gt_box: Sequence of dicts with ``"points"`` and ``"text"`` entries, or
            ``None`` to skip ground truth. Entries whose text is ``"###"`` are
            drawn in grey ``(128, 128, 128)``, all others with ``(0, 0, 255)``.
        result_dir: Output directory.
    """
    img = np.array(img)
    img_copy = img.copy()
    region = pre_output[0]
    affinity = pre_output[1]

    # Base name (without extension) used to build the output file names.
    filename = os.path.splitext(os.path.basename(img_file))[0]

    # Draw bounding boxes for predictions, colour green.
    for box in pre_box:
        poly = np.array(box).astype(np.int32).reshape((-1, 1, 2))
        try:
            cv2.polylines(img, [poly], True, color=(0, 255, 0), thickness=2)
        except Exception as exc:
            # Drawing is best-effort: report the failure and keep going, but
            # do not swallow KeyboardInterrupt/SystemExit like a bare except.
            print(f"Warning: could not draw predicted box for {img_file}: {exc}")

    # Draw ground truth; "###" entries get a distinct grey colour.
    if gt_box is not None:
        for gt in gt_box:
            gt_poly = np.array(gt["points"]).reshape(-1, 2).astype(np.int32)
            gt_color = (128, 128, 128) if gt["text"] == "###" else (0, 0, 255)
            cv2.polylines(img, [gt_poly], True, color=gt_color, thickness=2)

    # Composite of the untouched image, boxes and score maps.
    overlay_img = overlay(img_copy, region, affinity, pre_box)

    # Save result images.
    res_img_path = os.path.join(result_dir, "res_" + filename + ".jpg")
    cv2.imwrite(res_img_path, img)
    overlay_image_path = os.path.join(result_dir, "res_" + filename + "_box.jpg")
    cv2.imwrite(overlay_image_path, overlay_img)
def save_result_2013(img_file, img, pre_output, pre_box, gt_box=None, result_dir=""):
    """Draw predicted and ground-truth boxes on an image and save two JPEGs.

    Two files are written into ``result_dir``: ``res_<name>.jpg`` (the image
    with boxes drawn on it) and ``res_<name>_box.jpg`` (the composite returned
    by ``overlay``), where ``<name>`` is the base name of ``img_file`` without
    its extension.

    Args:
        img_file: Path of the source image; only its base name is used.
        img: Image data; converted with ``np.array`` before drawing.
        pre_output: Indexable whose items 0 and 1 are handed to ``overlay`` as
            the region and affinity maps.
        pre_box: Iterable of predicted boxes, drawn in green.
        gt_box: Optional sequence of dicts with a ``"points"`` entry, drawn
            with colour ``(0, 0, 255)``. Skipped when ``None``.
        result_dir: Output directory. An empty string means the current
            working directory (paths are built with ``os.path.join``).
    """
    img = np.array(img)
    img_copy = img.copy()
    region = pre_output[0]
    affinity = pre_output[1]

    # Base name (without extension) used to build the output file names.
    filename = os.path.splitext(os.path.basename(img_file))[0]

    # Draw bounding boxes for predictions, colour green.
    for box in pre_box:
        poly = np.array(box).astype(np.int32).reshape((-1, 1, 2))
        try:
            cv2.polylines(img, [poly], True, color=(0, 255, 0), thickness=2)
        except Exception as exc:
            # Drawing is best-effort: report the failure and keep going, but
            # do not swallow KeyboardInterrupt/SystemExit like a bare except.
            print(f"Warning: could not draw predicted box for {img_file}: {exc}")

    # Draw bounding boxes for ground truth, colour red.
    if gt_box is not None:
        for gt in gt_box:
            # Cast to int32 like the sibling save_result_* functions do;
            # cv2.polylines expects 32-bit integer point arrays.
            gt_poly = np.array(gt["points"]).astype(np.int32).reshape((-1, 1, 2))
            cv2.polylines(img, [gt_poly], True, color=(0, 0, 255), thickness=2)

    # Composite of the untouched image, boxes and score maps.
    overlay_img = overlay(img_copy, region, affinity, pre_box)

    # Save result images.
    res_img_path = os.path.join(result_dir, "res_" + filename + ".jpg")
    cv2.imwrite(res_img_path, img)
    overlay_image_path = os.path.join(result_dir, "res_" + filename + "_box.jpg")
    cv2.imwrite(overlay_image_path, overlay_img)
def overlay(image, region, affinity, single_img_bbox):
    """Build a 2x2 visualisation grid for one image.

    Top row: the input image next to a copy with every box in
    ``single_img_bbox`` outlined in green. Bottom row: the image blended with
    the resized region map next to the image blended with the resized
    affinity map.

    Args:
        image: Three-dimensional image array (height, width, channels).
        region: Region score map; resized to the image's width and height.
        affinity: Affinity score map; resized to the image's width and height.
        single_img_bbox: Iterable of numpy arrays holding box corner points.

    Returns:
        The stacked grid as a single array.
    """
    height, width, _ = image.shape
    target_size = (width, height)

    # Bring both score maps to the image resolution before blending.
    region_map, affinity_map = (
        cv2.resize(score_map, target_size) for score_map in (region, affinity)
    )

    # Weighted sum: 0.4 * image + 0.6 * score map + 5.
    region_blend = cv2.addWeighted(image.copy(), 0.4, region_map, 0.6, 5)
    affinity_blend = cv2.addWeighted(image.copy(), 0.4, affinity_map, 0.6, 5)

    with_boxes = image.copy()
    for corners in single_img_bbox:
        outline = corners.astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(with_boxes, [outline], True, color=(0, 255, 0), thickness=3)

    top_row = np.hstack([image, with_boxes])
    bottom_row = np.hstack([region_blend, affinity_blend])
    return np.vstack([top_row, bottom_row])
def load_test_dataset_iou(test_folder_name, config):
    """Load ground-truth boxes and image paths for a named test dataset.

    Args:
        test_folder_name: One of ``"synthtext"``, ``"icdar2013"``,
            ``"icdar2015"`` or ``"custom_data"``.
        config: Object exposing ``test_data_dir``, the dataset root folder.

    Returns:
        A ``(total_bboxes_gt, total_img_path)`` tuple from the matching
        loader, or ``(None, None)`` after printing a message when the dataset
        name is not recognised.
    """
    supported = ("synthtext", "icdar2013", "icdar2015", "custom_data")
    if test_folder_name not in supported:
        print("not found test dataset")
        return None, None

    data_dir = config.test_data_dir
    if test_folder_name == "synthtext":
        gt_boxes, img_paths = load_synthtext_gt(data_dir)
    elif test_folder_name == "icdar2013":
        gt_boxes, img_paths = load_icdar2013_gt(dataFolder=data_dir)
    else:
        # "icdar2015" and "custom_data" are both read with the ICDAR 2015 loader.
        gt_boxes, img_paths = load_icdar2015_gt(dataFolder=data_dir)
    return gt_boxes, img_paths
def viz_test(img, pre_output, pre_box, gt_box, img_name, result_dir, test_folder_name):
    """Save visualisation images using the writer that matches the dataset.

    Args:
        img: Image array; its last axis is reversed (and copied) before being
            handed to the writer.
        pre_output: Score maps forwarded to the writer.
        pre_box: Predicted boxes forwarded to the writer.
        gt_box: Ground-truth boxes forwarded to the writer.
        img_name: Image path, forwarded as the writer's ``img_file``.
        result_dir: Output directory.
        test_folder_name: ``"synthtext"``, ``"icdar2013"``, ``"icdar2015"`` or
            ``"custom_data"``; anything else prints a message and saves
            nothing.
    """
    if test_folder_name == "synthtext":
        save_result = save_result_synth
    elif test_folder_name == "icdar2013":
        save_result = save_result_2013
    elif test_folder_name in ("icdar2015", "custom_data"):
        # Custom data shares the ICDAR 2015 writer.
        save_result = save_result_2015
    else:
        print("not found test dataset")
        return

    flipped = img[:, :, ::-1].copy()
    save_result(img_name, flipped, pre_output, pre_box, gt_box, result_dir)
def main_eval(model_path, backbone, config, evaluator, result_dir, buffer, model, mode):
    """Run detection on the "custom_data" test set and score it with ``evaluator``.

    When ``model`` is ``None`` a CRAFT model is built, its weights are loaded
    from ``model_path`` and every test image is processed. When a model is
    passed in, the image list is split into one contiguous slice per GPU and
    only the slice belonging to the current device is processed; per-image
    predictions are then written into ``buffer`` (if given) at their global
    image index.

    Args:
        model_path: Checkpoint path; read only when ``model`` is ``None``.
        backbone: Architecture name; only ``"vgg"`` is accepted when the model
            has to be built here.
        config: Test configuration. Reads ``cuda``, ``text_threshold``,
            ``link_threshold``, ``low_text``, ``poly``, ``canvas_size``,
            ``mag_ratio`` and ``vis_opt``; also passed to
            ``load_test_dataset_iou``.
        evaluator: Object providing ``evaluate_image(gt, pred)`` and
            ``combine_results(results)``.
        result_dir: Directory for visualisation output; created if missing.
        buffer: ``None``, or an indexable with one slot per test image that
            collects the predictions. When a model is passed in, every slot
            must be ``None`` on entry.
        model: Ready-to-use model, or ``None`` to load one from ``model_path``.
        mode: ``"weak_supervision"`` halves the GPU count used for slicing
            unless exactly one device is present.

    Returns:
        The combined metrics returned by ``evaluator.combine_results``.

    Raises:
        Exception: If ``model`` is ``None`` and ``backbone`` is not ``"vgg"``.
    """
    import time  # local import: only needed for the polling wait below

    os.makedirs(result_dir, exist_ok=True)

    total_imgs_bboxes_gt, total_imgs_path = load_test_dataset_iou("custom_data", config)

    # NOTE(review): the CUDA device is queried unconditionally, so this
    # presumably needs a CUDA runtime even when config.cuda is false — confirm.
    device_count = torch.cuda.device_count()
    if mode == "weak_supervision" and device_count != 1:
        gpu_count = device_count // 2
    else:
        gpu_count = device_count
    gpu_idx = torch.cuda.current_device()
    torch.cuda.set_device(gpu_idx)

    if model is None:
        # Evaluation only: build the model here and process every image.
        start_idx = 0
        piece_imgs_path = total_imgs_path

        if backbone != "vgg":
            raise Exception("Undefined architecture")
        model = CRAFT()

        print("Loading weights from checkpoint (" + model_path + ")")
        net_param = torch.load(model_path, map_location=f"cuda:{gpu_idx}")
        model.load_state_dict(copyStateDict(net_param["craft"]))

        if config.cuda:
            model = model.cuda()
            cudnn.benchmark = False
    else:
        # Distributed evaluation in the middle of training.
        if buffer is not None:
            assert all(
                v is None for v in buffer
            ), "Buffer already filled with another value."
        slice_idx = len(total_imgs_bboxes_gt) // gpu_count
        start_idx = gpu_idx * slice_idx

        if gpu_idx == gpu_count - 1:
            # The last GPU also takes the remainder of the integer division.
            piece_imgs_path = total_imgs_path[start_idx:]
        else:
            piece_imgs_path = total_imgs_path[start_idx : start_idx + slice_idx]

    model.eval()

    total_imgs_bboxes_pre = []
    for k, img_path in enumerate(tqdm(piece_imgs_path)):
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        bboxes, polys, score_text = test_net(
            model,
            image,
            config.text_threshold,
            config.link_threshold,
            config.low_text,
            config.cuda,
            config.poly,
            config.canvas_size,
            config.mag_ratio,
        )

        single_img_bbox = [
            {"points": box, "text": "###", "ignore": False} for box in bboxes
        ]
        total_imgs_bboxes_pre.append(single_img_bbox)

        # Index of this image in the full dataset, not just in this slice.
        global_idx = start_idx + k

        if buffer is not None:
            buffer[global_idx] = single_img_bbox

        if config.vis_opt:
            viz_test(
                image,
                score_text,
                pre_box=polys,
                # Use the global index so the ground truth matches the image
                # when this process only handles a slice of the dataset.
                gt_box=total_imgs_bboxes_gt[global_idx],
                img_name=img_path,
                result_dir=result_dir,
                test_folder_name="custom_data",
            )

    # With a buffer, wait until every slot has been filled before scoring.
    if buffer is not None:
        while None in buffer:
            # Sleep between polls instead of spinning at full CPU.
            time.sleep(0.05)
        assert all(v is not None for v in buffer), "Buffer not filled"
        total_imgs_bboxes_pre = buffer

    results = [
        evaluator.evaluate_image(gt, pred)
        for gt, pred in zip(total_imgs_bboxes_gt, total_imgs_bboxes_pre)
    ]
    metrics = evaluator.combine_results(results)
    print(metrics)
    return metrics
def cal_eval(config, data, res_dir_name, opt, mode, exp_name=None):
    """Evaluate the trained model on one test dataset and return its metrics.

    Args:
        config: Full configuration. Reads ``config.test[data]``,
            ``config.test.trained_model`` and ``config.train.backbone``.
        data: Key of the test dataset inside ``config.test``.
        res_dir_name: Name of the result sub-directory; results go to
            ``exp/<exp_name>/<res_dir_name>``.
        opt: Evaluation type; only ``"iou_eval"`` is implemented.
        mode: Forwarded unchanged to ``main_eval``.
        exp_name: Experiment folder name under ``exp``. Defaults to the
            module-level ``args.yaml``, which is only defined when this file
            is run as a script; pass it explicitly when calling from other
            code.

    Returns:
        The metrics returned by ``main_eval``, or ``None`` (after printing a
        message) when ``opt`` is not ``"iou_eval"``.
    """
    if exp_name is None:
        # Backward-compatible fallback to the command-line arguments.
        exp_name = args.yaml

    evaluator = DetectionIoUEvaluator()
    test_config = DotDict(config.test[data])
    res_dir = os.path.join("exp", exp_name, "{}".format(res_dir_name))

    if opt != "iou_eval":
        print("Undefined evaluation")
        return None

    return main_eval(
        config.test.trained_model,
        config.train.backbone,
        test_config,
        evaluator,
        res_dir,
        buffer=None,
        model=None,
        mode=mode,
    )
if __name__ == "__main__":
    # Command-line entry point: read the YAML configuration named on the
    # command line and run the IoU evaluation on the "custom_data" test set.
    cli = argparse.ArgumentParser(description="CRAFT Text Detection Eval")
    cli.add_argument(
        "--yaml",
        "--yaml_file_name",
        default="custom_data_train",
        type=str,
        help="Load configuration",
    )
    # `args` must stay a module-level name: cal_eval reads args.yaml.
    args = cli.parse_args()

    # Load the configuration and wrap it for attribute-style access.
    config = DotDict(load_yaml(args.yaml))

    # Optional experiment tracking.
    if config["wandb_opt"]:
        wandb.init(project="evaluation", entity="gmuffiness", name=args.yaml)
        wandb.config.update(config)

    cal_eval(
        config,
        "custom_data",
        f"{args.yaml}-ic15-iou",
        opt="iou_eval",
        mode=None,
    )