testing dataset
This commit is contained in:
4
trainer/craft/.gitignore
vendored
Normal file
4
trainer/craft/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
__pycache__/
|
||||
model/__pycache__/
|
||||
wandb/*
|
||||
vis_result/*
|
||||
105
trainer/craft/README.md
Normal file
105
trainer/craft/README.md
Normal file
@@ -0,0 +1,105 @@
|
||||
# CRAFT-train
|
||||
On the official CRAFT github, there are many people who want to train CRAFT models.
|
||||
|
||||
However, the training code is not published in the official CRAFT repository.
|
||||
|
||||
There are other reproduced implementations, but there is a gap between their performance and the performance reported in the original paper. (https://arxiv.org/pdf/1904.01941.pdf)
|
||||
|
||||
The trained model with this code recorded a level of performance similar to that of the original paper.
|
||||
|
||||
```bash
|
||||
├── config
|
||||
│ ├── syn_train.yaml
|
||||
│ └── custom_data_train.yaml
|
||||
├── data
|
||||
│ ├── pseudo_label
|
||||
│ │ ├── make_charbox.py
|
||||
│ │ └── watershed.py
|
||||
│ ├── boxEnlarge.py
|
||||
│ ├── dataset.py
|
||||
│ ├── gaussian.py
|
||||
│ ├── imgaug.py
|
||||
│ └── imgproc.py
|
||||
├── loss
|
||||
│ └── mseloss.py
|
||||
├── metrics
|
||||
│ └── eval_det_iou.py
|
||||
├── model
|
||||
│ ├── craft.py
|
||||
│ └── vgg16_bn.py
|
||||
├── utils
|
||||
│ ├── craft_utils.py
|
||||
│ ├── inference_boxes.py
|
||||
│ └── utils.py
|
||||
├── trainSynth.py
|
||||
├── train.py
|
||||
├── train_distributed.py
|
||||
├── eval.py
|
||||
├── data_root_dir (place dataset folder here)
|
||||
└── exp (model and experiment result files will be saved here)
|
||||
```
|
||||
|
||||
### Installation
|
||||
|
||||
Install using `pip`
|
||||
|
||||
``` bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
|
||||
### Training
|
||||
1. Put your training, test data in the following format
|
||||
```
|
||||
└── data_root_dir (you can change root dir in yaml file)
|
||||
├── ch4_training_images
|
||||
│ ├── img_1.jpg
|
||||
│ └── img_2.jpg
|
||||
├── ch4_training_localization_transcription_gt
|
||||
│ ├── gt_img_1.txt
|
||||
│ └── gt_img_2.txt
|
||||
├── ch4_test_images
|
||||
│ ├── img_1.jpg
|
||||
│ └── img_2.jpg
|
||||
└── ch4_test_localization_transcription_gt
|
||||
├── gt_img_1.txt
|
||||
└── gt_img_2.txt
|
||||
```
|
||||
* localization_transcription_gt files format :
|
||||
```
|
||||
377,117,463,117,465,130,378,130,Genaxis Theatre
|
||||
493,115,519,115,519,131,493,131,[06]
|
||||
374,155,409,155,409,170,374,170,###
|
||||
```
|
||||
2. Write configuration in yaml format (example config files are provided in `config` folder.)
|
||||
* To speed up training with multiple GPUs, set `num_workers` > 0
|
||||
3. Put the yaml file in the config folder
|
||||
4. Run the training script as shown below (if you have multiple GPUs, run train_distributed.py)
|
||||
5. Then, experiment results will be saved to ```./exp/[yaml]``` by default.
|
||||
|
||||
* Step 1 : To train CRAFT with SynthText dataset from scratch
|
||||
* Note : This step is not necessary if you use <a href="https://drive.google.com/file/d/1enVIsgNvBf3YiRsVkxodspOn55PIK-LJ/view?usp=sharing">this pretrained model</a> as the checkpoint when starting training step 2. You can download it, put it at `exp/CRAFT_clr_amp_29500.pth`, and change `ckpt_path` in the config file according to your local setup.
|
||||
```
|
||||
CUDA_VISIBLE_DEVICES=0 python3 trainSynth.py --yaml=syn_train
|
||||
```
|
||||
|
||||
* Step 2 : To train CRAFT with [SynthText + IC15] or custom dataset
|
||||
```
|
||||
CUDA_VISIBLE_DEVICES=0 python3 train.py --yaml=custom_data_train ## if you run on single GPU
|
||||
CUDA_VISIBLE_DEVICES=0,1 python3 train_distributed.py --yaml=custom_data_train ## if you run on multi GPU
|
||||
```
|
||||
|
||||
### Arguments
|
||||
* ```--yaml``` : configuration file name
|
||||
|
||||
### Evaluation
|
||||
* In the official repository issues, the author mentioned that the F1-score of the first-row setting is around 0.75.
|
||||
* In the official paper, it is stated that the result F1-score of the second row setting is 0.87.
|
||||
* If you adjust the post-processing parameter 'text_threshold' from 0.85 to 0.75, the F1-score reaches 0.856.
|
||||
* It took 14 hours to train 25k iterations of weak supervision with 8 RTX 3090 Ti GPUs.
|
||||
* Half of the GPUs were assigned to training, and the other half to the supervision setting.
|
||||
|
||||
| Training Dataset | Evaluation Dataset | Precision | Recall | F1-score | pretrained model |
|
||||
| ------------- |-----|:-----:|:-----:|:-----:|-----:|
|
||||
| SynthText | ICDAR2013 | 0.801 | 0.748 | 0.773| <a href="https://drive.google.com/file/d/1enVIsgNvBf3YiRsVkxodspOn55PIK-LJ/view?usp=sharing">download link</a>|
|
||||
| SynthText + ICDAR2015 | ICDAR2015 | 0.909 | 0.794 | 0.848| <a href="https://drive.google.com/file/d/1qUeZIDSFCOuGS9yo8o0fi-zYHLEW6lBP/view">download link</a>|
|
||||
0
trainer/craft/config/__init__.py
Normal file
0
trainer/craft/config/__init__.py
Normal file
100
trainer/craft/config/custom_data_train.yaml
Normal file
100
trainer/craft/config/custom_data_train.yaml
Normal file
@@ -0,0 +1,100 @@
|
||||
wandb_opt: False
|
||||
|
||||
results_dir: "./exp/"
|
||||
vis_test_dir: "./vis_result/"
|
||||
|
||||
data_root_dir: "./data_root_dir/"
|
||||
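score_gt_dir: None # e.g. "/data/ICDAR2015_official_supervision". NOTE: YAML loads a bare None as the string "None"; write null if a real null value is intended.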
|
||||
mode: "weak_supervision"
|
||||
|
||||
|
||||
train:
|
||||
backbone : vgg
|
||||
use_synthtext: False # If you want to combine SynthText in train time as CRAFT did, you can turn on this option
|
||||
synth_data_dir: "/data/SynthText/"
|
||||
synth_ratio: 5
|
||||
real_dataset: custom
|
||||
ckpt_path: "./pretrained_model/CRAFT_clr_amp_29500.pth"
|
||||
eval_interval: 1000
|
||||
batch_size: 5
|
||||
st_iter: 0
|
||||
end_iter: 25000
|
||||
lr: 0.0001
|
||||
lr_decay: 7500
|
||||
gamma: 0.2
|
||||
weight_decay: 0.00001
|
||||
num_workers: 0 # On a single GPU, train.py only works with num_workers = 0. On multiple GPUs, you can set num_workers > 0 to speed up training.
|
||||
amp: True
|
||||
loss: 2
|
||||
neg_rto: 0.3
|
||||
n_min_neg: 5000
|
||||
data:
|
||||
vis_opt: False
|
||||
pseudo_vis_opt: False
|
||||
output_size: 768
|
||||
do_not_care_label: ['###', '']
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
variance: [0.229, 0.224, 0.225]
|
||||
enlarge_region : [0.5, 0.5] # x axis, y axis
|
||||
enlarge_affinity: [0.5, 0.5]
|
||||
gauss_init_size: 200
|
||||
gauss_sigma: 40
|
||||
watershed:
|
||||
version: "skimage"
|
||||
sure_fg_th: 0.75
|
||||
sure_bg_th: 0.05
|
||||
syn_sample: -1
|
||||
custom_sample: -1
|
||||
syn_aug:
|
||||
random_scale:
|
||||
range: [1.0, 1.5, 2.0]
|
||||
option: False
|
||||
random_rotate:
|
||||
max_angle: 20
|
||||
option: False
|
||||
random_crop:
|
||||
version: "random_resize_crop_synth"
|
||||
option: True
|
||||
random_horizontal_flip:
|
||||
option: False
|
||||
random_colorjitter:
|
||||
brightness: 0.2
|
||||
contrast: 0.2
|
||||
saturation: 0.2
|
||||
hue: 0.2
|
||||
option: True
|
||||
custom_aug:
|
||||
random_scale:
|
||||
range: [ 1.0, 1.5, 2.0 ]
|
||||
option: False
|
||||
random_rotate:
|
||||
max_angle: 20
|
||||
option: True
|
||||
random_crop:
|
||||
version: "random_resize_crop"
|
||||
scale: [0.03, 0.4]
|
||||
ratio: [0.75, 1.33]
|
||||
rnd_threshold: 1.0
|
||||
option: True
|
||||
random_horizontal_flip:
|
||||
option: True
|
||||
random_colorjitter:
|
||||
brightness: 0.2
|
||||
contrast: 0.2
|
||||
saturation: 0.2
|
||||
hue: 0.2
|
||||
option: True
|
||||
|
||||
test:
|
||||
trained_model : null
|
||||
custom_data:
|
||||
test_set_size: 500
|
||||
test_data_dir: "./data_root_dir/"
|
||||
text_threshold: 0.75
|
||||
low_text: 0.5
|
||||
link_threshold: 0.2
|
||||
canvas_size: 2240
|
||||
mag_ratio: 1.75
|
||||
poly: False
|
||||
cuda: True
|
||||
vis_opt: False
|
||||
37
trainer/craft/config/load_config.py
Normal file
37
trainer/craft/config/load_config.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import os
|
||||
import yaml
|
||||
from functools import reduce
|
||||
|
||||
CONFIG_PATH = os.path.dirname(__file__)
|
||||
|
||||
def load_yaml(config_name):
    """Load a YAML config file that lives next to this module.

    :param config_name: file name without the ``.yaml`` extension; it is
        joined onto ``CONFIG_PATH`` and ``.yaml`` is appended.
    :return: whatever ``yaml.safe_load`` produces for the file contents.
    """
    config_file = os.path.join(CONFIG_PATH, config_name) + ".yaml"
    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # default locale encoding.
    with open(config_file, encoding="utf-8") as file:
        return yaml.safe_load(file)
|
||||
|
||||
class DotDict(dict):
    """A ``dict`` with attribute access and dotted-path lookup.

    ``d.a`` is equivalent to ``d["a"]``; ``d["a.b"]`` (or ``d[("a", "b")]``)
    walks nested containers key by key.  Nested plain dicts reached through
    attribute access are wrapped in a new ``DotDict`` (a copy, so writes to
    the wrapper do not propagate back to the parent).
    """

    def __getattr__(self, k):
        """Return ``self[k]``; raise ``AttributeError`` when it cannot be resolved."""
        try:
            v = self[k]
        except (KeyError, TypeError):
            # ``dict`` has no ``__getattr__`` to delegate to, so raise the
            # AttributeError ourselves.  This also keeps hasattr(), copy and
            # pickle working, since they probe for optional attributes.
            raise AttributeError(k) from None
        if isinstance(v, dict):
            return DotDict(v)
        return v

    def __getitem__(self, k):
        """Look up a single key, a dotted string path, or a list/tuple path."""
        if isinstance(k, str) and "." in k:
            k = k.split(".")
        if isinstance(k, (list, tuple)):
            # Descend one level per path component, starting from self.
            return reduce(lambda d, kk: d[kk], k, self)
        return super().__getitem__(k)

    def get(self, k, default=None):
        """Return the value for ``k`` (dotted paths supported), else ``default``."""
        if isinstance(k, str) and "." in k:
            try:
                return self[k]
            except (KeyError, TypeError):
                # TypeError: the path runs through a value that cannot be
                # indexed by a string key; treat it like a missing key.
                return default
        # dict.get() only accepts positional arguments; passing default= as a
        # keyword raises TypeError.
        return super().get(k, default)
|
||||
68
trainer/craft/config/syn_train.yaml
Normal file
68
trainer/craft/config/syn_train.yaml
Normal file
@@ -0,0 +1,68 @@
|
||||
wandb_opt: False
|
||||
|
||||
results_dir: "./exp/"
|
||||
vis_test_dir: "./vis_result/"
|
||||
data_dir:
|
||||
synthtext: "/data/SynthText/"
|
||||
synthtext_gt: NULL
|
||||
|
||||
train:
|
||||
backbone : vgg
|
||||
dataset: ["synthtext"]
|
||||
ckpt_path: null
|
||||
eval_interval: 1000
|
||||
batch_size: 5
|
||||
st_iter: 0
|
||||
end_iter: 50000
|
||||
lr: 0.0001
|
||||
lr_decay: 15000
|
||||
gamma: 0.2
|
||||
weight_decay: 0.00001
|
||||
num_workers: 4
|
||||
amp: True
|
||||
loss: 3
|
||||
neg_rto: 1
|
||||
n_min_neg: 1000
|
||||
data:
|
||||
vis_opt: False
|
||||
output_size: 768
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
variance: [0.229, 0.224, 0.225]
|
||||
enlarge_region : [0.5, 0.5] # x axis, y axis
|
||||
enlarge_affinity: [0.5, 0.5]
|
||||
gauss_init_size: 200
|
||||
gauss_sigma: 40
|
||||
syn_sample : -1
|
||||
syn_aug:
|
||||
random_scale:
|
||||
range: [1.0, 1.5, 2.0]
|
||||
option: False
|
||||
random_rotate:
|
||||
max_angle: 20
|
||||
option: False
|
||||
random_crop:
|
||||
version: "random_resize_crop_synth"
|
||||
rnd_threshold : 1.0
|
||||
option: True
|
||||
random_horizontal_flip:
|
||||
option: False
|
||||
random_colorjitter:
|
||||
brightness: 0.2
|
||||
contrast: 0.2
|
||||
saturation: 0.2
|
||||
hue: 0.2
|
||||
option: True
|
||||
|
||||
test:
|
||||
trained_model: null
|
||||
icdar2013:
|
||||
test_set_size: 233
|
||||
cuda: True
|
||||
vis_opt: True
|
||||
test_data_dir : "/data/ICDAR2013/"
|
||||
text_threshold: 0.85
|
||||
low_text: 0.5
|
||||
link_threshold: 0.2
|
||||
canvas_size: 960
|
||||
mag_ratio: 1.5
|
||||
poly: False
|
||||
65
trainer/craft/data/boxEnlarge.py
Normal file
65
trainer/craft/data/boxEnlarge.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
def pointAngle(Apoint, Bpoint):
    """Return the slope (dy / dx) of the line through two ``(x, y)`` points.

    Despite the name, the result is a slope rather than an angle; callers
    apply ``atan`` themselves.  A small constant is added to dx so that a
    vertical line does not divide by zero.
    """
    rise = Bpoint[1] - Apoint[1]
    run = (Bpoint[0] - Apoint[0]) + 10e-8
    return rise / run
|
||||
|
||||
def pointDistance(Apoint, Bpoint):
    """Return the Euclidean distance between two ``(x, y)`` points."""
    delta_y = Bpoint[1] - Apoint[1]
    delta_x = Bpoint[0] - Apoint[0]
    squared_length = delta_y ** 2 + delta_x ** 2
    return math.sqrt(squared_length)
|
||||
|
||||
def lineBiasAndK(Apoint, Bpoint):
    """Return ``(K, B)`` for the line ``y = K * x + B`` through two points.

    ``K`` is the slope from ``pointAngle``; ``B`` is the intercept obtained
    by substituting ``Apoint`` into the line equation.
    """
    slope = pointAngle(Apoint, Bpoint)
    intercept = Apoint[1] - slope * Apoint[0]
    return slope, intercept
|
||||
|
||||
def getX(K, B, Ypoint):
    """Solve ``Ypoint = K * x + B`` for ``x`` and truncate the result to ``int``."""
    x_value = (Ypoint - B) / K
    return int(x_value)
|
||||
|
||||
def sidePoint(Apoint, Bpoint, h, w, placehold, enlarge_size):
    """Push one end of segment A-B outwards and clamp it to the image.

    The shift along each axis is the segment length projected on that axis
    (via the segment's slope angle) times the matching enlarge ratio.

    :param Apoint: first ``(x, y)`` point of the segment.
    :param Bpoint: second ``(x, y)`` point of the segment.
    :param h: upper clamp for y coordinates.
    :param w: upper clamp for x coordinates.
    :param placehold: one of ``'leftTop'``, ``'rightTop'``, ``'rightBottom'``,
        ``'leftBottom'``; selects which point is moved and in which direction.
    :param enlarge_size: ``(x ratio, y ratio)`` pair.
    :return: the moved point as an ``(int, int)`` tuple.
    :raises ValueError: if ``placehold`` is not one of the four names above
        (previously this fell through to an ``UnboundLocalError``).
    """
    angle = abs(math.atan(pointAngle(Apoint, Bpoint)))
    distance = pointDistance(Apoint, Bpoint)

    x_enlarge_size, y_enlarge_size = enlarge_size

    XaxisIncreaseDistance = abs(math.cos(angle) * x_enlarge_size * distance)
    YaxisIncreaseDistance = abs(math.sin(angle) * y_enlarge_size * distance)

    # "left"/"top" corners move Apoint towards 0; "right"/"bottom" corners
    # move Bpoint (or Apoint's y) towards w / h.
    if placehold == 'leftTop':
        x1 = max(0, Apoint[0] - XaxisIncreaseDistance)
        y1 = max(0, Apoint[1] - YaxisIncreaseDistance)
    elif placehold == 'rightTop':
        x1 = min(w, Bpoint[0] + XaxisIncreaseDistance)
        y1 = max(0, Bpoint[1] - YaxisIncreaseDistance)
    elif placehold == 'rightBottom':
        x1 = min(w, Bpoint[0] + XaxisIncreaseDistance)
        y1 = min(h, Bpoint[1] + YaxisIncreaseDistance)
    elif placehold == 'leftBottom':
        x1 = max(0, Apoint[0] - XaxisIncreaseDistance)
        y1 = min(h, Apoint[1] + YaxisIncreaseDistance)
    else:
        raise ValueError("Unknown placehold: %r" % (placehold,))
    return int(x1), int(y1)
|
||||
|
||||
def enlargebox(box, h, w, enlarge_size, horizontal_text_bool):
    """Enlarge a 4-point box away from the crossing point of its diagonals.

    :param box: array of four ``(x, y)`` points; it is rolled so the point
        with the smallest ``x + y`` comes first.
    :param h: upper clamp for y coordinates.
    :param w: upper clamp for x coordinates.
    :param enlarge_size: ``(x ratio, y ratio)``; the two are swapped when
        ``horizontal_text_bool`` is false.
    :param horizontal_text_bool: whether the text is treated as horizontal.
    :return: ``np.array`` of the four enlarged integer points.
    """
    if not horizontal_text_bool:
        enlarge_size = (enlarge_size[1], enlarge_size[0])

    # Rotate the point order so the corner with the smallest x + y is first.
    start = np.argmin(box.sum(axis=1))
    box = np.roll(box, -start, axis=0)
    first, second, third, fourth = box

    # Intersect the two diagonals (first-third and fourth-second).
    slope_a, bias_a = lineBiasAndK(box[0], box[2])
    slope_b, bias_b = lineBiasAndK(box[3], box[1])
    cross_x = (bias_b - bias_a) / (slope_a - slope_b)
    cross_y = slope_a * cross_x + bias_a
    center = [cross_x, cross_y]

    # Each corner is pushed outwards along its half-diagonal.
    corner_specs = (
        (first, center, 'leftTop'),
        (center, second, 'rightTop'),
        (center, third, 'rightBottom'),
        (fourth, center, 'leftBottom'),
    )
    enlarged = []
    for seg_start, seg_end, corner_name in corner_specs:
        new_x, new_y = sidePoint(seg_start, seg_end, h, w, corner_name, enlarge_size)
        enlarged.append([new_x, new_y])
    return np.array(enlarged)
|
||||
542
trainer/craft/data/dataset.py
Normal file
542
trainer/craft/data/dataset.py
Normal file
@@ -0,0 +1,542 @@
|
||||
import os
|
||||
import re
|
||||
import itertools
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import scipy.io as scio
|
||||
from PIL import Image
|
||||
import cv2
|
||||
from torch.utils.data import Dataset
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from data import imgproc
|
||||
from data.gaussian import GaussianBuilder
|
||||
from data.imgaug import (
|
||||
rescale,
|
||||
random_resize_crop_synth,
|
||||
random_resize_crop,
|
||||
random_horizontal_flip,
|
||||
random_rotate,
|
||||
random_scale,
|
||||
random_crop,
|
||||
)
|
||||
from data.pseudo_label.make_charbox import PseudoCharBoxBuilder
|
||||
from utils.util import saveInput, saveImage
|
||||
|
||||
|
||||
class CraftBaseDataset(Dataset):
    """Shared sample pipeline for the CRAFT training datasets.

    Subclasses must provide ``self.img_names`` and ``make_gt_score(index)``,
    plus ``load_saved_gt_score(index)`` when ``saved_gt_dir`` is not ``None``.
    ``__getitem__`` obtains the image and score maps from the subclass,
    augments them, resizes the score maps to half of ``output_size`` and
    normalizes the image.
    """

    def __init__(
        self,
        output_size,
        data_dir,
        saved_gt_dir,
        mean,
        variance,
        gauss_init_size,
        gauss_sigma,
        enlarge_region,
        enlarge_affinity,
        aug,
        vis_test_dir,
        vis_opt,
        sample,
    ):
        """Store the configuration and build the Gaussian map generator.

        :param output_size: side length used by the crop augmentations;
            score maps are returned at half this size.
        :param data_dir: dataset root directory.
        :param saved_gt_dir: directory of pre-saved score maps, or ``None``
            to build them with ``make_gt_score``.
        :param mean: passed to ``imgproc.normalizeMeanVariance``.
        :param variance: passed to ``imgproc.normalizeMeanVariance``.
        :param gauss_init_size: passed to ``GaussianBuilder``.
        :param gauss_sigma: passed to ``GaussianBuilder``.
        :param enlarge_region: passed to ``GaussianBuilder``.
        :param enlarge_affinity: passed to ``GaussianBuilder``.
        :param aug: augmentation config read by attribute
            (``aug.random_scale.option`` and so on).
        :param vis_test_dir: output directory for debug visualizations.
        :param vis_opt: when true, save debug images for every sample.
        :param sample: number of images in the fixed random subset, or ``-1``
            to use every image.
        """
        self.output_size = output_size
        self.data_dir = data_dir
        self.saved_gt_dir = saved_gt_dir
        self.mean, self.variance = mean, variance
        self.gaussian_builder = GaussianBuilder(
            gauss_init_size, gauss_sigma, enlarge_region, enlarge_affinity
        )
        self.aug = aug
        self.vis_test_dir = vis_test_dir
        self.vis_opt = vis_opt
        self.sample = sample
        # The subset indices are drawn lazily (see ``idx``): subclasses only
        # assign ``img_names`` after this constructor has run, so sampling
        # here would fail with AttributeError.
        self._idx = None

        self.pre_crop_area = []

    @property
    def idx(self):
        """Indices of the fixed random subset, drawn on first access.

        A private ``random.Random(0)`` is used, so the subset is reproducible
        and the global ``random`` state is left untouched.
        """
        if self.sample == -1:
            raise AttributeError("idx is only available when sample != -1")
        cached = getattr(self, "_idx", None)
        if cached is None:
            cached = random.Random(0).sample(
                range(len(self.img_names)), self.sample
            )
            self._idx = cached
        return cached

    @idx.setter
    def idx(self, value):
        # Kept assignable for code that sets ``dataset.idx`` directly.
        self._idx = value

    def augment_image(
        self, image, region_score, affinity_score, confidence_mask, word_level_char_bbox
    ):
        """Apply the augmentations enabled in ``self.aug``.

        Order: random scale, random rotate, random crop, horizontal flip,
        colour jitter (image only).

        :return: ``(image, region_score, affinity_score, confidence_mask)``
            with the image converted to a NumPy array.
        :raises ValueError: if cropping is enabled with an unknown
            ``random_crop.version``.
        """
        augment_targets = [image, region_score, affinity_score, confidence_mask]

        if self.aug.random_scale.option:
            augment_targets, word_level_char_bbox = random_scale(
                augment_targets, word_level_char_bbox, self.aug.random_scale.range
            )

        if self.aug.random_rotate.option:
            augment_targets = random_rotate(
                augment_targets, self.aug.random_rotate.max_angle
            )

        if self.aug.random_crop.option:
            crop_version = self.aug.random_crop.version
            if crop_version == "random_crop_with_bbox":
                # This name is not in the module-level import list, so import
                # it on demand.
                # NOTE(review): assumes data.imgaug defines it - confirm.
                from data.imgaug import random_crop_with_bbox

                augment_targets = random_crop_with_bbox(
                    augment_targets, word_level_char_bbox, self.output_size
                )
            elif crop_version == "random_resize_crop_synth":
                augment_targets = random_resize_crop_synth(
                    augment_targets, self.output_size
                )
            elif crop_version == "random_resize_crop":
                # An empty pre_crop_area means "no previous crop area".
                if len(self.pre_crop_area) > 0:
                    pre_crop_area = self.pre_crop_area
                else:
                    pre_crop_area = None

                augment_targets = random_resize_crop(
                    augment_targets,
                    self.aug.random_crop.scale,
                    self.aug.random_crop.ratio,
                    self.output_size,
                    self.aug.random_crop.rnd_threshold,
                    pre_crop_area,
                )
            elif crop_version == "random_crop":
                augment_targets = random_crop(augment_targets, self.output_size)
            else:
                # Was ``assert "<message>"``, which is always true and never
                # reported the misconfiguration.
                raise ValueError(
                    f"Undefined RandomCrop version: {crop_version!r}"
                )

        if self.aug.random_horizontal_flip.option:
            augment_targets = random_horizontal_flip(augment_targets)

        image, region_score, affinity_score, confidence_mask = augment_targets
        if self.aug.random_colorjitter.option:
            jitter = transforms.ColorJitter(
                brightness=self.aug.random_colorjitter.brightness,
                contrast=self.aug.random_colorjitter.contrast,
                saturation=self.aug.random_colorjitter.saturation,
                hue=self.aug.random_colorjitter.hue,
            )
            # ColorJitter is applied to a PIL image built from the array.
            image = jitter(Image.fromarray(image))

        return np.array(image), region_score, affinity_score, confidence_mask

    def resize_to_half(self, ground_truth, interpolation):
        """Resize a map to ``(output_size // 2, output_size // 2)``."""
        half = self.output_size // 2
        return cv2.resize(ground_truth, (half, half), interpolation=interpolation)

    def __len__(self):
        """Return the subset size when sampling is on, else the image count."""
        if self.sample != -1:
            return len(self.idx)
        return len(self.img_names)

    def __getitem__(self, index):
        """Build one training sample.

        :return: ``(image, region_score, affinity_score, confidence_mask)``;
            the image is normalized and transposed with ``(2, 0, 1)``, the
            three maps are resized to half of ``output_size``.
        """
        if self.sample != -1:
            # Map the position in the subset to the real image index.
            index = self.idx[index]

        if self.saved_gt_dir is None:
            (
                image,
                region_score,
                affinity_score,
                confidence_mask,
                word_level_char_bbox,
                all_affinity_bbox,
                words,
            ) = self.make_gt_score(index)
        else:
            (
                image,
                region_score,
                affinity_score,
                confidence_mask,
                word_level_char_bbox,
                words,
            ) = self.load_saved_gt_score(index)
            # Pre-saved scores come without affinity boxes.
            all_affinity_bbox = []

        if self.vis_opt:
            # Pass copies to the visualization helper.
            saveImage(
                self.img_names[index],
                self.vis_test_dir,
                image.copy(),
                word_level_char_bbox.copy(),
                all_affinity_bbox.copy(),
                region_score.copy(),
                affinity_score.copy(),
                confidence_mask.copy(),
            )

        image, region_score, affinity_score, confidence_mask = self.augment_image(
            image, region_score, affinity_score, confidence_mask, word_level_char_bbox
        )

        if self.vis_opt:
            saveInput(
                self.img_names[index],
                self.vis_test_dir,
                image,
                region_score,
                affinity_score,
                confidence_mask,
            )

        region_score = self.resize_to_half(region_score, interpolation=cv2.INTER_CUBIC)
        affinity_score = self.resize_to_half(
            affinity_score, interpolation=cv2.INTER_CUBIC
        )
        # Nearest-neighbour avoids blending neighbouring mask values.
        confidence_mask = self.resize_to_half(
            confidence_mask, interpolation=cv2.INTER_NEAREST
        )

        image = imgproc.normalizeMeanVariance(
            np.array(image), mean=self.mean, variance=self.variance
        )
        image = image.transpose(2, 0, 1)

        return image, region_score, affinity_score, confidence_mask
|
||||
|
||||
|
||||
class SynthTextDataSet(CraftBaseDataset):
    """Dataset backed by a SynthText-style ``gt.mat`` annotation file.

    Character boxes are read from the file, so region and affinity scores
    are built directly from them with a confidence mask of all ones.
    """

    def __init__(
        self,
        output_size,
        data_dir,
        saved_gt_dir,
        mean,
        variance,
        gauss_init_size,
        gauss_sigma,
        enlarge_region,
        enlarge_affinity,
        aug,
        vis_test_dir,
        vis_opt,
        sample,
    ):
        """Load ``gt.mat`` from ``data_dir`` and initialise the base dataset.

        Parameters are forwarded unchanged to ``CraftBaseDataset``.
        """
        # Load the annotations before running the base constructor: it may
        # need ``img_names`` to draw the ``sample`` subset.
        self.data_dir = data_dir
        self.img_names, self.char_bbox, self.img_words = self.load_data()
        super().__init__(
            output_size,
            data_dir,
            saved_gt_dir,
            mean,
            variance,
            gauss_init_size,
            gauss_sigma,
            enlarge_region,
            enlarge_affinity,
            aug,
            vis_test_dir,
            vis_opt,
            sample,
        )
        self.vis_index = list(range(1000))

    def load_data(self, bbox="char"):
        """Read image names, boxes and texts from ``<data_dir>/gt.mat``.

        :param bbox: ``"char"`` returns the ``charBB`` entry; any other value
            returns ``wordBB`` (word boxes, needed for testing).
        :return: ``(img_names, img_bbox, img_words)``.
        """
        gt = scio.loadmat(os.path.join(self.data_dir, "gt.mat"))
        img_names = gt["imnames"][0]
        img_words = gt["txt"][0]

        bbox_key = "charBB" if bbox == "char" else "wordBB"
        img_bbox = gt[bbox_key][0]

        return img_names, img_bbox, img_words

    def dilate_img_to_output_size(self, image, char_bbox):
        """Upscale so the shorter image side is at least ``output_size``.

        :return: ``(image, char_bbox)`` scaled by the same factor.  The boxes
            are returned as a new array; the caller's array is not modified.
        """
        h, w, _ = image.shape
        shorter_side = min(h, w)
        if shorter_side <= self.output_size:
            scale = float(self.output_size) / shorter_side
        else:
            scale = 1.0
        image = cv2.resize(
            image, dsize=None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC
        )
        # Not ``*=``: scaling in place would modify the caller's array, so
        # repeated calls on the same boxes would compound the scale.
        char_bbox = char_bbox * scale
        return image, char_bbox

    def make_gt_score(self, index):
        """Build the image, score maps and per-word character boxes.

        :return: ``(image, region_score, affinity_score, confidence_mask,
            word_level_char_bbox, all_affinity_bbox, words)``.
        :raises OSError: if the image cannot be read.
        """
        img_path = os.path.join(self.data_dir, self.img_names[index][0])
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread reports failure by returning None.
            raise OSError(f"Could not read image (missing or unreadable): {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        all_char_bbox = self.char_bbox[index].transpose(
            (2, 1, 0)
        )  # shape : (Number of characters in image, 4, 2)

        img_h, img_w, _ = image.shape

        # Every pixel is fully trusted for this dataset.
        confidence_mask = np.ones((img_h, img_w), dtype=np.float32)

        # Split each text entry on spaces/newlines into individual words.
        words = [
            re.split(" \n|\n |\n| ", word.strip()) for word in self.img_words[index]
        ]
        words = list(itertools.chain(*words))
        words = [word for word in words if len(word) > 0]

        # Consume the character boxes in order, len(word) boxes per word.
        word_level_char_bbox = []
        char_idx = 0
        for word in words:
            length_of_word = len(word)
            word_bbox = all_char_bbox[char_idx : char_idx + length_of_word]
            assert len(word_bbox) == length_of_word
            char_idx += length_of_word
            word_level_char_bbox.append(np.array(word_bbox))

        region_score = self.gaussian_builder.generate_region(
            img_h,
            img_w,
            word_level_char_bbox,
            horizontal_text_bools=[True] * len(words),
        )
        affinity_score, all_affinity_bbox = self.gaussian_builder.generate_affinity(
            img_h,
            img_w,
            word_level_char_bbox,
            horizontal_text_bools=[True] * len(words),
        )

        return (
            image,
            region_score,
            affinity_score,
            confidence_mask,
            word_level_char_bbox,
            all_affinity_bbox,
            words,
        )
|
||||
|
||||
|
||||
class CustomDataset(CraftBaseDataset):
    """Word-level annotated dataset trained with pseudo character boxes.

    Images are read from ``<data_dir>/ch4_training_images`` and annotations
    from ``<data_dir>/ch4_training_localization_transcription_gt``.
    ``update_model`` and ``update_device`` must be called before samples are
    built with ``make_gt_score``, because ``load_data`` reads ``self.net``
    and ``self.gpu``.
    """

    def __init__(
        self,
        output_size,
        data_dir,
        saved_gt_dir,
        mean,
        variance,
        gauss_init_size,
        gauss_sigma,
        enlarge_region,
        enlarge_affinity,
        aug,
        vis_test_dir,
        vis_opt,
        sample,
        watershed_param,
        pseudo_vis_opt,
        do_not_care_label,
    ):
        """List the training images and initialise the base dataset.

        The first thirteen parameters are forwarded unchanged to
        ``CraftBaseDataset``.

        :param watershed_param: passed to ``PseudoCharBoxBuilder``.
        :param pseudo_vis_opt: passed to ``PseudoCharBoxBuilder``.
        :param do_not_care_label: transcriptions whose boxes are ignored;
            the first entry is the canonical label stored for them.
        """
        self.img_dir = os.path.join(data_dir, "ch4_training_images")
        self.img_gt_box_dir = os.path.join(
            data_dir, "ch4_training_localization_transcription_gt"
        )
        # List the images before running the base constructor (it may need
        # ``img_names`` to draw the ``sample`` subset).  Sorting makes the
        # index -> image mapping deterministic; os.listdir order is arbitrary.
        self.img_names = sorted(os.listdir(self.img_dir))
        super().__init__(
            output_size,
            data_dir,
            saved_gt_dir,
            mean,
            variance,
            gauss_init_size,
            gauss_sigma,
            enlarge_region,
            enlarge_affinity,
            aug,
            vis_test_dir,
            vis_opt,
            sample,
        )
        self.pseudo_vis_opt = pseudo_vis_opt
        self.do_not_care_label = do_not_care_label
        self.pseudo_charbox_builder = PseudoCharBoxBuilder(
            watershed_param, vis_test_dir, pseudo_vis_opt, self.gaussian_builder
        )
        self.vis_index = list(range(1000))

    def update_model(self, net):
        """Set the network handed to the pseudo character box builder."""
        self.net = net

    def update_device(self, gpu):
        """Set the device handed to the pseudo character box builder."""
        self.gpu = gpu

    @staticmethod
    def _read_image(path, grayscale=False):
        """Read an image with OpenCV, raising ``OSError`` if it cannot be read."""
        if grayscale:
            image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        else:
            image = cv2.imread(path)
        if image is None:
            # cv2.imread reports failure by returning None.
            raise OSError(f"Could not read image (missing or unreadable): {path}")
        return image

    def load_img_gt_box(self, img_gt_box_path):
        """Parse one annotation file.

        Each non-blank line is ``x1,y1,x2,y2,x3,y3,x4,y4,transcription``; the
        transcription may itself contain commas.

        :return: ``(word_bboxes, words)``; ``word_bboxes`` is an array of
            ``(4, 2)`` float32 boxes, ``words`` the matching transcriptions.
            Transcriptions found in ``do_not_care_label`` are replaced by its
            first entry.
        """
        word_bboxes = []
        words = []
        # The context manager closes the file handle.
        with open(img_gt_box_path, encoding="utf-8") as gt_file:
            for line in gt_file:
                # Round-tripping through utf-8-sig drops a leading BOM.
                line = line.strip().encode("utf-8").decode("utf-8-sig")
                if not line:
                    # Skip blank lines instead of failing on int("").
                    continue
                box_info = line.split(",")
                box_points = [int(box_info[i]) for i in range(8)]
                box_points = np.array(box_points, np.float32).reshape(4, 2)
                word = ",".join(box_info[8:])
                if word in self.do_not_care_label:
                    word = self.do_not_care_label[0]
                word_bboxes.append(box_points)
                words.append(word)
        return np.array(word_bboxes), words

    def load_data(self, index):
        """Load an image and build pseudo character boxes for its words.

        :return: ``(image, word_level_char_bbox, do_care_words,
            confidence_mask, horizontal_text_bools)``.  Do-not-care boxes are
            filled with 0 in the confidence mask; every other box is filled
            with the confidence returned by the pseudo box builder.
        :raises OSError: if the image cannot be read.
        """
        img_name = self.img_names[index]
        img_path = os.path.join(self.img_dir, img_name)
        image = self._read_image(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        img_gt_box_path = os.path.join(
            self.img_gt_box_dir, "gt_%s.txt" % os.path.splitext(img_name)[0]
        )
        word_bboxes, words = self.load_img_gt_box(
            img_gt_box_path
        )  # shape : (Number of word bbox, 4, 2)
        confidence_mask = np.ones((image.shape[0], image.shape[1]), np.float32)

        word_level_char_bbox = []
        do_care_words = []
        horizontal_text_bools = []

        if len(word_bboxes) == 0:
            return (
                image,
                word_level_char_bbox,
                do_care_words,
                confidence_mask,
                horizontal_text_bools,
            )

        # The mask is drawn from a copy of the boxes, so it does not depend
        # on what build_char_box does with the box it is given.
        _word_bboxes = word_bboxes.copy()
        for word_bbox, mask_bbox, word in zip(word_bboxes, _word_bboxes, words):
            if word in self.do_not_care_label:
                cv2.fillPoly(confidence_mask, [np.int32(mask_bbox)], 0)
                continue

            (
                pseudo_char_bbox,
                confidence,
                horizontal_text_bool,
            ) = self.pseudo_charbox_builder.build_char_box(
                self.net, self.gpu, image, word_bbox, word, img_name=img_name
            )

            cv2.fillPoly(confidence_mask, [np.int32(mask_bbox)], confidence)
            do_care_words.append(word)
            word_level_char_bbox.append(pseudo_char_bbox)
            horizontal_text_bools.append(horizontal_text_bool)

        return (
            image,
            word_level_char_bbox,
            do_care_words,
            confidence_mask,
            horizontal_text_bools,
        )

    def make_gt_score(self, index):
        """Make region and affinity scores from pseudo character boxes.

        ``word_level_char_bbox`` has one entry per word, each holding that
        word's character boxes.

        :return: ``(image, region_score, affinity_score, confidence_mask,
            word_level_char_bbox, all_affinity_bbox, words)``.  Both score
            maps are all zeros when the image has no usable word.
        """
        (
            image,
            word_level_char_bbox,
            words,
            confidence_mask,
            horizontal_text_bools,
        ) = self.load_data(index)
        img_h, img_w, _ = image.shape

        if len(word_level_char_bbox) == 0:
            region_score = np.zeros((img_h, img_w), dtype=np.float32)
            affinity_score = np.zeros((img_h, img_w), dtype=np.float32)
            all_affinity_bbox = []
        else:
            region_score = self.gaussian_builder.generate_region(
                img_h, img_w, word_level_char_bbox, horizontal_text_bools
            )
            affinity_score, all_affinity_bbox = self.gaussian_builder.generate_affinity(
                img_h, img_w, word_level_char_bbox, horizontal_text_bools
            )

        return (
            image,
            region_score,
            affinity_score,
            confidence_mask,
            word_level_char_bbox,
            all_affinity_bbox,
            words,
        )

    def load_saved_gt_score(self, index):
        """Load pre-saved region, affinity and confidence maps for an image.

        The maps are read as grayscale JPEGs named
        ``res_img_<n>_region.jpg``, ``res_img_<n>_affi.jpg`` and
        ``res_img_<n>_cf_mask_thresh_0.6.jpg`` from ``saved_gt_dir``, resized
        to the rescaled image and divided by 255.

        :return: ``(image, region_score, affinity_score, confidence_mask,
            word_level_char_bbox, words)``.
        :raises OSError: if the image or one of the maps cannot be read.
        """
        img_name = self.img_names[index]
        img_path = os.path.join(self.img_dir, img_name)
        image = self._read_image(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        img_gt_box_path = os.path.join(
            self.img_gt_box_dir, "gt_%s.txt" % os.path.splitext(img_name)[0]
        )
        word_bboxes, words = self.load_img_gt_box(img_gt_box_path)
        image, word_bboxes = rescale(image, word_bboxes)
        img_h, img_w, _ = image.shape

        # The image number is the part after the first "_" of the file stem.
        query_idx = int(img_name.split(".")[0].split("_")[1])

        saved_region_scores_path = os.path.join(
            self.saved_gt_dir, f"res_img_{query_idx}_region.jpg"
        )
        saved_affi_scores_path = os.path.join(
            self.saved_gt_dir, f"res_img_{query_idx}_affi.jpg"
        )
        saved_cf_mask_path = os.path.join(
            self.saved_gt_dir, f"res_img_{query_idx}_cf_mask_thresh_0.6.jpg"
        )
        region_score = self._read_image(saved_region_scores_path, grayscale=True)
        affinity_score = self._read_image(saved_affi_scores_path, grayscale=True)
        confidence_mask = self._read_image(saved_cf_mask_path, grayscale=True)

        region_score = cv2.resize(region_score, (img_w, img_h))
        affinity_score = cv2.resize(affinity_score, (img_w, img_h))
        confidence_mask = cv2.resize(
            confidence_mask, (img_w, img_h), interpolation=cv2.INTER_NEAREST
        )

        region_score = region_score.astype(np.float32) / 255
        affinity_score = affinity_score.astype(np.float32) / 255
        confidence_mask = confidence_mask.astype(np.float32) / 255

        # NOTE : Even though word_level_char_bbox is not necessary, align bbox format with make_gt_score()
        word_level_char_bbox = [
            np.expand_dims(word_bbox, 0) for word_bbox in word_bboxes
        ]

        return (
            image,
            region_score,
            affinity_score,
            confidence_mask,
            word_level_char_bbox,
            words,
        )
|
||||
192
trainer/craft/data/gaussian.py
Normal file
192
trainer/craft/data/gaussian.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from data.boxEnlarge import enlargebox
|
||||
|
||||
|
||||
class GaussianBuilder(object):
|
||||
def __init__(self, init_size, sigma, enlarge_region, enlarge_affinity):
    """Store the settings and precompute the Gaussian template.

    :param init_size: side length of the square Gaussian template.
    :param sigma: sigma used in the Gaussian formula.
    :param enlarge_region: stored as ``self.enlarge_region``.
    :param enlarge_affinity: stored as ``self.enlarge_affinity``.
    """
    self.init_size, self.sigma = init_size, sigma
    self.enlarge_region, self.enlarge_affinity = enlarge_region, enlarge_affinity
    # Must come last: generate_gaussian_map reads init_size and sigma.
    template, template_color = self.generate_gaussian_map()
    self.gaussian_map = template
    self.gaussian_map_color = template_color
|
||||
|
||||
def generate_gaussian_map(self):
    """Build the square Gaussian template and a colour rendering of it.

    A 2D Gaussian centred at ``init_size / 2`` is evaluated on an
    ``init_size x init_size`` grid, multiplied by the circle mask and
    rescaled so that its maximum is 1.

    :return: ``(gaussian_map, gaussian_map_color)``: the float32 template
        and its ``cv2.COLORMAP_JET`` rendering.
    """
    circle_mask = self.generate_circle_mask()

    # Same arithmetic as the per-pixel formula, evaluated for the whole grid
    # at once instead of in a Python double loop.
    variance = self.sigma ** 2
    offsets = np.arange(self.init_size, dtype=np.float64) - self.init_size / 2
    scaled_sq = offsets ** 2 / variance
    # Row term + column term, broadcast to an (init_size, init_size) grid.
    exponent = -1 / 2 * (scaled_sq[:, np.newaxis] + scaled_sq[np.newaxis, :])
    gaussian_map = (1 / 2 / np.pi / variance * np.exp(exponent)).astype(np.float32)

    gaussian_map = gaussian_map * circle_mask
    gaussian_map = (gaussian_map / np.max(gaussian_map)).astype(np.float32)

    gaussian_map_color = (gaussian_map * 255).astype(np.uint8)
    gaussian_map_color = cv2.applyColorMap(gaussian_map_color, cv2.COLORMAP_JET)
    return gaussian_map, gaussian_map_color
|
||||
|
||||
def generate_circle_mask(self):
    """Return a float32 ``init_size x init_size`` array with a filled circle.

    The circle is drawn with value 1 on a zero background, centred at
    ``(init_size // 2, init_size // 2)`` with radius ``init_size // 2``.
    """
    side = self.init_size
    half = side // 2
    canvas = np.zeros((side, side), np.float32)
    # thickness=-1 asks OpenCV for a filled circle.
    return cv2.circle(
        img=canvas, center=(half, half), radius=half, color=1, thickness=-1
    )
|
||||
|
||||
def four_point_transform(self, bbox):
|
||||
"""
|
||||
Using the bbox, standard 2D gaussian map, returns Transformed 2d Gaussian map
|
||||
"""
|
||||
width, height = (
|
||||
np.max(bbox[:, 0]).astype(np.int32),
|
||||
np.max(bbox[:, 1]).astype(np.int32),
|
||||
)
|
||||
init_points = np.array(
|
||||
[
|
||||
[0, 0],
|
||||
[self.init_size, 0],
|
||||
[self.init_size, self.init_size],
|
||||
[0, self.init_size],
|
||||
],
|
||||
dtype="float32",
|
||||
)
|
||||
|
||||
M = cv2.getPerspectiveTransform(init_points, bbox)
|
||||
warped_gaussian_map = cv2.warpPerspective(self.gaussian_map, M, (width, height))
|
||||
return warped_gaussian_map, width, height
|
||||
|
||||
def add_gaussian_map_to_score_map(
|
||||
self, score_map, bbox, enlarge_size, horizontal_text_bool, map_type=None
|
||||
):
|
||||
"""
|
||||
Mapping 2D Gaussian to the character box coordinates of the score_map.
|
||||
|
||||
:param score_map: Target map to put 2D gaussian on character box
|
||||
:type score_map: np.float32
|
||||
:param bbox: character boxes
|
||||
:type bbox: np.float32
|
||||
:param enlarge_size: Enlarge size of gaussian map to fit character shape
|
||||
:type enlarge_size: list of enlarge size [x dim, y dim]
|
||||
:param horizontal_text_bool: Flag that bbox is horizontal text or not
|
||||
:type horizontal_text_bool: bool
|
||||
:param map_type: Whether map's type is "region" | "affinity"
|
||||
:type map_type: str
|
||||
:return score_map: score map that all 2D gaussian put on character box
|
||||
:rtype: np.float32
|
||||
"""
|
||||
|
||||
map_h, map_w = score_map.shape
|
||||
bbox = enlargebox(bbox, map_h, map_w, enlarge_size, horizontal_text_bool)
|
||||
|
||||
# If any one point of character bbox is out of range, don't put in on map
|
||||
if np.any(bbox < 0) or np.any(bbox[:, 0] > map_w) or np.any(bbox[:, 1] > map_h):
|
||||
return score_map
|
||||
|
||||
bbox_left, bbox_top = np.array([np.min(bbox[:, 0]), np.min(bbox[:, 1])]).astype(
|
||||
np.int32
|
||||
)
|
||||
bbox -= (bbox_left, bbox_top)
|
||||
warped_gaussian_map, width, height = self.four_point_transform(
|
||||
bbox.astype(np.float32)
|
||||
)
|
||||
|
||||
try:
|
||||
bbox_area_of_image = score_map[
|
||||
bbox_top : bbox_top + height, bbox_left : bbox_left + width,
|
||||
]
|
||||
high_value_score = np.where(
|
||||
warped_gaussian_map > bbox_area_of_image,
|
||||
warped_gaussian_map,
|
||||
bbox_area_of_image,
|
||||
)
|
||||
score_map[
|
||||
bbox_top : bbox_top + height, bbox_left : bbox_left + width,
|
||||
] = high_value_score
|
||||
|
||||
except Exception as e:
|
||||
print("Error : {}".format(e))
|
||||
print(
|
||||
"On generating {} map, strange box came out. (width: {}, height: {})".format(
|
||||
map_type, width, height
|
||||
)
|
||||
)
|
||||
|
||||
return score_map
|
||||
|
||||
def calculate_affinity_box_points(self, bbox_1, bbox_2, vertical=False):
|
||||
center_1, center_2 = np.mean(bbox_1, axis=0), np.mean(bbox_2, axis=0)
|
||||
if vertical:
|
||||
tl = (bbox_1[0] + bbox_1[-1] + center_1) / 3
|
||||
tr = (bbox_1[1:3].sum(0) + center_1) / 3
|
||||
br = (bbox_2[1:3].sum(0) + center_2) / 3
|
||||
bl = (bbox_2[0] + bbox_2[-1] + center_2) / 3
|
||||
else:
|
||||
tl = (bbox_1[0:2].sum(0) + center_1) / 3
|
||||
tr = (bbox_2[0:2].sum(0) + center_2) / 3
|
||||
br = (bbox_2[2:4].sum(0) + center_2) / 3
|
||||
bl = (bbox_1[2:4].sum(0) + center_1) / 3
|
||||
affinity_box = np.array([tl, tr, br, bl]).astype(np.float32)
|
||||
return affinity_box
|
||||
|
||||
def generate_region(
|
||||
self, img_h, img_w, word_level_char_bbox, horizontal_text_bools
|
||||
):
|
||||
region_map = np.zeros([img_h, img_w], dtype=np.float32)
|
||||
for i in range(
|
||||
len(word_level_char_bbox)
|
||||
): # shape : [word_num, [char_num_in_one_word, 4, 2]]
|
||||
for j in range(len(word_level_char_bbox[i])):
|
||||
region_map = self.add_gaussian_map_to_score_map(
|
||||
region_map,
|
||||
word_level_char_bbox[i][j].copy(),
|
||||
self.enlarge_region,
|
||||
horizontal_text_bools[i],
|
||||
map_type="region",
|
||||
)
|
||||
return region_map
|
||||
|
||||
def generate_affinity(
|
||||
self, img_h, img_w, word_level_char_bbox, horizontal_text_bools
|
||||
):
|
||||
|
||||
affinity_map = np.zeros([img_h, img_w], dtype=np.float32)
|
||||
all_affinity_bbox = []
|
||||
for i in range(len(word_level_char_bbox)):
|
||||
for j in range(len(word_level_char_bbox[i]) - 1):
|
||||
affinity_bbox = self.calculate_affinity_box_points(
|
||||
word_level_char_bbox[i][j], word_level_char_bbox[i][j + 1]
|
||||
)
|
||||
|
||||
affinity_map = self.add_gaussian_map_to_score_map(
|
||||
affinity_map,
|
||||
affinity_bbox.copy(),
|
||||
self.enlarge_affinity,
|
||||
horizontal_text_bools[i],
|
||||
map_type="affinity",
|
||||
)
|
||||
all_affinity_bbox.append(np.expand_dims(affinity_bbox, axis=0))
|
||||
|
||||
if len(all_affinity_bbox) > 0:
|
||||
all_affinity_bbox = np.concatenate(all_affinity_bbox, axis=0)
|
||||
return affinity_map, all_affinity_bbox
|
||||
175
trainer/craft/data/imgaug.py
Normal file
175
trainer/craft/data/imgaug.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from torchvision.transforms.functional import resized_crop, crop
|
||||
from torchvision.transforms import RandomResizedCrop, RandomCrop
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
|
||||
def rescale(img, bboxes, target_size=2240):
    """
    Resize ``img`` so that its longer side equals ``target_size`` and scale
    ``bboxes`` by the same factor.

    :param img: image array; only the first two shape entries are read
    :param bboxes: coordinates supporting multiplication by a float
    :param target_size: desired length of the longer image side
    :return: (resized image, scaled bboxes)
    """
    img_h, img_w = img.shape[0:2]
    factor = target_size / max(img_h, img_w)
    resized = cv2.resize(
        img, dsize=None, fx=factor, fy=factor, interpolation=cv2.INTER_CUBIC
    )
    return resized, bboxes * factor
|
||||
|
||||
|
||||
def random_resize_crop_synth(augment_targets, size):
    """
    Take one random square crop (side = shorter image side) from all four
    targets and resize each crop to ``size`` x ``size``.

    :param augment_targets: [image, region_score, affinity_score, confidence_mask]
    :param size: output side length
    :return: the four processed targets as numpy arrays, in the same order
    """
    image, region_score, affinity_score, confidence_mask = augment_targets
    pil_targets = [
        Image.fromarray(target)
        for target in (image, region_score, affinity_score, confidence_mask)
    ]

    # One crop window, sampled on the image, is shared by every target.
    short_side = min(pil_targets[0].size)
    top, left, crop_h, crop_w = RandomCrop.get_params(
        pil_targets[0], output_size=(short_side, short_side)
    )

    # The confidence mask is resampled with NEAREST; the others with BICUBIC.
    interpolations = (
        InterpolationMode.BICUBIC,
        InterpolationMode.BICUBIC,
        InterpolationMode.BICUBIC,
        InterpolationMode.NEAREST,
    )
    return [
        np.array(
            resized_crop(
                target,
                top,
                left,
                crop_h,
                crop_w,
                size=(size, size),
                interpolation=mode,
            )
        )
        for target, mode in zip(pil_targets, interpolations)
    ]
|
||||
|
||||
|
||||
def random_resize_crop(
    augment_targets, scale, ratio, size, threshold, pre_crop_area=None
):
    """
    Crop the same window from all four targets and resize each crop to
    ``size`` x ``size``.

    The window is ``pre_crop_area`` when given. Otherwise, with probability
    ``threshold`` it is sampled by ``RandomResizedCrop.get_params`` using
    ``scale`` / ``ratio``; in the remaining cases ``get_params`` is called
    with scale (1.0, 1.0) and ratio (1.0, 1.0).

    :param augment_targets: [image, region_score, affinity_score, confidence_mask]
    :param scale: area scale range forwarded to ``RandomResizedCrop.get_params``
    :param ratio: aspect-ratio range forwarded to ``RandomResizedCrop.get_params``
    :param size: output side length
    :param threshold: probability of using the random scale / ratio crop
    :param pre_crop_area: optional (top, left, height, width) window
    :return: the four processed targets as numpy arrays, in the same order
    """
    image, region_score, affinity_score, confidence_mask = augment_targets
    pil_targets = [
        Image.fromarray(target)
        for target in (image, region_score, affinity_score, confidence_mask)
    ]

    # Identity test instead of ``!= None``: comparing an ndarray crop area
    # with ``!=`` yields an array whose truth value is ambiguous.
    if pre_crop_area is not None:
        top, left, crop_h, crop_w = pre_crop_area
    elif random.random() < threshold:
        top, left, crop_h, crop_w = RandomResizedCrop.get_params(
            pil_targets[0], scale=scale, ratio=ratio
        )
    else:
        top, left, crop_h, crop_w = RandomResizedCrop.get_params(
            pil_targets[0], scale=(1.0, 1.0), ratio=(1.0, 1.0)
        )

    # The confidence mask is resampled with NEAREST; the others with BICUBIC.
    interpolations = (
        InterpolationMode.BICUBIC,
        InterpolationMode.BICUBIC,
        InterpolationMode.BICUBIC,
        InterpolationMode.NEAREST,
    )
    return [
        np.array(
            resized_crop(
                target,
                top,
                left,
                crop_h,
                crop_w,
                size=(size, size),
                interpolation=mode,
            )
        )
        for target, mode in zip(pil_targets, interpolations)
    ]
|
||||
|
||||
|
||||
def random_crop(augment_targets, size):
    """
    Cut the same random ``size`` x ``size`` window out of all four targets.

    :param augment_targets: [image, region_score, affinity_score, confidence_mask]
    :param size: side length of the crop window
    :return: the four cropped targets as numpy arrays, in the same order
    """
    image, region_score, affinity_score, confidence_mask = augment_targets
    pil_targets = [
        Image.fromarray(target)
        for target in (image, region_score, affinity_score, confidence_mask)
    ]

    # The window is sampled once, on the image, and reused for every target.
    top, left, crop_h, crop_w = RandomCrop.get_params(
        pil_targets[0], output_size=(size, size)
    )
    return [
        np.array(crop(target, top, left, crop_h, crop_w)) for target in pil_targets
    ]
|
||||
|
||||
|
||||
def random_horizontal_flip(imgs):
    """
    With probability 0.5, mirror every array in ``imgs`` along axis 1.

    The list is modified in place (entries are replaced by flipped copies)
    and also returned.
    """
    if random.random() >= 0.5:
        return imgs
    for idx, img in enumerate(imgs):
        imgs[idx] = np.flip(img, axis=1).copy()
    return imgs
|
||||
|
||||
|
||||
def random_scale(images, word_level_char_bbox, scale_range):
    """
    Pick one scale from ``scale_range`` and apply it to every image and box.

    Both ``images`` and ``word_level_char_bbox`` are updated in place; only
    ``images`` is returned.

    :param images: list of arrays resized with ``cv2.resize``
    :param word_level_char_bbox: indexable collection whose entries are
        multiplied by the chosen scale
    :param scale_range: population of candidate scales
    :return: the (mutated) ``images`` list
    """
    scale = random.sample(scale_range, 1)[0]

    for idx, img in enumerate(images):
        images[idx] = cv2.resize(img, dsize=None, fx=scale, fy=scale)

    for idx in range(len(word_level_char_bbox)):
        word_level_char_bbox[idx] *= scale

    return images
|
||||
|
||||
|
||||
def random_rotate(images, max_angle):
    """
    Rotate every array in ``images`` about its centre by one shared angle
    drawn uniformly from [-max_angle, max_angle).

    The last entry of the list is warped with nearest-neighbour
    interpolation; the others use ``cv2.warpAffine``'s default. The list is
    modified in place and returned.
    """
    angle = random.random() * 2 * max_angle - max_angle
    last_idx = len(images) - 1
    for idx, img in enumerate(images):
        rows, cols = img.shape[:2]
        # OpenCV takes the centre as (x, y) and the output size as (width, height).
        rotation_matrix = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)
        if idx == last_idx:
            rotated = cv2.warpAffine(
                img, M=rotation_matrix, dsize=(cols, rows), flags=cv2.INTER_NEAREST
            )
        else:
            rotated = cv2.warpAffine(img, rotation_matrix, (cols, rows))
        images[idx] = rotated
    return images
|
||||
91
trainer/craft/data/imgproc.py
Normal file
91
trainer/craft/data/imgproc.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
Copyright (c) 2019-present NAVER Corp.
|
||||
MIT License
|
||||
"""
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
import numpy as np
|
||||
|
||||
import cv2
|
||||
from skimage import io
|
||||
|
||||
|
||||
def loadImage(img_file):
    """
    Read an image file and return it as a 3-channel array.

    :param img_file: path (or other source accepted by ``skimage.io.imread``)
    :return: numpy array with the alpha channel dropped and grayscale
        expanded to three channels
    """
    loaded = io.imread(img_file)  # RGB order
    # NOTE(review): a leading dimension of 2 is presumably a two-frame file
    # of which only the first frame is kept - confirm against the data.
    if loaded.shape[0] == 2:
        loaded = loaded[0]
    if len(loaded.shape) == 2:
        loaded = cv2.cvtColor(loaded, cv2.COLOR_GRAY2RGB)
    if loaded.shape[2] == 4:
        # Four channels: keep the first three only.
        loaded = loaded[:, :, :3]
    return np.array(loaded)
|
||||
|
||||
|
||||
def normalizeMeanVariance(
    in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)
):
    """
    Return a float32 copy of ``in_img`` normalised per channel as
    ``(pixel - mean * 255) / (variance * 255)``.

    :param in_img: image array whose last axis has three channels (RGB order)
    :param mean: per-channel means on a 0-1 scale
    :param variance: per-channel divisors on a 0-1 scale
    :return: normalised float32 array; the input is left untouched
    """
    # should be RGB order
    channel_offset = np.array([mean[c] * 255.0 for c in range(3)], dtype=np.float32)
    channel_scale = np.array(
        [variance[c] * 255.0 for c in range(3)], dtype=np.float32
    )

    normalised = in_img.copy().astype(np.float32)
    normalised -= channel_offset
    normalised /= channel_scale
    return normalised
|
||||
|
||||
|
||||
def denormalizeMeanVariance(
    in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)
):
    """
    Undo ``normalizeMeanVariance``: compute ``(x * variance + mean) * 255``,
    clip to [0, 255] and convert to uint8.

    The arithmetic is done in place on a copy, so it keeps the input's dtype
    until the final conversion; the input itself is not modified.

    :param in_img: normalised image array (RGB order)
    :param mean: per-channel means on a 0-1 scale
    :param variance: per-channel multipliers on a 0-1 scale
    :return: uint8 image array
    """
    # should be RGB order
    restored = in_img.copy()
    np.multiply(restored, variance, out=restored)
    np.add(restored, mean, out=restored)
    np.multiply(restored, 255.0, out=restored)
    return np.clip(restored, 0, 255).astype(np.uint8)
|
||||
|
||||
|
||||
def resize_aspect_ratio(img, square_size, interpolation, mag_ratio=1):
    """
    Resize ``img`` keeping its aspect ratio, then zero-pad the bottom/right so
    both sides are multiples of 32.

    The longer side becomes ``mag_ratio * max(height, width)``, capped at
    ``square_size``.

    :param img: image of shape (height, width, channel)
    :param square_size: upper bound for the longer side after resizing
    :param interpolation: OpenCV interpolation flag passed to ``cv2.resize``
    :param mag_ratio: magnification applied before the cap
    :return: (padded float32 image, applied ratio, valid_size_heatmap) where
        valid_size_heatmap is (resized_height // 2, resized_width // 2) as ints
    """
    height, width, channel = img.shape
    longest_side = max(height, width)

    # magnify image size, but never beyond square_size
    target_size = min(mag_ratio * longest_side, square_size)
    ratio = target_size / longest_side

    target_h = int(height * ratio)
    target_w = int(width * ratio)

    # NOTE: size of the un-padded area at half resolution
    valid_size_heatmap = (int(target_h / 2), int(target_w / 2))

    proc = cv2.resize(img, (target_w, target_h), interpolation=interpolation)

    # make canvas and paste image; each side is rounded up to a multiple of 32
    target_h32 = target_h + (-target_h) % 32
    target_w32 = target_w + (-target_w) % 32
    resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32)
    resized[0:target_h, 0:target_w, :] = proc

    return resized, ratio, valid_size_heatmap
|
||||
|
||||
|
||||
def cvt2HeatmapImg(img):
    """
    Turn a score map into a JET colour-mapped uint8 image.

    Values are clipped to [0, 1] and scaled to 0-255 before colouring.
    """
    scaled = (np.clip(img, 0, 1) * 255).astype(np.uint8)
    return cv2.applyColorMap(scaled, cv2.COLORMAP_JET)
|
||||
263
trainer/craft/data/pseudo_label/make_charbox.py
Normal file
263
trainer/craft/data/pseudo_label/make_charbox.py
Normal file
@@ -0,0 +1,263 @@
|
||||
import os
|
||||
import random
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from data import imgproc
|
||||
from data.pseudo_label.watershed import exec_watershed_by_version
|
||||
|
||||
|
||||
class PseudoCharBoxBuilder:
    """Generate pseudo character boxes for a word from the network's region score.

    A word is cropped and rectified, run through the network, split into
    characters with a watershed on the predicted region score, and the
    resulting boxes are mapped back into the coordinates of the source image.
    """

    def __init__(self, watershed_param, vis_test_dir, pseudo_vis_opt, gaussian_builder):
        """
        :param watershed_param: configuration forwarded to ``exec_watershed_by_version``
        :param vis_test_dir: directory that receives visualisation images
        :param pseudo_vis_opt: truthy to enable visualisation output
        :param gaussian_builder: object providing ``generate_region`` (used for
            visualisation only)
        """
        self.watershed_param = watershed_param
        self.vis_test_dir = vis_test_dir
        self.pseudo_vis_opt = pseudo_vis_opt
        self.gaussian_builder = gaussian_builder
        self.cnt = 0
        # True while the most recently cropped word was classified as vertical.
        self.flag = False

    def crop_image_by_bbox(self, image, box, word):
        """
        Rectify the word quadrilateral ``box`` into an axis-aligned crop.

        :param image: source image
        :param box: four (x, y) corners of the word
        :param word: transcription; its length feeds the vertical-text test
        :return: (warped crop, perspective matrix M, horizontal_text_bool)
        :raises ValueError: if the box has zero width
        """
        w = max(
            int(np.linalg.norm(box[0] - box[1])), int(np.linalg.norm(box[2] - box[3]))
        )
        h = max(
            int(np.linalg.norm(box[0] - box[3])), int(np.linalg.norm(box[1] - box[2]))
        )
        if w == 0:
            # A zero-width box has no aspect ratio; fail with a clear message
            # instead of stopping inside an interactive debugger.
            raise ValueError(
                "Degenerate word box with zero width for word {!r}".format(word)
            )
        word_ratio = h / w

        one_char_ratio = min(h, w) / (max(h, w) / len(word))

        # NOTE: criterion to split vertical word in here is set to work properly on IC15 dataset
        if word_ratio > 2 or (word_ratio > 1.6 and one_char_ratio > 2.4):
            # warping method of vertical word (classified by upper condition)
            horizontal_text_bool = False
            long_side, short_side = h, w
            target_points = [
                [long_side, 0],
                [long_side, short_side],
                [0, short_side],
                [0, 0],
            ]
        else:
            # warping method of horizontal word
            horizontal_text_bool = True
            long_side, short_side = w, h
            target_points = [
                [0, 0],
                [long_side, 0],
                [long_side, short_side],
                [0, short_side],
            ]

        M = cv2.getPerspectiveTransform(
            np.float32(box), np.float32(np.array(target_points))
        )
        self.flag = not horizontal_text_bool

        warped = cv2.warpPerspective(image, M, (long_side, short_side))
        return warped, M, horizontal_text_bool

    def inference_word_box(self, net, gpu, word_image):
        """
        Run ``net`` on one word crop and return its first output.

        The network is switched to eval mode if it was training; it is not
        switched back here.

        :param net: model returning a tuple whose first item is the score tensor
        :param gpu: CUDA device passed to ``Tensor.cuda``
        :param word_image: word crop in (H, W, C) layout
        :return: score tensor produced by the network
        """
        if net.training:
            net.eval()

        with torch.no_grad():
            word_img_torch = torch.from_numpy(
                imgproc.normalizeMeanVariance(
                    word_image,
                    mean=(0.485, 0.456, 0.406),
                    variance=(0.229, 0.224, 0.225),
                )
            )
            # (H, W, C) -> (1, C, H, W)
            word_img_torch = word_img_torch.permute(2, 0, 1).unsqueeze(0)
            word_img_torch = word_img_torch.type(torch.FloatTensor).cuda(gpu)
            with torch.cuda.amp.autocast():
                word_img_scores, _ = net(word_img_torch)
        return word_img_scores

    def visualize_pseudo_label(
        self, word_image, region_score, watershed_box, pseudo_char_bbox, img_name,
    ):
        """
        Write a side-by-side debug image for one word into ``vis_test_dir``.

        Panels, left to right: word crop, coloured region score, crop with
        watershed boxes, crop with final pseudo boxes, Gaussian map rendered
        from the pseudo boxes, and that map blended over the crop.
        """
        word_img_h, word_img_w, _ = word_image.shape
        # astype() always returns a new array, so the outlines are drawn on
        # private uint8 canvases that are then stacked into the result.
        word_img_cp1 = word_image.astype(np.uint8)
        word_img_cp2 = word_image.astype(np.uint8)
        _watershed_box = np.int32(watershed_box)
        _pseudo_char_bbox = np.int32(pseudo_char_bbox)

        region_score_color = cv2.applyColorMap(np.uint8(region_score), cv2.COLORMAP_JET)
        region_score_color = cv2.resize(region_score_color, (word_img_w, word_img_h))

        for box in _watershed_box:
            cv2.polylines(
                word_img_cp1, [np.reshape(box, (-1, 1, 2))], True, (255, 0, 0),
            )

        for box in _pseudo_char_bbox:
            cv2.polylines(
                word_img_cp2, [np.reshape(box, (-1, 1, 2))], True, (255, 0, 0)
            )

        # NOTE: Just for visualize, put gaussian map on char box
        pseudo_gt_region_score = self.gaussian_builder.generate_region(
            word_img_h, word_img_w, [_pseudo_char_bbox], [True]
        )

        pseudo_gt_region_score = cv2.applyColorMap(
            (pseudo_gt_region_score * 255).astype("uint8"), cv2.COLORMAP_JET
        )

        # [:, :, ::-1] reverses the channel order of the crop before display.
        overlay_img = cv2.addWeighted(
            word_image[:, :, ::-1], 0.7, pseudo_gt_region_score, 0.3, 5
        )
        vis_result = np.hstack(
            [
                word_image[:, :, ::-1],
                region_score_color,
                word_img_cp1[:, :, ::-1],
                word_img_cp2[:, :, ::-1],
                pseudo_gt_region_score,
                overlay_img,
            ]
        )

        # Create the directory that is actually written to (not its parent).
        os.makedirs(self.vis_test_dir, exist_ok=True)
        cv2.imwrite(
            os.path.join(
                self.vis_test_dir,
                "{}_{}".format(
                    img_name, f"pseudo_char_bbox_{random.randint(0,100)}.jpg"
                ),
            ),
            vis_result,
        )

    def clip_into_boundary(self, box, bound):
        """
        Clip boxes in place to ``[0, bound[1]]`` in x and ``[0, bound[0]]`` in y.

        :param box: array of shape (N, 4, 2); returned unchanged when empty
        :param bound: shape-like sequence, (height, width, ...)
        :return: the same ``box`` object
        """
        if len(box) == 0:
            return box
        box[:, :, 0] = np.clip(box[:, :, 0], 0, bound[1])
        box[:, :, 1] = np.clip(box[:, :, 1], 0, bound[0])
        return box

    def get_confidence(self, real_len, pseudo_len):
        """
        Score how well the number of predicted boxes matches the word length.

        Returns ``(real_len - min(real_len, |real_len - pseudo_len|)) / real_len``,
        or 0.0 when either length is zero.
        """
        if pseudo_len == 0 or real_len == 0:
            return 0.0
        return (real_len - min(real_len, abs(real_len - pseudo_len))) / real_len

    def split_word_equal_gap(self, word_img_w, word_img_h, word):
        """
        Fallback labelling: cut the word crop into ``len(word)`` equal-width
        full-height slots and return a box for every non-space character.

        :return: float32 array of shape (num_non_space_chars, 4, 2)
        """
        width = word_img_w
        height = word_img_h

        width_per_char = width / len(word)
        bboxes = []
        for j, char in enumerate(word):
            if char == " ":
                # A space keeps its slot but produces no box.
                continue
            left = j * width_per_char
            right = (j + 1) * width_per_char
            bbox = np.array([[left, 0], [right, 0], [right, height], [left, height]])
            bboxes.append(bbox)

        bboxes = np.array(bboxes, np.float32)
        return bboxes

    def cal_angle(self, v1):
        """Return the angle of 2D vector ``v1`` in radians within [0, 2*pi)."""
        # The small epsilon avoids division by zero for a null vector.
        theta = np.arccos(min(1, v1[0] / (np.linalg.norm(v1) + 10e-8)))
        return 2 * math.pi - theta if v1[1] < 0 else theta

    def clockwise_sort(self, points):
        """
        Order four points by increasing angle around their centroid.

        :param points: four (x, y) numpy vectors
        :return: 4x2 [[x1,y1],[x2,y2],[x3,y3],[x4,y4]] ndarray
        """
        v1, v2, v3, v4 = points
        center = (v1 + v2 + v3 + v4) / 4
        theta = np.array(
            [
                self.cal_angle(v1 - center),
                self.cal_angle(v2 - center),
                self.cal_angle(v3 - center),
                self.cal_angle(v4 - center),
            ]
        )
        index = np.argsort(theta)
        return np.array([v1, v2, v3, v4])[index, :]

    def build_char_box(self, net, gpu, image, word_bbox, word, img_name=""):
        """
        Produce pseudo character boxes for one word.

        :param net: detection network used for the region score
        :param gpu: CUDA device passed on to ``inference_word_box``
        :param image: full source image
        :param word_bbox: four (x, y) corners of the word in ``image``
        :param word: transcription of the word
        :param img_name: prefix for visualisation file names
        :return: (pseudo_char_bbox in ``image`` coordinates, confidence,
            horizontal_text_bool)
        """
        word_image, M, horizontal_text_bool = self.crop_image_by_bbox(
            image, word_bbox, word
        )
        # Drop every whitespace character so that only real characters count.
        real_word_without_space = "".join(word.split())
        real_char_len = len(real_word_without_space)

        # Bring the crop to a fixed height of 128 pixels before inference.
        scale = 128.0 / word_image.shape[0]

        word_image = cv2.resize(word_image, None, fx=scale, fy=scale)
        word_img_h, word_img_w, _ = word_image.shape

        scores = self.inference_word_box(net, gpu, word_image)
        region_score = scores[0, :, :, 0].cpu().data.numpy()
        region_score = np.uint8(np.clip(region_score, 0, 1) * 255)

        region_score_rgb = cv2.resize(region_score, (word_img_w, word_img_h))
        region_score_rgb = cv2.cvtColor(region_score_rgb, cv2.COLOR_GRAY2RGB)

        pseudo_char_bbox = exec_watershed_by_version(
            self.watershed_param, region_score, word_image, self.pseudo_vis_opt
        )

        # Used for visualize only
        watershed_box = pseudo_char_bbox.copy()

        pseudo_char_bbox = self.clip_into_boundary(
            pseudo_char_bbox, region_score_rgb.shape
        )

        confidence = self.get_confidence(real_char_len, len(pseudo_char_bbox))

        if confidence <= 0.5:
            # Prediction disagrees too much with the word length: fall back to
            # equal-width slots and pin the confidence to 0.5.
            pseudo_char_bbox = self.split_word_equal_gap(word_img_w, word_img_h, word)
            confidence = 0.5

        if self.pseudo_vis_opt and self.flag:
            self.visualize_pseudo_label(
                word_image, region_score, watershed_box, pseudo_char_bbox, img_name,
            )

        if len(pseudo_char_bbox) != 0:
            # Order boxes by the x coordinate of their first corner.
            index = np.argsort(pseudo_char_bbox[:, 0, 0])
            pseudo_char_bbox = pseudo_char_bbox[index]

        # Undo the resize, then undo the rectifying perspective transform.
        pseudo_char_bbox /= scale

        M_inv = np.linalg.pinv(M)
        for i in range(len(pseudo_char_bbox)):
            pseudo_char_bbox[i] = cv2.perspectiveTransform(
                pseudo_char_bbox[i][None, :, :], M_inv
            )

        pseudo_char_bbox = self.clip_into_boundary(pseudo_char_bbox, image.shape)

        return pseudo_char_bbox, confidence, horizontal_text_bool
|
||||
45
trainer/craft/data/pseudo_label/watershed.py
Normal file
45
trainer/craft/data/pseudo_label/watershed.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from skimage.segmentation import watershed
|
||||
|
||||
|
||||
def segment_region_score(watershed_param, region_score, word_image, pseudo_vis_opt):
    """
    Split a region score map into per-character boxes with a watershed.

    Pixels above 0.75 (after dividing by 255) seed the characters, pixels
    below 0.05 are background, and the rest is left for the watershed to
    assign. ``watershed_param``, ``word_image`` and ``pseudo_vis_opt`` are
    accepted but not used here.

    :param region_score: score map on a 0-255 scale
    :return: float32 array of axis-aligned boxes, shape (N, 4, 2) when at
        least one seed exists
    """
    score = np.float32(region_score) / 255
    foreground = np.uint8(score > 0.75)
    background = np.uint8(score < 0.05)
    undecided = 1 - (foreground + background)

    num_labels, markers = cv2.connectedComponents(foreground)
    # Shift labels up by one so that 0 is free to mark the undecided pixels.
    markers += 1
    markers[undecided == 1] = 0

    labels = watershed(-score, markers)

    char_boxes = []
    # After the shift, label 1 is the non-seed component; seeds are 2..num_labels.
    for label in range(2, num_labels + 1):
        ys, xs = np.where(labels == label)
        x_min, x_max = xs.min(), xs.max()
        y_min, y_max = ys.min(), ys.max()
        corners = np.array(
            [[x_min, y_min], [x_max, y_min], [x_max, y_max], [x_min, y_max]]
        )
        # NOTE(review): the factor 2 presumably maps the half-resolution score
        # map back to word-image pixels - confirm against the network output.
        corners *= 2
        char_boxes.append(corners)
    return np.array(char_boxes, dtype=np.float32)
|
||||
|
||||
|
||||
def exec_watershed_by_version(
    watershed_param, region_score, word_image, pseudo_vis_opt
):
    """
    Dispatch to the watershed implementation named by ``watershed_param.version``.

    Only a failed lookup is treated as "unknown version"; errors raised by the
    selected implementation propagate unchanged instead of being reported as a
    missing version.

    :param watershed_param: object with a ``version`` attribute
    :param region_score: score map handed to the implementation
    :param word_image: word crop handed to the implementation
    :param pseudo_vis_opt: visualisation flag handed to the implementation
    :return: whatever the selected implementation returns
    :raises ValueError: if the version is not registered
    """
    func_name_map_dict = {
        "skimage": segment_region_score,
    }

    version = watershed_param.version
    try:
        segment_fn = func_name_map_dict[version]
    except KeyError:
        message = (
            f"Watershed version {version} does not exist in func_name_map_dict."
        )
        print(message)
        raise ValueError(message) from None

    return segment_fn(watershed_param, region_score, word_image, pseudo_vis_opt)
|
||||
1
trainer/craft/data_root_dir/folder.txt
Normal file
1
trainer/craft/data_root_dir/folder.txt
Normal file
@@ -0,0 +1 @@
|
||||
place dataset folder here
|
||||
381
trainer/craft/eval.py
Normal file
381
trainer/craft/eval.py
Normal file
@@ -0,0 +1,381 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
from tqdm import tqdm
|
||||
import wandb
|
||||
|
||||
from config.load_config import load_yaml, DotDict
|
||||
from model.craft import CRAFT
|
||||
from metrics.eval_det_iou import DetectionIoUEvaluator
|
||||
from utils.inference_boxes import (
|
||||
test_net,
|
||||
load_icdar2015_gt,
|
||||
load_icdar2013_gt,
|
||||
load_synthtext_gt,
|
||||
)
|
||||
from utils.util import copyStateDict
|
||||
|
||||
|
||||
|
||||
def save_result_synth(img_file, img, pre_output, pre_box, gt_box=None, result_dir=""):
    """
    Save the detection visualisations for one SynthText image.

    Two files are written into ``result_dir``: ``res_<name>.jpg`` (predicted
    boxes in (0, 255, 0), ground truth in (0, 0, 255)) and
    ``res_<name>_box.jpg`` (the composite returned by ``overlay``).

    :param img_file: source image path; only its base name is used
    :param img: image to draw on (converted with ``np.array``)
    :param pre_output: sequence whose items 0 / 1 are the region / affinity maps
    :param pre_box: predicted boxes
    :param gt_box: optional list of dicts with a ``"points"`` entry
    :param result_dir: output directory; "" means the current directory
    """
    img = np.array(img)
    img_copy = img.copy()
    region = pre_output[0]
    affinity = pre_output[1]

    # make result file list
    filename, _ = os.path.splitext(os.path.basename(img_file))

    # draw bounding boxes for prediction, color green
    for box in pre_box:
        poly = np.array(box).astype(np.int32).reshape((-1, 1, 2))
        try:
            cv2.polylines(img, [poly], True, color=(0, 255, 0), thickness=2)
        except Exception:
            # Best effort: a box OpenCV refuses to draw is skipped.
            pass

    # draw bounding boxes for gt, color red
    if gt_box is not None:
        for gt in gt_box:
            cv2.polylines(
                img,
                [np.array(gt["points"]).astype(np.int32).reshape((-1, 1, 2))],
                True,
                color=(0, 0, 255),
                thickness=2,
            )

    # draw overlay image
    overlay_img = overlay(img_copy, region, affinity, pre_box)

    # Save result image; os.path.join keeps an empty result_dir relative
    # instead of producing a path under the filesystem root.
    res_img_path = os.path.join(result_dir, "res_" + filename + ".jpg")
    cv2.imwrite(res_img_path, img)

    overlay_image_path = os.path.join(result_dir, "res_" + filename + "_box.jpg")
    cv2.imwrite(overlay_image_path, overlay_img)
|
||||
|
||||
|
||||
def save_result_2015(img_file, img, pre_output, pre_box, gt_box, result_dir):
    """
    Save the detection visualisations for one ICDAR2015-style image.

    Predicted boxes are drawn in (0, 255, 0). Ground-truth boxes whose text is
    "###" are drawn in (128, 128, 128), all others in (0, 0, 255). Two files
    are written into ``result_dir``: ``res_<name>.jpg`` and
    ``res_<name>_box.jpg`` (the composite returned by ``overlay``).

    :param img_file: source image path; only its base name is used
    :param img: image to draw on (converted with ``np.array``)
    :param pre_output: sequence whose items 0 / 1 are the region / affinity maps
    :param pre_box: predicted boxes
    :param gt_box: list of dicts with ``"points"`` and ``"text"``, or None
    :param result_dir: output directory
    """
    img = np.array(img)
    img_copy = img.copy()
    region = pre_output[0]
    affinity = pre_output[1]

    # make result file list
    filename, _ = os.path.splitext(os.path.basename(img_file))

    for box in pre_box:
        poly = np.array(box).astype(np.int32).reshape((-1, 1, 2))
        try:
            cv2.polylines(img, [poly], True, color=(0, 255, 0), thickness=2)
        except Exception:
            # Best effort: a box OpenCV refuses to draw is skipped.
            pass

    if gt_box is not None:
        for gt in gt_box:
            _gt_box = np.array(gt["points"]).reshape(-1, 2).astype(np.int32)
            # "###" marks an entry that gets the grey outline.
            color = (128, 128, 128) if gt["text"] == "###" else (0, 0, 255)
            cv2.polylines(img, [_gt_box], True, color=color, thickness=2)

    # draw overlay image
    overlay_img = overlay(img_copy, region, affinity, pre_box)

    # Save result image
    res_img_path = os.path.join(result_dir, "res_" + filename + ".jpg")
    cv2.imwrite(res_img_path, img)

    overlay_image_path = os.path.join(result_dir, "res_" + filename + "_box.jpg")
    cv2.imwrite(overlay_image_path, overlay_img)
|
||||
|
||||
|
||||
def save_result_2013(img_file, img, pre_output, pre_box, gt_box=None, result_dir=""):
    """
    Save the detection visualisations for one ICDAR2013 image.

    Two files are written into ``result_dir``: ``res_<name>.jpg`` (predicted
    boxes in (0, 255, 0), ground truth in (0, 0, 255)) and
    ``res_<name>_box.jpg`` (the composite returned by ``overlay``).

    :param img_file: source image path; only its base name is used
    :param img: image to draw on (converted with ``np.array``)
    :param pre_output: sequence whose items 0 / 1 are the region / affinity maps
    :param pre_box: predicted boxes
    :param gt_box: optional list of dicts with a ``"points"`` entry
    :param result_dir: output directory; "" means the current directory
    """
    img = np.array(img)
    img_copy = img.copy()
    region = pre_output[0]
    affinity = pre_output[1]

    # make result file list
    filename, _ = os.path.splitext(os.path.basename(img_file))

    # draw bounding boxes for prediction, color green
    for box in pre_box:
        poly = np.array(box).astype(np.int32).reshape((-1, 1, 2))
        try:
            cv2.polylines(img, [poly], True, color=(0, 255, 0), thickness=2)
        except Exception:
            # Best effort: a box OpenCV refuses to draw is skipped.
            pass

    # draw bounding boxes for gt, color red
    if gt_box is not None:
        for gt in gt_box:
            # Cast to int32 like the other save_result_* helpers do, since
            # cv2.polylines only accepts integer point arrays.
            gt_poly = np.array(gt["points"]).astype(np.int32).reshape((-1, 1, 2))
            cv2.polylines(img, [gt_poly], True, color=(0, 0, 255), thickness=2)

    # draw overlay image
    overlay_img = overlay(img_copy, region, affinity, pre_box)

    # Save result image; os.path.join keeps an empty result_dir relative
    # instead of producing a path under the filesystem root.
    res_img_path = os.path.join(result_dir, "res_" + filename + ".jpg")
    cv2.imwrite(res_img_path, img)

    overlay_image_path = os.path.join(result_dir, "res_" + filename + "_box.jpg")
    cv2.imwrite(overlay_image_path, overlay_img)
|
||||
|
||||
|
||||
def overlay(image, region, affinity, single_img_bbox):
    """
    Compose a 2 x 2 debug mosaic for one image.

    Top row: the image and the image with ``single_img_bbox`` outlined.
    Bottom row: the image blended (0.4 / 0.6, +5) with the region map and with
    the affinity map, both resized to the image size.

    :param image: array of shape (height, width, channel)
    :param region: region map accepted by ``cv2.resize`` / ``cv2.addWeighted``
    :param affinity: affinity map, same requirements as ``region``
    :param single_img_bbox: iterable of box arrays
    :return: the stacked mosaic
    """
    height, width, channel = image.shape

    region_resized = cv2.resize(region, (width, height))
    affinity_resized = cv2.resize(affinity, (width, height))

    region_blend = cv2.addWeighted(image.copy(), 0.4, region_resized, 0.6, 5)
    affinity_blend = cv2.addWeighted(image.copy(), 0.4, affinity_resized, 0.6, 5)

    boxed_img = image.copy()
    for word_box in single_img_bbox:
        outline = word_box.astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(boxed_img, [outline], True, color=(0, 255, 0), thickness=3)

    top_row = np.hstack([image, boxed_img])
    bottom_row = np.hstack([region_blend, affinity_blend])
    return np.vstack([top_row, bottom_row])
|
||||
|
||||
|
||||
def load_test_dataset_iou(test_folder_name, config):
    """
    Load ground-truth boxes and image paths for the named test set.

    "icdar2015" and "custom_data" both go through ``load_icdar2015_gt``.

    :param test_folder_name: "synthtext", "icdar2013", "icdar2015" or "custom_data"
    :param config: object exposing ``test_data_dir``
    :return: (total_bboxes_gt, total_img_path), or (None, None) after printing
        a message when the name is not recognised
    """
    if test_folder_name == "synthtext":
        total_bboxes_gt, total_img_path = load_synthtext_gt(config.test_data_dir)
        return total_bboxes_gt, total_img_path

    if test_folder_name == "icdar2013":
        total_bboxes_gt, total_img_path = load_icdar2013_gt(
            dataFolder=config.test_data_dir
        )
        return total_bboxes_gt, total_img_path

    if test_folder_name in ("icdar2015", "custom_data"):
        total_bboxes_gt, total_img_path = load_icdar2015_gt(
            dataFolder=config.test_data_dir
        )
        return total_bboxes_gt, total_img_path

    print("not found test dataset")
    return None, None
|
||||
|
||||
|
||||
def viz_test(img, pre_output, pre_box, gt_box, img_name, result_dir, test_folder_name):
    """
    Save result images using the writer that matches ``test_folder_name``.

    The image is handed over as a copy with its channel axis reversed.
    "icdar2015" and "custom_data" share ``save_result_2015``. An unrecognised
    name only prints a message.
    """
    if test_folder_name == "synthtext":
        save_fn = save_result_synth
    elif test_folder_name == "icdar2013":
        save_fn = save_result_2013
    elif test_folder_name in ("icdar2015", "custom_data"):
        save_fn = save_result_2015
    else:
        print("not found test dataset")
        return

    save_fn(img_name, img[:, :, ::-1].copy(), pre_output, pre_box, gt_box, result_dir)
|
||||
|
||||
|
||||
def main_eval(model_path, backbone, config, evaluator, result_dir, buffer, model, mode):
    """Run detection on the "custom_data" test set and return the IoU metrics.

    Two modes are supported:

    * ``model is None`` -- stand-alone evaluation: a CRAFT network is built,
      weights are read from ``model_path`` (checkpoint key ``"craft"``) and
      every test image is processed by this call.
    * ``model`` given -- evaluation during training: the test images are split
      into ``gpu_count`` contiguous shards and this call only processes the
      shard of the current CUDA device (the last device also takes the
      remainder).

    Args:
        model_path: Checkpoint path; only read when ``model`` is None.
        backbone: Architecture name; only "vgg" is accepted when a model must
            be built.
        config: Test configuration exposing ``text_threshold``,
            ``link_threshold``, ``low_text``, ``cuda``, ``poly``,
            ``canvas_size``, ``mag_ratio`` and ``vis_opt``; it is also handed
            to ``load_test_dataset_iou``.
        evaluator: Object providing ``evaluate_image(gt, pred)`` and
            ``combine_results(results)``.
        result_dir: Directory for visualisations; created when missing.
        buffer: ``None``, or a sequence of ``None`` slots (one per test image)
            that is filled with the per-image predictions.
            NOTE(review): presumably shared between worker processes so each
            one fills its own shard -- confirm against the caller.
        model: Ready-to-use network, or ``None`` to build one.
        mode: With "weak_supervision" and a CUDA device count other than 1,
            only half of the devices are counted as evaluation workers.

    Returns:
        The value returned by ``evaluator.combine_results``.

    Raises:
        Exception: If a model must be built and ``backbone`` is not "vgg".
        AssertionError: If ``buffer`` is given with a model and already holds
            non-``None`` entries.
    """
    import time  # local import: only needed for the polling wait below

    os.makedirs(result_dir, exist_ok=True)

    total_imgs_bboxes_gt, total_imgs_path = load_test_dataset_iou("custom_data", config)

    if mode == "weak_supervision" and torch.cuda.device_count() != 1:
        gpu_count = torch.cuda.device_count() // 2
    else:
        gpu_count = torch.cuda.device_count()
    gpu_idx = torch.cuda.current_device()
    torch.cuda.set_device(gpu_idx)

    if model is None:
        # Stand-alone evaluation: this call handles every image.
        piece_imgs_path = total_imgs_path
        start_idx = 0

        if backbone == "vgg":
            model = CRAFT()
        else:
            raise Exception("Undefined architecture")

        print("Loading weights from checkpoint (" + model_path + ")")
        net_param = torch.load(model_path, map_location=f"cuda:{gpu_idx}")
        model.load_state_dict(copyStateDict(net_param["craft"]))

        if config.cuda:
            model = model.cuda()
            cudnn.benchmark = False
    else:
        # Evaluation in the middle of training: process only this device's shard.
        if buffer is not None:
            assert all(
                v is None for v in buffer
            ), "Buffer already filled with another value."
        slice_idx = len(total_imgs_bboxes_gt) // gpu_count
        start_idx = gpu_idx * slice_idx

        if gpu_idx == gpu_count - 1:
            # The last device also takes the images left over by the integer division.
            piece_imgs_path = total_imgs_path[start_idx:]
        else:
            piece_imgs_path = total_imgs_path[start_idx : start_idx + slice_idx]

    model.eval()

    total_imgs_bboxes_pre = []
    for k, img_path in enumerate(tqdm(piece_imgs_path)):
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        bboxes, polys, score_text = test_net(
            model,
            image,
            config.text_threshold,
            config.link_threshold,
            config.low_text,
            config.cuda,
            config.poly,
            config.canvas_size,
            config.mag_ratio,
        )

        single_img_bbox = [
            {"points": box, "text": "###", "ignore": False} for box in bboxes
        ]
        total_imgs_bboxes_pre.append(single_img_bbox)

        # Index of this image in the full test set (k is only the position in the shard).
        global_idx = start_idx + k
        if buffer is not None:
            buffer[global_idx] = single_img_bbox

        if config.vis_opt:
            viz_test(
                image,
                score_text,
                pre_box=polys,
                # Use the global index so the GT drawn belongs to this image
                # even when the shard does not start at 0.
                gt_box=total_imgs_bboxes_gt[global_idx],
                img_name=img_path,
                result_dir=result_dir,
                test_folder_name="custom_data",
            )

    if buffer is not None:
        # Wait until every slot has been filled; sleep between polls instead of
        # spinning on the buffer at full speed.
        while None in buffer:
            time.sleep(0.05)
        assert all(v is not None for v in buffer), "Buffer not filled"
        total_imgs_bboxes_pre = buffer
        eval_gt = total_imgs_bboxes_gt
    else:
        # Without a buffer only this call's images were predicted; score them
        # against the matching slice of the ground truth.
        eval_gt = total_imgs_bboxes_gt[start_idx : start_idx + len(total_imgs_bboxes_pre)]

    results = [
        evaluator.evaluate_image(gt, pred)
        for gt, pred in zip(eval_gt, total_imgs_bboxes_pre)
    ]

    metrics = evaluator.combine_results(results)
    print(metrics)
    return metrics
|
||||
|
||||
def cal_eval(config, data, res_dir_name, opt, mode, exp_name=None):
    """Evaluate the trained model named in ``config`` and return its metrics.

    Args:
        config: Full configuration; ``config.test[data]`` is wrapped in a
            ``DotDict`` and used as the test configuration,
            ``config.test.trained_model`` is the checkpoint and
            ``config.train.backbone`` the architecture name.
        data: Key of the test-set section inside ``config.test``.
        res_dir_name: Name of the result folder created under ``exp/<exp_name>``.
        opt: Evaluation type; only "iou_eval" is implemented.
        mode: Forwarded to ``main_eval``.
        exp_name: Experiment folder name. When omitted, the module-level
            ``args.yaml`` set by the command-line entry point is used, so
            callers importing this function should pass it explicitly.

    Returns:
        The metrics returned by ``main_eval`` for "iou_eval"; ``None`` (after
        printing a notice) for any other ``opt``.
    """
    evaluator = DetectionIoUEvaluator()
    test_config = DotDict(config.test[data])
    if exp_name is None:
        # Historical behaviour: rely on the CLI namespace defined under __main__.
        exp_name = args.yaml
    res_dir = os.path.join("exp", exp_name, "{}".format(res_dir_name))

    if opt != "iou_eval":
        print("Undefined evaluation")
        return None

    return main_eval(
        config.test.trained_model,
        config.train.backbone,
        test_config,
        evaluator,
        res_dir,
        buffer=None,
        model=None,
        mode=mode,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: evaluate the model described by a YAML config.
    parser = argparse.ArgumentParser(description="CRAFT Text Detection Eval")
    parser.add_argument(
        "--yaml",
        "--yaml_file_name",
        type=str,
        default="custom_data_train",
        help="Load configuration",
    )
    args = parser.parse_args()

    # Load the configuration and expose it with attribute access.
    config = DotDict(load_yaml(args.yaml))

    # Optional experiment tracking.
    if config["wandb_opt"]:
        wandb.init(project="evaluation", entity="gmuffiness", name=args.yaml)
        wandb.config.update(config)

    val_result_dir_name = args.yaml
    cal_eval(
        config,
        "custom_data",
        val_result_dir_name + "-ic15-iou",
        opt="iou_eval",
        mode=None,
    )
|
||||
1
trainer/craft/exp/folder.txt
Normal file
1
trainer/craft/exp/folder.txt
Normal file
@@ -0,0 +1 @@
|
||||
trained model will be saved here
|
||||
172
trainer/craft/loss/mseloss.py
Normal file
172
trainer/craft/loss/mseloss.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Loss(nn.Module):
    """Confidence-weighted squared error over the region and affinity maps."""

    def __init__(self):
        super().__init__()

    def forward(self, gt_region, gt_affinity, pred_region, pred_affinity, conf_map):
        """Return the mean of (region error² + affinity error²) * ``conf_map``.

        All five tensors are combined element-wise, so they must be
        broadcast-compatible with each other.
        """
        region_err = (gt_region - pred_region).pow(2)
        affinity_err = (gt_affinity - pred_affinity).pow(2)
        weighted_err = (region_err + affinity_err) * conf_map
        return torch.mean(weighted_err)
|
||||
|
||||
|
||||
class Maploss_v2(nn.Module):
    """Masked MSE loss with hard-negative selection computed over the whole batch."""

    def __init__(self):
        super(Maploss_v2, self).__init__()

    def batch_image_loss(self, pred_score, label_score, neg_rto, n_min_neg):
        """Combine the positive-pixel loss with the largest negative-pixel losses.

        Args:
            pred_score: Per-pixel loss map (``forward`` passes the masked
                squared error here, not a raw prediction).
            label_score: Ground-truth score map; pixels with a value above 0.1
                are positives, all others negatives.
            neg_rto: Number of negatives selected per positive pixel.
            n_min_neg: Number of negatives selected when there are no
                positives, or fewer negatives than ``neg_rto`` * positives.

        Returns:
            Mean positive loss plus mean of the selected negative losses.
        """
        # Positive pixels.
        positive_pixel = (label_score > 0.1).float()
        positive_pixel_number = torch.sum(positive_pixel)
        positive_loss_region = pred_score * positive_pixel

        # Negative pixels, flattened for top-k selection.
        negative_pixel = (label_score <= 0.1).float()
        negative_pixel_number = torch.sum(negative_pixel)
        negative_losses = (pred_score * negative_pixel).reshape(-1)

        if positive_pixel_number != 0:
            if negative_pixel_number < neg_rto * positive_pixel_number:
                hardest = torch.topk(negative_losses, n_min_neg, sorted=False)[0]
                negative_loss = torch.sum(hardest) / n_min_neg
            else:
                hardest = torch.topk(
                    negative_losses,
                    int(neg_rto * positive_pixel_number),
                    sorted=False,
                )[0]
                negative_loss = torch.sum(hardest) / (positive_pixel_number * neg_rto)
            positive_loss = torch.sum(positive_loss_region) / positive_pixel_number
        else:
            # Only negative pixels in the batch.
            hardest = torch.topk(negative_losses, n_min_neg, sorted=False)[0]
            negative_loss = torch.sum(hardest) / n_min_neg
            positive_loss = 0.0

        return positive_loss + negative_loss

    def forward(
        self,
        region_scores_label,
        affinity_socres_label,
        region_scores_pre,
        affinity_scores_pre,
        mask,
        neg_rto,
        n_min_neg,
    ):
        """Return the sum of the region loss and the affinity loss.

        Each label tensor must have the same size as its prediction; ``mask``
        is multiplied element-wise into both per-pixel loss maps.
        (The parameter name ``affinity_socres_label`` is kept as-is because
        callers may pass it by keyword.)
        """
        # reduction="none" keeps the per-pixel losses; it replaces the
        # deprecated reduce=False / size_average=False arguments.
        loss_fn = torch.nn.MSELoss(reduction="none")
        assert (
            region_scores_label.size() == region_scores_pre.size()
            and affinity_socres_label.size() == affinity_scores_pre.size()
        )
        loss_region = torch.mul(loss_fn(region_scores_pre, region_scores_label), mask)
        loss_affinity = torch.mul(
            loss_fn(affinity_scores_pre, affinity_socres_label), mask
        )

        char_loss = self.batch_image_loss(
            loss_region, region_scores_label, neg_rto, n_min_neg
        )
        affi_loss = self.batch_image_loss(
            loss_affinity, affinity_socres_label, neg_rto, n_min_neg
        )
        return char_loss + affi_loss
|
||||
|
||||
|
||||
class Maploss_v3(nn.Module):
    """Masked MSE loss with hard-negative selection computed per image."""

    def __init__(self):
        super(Maploss_v3, self).__init__()

    def single_image_loss(self, pre_loss, loss_label, neg_rto, n_min_neg):
        """Average the per-image positive and hard-negative losses over the batch.

        Args:
            pre_loss: Per-pixel loss maps, iterated along the first dimension.
            loss_label: Ground-truth score maps paired with ``pre_loss``;
                pixels with a value of at least 0.1 are positives.
            neg_rto: Number of negatives selected per positive pixel.
            n_min_neg: Minimum number of negatives selected.

        Returns:
            (sum of per-image positive means + sum of per-image negative
            means) divided by the batch size.
        """
        batch_size = pre_loss.shape[0]

        positive_loss, negative_loss = 0, 0
        for single_loss, single_label in zip(pre_loss, loss_label):
            # Positive pixels; the clamp avoids 0/0 when there are none.
            pos_pixel = (single_label >= 0.1).float()
            n_pos_pixel = torch.sum(pos_pixel)
            pos_loss_region = single_loss * pos_pixel
            positive_loss += torch.sum(pos_loss_region) / torch.clamp(
                n_pos_pixel, min=1e-12
            )

            # Negative pixels.
            neg_pixel = (single_label < 0.1).float()
            n_neg_pixel = torch.sum(neg_pixel)
            neg_loss_region = single_loss * neg_pixel

            if n_pos_pixel != 0:
                if n_neg_pixel < neg_rto * n_pos_pixel:
                    # Too few negatives to sub-sample: use them all. The clamp
                    # keeps the result at 0 (not NaN) when an image has no
                    # negative pixel at all.
                    negative_loss += torch.sum(neg_loss_region) / torch.clamp(
                        n_neg_pixel, min=1e-12
                    )
                else:
                    n_hard_neg = max(n_min_neg, neg_rto * n_pos_pixel)
                    hardest = torch.topk(
                        neg_loss_region.reshape(-1), int(n_hard_neg)
                    )[0]
                    negative_loss += torch.sum(hardest) / n_hard_neg
            else:
                # Only negative pixels in this image.
                hardest = torch.topk(neg_loss_region.reshape(-1), n_min_neg)[0]
                negative_loss += torch.sum(hardest) / n_min_neg

        return (positive_loss + negative_loss) / batch_size

    def forward(
        self,
        region_scores_label,
        affinity_scores_label,
        region_scores_pre,
        affinity_scores_pre,
        mask,
        neg_rto,
        n_min_neg,
    ):
        """Return the sum of the region loss and the affinity loss.

        Each label tensor must have the same size as its prediction; ``mask``
        is multiplied element-wise into both per-pixel loss maps.
        """
        # reduction="none" keeps the per-pixel losses; it replaces the
        # deprecated reduce=False / size_average=False arguments.
        loss_fn = torch.nn.MSELoss(reduction="none")

        assert (
            region_scores_label.size() == region_scores_pre.size()
            and affinity_scores_label.size() == affinity_scores_pre.size()
        )
        loss_region = torch.mul(loss_fn(region_scores_pre, region_scores_label), mask)
        loss_affinity = torch.mul(
            loss_fn(affinity_scores_pre, affinity_scores_label), mask
        )

        char_loss = self.single_image_loss(
            loss_region, region_scores_label, neg_rto, n_min_neg
        )
        affi_loss = self.single_image_loss(
            loss_affinity, affinity_scores_label, neg_rto, n_min_neg
        )
        return char_loss + affi_loss
|
||||
244
trainer/craft/metrics/eval_det_iou.py
Normal file
244
trainer/craft/metrics/eval_det_iou.py
Normal file
@@ -0,0 +1,244 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from collections import namedtuple
|
||||
import numpy as np
|
||||
from shapely.geometry import Polygon
|
||||
"""
|
||||
cite from:
|
||||
PaddleOCR, github: https://github.com/PaddlePaddle/PaddleOCR
|
||||
PaddleOCR reference from :
|
||||
https://github.com/MhLiao/DB/blob/3c32b808d4412680310d3d28eeb6a2d5bf1566c5/concern/icdar2015_eval/detection/iou.py#L8
|
||||
"""
|
||||
|
||||
|
||||
class DetectionIoUEvaluator(object):
    """IoU-based detection evaluator: per-image matching plus dataset-level totals."""

    def __init__(self, iou_constraint=0.5, area_precision_constraint=0.5):
        # A GT/detection pair is a match when its IoU is strictly above this value.
        self.iou_constraint = iou_constraint
        # A detection becomes "don't care" when more than this fraction of its
        # area lies inside a "don't care" ground-truth polygon.
        self.area_precision_constraint = area_precision_constraint

    @staticmethod
    def _to_polygon(points):
        """Return a shapely Polygon for ``points``, or None when it is unusable.

        A polygon is unusable when it cannot be constructed, is not valid or
        is not simple; such entries are left out of the evaluation.
        """
        try:
            polygon = Polygon(points)
            usable = polygon.is_valid and polygon.is_simple
        except (ValueError, TypeError):
            return None
        return polygon if usable else None

    def evaluate_image(self, gt, pred):
        """Match the detections of one image against its ground truth.

        Args:
            gt: Sequence of dicts with keys ``'points'`` (polygon vertices) and
                ``'ignore'`` (truthy for "don't care" regions).
            pred: Sequence of dicts with key ``'points'``.

        Returns:
            Dict with ``precision``, ``recall``, ``hmean``, the matched
            ``pairs``, ``iouMat`` (empty list when more than 100 detections
            were kept), the kept ``gtPolPoints`` / ``detPolPoints``, the counts
            ``gtCare`` / ``detCare`` / ``detMatched``, the "don't care" index
            lists ``gtDontCare`` / ``detDontCare`` and ``evaluationLog``.
        """
        log_parts = []

        # --- ground truth -------------------------------------------------
        gt_polys = []
        gt_points = []
        gt_dont_care = []  # indices into gt_polys flagged as "don't care"
        for item in gt:
            points = item['points']
            dont_care = item['ignore']
            polygon = self._to_polygon(points)
            if polygon is None:
                continue
            gt_polys.append(polygon)
            gt_points.append(points)
            if dont_care:
                gt_dont_care.append(len(gt_polys) - 1)

        log_parts.append(
            "GT polygons: " + str(len(gt_polys)) + (
                " (" + str(len(gt_dont_care)) + " don't care)\n"
                if len(gt_dont_care) > 0 else "\n"))

        # --- detections ---------------------------------------------------
        det_polys = []
        det_points = []
        det_dont_care = []  # indices into det_polys covered by a "don't care" GT
        for item in pred:
            points = item['points']
            polygon = self._to_polygon(points)
            if polygon is None:
                continue
            det_polys.append(polygon)
            det_points.append(points)

            det_area = polygon.area
            for gt_idx in gt_dont_care:
                overlap = gt_polys[gt_idx].intersection(polygon).area
                covered = 0 if det_area == 0 else overlap / det_area
                if covered > self.area_precision_constraint:
                    det_dont_care.append(len(det_polys) - 1)
                    break

        log_parts.append(
            "DET polygons: " + str(len(det_polys)) + (
                " (" + str(len(det_dont_care)) + " don't care)\n"
                if len(det_dont_care) > 0 else "\n"))

        # --- IoU matrix and greedy matching --------------------------------
        # Deterministic placeholder for images without GT or detections.
        iou_mat = np.zeros([1, 1])
        pairs = []
        det_matched = 0

        if len(gt_polys) > 0 and len(det_polys) > 0:
            iou_mat = np.empty([len(gt_polys), len(det_polys)])
            for gt_num, gt_poly in enumerate(gt_polys):
                for det_num, det_poly in enumerate(det_polys):
                    union_area = det_poly.union(gt_poly).area
                    if union_area > 0:
                        iou_mat[gt_num, det_num] = (
                            det_poly.intersection(gt_poly).area / union_area
                        )
                    else:
                        # Two empty polygons: no overlap to speak of.
                        iou_mat[gt_num, det_num] = 0.0

            gt_ignored = set(gt_dont_care)
            det_ignored = set(det_dont_care)
            det_taken = [False] * len(det_polys)
            for gt_num in range(len(gt_polys)):
                if gt_num in gt_ignored:
                    continue
                for det_num in range(len(det_polys)):
                    if det_taken[det_num] or det_num in det_ignored:
                        continue
                    if iou_mat[gt_num, det_num] > self.iou_constraint:
                        # First free detection above the threshold wins; each
                        # GT and each detection is matched at most once.
                        det_taken[det_num] = True
                        det_matched += 1
                        pairs.append({'gt': gt_num, 'det': det_num})
                        log_parts.append(
                            "Match GT #" + str(gt_num)
                            + " with Det #" + str(det_num) + "\n")
                        break

        # --- per-image metrics ---------------------------------------------
        num_gt_care = len(gt_polys) - len(gt_dont_care)
        num_det_care = len(det_polys) - len(det_dont_care)
        if num_gt_care == 0:
            recall = float(1)
            precision = float(0) if num_det_care > 0 else float(1)
        else:
            recall = float(det_matched) / num_gt_care
            precision = 0 if num_det_care == 0 else float(det_matched) / num_det_care

        if (precision + recall) == 0:
            hmean = 0
        else:
            hmean = 2.0 * precision * recall / (precision + recall)

        return {
            'precision': precision,
            'recall': recall,
            'hmean': hmean,
            'pairs': pairs,
            'iouMat': [] if len(det_polys) > 100 else iou_mat.tolist(),
            'gtPolPoints': gt_points,
            'detPolPoints': det_points,
            'gtCare': num_gt_care,
            'detCare': num_det_care,
            'gtDontCare': gt_dont_care,
            'detDontCare': det_dont_care,
            'detMatched': det_matched,
            'evaluationLog': "".join(log_parts),
        }

    def combine_results(self, results):
        """Aggregate per-image results into dataset-level precision/recall/hmean.

        Args:
            results: Iterable of dicts with keys ``'gtCare'``, ``'detCare'``
                and ``'detMatched'`` (as returned by ``evaluate_image``).

        Returns:
            Dict with ``'precision'``, ``'recall'`` and ``'hmean'``; each is 0
            when its denominator is 0.
        """
        total_gt_care = 0
        total_det_care = 0
        total_matched = 0
        for result in results:
            total_gt_care += result['gtCare']
            total_det_care += result['detCare']
            total_matched += result['detMatched']

        recall = 0 if total_gt_care == 0 else float(total_matched) / total_gt_care
        precision = 0 if total_det_care == 0 else float(total_matched) / total_det_care
        if recall + precision == 0:
            hmean = 0
        else:
            hmean = 2 * recall * precision / (recall + precision)

        return {
            'precision': precision,
            'recall': recall,
            'hmean': hmean,
        }
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Small self-check: two unit-square GT boxes, one detection overlapping the first.
    evaluator = DetectionIoUEvaluator()

    first_gt = {
        'points': [(0, 0), (1, 0), (1, 1), (0, 1)],
        'text': 1234,
        'ignore': False,
    }
    second_gt = {
        'points': [(2, 2), (3, 2), (3, 3), (2, 3)],
        'text': 5678,
        'ignore': False,
    }
    detection = {
        'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
        'text': 123,
        'ignore': False,
    }
    gts = [[first_gt, second_gt]]
    preds = [[detection]]

    results = [evaluator.evaluate_image(gt, pred) for gt, pred in zip(gts, preds)]
    metrics = evaluator.combine_results(results)
    print(metrics)
|
||||
112
trainer/craft/model/craft.py
Normal file
112
trainer/craft/model/craft.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""
|
||||
Copyright (c) 2019-present NAVER Corp.
|
||||
MIT License
|
||||
"""
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from model.vgg16_bn import vgg16_bn, init_weights
|
||||
|
||||
class double_conv(nn.Module):
    """1x1 conv then 3x3 conv, each followed by BatchNorm and ReLU.

    The block expects ``in_ch + mid_ch`` input channels (the caller
    concatenates two feature maps before calling it) and emits ``out_ch``
    channels.
    """

    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__()
        layers = [
            # Fuse the concatenated inputs down to mid_ch channels.
            nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            # Spatial mixing; padding=1 keeps the resolution.
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
|
||||
|
||||
|
||||
class CRAFT(nn.Module):
    """VGG16-BN encoder with a U-shaped decoder producing a 2-channel score map."""

    def __init__(self, pretrained=True, freeze=False, amp=False):
        """Build the network.

        Args:
            pretrained: Forwarded to ``vgg16_bn``.
            freeze: Forwarded to ``vgg16_bn``.
            amp: When true, ``forward`` runs inside ``torch.cuda.amp.autocast()``.
        """
        super(CRAFT, self).__init__()

        self.amp = amp

        # Base network
        self.basenet = vgg16_bn(pretrained, freeze)

        # U network
        self.upconv1 = double_conv(1024, 512, 256)
        self.upconv2 = double_conv(512, 256, 128)
        self.upconv3 = double_conv(256, 128, 64)
        self.upconv4 = double_conv(128, 64, 32)

        num_class = 2
        self.conv_cls = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, num_class, kernel_size=1),
        )

        init_weights(self.upconv1.modules())
        init_weights(self.upconv2.modules())
        init_weights(self.upconv3.modules())
        init_weights(self.upconv4.modules())
        init_weights(self.conv_cls.modules())

    def _forward_impl(self, x):
        """Encoder/decoder pass shared by the amp and non-amp paths."""
        # Base network: five feature maps, deepest first.
        sources = self.basenet(x)

        # U network: fuse the two deepest maps, then repeatedly resize to the
        # next skip connection, concatenate and refine.
        y = torch.cat([sources[0], sources[1]], dim=1)
        y = self.upconv1(y)

        y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[2]], dim=1)
        y = self.upconv2(y)

        y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[3]], dim=1)
        y = self.upconv3(y)

        y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[4]], dim=1)
        feature = self.upconv4(y)

        y = self.conv_cls(feature)

        # Move the channel dimension last: (N, C, H, W) -> (N, H, W, C).
        return y.permute(0, 2, 3, 1), feature

    def forward(self, x):
        """Return ``(score_map, feature)`` for the input batch ``x``.

        ``score_map`` is the 2-channel head output with channels moved to the
        last dimension; ``feature`` is the 32-channel output of the last
        decoder block.
        """
        if self.amp:
            with torch.cuda.amp.autocast():
                return self._forward_impl(x)
        # No autocast context is entered here, so an enclosing one (if any)
        # stays in effect exactly as before.
        return self._forward_impl(x)
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: one random 768x768 image through the network on the GPU.
    net = CRAFT(pretrained=True).cuda()
    dummy_batch = torch.randn(1, 3, 768, 768).cuda()
    score_map, _ = net(dummy_batch)
    print(score_map.shape)
|
||||
77
trainer/craft/model/vgg16_bn.py
Normal file
77
trainer/craft/model/vgg16_bn.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
import torchvision
|
||||
from torchvision import models
|
||||
from packaging import version
|
||||
|
||||
def init_weights(modules):
    """Initialise Conv2d, BatchNorm2d and Linear layers in place.

    * Conv2d: Xavier-uniform weights, zero bias.
    * BatchNorm2d: weight 1, bias 0.
    * Linear: weights drawn from N(0, 0.01), zero bias.

    Parameters that are absent (``bias=False`` layers, ``affine=False`` batch
    norms) are skipped; other module types are left untouched.

    Args:
        modules: Iterable of modules, e.g. ``net.modules()``.
    """
    for m in modules:
        if isinstance(m, nn.Conv2d):
            init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            # affine=False batch norms have neither weight nor bias.
            if m.weight is not None:
                m.weight.data.fill_(1)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            if m.bias is not None:
                m.bias.data.zero_()
|
||||
|
||||
|
||||
class vgg16_bn(torch.nn.Module):
    """torchvision VGG16-BN feature extractor cut into five sequential stages."""

    def __init__(self, pretrained=True, freeze=True):
        """Build the five stages.

        Args:
            pretrained: Load the torchvision weights for the VGG layers; when
                false, stages 1-4 are initialised with ``init_weights``.
            freeze: Disable gradients for every parameter of the first stage.
        """
        super().__init__()

        # torchvision >= 0.13 selects weights through the ``weights`` argument.
        if version.parse(torchvision.__version__) >= version.parse('0.13'):
            chosen_weights = models.VGG16_BN_Weights.DEFAULT if pretrained else None
            features = models.vgg16_bn(weights=chosen_weights).features
        else:  # torchvision.__version__ < 0.13
            # Download over http instead of https.
            models.vgg.model_urls['vgg16_bn'] = models.vgg.model_urls['vgg16_bn'].replace('https://', 'http://')
            features = models.vgg16_bn(pretrained=pretrained).features

        # (attribute name, first layer index, end index) of each VGG stage;
        # layers keep their original torchvision index as module name.
        stage_bounds = (
            ("slice1", 0, 12),   # conv2_2
            ("slice2", 12, 19),  # conv3_3
            ("slice3", 19, 29),  # conv4_3
            ("slice4", 29, 39),  # conv5_3
        )
        for attr_name, first, end in stage_bounds:
            stage = torch.nn.Sequential()
            for idx in range(first, end):
                stage.add_module(str(idx), features[idx])
            setattr(self, attr_name, stage)

        # fc6, fc7 without atrous conv
        self.slice5 = torch.nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
            nn.Conv2d(1024, 1024, kernel_size=1)
        )

        if not pretrained:
            for stage in (self.slice1, self.slice2, self.slice3, self.slice4):
                init_weights(stage.modules())

        # No pretrained weights exist for fc6 and fc7, so always initialise them.
        init_weights(self.slice5.modules())

        if freeze:
            # Freezes all parameters of the first stage.
            for param in self.slice1.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return the five stage outputs, deepest first."""
        out1 = self.slice1(X)
        out2 = self.slice2(out1)
        out3 = self.slice3(out2)
        out4 = self.slice4(out3)
        out5 = self.slice5(out4)
        return out5, out4, out3, out2, out1
|
||||
10
trainer/craft/requirements.txt
Normal file
10
trainer/craft/requirements.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
conda==4.10.3
|
||||
opencv-python==4.5.3.56
|
||||
Pillow==9.3.0
|
||||
Polygon3==3.0.9.1
|
||||
PyYAML==5.4.1
|
||||
scikit-image==0.17.2
|
||||
Shapely==1.8.0
|
||||
torch==1.13.1
|
||||
torchvision==0.10.0
|
||||
wandb==0.12.9
|
||||
7
trainer/craft/scripts/run_cde.sh
Normal file
7
trainer/craft/scripts/run_cde.sh
Normal file
@@ -0,0 +1,7 @@
|
||||
# sed -i -e 's/\r$//' scripts/run_cde.sh
# Copy the base config under an experiment-specific name, train on GPU 0 with
# it, then remove the temporary copy.
EXP_NAME=custom_data_release_test_3
yaml_path="config/$EXP_NAME.yaml"
cp config/custom_data_train.yaml "$yaml_path"
#CUDA_VISIBLE_DEVICES=0,1 python3 train_distributed.py --yaml=$EXP_NAME --port=2468
CUDA_VISIBLE_DEVICES=0 python3 train.py --yaml="$EXP_NAME" --port=2468
rm "$yaml_path"
|
||||
479
trainer/craft/train.py
Normal file
479
trainer/craft/train.py
Normal file
@@ -0,0 +1,479 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
import multiprocessing as mp
|
||||
import yaml
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import wandb
|
||||
|
||||
from config.load_config import load_yaml, DotDict
|
||||
from data.dataset import SynthTextDataSet, CustomDataset
|
||||
from loss.mseloss import Maploss_v2, Maploss_v3
|
||||
from model.craft import CRAFT
|
||||
from eval import main_eval
|
||||
from metrics.eval_det_iou import DetectionIoUEvaluator
|
||||
from utils.util import copyStateDict, save_parser
|
||||
|
||||
|
||||
class Trainer(object):
|
||||
def __init__(self, config, gpu, mode):
|
||||
|
||||
self.config = config
|
||||
self.gpu = gpu
|
||||
self.mode = mode
|
||||
self.net_param = self.get_load_param(gpu)
|
||||
|
||||
def get_synth_loader(self):
|
||||
|
||||
dataset = SynthTextDataSet(
|
||||
output_size=self.config.train.data.output_size,
|
||||
data_dir=self.config.train.synth_data_dir,
|
||||
saved_gt_dir=None,
|
||||
mean=self.config.train.data.mean,
|
||||
variance=self.config.train.data.variance,
|
||||
gauss_init_size=self.config.train.data.gauss_init_size,
|
||||
gauss_sigma=self.config.train.data.gauss_sigma,
|
||||
enlarge_region=self.config.train.data.enlarge_region,
|
||||
enlarge_affinity=self.config.train.data.enlarge_affinity,
|
||||
aug=self.config.train.data.syn_aug,
|
||||
vis_test_dir=self.config.vis_test_dir,
|
||||
vis_opt=self.config.train.data.vis_opt,
|
||||
sample=self.config.train.data.syn_sample,
|
||||
)
|
||||
|
||||
syn_loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=self.config.train.batch_size // self.config.train.synth_ratio,
|
||||
shuffle=False,
|
||||
num_workers=self.config.train.num_workers,
|
||||
drop_last=True,
|
||||
pin_memory=True,
|
||||
)
|
||||
return syn_loader
|
||||
|
||||
def get_custom_dataset(self):
|
||||
|
||||
custom_dataset = CustomDataset(
|
||||
output_size=self.config.train.data.output_size,
|
||||
data_dir=self.config.data_root_dir,
|
||||
saved_gt_dir=None,
|
||||
mean=self.config.train.data.mean,
|
||||
variance=self.config.train.data.variance,
|
||||
gauss_init_size=self.config.train.data.gauss_init_size,
|
||||
gauss_sigma=self.config.train.data.gauss_sigma,
|
||||
enlarge_region=self.config.train.data.enlarge_region,
|
||||
enlarge_affinity=self.config.train.data.enlarge_affinity,
|
||||
watershed_param=self.config.train.data.watershed,
|
||||
aug=self.config.train.data.custom_aug,
|
||||
vis_test_dir=self.config.vis_test_dir,
|
||||
sample=self.config.train.data.custom_sample,
|
||||
vis_opt=self.config.train.data.vis_opt,
|
||||
pseudo_vis_opt=self.config.train.data.pseudo_vis_opt,
|
||||
do_not_care_label=self.config.train.data.do_not_care_label,
|
||||
)
|
||||
|
||||
return custom_dataset
|
||||
|
||||
def get_load_param(self, gpu):
|
||||
|
||||
if self.config.train.ckpt_path is not None:
|
||||
map_location = "cuda:%d" % gpu
|
||||
param = torch.load(self.config.train.ckpt_path, map_location=map_location)
|
||||
else:
|
||||
param = None
|
||||
|
||||
return param
|
||||
|
||||
def adjust_learning_rate(self, optimizer, gamma, step, lr):
|
||||
lr = lr * (gamma ** step)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["lr"] = lr
|
||||
return param_group["lr"]
|
||||
|
||||
def get_loss(self):
|
||||
if self.config.train.loss == 2:
|
||||
criterion = Maploss_v2()
|
||||
elif self.config.train.loss == 3:
|
||||
criterion = Maploss_v3()
|
||||
else:
|
||||
raise Exception("Undefined loss")
|
||||
return criterion
|
||||
|
||||
def iou_eval(self, dataset, train_step, buffer, model):
|
||||
test_config = DotDict(self.config.test[dataset])
|
||||
|
||||
val_result_dir = os.path.join(
|
||||
self.config.results_dir, "{}/{}".format(dataset + "_iou", str(train_step))
|
||||
)
|
||||
|
||||
evaluator = DetectionIoUEvaluator()
|
||||
|
||||
metrics = main_eval(
|
||||
None,
|
||||
self.config.train.backbone,
|
||||
test_config,
|
||||
evaluator,
|
||||
val_result_dir,
|
||||
buffer,
|
||||
model,
|
||||
self.mode,
|
||||
)
|
||||
if self.gpu == 0 and self.config.wandb_opt:
|
||||
wandb.log(
|
||||
{
|
||||
"{} iou Recall".format(dataset): np.round(metrics["recall"], 3),
|
||||
"{} iou Precision".format(dataset): np.round(
|
||||
metrics["precision"], 3
|
||||
),
|
||||
"{} iou F1-score".format(dataset): np.round(metrics["hmean"], 3),
|
||||
}
|
||||
)
|
||||
|
||||
def train(self, buffer_dict):
    """Train CRAFT on the custom dataset, optionally mixed with SynthText batches.

    Builds the model, the real-data loader, the optimizer, the optional AMP
    grad scaler and the loss, then iterates from ``config.train.st_iter`` up
    to ``config.train.end_iter``. Every ``config.train.eval_interval`` steps
    a checkpoint is written under ``config.results_dir`` and ``iou_eval`` is
    run on ``"custom_data"``. A final checkpoint is written when the loop ends.

    Args:
        buffer_dict: Mapping with a ``"custom_data"`` entry that is forwarded
            to ``iou_eval``.

    Raises:
        Exception: If the configured backbone or real dataset is not supported.
    """
    torch.cuda.set_device(self.gpu)

    # MODEL -------------------------------------------------------------------#
    # SUPERVISION model
    if self.config.mode == "weak_supervision":
        if self.config.train.backbone == "vgg":
            supervision_model = CRAFT(pretrained=False, amp=self.config.train.amp)
        else:
            raise Exception("Undefined architecture")

        supervision_device = self.gpu
        if self.config.train.ckpt_path is not None:
            supervision_param = self.get_load_param(supervision_device)
            supervision_model.load_state_dict(
                copyStateDict(supervision_param["craft"])
            )
        # Place the supervision model on its device whether or not a
        # checkpoint was loaded.
        supervision_model = supervision_model.to(f"cuda:{supervision_device}")
        print(f"Supervision model loading on : gpu {supervision_device}")
    else:
        supervision_model, supervision_device = None, None

    # TRAIN model
    if self.config.train.backbone == "vgg":
        craft = CRAFT(pretrained=False, amp=self.config.train.amp)
    else:
        raise Exception("Undefined architecture")

    if self.config.train.ckpt_path is not None:
        craft.load_state_dict(copyStateDict(self.net_param["craft"]))

    craft = craft.cuda()
    craft = torch.nn.DataParallel(craft)

    torch.backends.cudnn.benchmark = True

    # DATASET -----------------------------------------------------------------#
    if self.config.train.use_synthtext:
        trn_syn_loader = self.get_synth_loader()
        batch_syn = iter(trn_syn_loader)

    if self.config.train.real_dataset == "custom":
        trn_real_dataset = self.get_custom_dataset()
    else:
        raise Exception("Undefined dataset")

    if self.config.mode == "weak_supervision":
        trn_real_dataset.update_model(supervision_model)
        trn_real_dataset.update_device(supervision_device)

    trn_real_loader = torch.utils.data.DataLoader(
        trn_real_dataset,
        batch_size=self.config.train.batch_size,
        shuffle=False,
        num_workers=self.config.train.num_workers,
        drop_last=False,
        pin_memory=True,
    )

    # OPTIMIZER ---------------------------------------------------------------#
    optimizer = optim.Adam(
        craft.parameters(),
        lr=self.config.train.lr,
        weight_decay=self.config.train.weight_decay,
    )

    if self.config.train.ckpt_path is not None and self.config.train.st_iter != 0:
        # Resume: restore the optimizer and continue from its step count and lr.
        optimizer.load_state_dict(copyStateDict(self.net_param["optimizer"]))
        self.config.train.st_iter = self.net_param["optimizer"]["state"][0]["step"]
        self.config.train.lr = self.net_param["optimizer"]["param_groups"][0]["lr"]

    # LOSS --------------------------------------------------------------------#
    # mixed precision
    if self.config.train.amp:
        scaler = torch.cuda.amp.GradScaler()

        if (
            self.config.train.ckpt_path is not None
            and self.config.train.st_iter != 0
        ):
            scaler.load_state_dict(copyStateDict(self.net_param["scaler"]))
    else:
        scaler = None

    criterion = self.get_loss()

    # TRAIN -------------------------------------------------------------------#
    train_step = self.config.train.st_iter
    whole_training_step = self.config.train.end_iter
    update_lr_rate_step = 0
    training_lr = self.config.train.lr
    loss_value = 0
    batch_time = 0
    start_time = time.time()

    print(
        "================================ Train start ================================"
    )
    while train_step < whole_training_step:
        for images, region_scores, affinity_scores, confidence_masks in trn_real_loader:
            craft.train()
            if train_step > 0 and train_step % self.config.train.lr_decay == 0:
                update_lr_rate_step += 1
                training_lr = self.adjust_learning_rate(
                    optimizer,
                    self.config.train.gamma,
                    update_lr_rate_step,
                    self.config.train.lr,
                )

            images = images.cuda(non_blocking=True)
            region_scores = region_scores.cuda(non_blocking=True)
            affinity_scores = affinity_scores.cuda(non_blocking=True)
            confidence_masks = confidence_masks.cuda(non_blocking=True)

            if self.config.train.use_synthtext:
                # Synth image load
                try:
                    syn_batch = next(batch_syn)
                except StopIteration:
                    # The SynthText iterator is exhausted; start another pass
                    # instead of letting StopIteration abort training.
                    batch_syn = iter(trn_syn_loader)
                    syn_batch = next(batch_syn)
                syn_image, syn_region_label, syn_affi_label, syn_confidence_mask = syn_batch
                syn_image = syn_image.cuda(non_blocking=True)
                syn_region_label = syn_region_label.cuda(non_blocking=True)
                syn_affi_label = syn_affi_label.cuda(non_blocking=True)
                syn_confidence_mask = syn_confidence_mask.cuda(non_blocking=True)

                # concat syn & custom image
                images = torch.cat((syn_image, images), 0)
                region_image_label = torch.cat((syn_region_label, region_scores), 0)
                affinity_image_label = torch.cat((syn_affi_label, affinity_scores), 0)
                confidence_mask_label = torch.cat(
                    (syn_confidence_mask, confidence_masks), 0
                )
            else:
                region_image_label = region_scores
                affinity_image_label = affinity_scores
                confidence_mask_label = confidence_masks

            if self.config.train.amp:
                with torch.cuda.amp.autocast():
                    output, _ = craft(images)
                    out1 = output[:, :, :, 0]
                    out2 = output[:, :, :, 1]

                    loss = criterion(
                        region_image_label,
                        affinity_image_label,
                        out1,
                        out2,
                        confidence_mask_label,
                        self.config.train.neg_rto,
                        self.config.train.n_min_neg,
                    )

                optimizer.zero_grad()
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                output, _ = craft(images)
                out1 = output[:, :, :, 0]
                out2 = output[:, :, :, 1]
                # NOTE(review): the AMP branch also passes
                # ``self.config.train.n_min_neg``; this branch does not.
                # Confirm against the loss signature.
                loss = criterion(
                    region_image_label,
                    affinity_image_label,
                    out1,
                    out2,
                    confidence_mask_label,
                    self.config.train.neg_rto,
                )

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            end_time = time.time()
            loss_value += loss.item()
            batch_time += end_time - start_time
            # Restart the clock so every iteration adds only its own duration;
            # otherwise the average reports time elapsed since training began.
            start_time = end_time

            if train_step > 0 and train_step % 5 == 0:
                mean_loss = loss_value / 5
                loss_value = 0
                avg_batch_time = batch_time / 5
                batch_time = 0

                print(
                    "{}, training_step: {}|{}, learning rate: {:.8f}, "
                    "training_loss: {:.5f}, avg_batch_time: {:.5f}".format(
                        time.strftime(
                            "%Y-%m-%d:%H:%M:%S", time.localtime(time.time())
                        ),
                        train_step,
                        whole_training_step,
                        training_lr,
                        mean_loss,
                        avg_batch_time,
                    )
                )

                if self.config.wandb_opt:
                    wandb.log({"train_step": train_step, "mean_loss": mean_loss})

            if (
                train_step % self.config.train.eval_interval == 0
                and train_step != 0
            ):
                craft.eval()

                print("Saving state, index:", train_step)
                save_param_dic = {
                    "iter": train_step,
                    "craft": craft.state_dict(),
                    "optimizer": optimizer.state_dict(),
                }
                save_param_path = (
                    self.config.results_dir
                    + "/CRAFT_clr_"
                    + repr(train_step)
                    + ".pth"
                )

                if self.config.train.amp:
                    # AMP runs also store the scaler and use a distinct file name.
                    save_param_dic["scaler"] = scaler.state_dict()
                    save_param_path = (
                        self.config.results_dir
                        + "/CRAFT_clr_amp_"
                        + repr(train_step)
                        + ".pth"
                    )

                torch.save(save_param_dic, save_param_path)

                # validation
                self.iou_eval(
                    "custom_data",
                    train_step,
                    buffer_dict["custom_data"],
                    craft,
                )

            train_step += 1
            if train_step >= whole_training_step:
                break

        if self.config.mode == "weak_supervision":
            # After each pass over the real data, copy the trained weights
            # into the supervision model and hand it back to the dataset.
            state_dict = craft.module.state_dict()
            supervision_model.load_state_dict(state_dict)
            trn_real_dataset.update_model(supervision_model)

    # save last model
    save_param_dic = {
        "iter": train_step,
        "craft": craft.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    save_param_path = (
        self.config.results_dir + "/CRAFT_clr_" + repr(train_step) + ".pth"
    )

    if self.config.train.amp:
        save_param_dic["scaler"] = scaler.state_dict()
        save_param_path = (
            self.config.results_dir
            + "/CRAFT_clr_amp_"
            + repr(train_step)
            + ".pth"
        )
    torch.save(save_param_dic, save_param_path)
|
||||
|
||||
|
||||
def main():
    """Parse CLI options, prepare the result directory and run custom-data training.

    Loads the YAML configuration named by ``--yaml``, creates
    ``<results_dir>/<yaml name>``, copies the YAML file there, optionally
    initialises wandb, and runs a ``Trainer`` on GPU 0.
    """
    parser = argparse.ArgumentParser(description="CRAFT custom data train")
    parser.add_argument(
        "--yaml",
        "--yaml_file_name",
        default="custom_data_train",
        type=str,
        help="Load configuration",
    )
    parser.add_argument(
        "--port", "--use ddp port", default="2346", type=str, help="Port number"
    )

    args = parser.parse_args()

    # load configure
    exp_name = args.yaml
    config = load_yaml(args.yaml)

    print("-" * 20 + " Options " + "-" * 20)
    print(yaml.dump(config))
    print("-" * 40)

    # Make result_dir
    res_dir = os.path.join(config["results_dir"], args.yaml)
    config["results_dir"] = res_dir
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(res_dir, exist_ok=True)

    # Duplicate yaml file to result_dir
    shutil.copy(
        "config/" + args.yaml + ".yaml", os.path.join(res_dir, args.yaml) + ".yaml"
    )

    # Any mode other than "weak_supervision" is passed to the trainer as None.
    mode = "weak_supervision" if config["mode"] == "weak_supervision" else None

    # Apply config to wandb
    if config["wandb_opt"]:
        wandb.init(project="craft-stage2", entity="user_name", name=exp_name)
        wandb.config.update(config)

    config = DotDict(config)

    # Start train
    buffer_dict = {"custom_data": None}
    trainer = Trainer(config, 0, mode)
    trainer.train(buffer_dict)

    if config["wandb_opt"]:
        wandb.finish()
|
||||
|
||||
|
||||
# Script entry point: start training only when this file is executed directly.
if __name__ == "__main__":
    main()
|
||||
409
trainer/craft/trainSynth.py
Normal file
409
trainer/craft/trainSynth.py
Normal file
@@ -0,0 +1,409 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
import yaml
|
||||
import multiprocessing as mp
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import wandb
|
||||
|
||||
from config.load_config import load_yaml, DotDict
|
||||
from data.dataset import SynthTextDataSet
|
||||
from loss.mseloss import Maploss_v2, Maploss_v3
|
||||
from model.craft import CRAFT
|
||||
from metrics.eval_det_iou import DetectionIoUEvaluator
|
||||
from eval import main_eval
|
||||
from utils.util import copyStateDict, save_parser
|
||||
|
||||
|
||||
class Trainer(object):
|
||||
def __init__(self, config, gpu):
    """Store the run settings, then build the data loader and load any checkpoint.

    Args:
        config: Configuration object read through attribute access.
        gpu: Index of the CUDA device this trainer runs on.
    """
    self.config, self.gpu = config, gpu
    # This trainer always passes ``None`` as the mode to ``main_eval``.
    self.mode = None
    loader_pair = self.get_trn_loader()
    self.trn_loader, self.trn_sampler = loader_pair
    self.net_param = self.get_load_param(gpu)
|
||||
|
||||
def get_trn_loader(self):
    """Build the SynthText training loader together with its distributed sampler.

    Returns:
        A ``(loader, sampler)`` tuple.
    """
    data_cfg = self.config.train.data
    train_cfg = self.config.train

    synth_dataset = SynthTextDataSet(
        output_size=data_cfg.output_size,
        data_dir=self.config.data_dir.synthtext,
        saved_gt_dir=None,
        mean=data_cfg.mean,
        variance=data_cfg.variance,
        gauss_init_size=data_cfg.gauss_init_size,
        gauss_sigma=data_cfg.gauss_sigma,
        enlarge_region=data_cfg.enlarge_region,
        enlarge_affinity=data_cfg.enlarge_affinity,
        aug=data_cfg.syn_aug,
        vis_test_dir=self.config.vis_test_dir,
        vis_opt=data_cfg.vis_opt,
        sample=data_cfg.syn_sample,
    )

    sampler = torch.utils.data.distributed.DistributedSampler(synth_dataset)

    # ``shuffle`` stays False because a sampler is supplied.
    loader = torch.utils.data.DataLoader(
        synth_dataset,
        batch_size=train_cfg.batch_size,
        shuffle=False,
        num_workers=train_cfg.num_workers,
        sampler=sampler,
        drop_last=True,
        pin_memory=True,
    )
    return loader, sampler
|
||||
|
||||
def get_load_param(self, gpu):
    """Load the configured checkpoint, remapping ``cuda:0`` storages to ``gpu``.

    Args:
        gpu: CUDA device index that tensors saved on ``cuda:0`` are mapped to.

    Returns:
        The object returned by ``torch.load``, or ``None`` when
        ``config.train.ckpt_path`` is ``None``.
    """
    ckpt_path = self.config.train.ckpt_path
    if ckpt_path is None:
        return None
    device_remap = {"cuda:%d" % 0: "cuda:%d" % gpu}
    return torch.load(ckpt_path, map_location=device_remap)
|
||||
|
||||
def adjust_learning_rate(self, optimizer, gamma, step, lr):
    """Set every parameter group's learning rate to ``lr * gamma ** step``.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            whose ``"lr"`` entry is overwritten.
        gamma: Multiplicative decay factor.
        step: Exponent applied to ``gamma``.
        lr: Base learning rate before decay.

    Returns:
        The decayed learning rate.
    """
    new_lr = lr * (gamma ** step)
    for param_group in optimizer.param_groups:
        param_group["lr"] = new_lr
    # The loop variable is unbound when there are no parameter groups, so
    # return the computed rate directly.
    return new_lr
|
||||
|
||||
def get_loss(self):
    """Return the loss object chosen by ``config.train.loss``.

    Returns:
        ``Maploss_v2()`` when the setting equals 2, ``Maploss_v3()`` when it
        equals 3.

    Raises:
        Exception: When the setting is neither 2 nor 3.
    """
    selected = self.config.train.loss
    if selected == 2:
        return Maploss_v2()
    if selected == 3:
        return Maploss_v3()
    raise Exception("Undefined loss")
|
||||
|
||||
def iou_eval(self, dataset, train_step, save_param_path, buffer, model):
    """Evaluate detection IoU on ``dataset`` and optionally log the scores to wandb.

    Args:
        dataset: Key into ``config.test``; also used to name the result folder.
        train_step: Current training step, used as the result sub-folder name.
        save_param_path: Checkpoint path forwarded to ``main_eval``.
        buffer: Forwarded unchanged to ``main_eval``.
        model: Forwarded unchanged to ``main_eval``.
    """
    eval_cfg = DotDict(self.config.test[dataset])
    result_subdir = "{}/{}".format(dataset + "_iou", str(train_step))
    val_result_dir = os.path.join(self.config.results_dir, result_subdir)
    iou_evaluator = DetectionIoUEvaluator()

    scores = main_eval(
        save_param_path,
        self.config.train.backbone,
        eval_cfg,
        iou_evaluator,
        val_result_dir,
        buffer,
        model,
        self.mode,
    )

    # Only GPU 0 reports, and only when wandb logging is switched on.
    if self.gpu != 0 or not self.config.wandb_opt:
        return

    report = {}
    report["{} IoU Recall".format(dataset)] = np.round(scores["recall"], 3)
    report["{} IoU Precision".format(dataset)] = np.round(scores["precision"], 3)
    report["{} IoU F1-score".format(dataset)] = np.round(scores["hmean"], 3)
    wandb.log(report)
|
||||
|
||||
def train(self, buffer_dict):
    """Train CRAFT on SynthText with DistributedDataParallel.

    Wraps the model in SyncBatchNorm and DDP, then iterates from
    ``config.train.st_iter`` up to ``config.train.end_iter``. Every
    ``config.train.eval_interval`` steps GPU 0 resets the shared buffers and
    writes a checkpoint, and ``iou_eval`` is run on ``"icdar2013"``. GPU 0
    writes a final checkpoint when the loop ends.

    Args:
        buffer_dict: Mapping of dataset name to an indexable buffer; it must
            contain ``"icdar2013"``. Every slot of every buffer is set to
            ``None`` before each evaluation.

    Raises:
        Exception: If the configured backbone is not supported.
    """
    torch.cuda.set_device(self.gpu)

    # DATASET -----------------------------------------------------------------#
    trn_loader = self.trn_loader

    # MODEL -------------------------------------------------------------------#
    if self.config.train.backbone == "vgg":
        craft = CRAFT(pretrained=True, amp=self.config.train.amp)
    else:
        raise Exception("Undefined architecture")

    if self.config.train.ckpt_path is not None:
        craft.load_state_dict(copyStateDict(self.net_param["craft"]))
    craft = nn.SyncBatchNorm.convert_sync_batchnorm(craft)
    craft = craft.cuda()
    craft = torch.nn.parallel.DistributedDataParallel(craft, device_ids=[self.gpu])

    torch.backends.cudnn.benchmark = True

    # OPTIMIZER ---------------------------------------------------------------#
    optimizer = optim.Adam(
        craft.parameters(),
        lr=self.config.train.lr,
        weight_decay=self.config.train.weight_decay,
    )

    if self.config.train.ckpt_path is not None and self.config.train.st_iter != 0:
        # Resume: restore the optimizer and continue from its step count and lr.
        optimizer.load_state_dict(copyStateDict(self.net_param["optimizer"]))
        self.config.train.st_iter = self.net_param["optimizer"]["state"][0]["step"]
        self.config.train.lr = self.net_param["optimizer"]["param_groups"][0]["lr"]

    # LOSS --------------------------------------------------------------------#
    # mixed precision
    if self.config.train.amp:
        scaler = torch.cuda.amp.GradScaler()

        # load model
        if (
            self.config.train.ckpt_path is not None
            and self.config.train.st_iter != 0
        ):
            scaler.load_state_dict(copyStateDict(self.net_param["scaler"]))
    else:
        scaler = None

    criterion = self.get_loss()

    # TRAIN -------------------------------------------------------------------#
    train_step = self.config.train.st_iter
    whole_training_step = self.config.train.end_iter
    update_lr_rate_step = 0
    training_lr = self.config.train.lr
    loss_value = 0
    batch_time = 0
    start_time = time.time()

    while train_step < whole_training_step:
        # The current step count is the value handed to the sampler at the
        # start of each pass over the data.
        self.trn_sampler.set_epoch(train_step)
        for image, region_image, affinity_image, confidence_mask in trn_loader:
            craft.train()
            if train_step > 0 and train_step % self.config.train.lr_decay == 0:
                update_lr_rate_step += 1
                training_lr = self.adjust_learning_rate(
                    optimizer,
                    self.config.train.gamma,
                    update_lr_rate_step,
                    self.config.train.lr,
                )

            images = image.cuda(non_blocking=True)
            region_image_label = region_image.cuda(non_blocking=True)
            affinity_image_label = affinity_image.cuda(non_blocking=True)
            confidence_mask_label = confidence_mask.cuda(non_blocking=True)

            if self.config.train.amp:
                with torch.cuda.amp.autocast():
                    output, _ = craft(images)
                    out1 = output[:, :, :, 0]
                    out2 = output[:, :, :, 1]

                    loss = criterion(
                        region_image_label,
                        affinity_image_label,
                        out1,
                        out2,
                        confidence_mask_label,
                        self.config.train.neg_rto,
                        self.config.train.n_min_neg,
                    )

                optimizer.zero_grad()
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                output, _ = craft(images)
                out1 = output[:, :, :, 0]
                out2 = output[:, :, :, 1]
                # NOTE(review): the AMP branch also passes
                # ``self.config.train.n_min_neg``; this branch does not.
                # Confirm against the loss signature.
                loss = criterion(
                    region_image_label,
                    affinity_image_label,
                    out1,
                    out2,
                    confidence_mask_label,
                    self.config.train.neg_rto,
                )

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            end_time = time.time()
            loss_value += loss.item()
            batch_time += end_time - start_time
            # Restart the clock so every iteration adds only its own duration;
            # otherwise the average reports time elapsed since training began.
            start_time = end_time

            if train_step > 0 and train_step % 5 == 0 and self.gpu == 0:
                mean_loss = loss_value / 5
                loss_value = 0
                avg_batch_time = batch_time / 5
                batch_time = 0

                print(
                    "{}, training_step: {}|{}, learning rate: {:.8f}, "
                    "training_loss: {:.5f}, avg_batch_time: {:.5f}".format(
                        time.strftime(
                            "%Y-%m-%d:%H:%M:%S", time.localtime(time.time())
                        ),
                        train_step,
                        whole_training_step,
                        training_lr,
                        mean_loss,
                        avg_batch_time,
                    )
                )
                if self.gpu == 0 and self.config.wandb_opt:
                    wandb.log({"train_step": train_step, "mean_loss": mean_loss})

            if (
                train_step % self.config.train.eval_interval == 0
                and train_step != 0
            ):
                # Reset every shared buffer slot to None before evaluation.
                if self.gpu == 0:
                    for buffer in buffer_dict.values():
                        for i in range(len(buffer)):
                            buffer[i] = None

                print("Saving state, index:", train_step)
                save_param_dic = {
                    "iter": train_step,
                    "craft": craft.state_dict(),
                    "optimizer": optimizer.state_dict(),
                }
                save_param_path = (
                    self.config.results_dir
                    + "/CRAFT_clr_"
                    + repr(train_step)
                    + ".pth"
                )

                if self.config.train.amp:
                    # AMP runs also store the scaler and use a distinct file name.
                    save_param_dic["scaler"] = scaler.state_dict()
                    save_param_path = (
                        self.config.results_dir
                        + "/CRAFT_clr_amp_"
                        + repr(train_step)
                        + ".pth"
                    )

                # Only GPU 0 writes the file; every rank still runs iou_eval.
                if self.gpu == 0:
                    torch.save(save_param_dic, save_param_path)

                # validation
                self.iou_eval(
                    "icdar2013",
                    train_step,
                    save_param_path,
                    buffer_dict["icdar2013"],
                    craft,
                )

            train_step += 1
            if train_step >= whole_training_step:
                break

    # save last model
    if self.gpu == 0:
        save_param_dic = {
            "iter": train_step,
            "craft": craft.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_param_path = (
            self.config.results_dir + "/CRAFT_clr_" + repr(train_step) + ".pth"
        )

        if self.config.train.amp:
            save_param_dic["scaler"] = scaler.state_dict()
            save_param_path = (
                self.config.results_dir
                + "/CRAFT_clr_amp_"
                + repr(train_step)
                + ".pth"
            )
        torch.save(save_param_dic, save_param_path)
|
||||
|
||||
|
||||
def main():
    """Parse CLI options, prepare the result directory and spawn one worker per GPU.

    Loads the YAML configuration named by ``--yaml``, creates
    ``<results_dir>/<yaml name>``, copies the YAML file there, allocates a
    manager-backed list of ``test_set_size`` slots for ``"icdar2013"``, and
    launches ``main_worker`` once per visible CUDA device.
    """
    parser = argparse.ArgumentParser(description="CRAFT SynthText Train")
    parser.add_argument(
        "--yaml",
        "--yaml_file_name",
        default="syn_train",
        type=str,
        help="Load configuration",
    )
    # The help text previously duplicated the --yaml description.
    parser.add_argument(
        "--port", "--use ddp port", default="2646", type=str, help="Port number"
    )

    args = parser.parse_args()

    # load configure
    exp_name = args.yaml
    config = load_yaml(args.yaml)

    print("-" * 20 + " Options " + "-" * 20)
    print(yaml.dump(config))
    print("-" * 40)

    # Make result_dir
    res_dir = os.path.join(config["results_dir"], args.yaml)
    config["results_dir"] = res_dir
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(res_dir, exist_ok=True)

    # Duplicate yaml file to result_dir
    shutil.copy(
        "config/" + args.yaml + ".yaml", os.path.join(res_dir, args.yaml) + ".yaml"
    )

    ngpus_per_node = torch.cuda.device_count()
    print(f"Total device num : {ngpus_per_node}")

    # Manager-backed list passed to every spawned worker.
    manager = mp.Manager()
    buffer1 = manager.list([None] * config["test"]["icdar2013"]["test_set_size"])
    buffer_dict = {"icdar2013": buffer1}
    torch.multiprocessing.spawn(
        main_worker,
        nprocs=ngpus_per_node,
        args=(args.port, ngpus_per_node, config, buffer_dict, exp_name,),
    )
|
||||
|
||||
|
||||
def main_worker(gpu, port, ngpus_per_node, config, buffer_dict, exp_name):
    """Per-process entry point: join the process group and run training.

    Args:
        gpu: Process index, used both as the distributed rank and the trainer's GPU.
        port: TCP port (string) appended to the ``tcp://127.0.0.1:`` init address.
        ngpus_per_node: World size; also divides the configured batch size.
        config: Plain configuration mapping; converted to ``DotDict`` here.
        buffer_dict: Passed through to ``Trainer.train``.
        exp_name: Run name given to wandb.
    """
    rendezvous = "tcp://127.0.0.1:" + port
    torch.distributed.init_process_group(
        backend="nccl",
        init_method=rendezvous,
        world_size=ngpus_per_node,
        rank=gpu,
    )

    # Apply config to wandb (rank 0 only)
    if gpu == 0 and config["wandb_opt"]:
        wandb.init(project="craft-stage1", entity="gmuffiness", name=exp_name)
        wandb.config.update(config)

    # Split the configured batch size across the processes.
    per_gpu_batch = int(config["train"]["batch_size"] / ngpus_per_node)
    config["train"]["batch_size"] = per_gpu_batch
    config = DotDict(config)

    # Start train
    synth_trainer = Trainer(config, gpu)
    synth_trainer.train(buffer_dict)

    if gpu == 0 and config["wandb_opt"]:
        wandb.finish()
    torch.distributed.destroy_process_group()
|
||||
|
||||
# Script entry point: launch the workers only when this file is executed directly.
if __name__ == "__main__":
    main()
|
||||
523
trainer/craft/train_distributed.py
Normal file
523
trainer/craft/train_distributed.py
Normal file
@@ -0,0 +1,523 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
import multiprocessing as mp
|
||||
import yaml
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import wandb
|
||||
|
||||
from config.load_config import load_yaml, DotDict
|
||||
from data.dataset import SynthTextDataSet, CustomDataset
|
||||
from loss.mseloss import Maploss_v2, Maploss_v3
|
||||
from model.craft import CRAFT
|
||||
from eval import main_eval
|
||||
from metrics.eval_det_iou import DetectionIoUEvaluator
|
||||
from utils.util import copyStateDict, save_parser
|
||||
|
||||
|
||||
class Trainer(object):
|
||||
def __init__(self, config, gpu, mode):
    """Store the run settings and load the checkpoint, if one is configured.

    Args:
        config: Configuration object read through attribute access.
        gpu: Index of the CUDA device this trainer runs on.
        mode: Value later forwarded to ``main_eval`` by ``iou_eval``.
    """
    self.config, self.gpu, self.mode = config, gpu, mode
    # ``get_load_param`` reads ``self.config``, so it must run after the assignment.
    self.net_param = self.get_load_param(gpu)
|
||||
|
||||
def get_synth_loader(self):
    """Build the SynthText loader used alongside the real training data.

    The batch size is ``config.train.batch_size // config.train.synth_ratio``.

    Returns:
        The ``DataLoader`` over the SynthText dataset.
    """
    data_cfg = self.config.train.data
    train_cfg = self.config.train

    synth_dataset = SynthTextDataSet(
        output_size=data_cfg.output_size,
        data_dir=train_cfg.synth_data_dir,
        saved_gt_dir=None,
        mean=data_cfg.mean,
        variance=data_cfg.variance,
        gauss_init_size=data_cfg.gauss_init_size,
        gauss_sigma=data_cfg.gauss_sigma,
        enlarge_region=data_cfg.enlarge_region,
        enlarge_affinity=data_cfg.enlarge_affinity,
        aug=data_cfg.syn_aug,
        vis_test_dir=self.config.vis_test_dir,
        vis_opt=data_cfg.vis_opt,
        sample=data_cfg.syn_sample,
    )

    sampler = torch.utils.data.distributed.DistributedSampler(synth_dataset)
    synth_batch_size = train_cfg.batch_size // train_cfg.synth_ratio

    # ``shuffle`` stays False because a sampler is supplied.
    return torch.utils.data.DataLoader(
        synth_dataset,
        batch_size=synth_batch_size,
        shuffle=False,
        num_workers=train_cfg.num_workers,
        sampler=sampler,
        drop_last=True,
        pin_memory=True,
    )
|
||||
|
||||
def get_custom_dataset(self):
    """Construct the ``CustomDataset`` from the training data configuration.

    Returns:
        The ``CustomDataset`` instance rooted at ``config.data_root_dir``.
    """
    data_cfg = self.config.train.data

    # Collect the keyword arguments first so the constructor call stays short.
    dataset_kwargs = dict(
        output_size=data_cfg.output_size,
        data_dir=self.config.data_root_dir,
        saved_gt_dir=None,
        mean=data_cfg.mean,
        variance=data_cfg.variance,
        gauss_init_size=data_cfg.gauss_init_size,
        gauss_sigma=data_cfg.gauss_sigma,
        enlarge_region=data_cfg.enlarge_region,
        enlarge_affinity=data_cfg.enlarge_affinity,
        watershed_param=data_cfg.watershed,
        aug=data_cfg.custom_aug,
        vis_test_dir=self.config.vis_test_dir,
        sample=data_cfg.custom_sample,
        vis_opt=data_cfg.vis_opt,
        pseudo_vis_opt=data_cfg.pseudo_vis_opt,
        do_not_care_label=data_cfg.do_not_care_label,
    )
    return CustomDataset(**dataset_kwargs)
|
||||
|
||||
def get_load_param(self, gpu):
    """Load the configured checkpoint onto the given CUDA device.

    Args:
        gpu: CUDA device index used to build the ``map_location`` string.

    Returns:
        The object returned by ``torch.load``, or ``None`` when
        ``config.train.ckpt_path`` is ``None``.
    """
    checkpoint_file = self.config.train.ckpt_path
    if checkpoint_file is None:
        return None
    target_device = "cuda:%d" % gpu
    return torch.load(checkpoint_file, map_location=target_device)
|
||||
|
||||
def adjust_learning_rate(self, optimizer, gamma, step, lr):
    """Write the decayed rate ``lr * gamma ** step`` into every parameter group.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            whose ``"lr"`` entry is overwritten.
        gamma: Multiplicative decay factor.
        step: Exponent applied to ``gamma``.
        lr: Base learning rate before decay.

    Returns:
        The decayed learning rate.
    """
    scheduled_lr = lr * (gamma ** step)
    for param_group in optimizer.param_groups:
        param_group["lr"] = scheduled_lr
    # Returning the computed value avoids an unbound loop variable when the
    # optimizer has no parameter groups.
    return scheduled_lr
|
||||
|
||||
def get_loss(self):
    """Create the loss object selected by ``config.train.loss``.

    Returns:
        ``Maploss_v2()`` when the setting equals 2, ``Maploss_v3()`` when it
        equals 3.

    Raises:
        Exception: When the setting is neither 2 nor 3.
    """
    loss_id = self.config.train.loss
    # Handle the unsupported case first; the remaining ids map to one class each.
    if not (loss_id == 2 or loss_id == 3):
        raise Exception("Undefined loss")
    if loss_id == 2:
        return Maploss_v2()
    return Maploss_v3()
|
||||
|
||||
def iou_eval(self, dataset, train_step, buffer, model):
    """Evaluate detection IoU on ``dataset`` and optionally log the scores to wandb.

    Args:
        dataset: Key into ``config.test``; also used to name the result folder.
        train_step: Current training step, used as the result sub-folder name.
        buffer: Forwarded unchanged to ``main_eval``.
        model: Forwarded unchanged to ``main_eval``.
    """
    eval_cfg = DotDict(self.config.test[dataset])
    result_subdir = "{}/{}".format(dataset + "_iou", str(train_step))
    val_result_dir = os.path.join(self.config.results_dir, result_subdir)
    iou_evaluator = DetectionIoUEvaluator()

    # The first argument is None here; the in-memory ``model`` is passed instead.
    eval_args = (
        None,
        self.config.train.backbone,
        eval_cfg,
        iou_evaluator,
        val_result_dir,
        buffer,
        model,
        self.mode,
    )
    scores = main_eval(*eval_args)

    should_log = self.gpu == 0 and self.config.wandb_opt
    if should_log:
        wandb.log(
            {
                "{} iou Recall".format(dataset): np.round(scores["recall"], 3),
                "{} iou Precision".format(dataset): np.round(scores["precision"], 3),
                "{} iou F1-score".format(dataset): np.round(scores["hmean"], 3),
            }
        )
|
||||
|
||||
def train(self, buffer_dict):
    """Run the distributed training loop for this process.

    Builds the (optional) supervision model, the DDP-wrapped CRAFT model, the
    data loaders, optimizer and loss, then iterates until
    ``config.train.end_iter`` steps have been executed. Every
    ``config.train.eval_interval`` steps GPU 0 writes a checkpoint and all
    processes run ``iou_eval``; a final checkpoint is written by GPU 0 at the end.

    :param buffer_dict: mapping of dataset name to a shared buffer; the
        ``"custom_data"`` entry is cleared before each evaluation and passed
        to ``iou_eval``.
    :raises Exception: when the configured backbone or real dataset is unknown.
    """

    torch.cuda.set_device(self.gpu)
    total_gpu_num = torch.cuda.device_count()

    # MODEL ---------------------------------------------------------------#
    # SUPERVISION model
    if self.config.mode == "weak_supervision":
        if self.config.train.backbone == "vgg":
            supervision_model = CRAFT(pretrained=False, amp=self.config.train.amp)
        else:
            raise Exception("Undefined architecture")

        # NOTE: only work on half GPU assign train / half GPU assign supervision setting
        supervision_device = total_gpu_num // 2 + self.gpu
        if self.config.train.ckpt_path is not None:
            supervision_param = self.get_load_param(supervision_device)
            supervision_model.load_state_dict(
                copyStateDict(supervision_param["craft"])
            )
        # Move the model regardless of whether a checkpoint was loaded, so it
        # always lives on the device that is handed to the dataset below.
        supervision_model = supervision_model.to(f"cuda:{supervision_device}")
        print(f"Supervision model loading on : gpu {supervision_device}")
    else:
        supervision_model, supervision_device = None, None

    # TRAIN model
    if self.config.train.backbone == "vgg":
        craft = CRAFT(pretrained=False, amp=self.config.train.amp)
    else:
        raise Exception("Undefined architecture")

    if self.config.train.ckpt_path is not None:
        craft.load_state_dict(copyStateDict(self.net_param["craft"]))

    craft = nn.SyncBatchNorm.convert_sync_batchnorm(craft)
    craft = craft.cuda()
    craft = torch.nn.parallel.DistributedDataParallel(craft, device_ids=[self.gpu])

    torch.backends.cudnn.benchmark = True

    # DATASET -------------------------------------------------------------#
    if self.config.train.use_synthtext:
        trn_syn_loader = self.get_synth_loader()
        batch_syn = iter(trn_syn_loader)

    if self.config.train.real_dataset == "custom":
        trn_real_dataset = self.get_custom_dataset()
    else:
        raise Exception("Undefined dataset")

    if self.config.mode == "weak_supervision":
        trn_real_dataset.update_model(supervision_model)
        trn_real_dataset.update_device(supervision_device)

    trn_real_sampler = torch.utils.data.distributed.DistributedSampler(
        trn_real_dataset
    )
    trn_real_loader = torch.utils.data.DataLoader(
        trn_real_dataset,
        batch_size=self.config.train.batch_size,
        shuffle=False,
        num_workers=self.config.train.num_workers,
        sampler=trn_real_sampler,
        drop_last=False,
        pin_memory=True,
    )

    # OPTIMIZER -----------------------------------------------------------#
    optimizer = optim.Adam(
        craft.parameters(),
        lr=self.config.train.lr,
        weight_decay=self.config.train.weight_decay,
    )

    if self.config.train.ckpt_path is not None and self.config.train.st_iter != 0:
        optimizer.load_state_dict(copyStateDict(self.net_param["optimizer"]))
        # The stored step may be a tensor rather than an int; normalise it so
        # the modulo checks and the checkpoint file name below stay plain ints.
        self.config.train.st_iter = int(
            self.net_param["optimizer"]["state"][0]["step"]
        )
        self.config.train.lr = self.net_param["optimizer"]["param_groups"][0]["lr"]

    # LOSS ----------------------------------------------------------------#
    # mixed precision
    if self.config.train.amp:
        scaler = torch.cuda.amp.GradScaler()

        if (
            self.config.train.ckpt_path is not None
            and self.config.train.st_iter != 0
        ):
            scaler.load_state_dict(copyStateDict(self.net_param["scaler"]))
    else:
        scaler = None

    criterion = self.get_loss()

    # TRAIN ---------------------------------------------------------------#
    train_step = self.config.train.st_iter
    whole_training_step = self.config.train.end_iter
    update_lr_rate_step = 0
    training_lr = self.config.train.lr
    loss_value = 0
    batch_time = 0
    start_time = time.time()

    print(
        "================================ Train start ================================"
    )
    while train_step < whole_training_step:
        trn_real_sampler.set_epoch(train_step)
        for (
            images,
            region_scores,
            affinity_scores,
            confidence_masks,
        ) in trn_real_loader:
            craft.train()
            if train_step > 0 and train_step % self.config.train.lr_decay == 0:
                update_lr_rate_step += 1
                training_lr = self.adjust_learning_rate(
                    optimizer,
                    self.config.train.gamma,
                    update_lr_rate_step,
                    self.config.train.lr,
                )

            images = images.cuda(non_blocking=True)
            region_scores = region_scores.cuda(non_blocking=True)
            affinity_scores = affinity_scores.cuda(non_blocking=True)
            confidence_masks = confidence_masks.cuda(non_blocking=True)

            if self.config.train.use_synthtext:
                # Synth image load; restart the iterator once the synthetic
                # loader is exhausted instead of dying with StopIteration.
                try:
                    syn_batch = next(batch_syn)
                except StopIteration:
                    batch_syn = iter(trn_syn_loader)
                    syn_batch = next(batch_syn)
                (
                    syn_image,
                    syn_region_label,
                    syn_affi_label,
                    syn_confidence_mask,
                ) = syn_batch
                syn_image = syn_image.cuda(non_blocking=True)
                syn_region_label = syn_region_label.cuda(non_blocking=True)
                syn_affi_label = syn_affi_label.cuda(non_blocking=True)
                syn_confidence_mask = syn_confidence_mask.cuda(non_blocking=True)

                # concat syn & custom image
                images = torch.cat((syn_image, images), 0)
                region_image_label = torch.cat(
                    (syn_region_label, region_scores), 0
                )
                affinity_image_label = torch.cat((syn_affi_label, affinity_scores), 0)
                confidence_mask_label = torch.cat(
                    (syn_confidence_mask, confidence_masks), 0
                )
            else:
                region_image_label = region_scores
                affinity_image_label = affinity_scores
                confidence_mask_label = confidence_masks

            if self.config.train.amp:
                with torch.cuda.amp.autocast():

                    output, _ = craft(images)
                    out1 = output[:, :, :, 0]
                    out2 = output[:, :, :, 1]

                    loss = criterion(
                        region_image_label,
                        affinity_image_label,
                        out1,
                        out2,
                        confidence_mask_label,
                        self.config.train.neg_rto,
                        self.config.train.n_min_neg,
                    )

                optimizer.zero_grad()
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

            else:
                output, _ = craft(images)
                out1 = output[:, :, :, 0]
                out2 = output[:, :, :, 1]
                # Pass the same arguments as the mixed-precision branch; the
                # original call here omitted n_min_neg.
                loss = criterion(
                    region_image_label,
                    affinity_image_label,
                    out1,
                    out2,
                    confidence_mask_label,
                    self.config.train.neg_rto,
                    self.config.train.n_min_neg,
                )

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            end_time = time.time()
            loss_value += loss.item()
            batch_time += end_time - start_time

            if train_step > 0 and train_step % 5 == 0 and self.gpu == 0:
                mean_loss = loss_value / 5
                loss_value = 0
                avg_batch_time = batch_time / 5
                batch_time = 0

                print(
                    "{}, training_step: {}|{}, learning rate: {:.8f}, "
                    "training_loss: {:.5f}, avg_batch_time: {:.5f}".format(
                        time.strftime(
                            "%Y-%m-%d:%H:%M:%S", time.localtime(time.time())
                        ),
                        train_step,
                        whole_training_step,
                        training_lr,
                        mean_loss,
                        avg_batch_time,
                    )
                )

                if self.gpu == 0 and self.config.wandb_opt:
                    wandb.log({"train_step": train_step, "mean_loss": mean_loss})

            if (
                train_step % self.config.train.eval_interval == 0
                and train_step != 0
            ):

                craft.eval()
                # initialize all buffer value with zero
                if self.gpu == 0:
                    for buffer in buffer_dict.values():
                        for i in range(len(buffer)):
                            buffer[i] = None

                    print("Saving state, index:", train_step)
                    save_param_dic = {
                        "iter": train_step,
                        "craft": craft.state_dict(),
                        "optimizer": optimizer.state_dict(),
                    }
                    save_param_path = (
                        self.config.results_dir
                        + "/CRAFT_clr_"
                        + repr(train_step)
                        + ".pth"
                    )

                    if self.config.train.amp:
                        save_param_dic["scaler"] = scaler.state_dict()
                        save_param_path = (
                            self.config.results_dir
                            + "/CRAFT_clr_amp_"
                            + repr(train_step)
                            + ".pth"
                        )

                    torch.save(save_param_dic, save_param_path)

                # validation
                self.iou_eval(
                    "custom_data",
                    train_step,
                    buffer_dict["custom_data"],
                    craft,
                )

            # Restart the timer for the next iteration; without this the
            # logged "avg_batch_time" kept growing with total elapsed time.
            start_time = time.time()

            train_step += 1
            if train_step >= whole_training_step:
                break

        if self.config.mode == "weak_supervision":
            # Refresh the supervision model with the latest trained weights
            # after each pass over the real dataset.
            state_dict = craft.module.state_dict()
            supervision_model.load_state_dict(state_dict)
            trn_real_dataset.update_model(supervision_model)

    # save last model
    if self.gpu == 0:
        save_param_dic = {
            "iter": train_step,
            "craft": craft.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_param_path = (
            self.config.results_dir + "/CRAFT_clr_" + repr(train_step) + ".pth"
        )

        if self.config.train.amp:
            save_param_dic["scaler"] = scaler.state_dict()
            save_param_path = (
                self.config.results_dir
                + "/CRAFT_clr_amp_"
                + repr(train_step)
                + ".pth"
            )
        torch.save(save_param_dic, save_param_path)
|
||||
|
||||
def main():
    """Parse CLI options, prepare the result directory and spawn the workers.

    Loads the YAML configuration named by ``--yaml``, copies it into the
    experiment's result directory, and starts one ``main_worker`` process per
    usable GPU (half of the visible GPUs in ``weak_supervision`` mode).

    :raises RuntimeError: when no worker process could be started because too
        few CUDA devices are visible.
    """
    parser = argparse.ArgumentParser(description="CRAFT custom data train")
    parser.add_argument(
        "--yaml",
        "--yaml_file_name",
        default="custom_data_train",
        type=str,
        help="Load configuration",
    )
    parser.add_argument(
        "--port", "--use ddp port", default="2346", type=str, help="Port number"
    )

    args = parser.parse_args()

    # load configure
    exp_name = args.yaml
    config = load_yaml(args.yaml)

    print("-" * 20 + " Options " + "-" * 20)
    print(yaml.dump(config))
    print("-" * 40)

    # Make result_dir (exist_ok avoids the check-then-create race)
    res_dir = os.path.join(config["results_dir"], args.yaml)
    config["results_dir"] = res_dir
    os.makedirs(res_dir, exist_ok=True)

    # Duplicate yaml file to result_dir
    shutil.copy(
        "config/" + args.yaml + ".yaml", os.path.join(res_dir, args.yaml) + ".yaml"
    )

    if config["mode"] == "weak_supervision":
        # NOTE: half GPU assign train / half GPU assign supervision setting
        ngpus_per_node = torch.cuda.device_count() // 2
        mode = "weak_supervision"
    else:
        ngpus_per_node = torch.cuda.device_count()
        mode = None

    print(f"Total process num : {ngpus_per_node}")

    # Spawning zero processes would return without training anything; fail
    # loudly instead so a mis-sized machine is noticed immediately.
    if ngpus_per_node < 1:
        raise RuntimeError(
            "No training process can be started: {} CUDA device(s) visible "
            "(weak_supervision mode needs at least two).".format(
                torch.cuda.device_count()
            )
        )

    manager = mp.Manager()
    buffer1 = manager.list([None] * config["test"]["custom_data"]["test_set_size"])

    buffer_dict = {"custom_data": buffer1}
    torch.multiprocessing.spawn(
        main_worker,
        nprocs=ngpus_per_node,
        args=(args.port, ngpus_per_node, config, buffer_dict, exp_name, mode,),
    )
|
||||
|
||||
|
||||
def main_worker(gpu, port, ngpus_per_node, config, buffer_dict, exp_name, mode):
    """Entry point of one spawned training process.

    :param gpu: process index, used both as distributed rank and as the GPU id
        handed to ``Trainer``.
    :param port: TCP port (string) appended to the local rendezvous address.
    :param ngpus_per_node: number of processes; used as world size and as the
        divisor of the configured batch size.
    :param config: plain configuration mapping; its train batch size is
        overwritten with the per-process value before it is wrapped in ``DotDict``.
    :param buffer_dict: forwarded to ``Trainer.train``.
    :param exp_name: run name reported to wandb.
    :param mode: forwarded to ``Trainer``.
    """
    torch.distributed.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:" + port,
        world_size=ngpus_per_node,
        rank=gpu,
    )
    is_first_process = gpu == 0

    # Apply config to wandb (first process only)
    if is_first_process and config["wandb_opt"]:
        wandb.init(project="craft-stage2", entity="user_name", name=exp_name)
        wandb.config.update(config)

    # Split the configured batch size across the processes.
    config["train"]["batch_size"] = int(
        config["train"]["batch_size"] / ngpus_per_node
    )
    config = DotDict(config)

    # Start train
    Trainer(config, gpu, mode).train(buffer_dict)

    if is_first_process and config["wandb_opt"]:
        wandb.finish()

    torch.distributed.barrier()
    torch.distributed.destroy_process_group()
|
||||
|
||||
# Script entry point: start training only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()
|
||||
345
trainer/craft/utils/craft_utils.py
Normal file
345
trainer/craft/utils/craft_utils.py
Normal file
@@ -0,0 +1,345 @@
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import torch
|
||||
import cv2
|
||||
import math
|
||||
import numpy as np
|
||||
from data import imgproc
|
||||
|
||||
""" auxilary functions """
|
||||
# unwarp coordinates
|
||||
|
||||
|
||||
|
||||
|
||||
def warpCoord(Minv, pt):
    """Map the 2-D point ``pt`` through the matrix ``Minv``.

    The point is extended with a third component of 1, multiplied by ``Minv``
    and the first two results are divided by the third.

    :return: numpy array holding the two resulting coordinates.
    """
    homogeneous = np.matmul(Minv, (pt[0], pt[1], 1))
    return homogeneous[:2] / homogeneous[2]
|
||||
""" end of auxilary functions """
|
||||
|
||||
def test():
    """Print the literal text ``pass``; has no other effect."""
    message = 'pass'
    print(message)
|
||||
|
||||
|
||||
def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text):
    """Extract one box per connected text component from the score maps.

    :param textmap: 2-D score map, thresholded at ``low_text``.
    :param linkmap: score map thresholded at ``link_threshold``.
    :param text_threshold: a component is kept only if its peak ``textmap``
        value reaches this threshold.
    :param link_threshold: binarisation threshold for ``linkmap``.
    :param low_text: binarisation threshold for ``textmap``.
    :return: ``(det, labels, mapper)`` - the list of 4x2 boxes, the connected
        component label map, and the component label of each box.
    """
    # Work on copies so the caller's maps are never touched.
    linkmap = linkmap.copy()
    textmap = textmap.copy()
    img_h, img_w = textmap.shape

    # labeling method: binarise both maps and label their union
    _, text_score = cv2.threshold(textmap, low_text, 1, 0)
    _, link_score = cv2.threshold(linkmap, link_threshold, 1, 0)
    combined_score = np.clip(text_score + link_score, 0, 1)
    nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        combined_score.astype(np.uint8), connectivity=4
    )

    det, mapper = [], []
    for k in range(1, nLabels):
        # size filtering
        area = stats[k, cv2.CC_STAT_AREA]
        if area < 10:
            continue

        # thresholding on the strongest text response inside the component
        component = labels == k
        if np.max(textmap[component]) < text_threshold:
            continue

        # make segmentation map
        segmap = np.zeros(textmap.shape, dtype=np.uint8)
        segmap[component] = 255
        segmap[np.logical_and(link_score == 1, text_score == 0)] = 0  # remove link area

        left, top = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]
        width, height = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]
        niter = int(math.sqrt(area * min(width, height) / (width * height)) * 2)

        # dilation window, clamped to the image bounds
        sx = max(left - niter, 0)
        sy = max(top - niter, 0)
        ex = min(left + width + niter + 1, img_w)
        ey = min(top + height + niter + 1, img_h)
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1 + niter, 1 + niter))
        segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel, iterations=1)

        # make box: (x, y) coordinates of every foreground pixel
        foreground = np.array(np.where(segmap != 0))
        np_contours = np.roll(foreground, 1, axis=0).transpose().reshape(-1, 2)
        box = cv2.boxPoints(cv2.minAreaRect(np_contours))

        # align diamond-shape: near-square boxes become axis-aligned
        side_a = np.linalg.norm(box[0] - box[1])
        side_b = np.linalg.norm(box[1] - box[2])
        box_ratio = max(side_a, side_b) / (min(side_a, side_b) + 1e-5)
        if abs(1 - box_ratio) <= 0.1:
            xs, ys = np_contours[:, 0], np_contours[:, 1]
            l, r = min(xs), max(xs)
            t, b = min(ys), max(ys)
            box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)

        # make clock-wise order, starting at the corner with the smallest x + y
        startidx = box.sum(axis=1).argmin()
        box = np.roll(box, 4 - startidx, 0)

        det.append(box)
        mapper.append(k)

    return det, labels, mapper
|
||||
|
||||
def getPoly_core(boxes, labels, mapper, linkmap):
    """Build a polygon for each box from its labelled component.

    :param boxes: iterable of 4x2 boxes.
    :param labels: label map that is warped into each box's own frame.
    :param mapper: label id of the component belonging to each box.
    :param linkmap: accepted for interface compatibility; not read here.
    :return: list with one entry per box - a numpy array of polygon points, or
        ``None`` when the box is too small, the warp cannot be inverted, or
        any of the later polygon checks fails.
    """
    # configs
    num_cp = 5
    max_len_ratio = 0.7
    expand_ratio = 1.45
    max_r = 2.0
    step_r = 0.2

    polys = []
    for k, box in enumerate(boxes):
        # size filter for small instance
        w, h = int(np.linalg.norm(box[0] - box[1]) + 1), int(np.linalg.norm(box[1] - box[2]) + 1)
        if w < 30 or h < 30:
            polys.append(None)
            continue

        # warp image
        tar = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
        M = cv2.getPerspectiveTransform(box, tar)
        word_label = cv2.warpPerspective(labels, M, (w, h), flags=cv2.INTER_NEAREST)
        try:
            Minv = np.linalg.inv(M)
        except np.linalg.LinAlgError:
            # Singular transform: no way back to image coordinates. Only the
            # inversion failure is handled, so unrelated errors still surface.
            polys.append(None)
            continue

        # binarization for selected label
        cur_label = mapper[k]
        word_label[word_label != cur_label] = 0
        word_label[word_label > 0] = 1

        # Polygon generation
        # find top/bottom contours
        cp = []
        max_len = -1
        for i in range(w):
            region = np.where(word_label[:, i] != 0)[0]
            if len(region) < 2:
                continue
            cp.append((i, region[0], region[-1]))
            length = region[-1] - region[0] + 1
            if length > max_len:
                max_len = length

        # pass if max_len is similar to h
        if h * max_len_ratio < max_len:
            polys.append(None)
            continue

        # get pivot points with fixed length
        tot_seg = num_cp * 2 + 1
        seg_w = w / tot_seg  # segment width
        pp = [None] * num_cp  # init pivot points
        cp_section = [[0, 0]] * tot_seg
        seg_height = [0] * num_cp
        seg_num = 0
        num_sec = 0
        prev_h = -1
        for i in range(0, len(cp)):
            (x, sy, ey) = cp[i]
            if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg:
                # average previous segment
                if num_sec == 0:
                    break
                cp_section[seg_num] = [cp_section[seg_num][0] / num_sec, cp_section[seg_num][1] / num_sec]
                num_sec = 0

                # reset variables
                seg_num += 1
                prev_h = -1

            # accumulate center points
            cy = (sy + ey) * 0.5
            cur_h = ey - sy + 1
            cp_section[seg_num] = [cp_section[seg_num][0] + x, cp_section[seg_num][1] + cy]
            num_sec += 1

            if seg_num % 2 == 0:
                continue  # No polygon area

            if prev_h < cur_h:
                pp[int((seg_num - 1) / 2)] = (x, cy)
                seg_height[int((seg_num - 1) / 2)] = cur_h
                prev_h = cur_h

        # processing last segment
        if num_sec != 0:
            cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec]

        # pass if num of pivots is not sufficient or segment width is smaller than character height
        if None in pp or seg_w < np.max(seg_height) * 0.25:
            polys.append(None)
            continue

        # calc median maximum of pivot points
        half_char_h = np.median(seg_height) * expand_ratio / 2

        # calc gradient and apply to make horizontal pivots
        new_pp = []
        for i, (x, cy) in enumerate(pp):
            dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0]
            dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1]
            if dx == 0:  # gradient is zero
                new_pp.append([x, cy - half_char_h, x, cy + half_char_h])
                continue
            rad = - math.atan2(dy, dx)
            c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad)
            new_pp.append([x - s, cy - c, x + s, cy + c])

        # get edge points to cover character heatmaps
        isSppFound, isEppFound = False, False
        grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0])
        grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0])
        for r in np.arange(0.5, max_r, step_r):
            dx = 2 * half_char_h * r
            if not isSppFound:
                line_img = np.zeros(word_label.shape, dtype=np.uint8)
                dy = grad_s * dx
                p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy])
                cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
                # accept the first line that no longer touches the word, or
                # give up searching when the radius limit is about to be hit
                if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
                    spp = p
                    isSppFound = True
            if not isEppFound:
                line_img = np.zeros(word_label.shape, dtype=np.uint8)
                dy = grad_e * dx
                p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy])
                cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
                if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
                    epp = p
                    isEppFound = True
            if isSppFound and isEppFound:
                break

        # pass if boundary of polygon is not found
        if not (isSppFound and isEppFound):
            polys.append(None)
            continue

        # make final polygon: upper edge left-to-right, lower edge right-to-left
        poly = []
        poly.append(warpCoord(Minv, (spp[0], spp[1])))
        for p in new_pp:
            poly.append(warpCoord(Minv, (p[0], p[1])))
        poly.append(warpCoord(Minv, (epp[0], epp[1])))
        poly.append(warpCoord(Minv, (epp[2], epp[3])))
        for p in reversed(new_pp):
            poly.append(warpCoord(Minv, (p[2], p[3])))
        poly.append(warpCoord(Minv, (spp[2], spp[3])))

        # add to final result
        polys.append(np.array(poly))

    return polys
|
||||
|
||||
def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False):
    """Detect boxes and, when ``poly`` is true, a polygon for each of them.

    :return: ``(boxes, polys)``; ``polys`` has one entry per box and holds only
        ``None`` values when ``poly`` is false.
    """
    boxes, labels, mapper = getDetBoxes_core(
        textmap, linkmap, text_threshold, link_threshold, low_text
    )
    polys = getPoly_core(boxes, labels, mapper, linkmap) if poly else [None] * len(boxes)
    return boxes, polys
|
||||
|
||||
def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net = 2):
    """Scale box/polygon coordinates by the given ratios.

    :param polys: sequence of point arrays (last axis is ``(x, y)``); entries
        may be ``None`` and are then left untouched.
    :param ratio_w: horizontal scale factor.
    :param ratio_h: vertical scale factor.
    :param ratio_net: extra factor applied to both axes (default 2).
    :return: ``polys`` unchanged when it is empty, otherwise a numpy array of
        the scaled entries (an object array when the entries differ in shape
        or include ``None``).
    """
    if len(polys) > 0:
        try:
            polys = np.array(polys)
        except ValueError:
            # Entries of different shapes (or None mixed with arrays) cannot be
            # stacked; recent NumPy raises here instead of silently building
            # an object array, so build that object array explicitly.
            ragged = np.empty(len(polys), dtype=object)
            for k, poly in enumerate(polys):
                ragged[k] = poly
            polys = ragged

        scale = (ratio_w * ratio_net, ratio_h * ratio_net)
        for k in range(len(polys)):
            if polys[k] is not None:
                polys[k] *= scale
    return polys
|
||||
|
||||
def save_outputs(image, region_scores, affinity_scores, text_threshold, link_threshold,
                 low_text, outoput_path, confidence_mask = None):
    """save image, region_scores, and affinity_scores in a single image. region_scores and affinity_scores must be
    cpu numpy arrays. You can convert GPU Tensors to CPU numpy arrays like this:
    >>> array = tensor.cpu().data.numpy()
    When saving outputs of the network during training, make sure you convert ALL tensors (image, region_score,
    affinity_score) to numpy array first.

    Detected boxes are drawn directly onto ``image`` (the array is modified in place).

    :param image: numpy array
    :param region_scores: [] 2D numpy array with each element between 0~1.
    :param affinity_scores: same as region_scores
    :param text_threshold: 0 ~ 1. Closer to 0, characters with lower confidence will also be considered a word and be boxed
    :param link_threshold: 0 ~ 1. Closer to 0, links with lower confidence will also be considered a word and be boxed
    :param low_text: 0 ~ 1. Closer to 0, boxes will be more loosely drawn.
    :param outoput_path: file path the composed image is written to
    :param confidence_mask: optional map rendered below the score heatmaps
    :return: the composed image that was written to ``outoput_path``
    """

    assert region_scores.shape == affinity_scores.shape
    assert len(image.shape) - 1 == len(region_scores.shape)

    boxes, polys = getDetBoxes(region_scores, affinity_scores, text_threshold, link_threshold,
                               low_text, False)
    # Box coordinates are doubled before being drawn on the image.
    boxes = np.array(boxes, np.int32) * 2
    if len(boxes) > 0:
        # Clip in place: np.clip returns a new array, so without ``out`` the
        # clipped coordinates were silently thrown away.
        np.clip(boxes[:, :, 0], 0, image.shape[1], out=boxes[:, :, 0])
        np.clip(boxes[:, :, 1], 0, image.shape[0], out=boxes[:, :, 1])
        for box in boxes:
            cv2.polylines(image, [np.reshape(box, (-1, 1, 2))], True, (0, 0, 255))

    target_gaussian_heatmap_color = imgproc.cvt2HeatmapImg(region_scores)
    target_gaussian_affinity_heatmap_color = imgproc.cvt2HeatmapImg(affinity_scores)

    if confidence_mask is not None:
        confidence_mask_gray = imgproc.cvt2HeatmapImg(confidence_mask)
        # top row: region | affinity, bottom row: blank | confidence mask
        gt_scores = np.hstack([target_gaussian_heatmap_color, target_gaussian_affinity_heatmap_color])
        confidence_mask_gray = np.hstack([np.zeros_like(confidence_mask_gray), confidence_mask_gray])
        output = np.concatenate([gt_scores, confidence_mask_gray], axis=0)
        output = np.hstack([image, output])

    else:
        # region heatmap stacked above the affinity heatmap, right of the image
        gt_scores = np.concatenate([target_gaussian_heatmap_color, target_gaussian_affinity_heatmap_color], axis=0)
        output = np.hstack([image, gt_scores])

    cv2.imwrite(outoput_path, output)
    return output
|
||||
|
||||
|
||||
def save_outputs_from_tensors(images, region_scores, affinity_scores, text_threshold, link_threshold,
                              low_text, output_dir, image_names, confidence_mask = None):

    """takes images, region_scores, and affinity_scores as tensors (can be GPU).
    :param images: 4D tensor or numpy array; indexed along the first (batch) axis
    :param region_scores: 3D tensor with values between 0 ~ 1
    :param affinity_scores: 3D tensor with values between 0 ~ 1
    :param text_threshold: forwarded to ``save_outputs``
    :param link_threshold: forwarded to ``save_outputs``
    :param low_text: forwarded to ``save_outputs``
    :param output_dir: directory to save the output images. Will be joined with base names of image_names
    :param image_names: names of each image. Doesn't have to be the base name (image file names)
    :param confidence_mask: forwarded unchanged to ``save_outputs`` for every image
    :return: list with the composed output image of every batch element
    """
    if isinstance(images, torch.Tensor):
        # Go through the CPU first so CUDA tensors are accepted too; np.array
        # copies, so drawing on the images never writes into the tensor.
        images = np.array(images.detach().cpu())

    region_scores = region_scores.detach().cpu().numpy()
    affinity_scores = affinity_scores.detach().cpu().numpy()

    batch_size = images.shape[0]
    assert batch_size == region_scores.shape[0] and batch_size == affinity_scores.shape[0] and batch_size == len(image_names), \
        "The first dimension (i.e. batch size) of images, region scores, and affinity scores must be equal"

    output_images = []

    for i in range(batch_size):
        image = images[i]
        region_score = region_scores[i]
        affinity_score = affinity_scores[i]

        image_name = os.path.basename(image_names[i])
        outoput_path = os.path.join(output_dir, image_name)

        output_image = save_outputs(image, region_score, affinity_score, text_threshold, link_threshold,
                                    low_text, outoput_path, confidence_mask=confidence_mask)

        output_images.append(output_image)

    return output_images
|
||||
361
trainer/craft/utils/inference_boxes.py
Normal file
361
trainer/craft/utils/inference_boxes.py
Normal file
@@ -0,0 +1,361 @@
|
||||
import os
|
||||
import re
|
||||
import itertools
|
||||
|
||||
import cv2
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
|
||||
from utils.craft_utils import getDetBoxes, adjustResultCoordinates
|
||||
from data import imgproc
|
||||
from data.dataset import SynthTextDataSet
|
||||
import math
|
||||
import xml.etree.ElementTree as elemTree
|
||||
|
||||
|
||||
#-------------------------------------------------------------------------------------------------------------------#
|
||||
def rotatePoint(xc, yc, xp, yp, theta):
    """Rotate point ``(xp, yp)`` about centre ``(xc, yc)`` by ``theta`` radians.

    :return: the rotated point as a tuple of two ints (coordinates are
        truncated with ``int``).
    """
    dx, dy = xp - xc, yp - yc
    cos_t, sin_t = math.cos(theta), math.sin(theta)

    rotated_dx = cos_t * dx + sin_t * dy
    rotated_dy = - sin_t * dx + cos_t * dy
    return int(xc + rotated_dx), int(yc + rotated_dy)
|
||||
|
||||
def addRotatedShape(cx, cy, w, h, angle):
    """Return the four corners of a ``w`` x ``h`` rectangle centred at ``(cx, cy)``.

    Each corner is passed through ``rotatePoint`` with ``-angle``. The corners
    are listed starting from ``(cx - w/2, cy - h/2)``, then ``(+, -)``,
    ``(+, +)`` and ``(-, +)``.

    :return: list of four ``[x, y]`` lists.
    """
    unrotated_corners = (
        (cx - w / 2, cy - h / 2),
        (cx + w / 2, cy - h / 2),
        (cx + w / 2, cy + h / 2),
        (cx - w / 2, cy + h / 2),
    )
    return [
        list(rotatePoint(cx, cy, corner_x, corner_y, -angle))
        for corner_x, corner_y in unrotated_corners
    ]
|
||||
|
||||
def xml_parsing(xml):
    """Parse an annotation XML into a list of box dictionaries.

    Every ``object`` element contributes one entry per ``robndbox`` child
    (centre, size and angle, converted with ``addRotatedShape``) followed by
    one entry per ``bndbox`` child (axis-aligned min/max corners).

    :param xml: file name or file object accepted by ``ElementTree.parse``.
    :return: list of dicts with keys ``"points"`` (numpy array of four
        ``[x, y]`` corners), ``"text"`` and ``"ignore"``. Objects named
        ``"dnc"`` get text ``"###"`` and ``ignore=True``.
    """
    tree = elemTree.parse(xml)

    bounds = []
    for element in tree.iter(tag="object"):
        name = element.find("name").text
        is_dnc = name == "dnc"

        # Collect the corner lists of this object first. A fresh entry is
        # created per box, so several boxes under one object no longer end up
        # sharing (and overwriting) a single dictionary.
        corner_sets = []
        for box_coord in element.iter(tag="robndbox"):
            cx = float(box_coord.find("cx").text)
            cy = float(box_coord.find("cy").text)
            w = float(box_coord.find("w").text)
            h = float(box_coord.find("h").text)
            angle = float(box_coord.find("angle").text)
            corner_sets.append(addRotatedShape(cx, cy, w, h, angle))

        for box_coord in element.iter(tag="bndbox"):
            xmin = int(box_coord.find("xmin").text)
            ymin = int(box_coord.find("ymin").text)
            xmax = int(box_coord.find("xmax").text)
            ymax = int(box_coord.find("ymax").text)
            corner_sets.append(
                [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
            )

        for corners in corner_sets:
            bounds.append(
                {
                    "points": np.array(corners),
                    "text": "###" if is_dnc else name,
                    "ignore": is_dnc,
                }
            )

    return bounds
|
||||
|
||||
#-------------------------------------------------------------------------------------------------------------------#
|
||||
|
||||
def load_prescription_gt(dataFolder):
    """Collect images and their parsed XML annotations below ``dataFolder``.

    Files whose name contains ``.jpg`` are treated as images and files whose
    name contains ``.xml`` as annotations; both lists are sorted and paired
    position by position.

    :return: ``(parsed_annotations, image_paths)`` where ``image_paths`` is the
        sorted list of image paths and ``parsed_annotations`` holds the
        ``xml_parsing`` result of each paired annotation file.
    :raises AssertionError: when a paired image and annotation do not share the
        same path stem.
    """
    image_paths, xml_paths = [], []
    for root, _directories, files in os.walk(dataFolder):
        for file_name in files:
            full_path = os.path.join(root, file_name)
            if '.jpg' in file_name:
                image_paths.append(full_path)
            if '.xml' in file_name:
                xml_paths.append(full_path)

    image_paths.sort()
    xml_paths.sort()

    parsed_annotations = []
    for img_path, xml_path in zip(image_paths, xml_paths):
        # check file: image and annotation must belong together
        assert img_path.split(".jpg")[0] == xml_path.split(".xml")[0]
        parsed_annotations.append(xml_parsing(xml_path))

    return parsed_annotations, image_paths
|
||||
|
||||
|
||||
# NOTE
|
||||
def load_prescription_cleval_gt(dataFolder):
    """Collect images and their ``*_cl.txt`` box files below ``dataFolder``.

    Files whose name contains ``.jpg`` are images, files whose name contains
    ``_cl.txt`` are ground-truth files; both lists are sorted and paired
    position by position. Each non-empty ground-truth line must start with
    eight comma-separated integers.

    :return: ``(boxes_per_image, image_paths)``; every box is a dict whose
        ``"points"`` is a numpy array of the eight integers and whose
        ``"text"`` and ``"ignore"`` are ``None``.
    :raises AssertionError: when a paired image and ground-truth file do not
        share the same path stem (``<stem>.jpg`` / ``<stem>_label_cl.txt``).
    """
    total_img_path = []
    total_gt_path = []
    for root, _directories, files in os.walk(dataFolder):
        for file in files:
            if '.jpg' in file:
                total_img_path.append(os.path.join(root, file))
            if '_cl.txt' in file:
                total_gt_path.append(os.path.join(root, file))

    total_img_path.sort()
    total_gt_path.sort()

    total_imgs_parsing_bboxes = []
    for img_path, gt_path in zip(total_img_path, total_gt_path):
        # check file
        assert img_path.split(".jpg")[0] == gt_path.split('_label_cl.txt')[0]

        # Context manager so the handle is closed even if parsing fails.
        with open(gt_path, encoding="utf-8") as gt_file:
            lines = gt_file.readlines()

        word_bboxes = []
        for line in lines:
            # Round-trip through utf-8-sig to drop a leading byte-order mark.
            record = line.strip().encode("utf-8").decode("utf-8-sig")
            if not record:
                # Blank lines (e.g. a trailing newline) carry no box.
                continue
            box_info = record.split(",")
            box_points = [int(box_info[i]) for i in range(8)]
            word_bboxes.append(
                {"points": np.array(box_points), "text": None, "ignore": None}
            )
        total_imgs_parsing_bboxes.append(word_bboxes)

    return total_imgs_parsing_bboxes, total_img_path
|
||||
|
||||
|
||||
def load_synthtext_gt(data_folder, num_samples=100):
    """Load word-level ground truth for the first SynthText samples.

    Args:
        data_folder: SynthText root; it is passed to ``SynthTextDataSet`` as
            both ``data_dir`` and ``saved_gt_dir`` and is prefixed to every
            image name.
        num_samples: Maximum number of samples to load. Defaults to 100, the
            cap that was previously hard-coded; ``None`` loads everything.

    Returns:
        A tuple ``(total_imgs_bboxes, total_img_path)``. Each element of
        ``total_imgs_bboxes`` is a list of dicts with keys ``"points"`` (one
        word box), ``"text"`` (the word) and ``"ignore"`` (always ``False``).

    Raises:
        ValueError: If the number of words parsed from the transcription does
            not match the number of word boxes for an image.
    """
    synth_dataset = SynthTextDataSet(
        output_size=768, data_dir=data_folder, saved_gt_dir=data_folder, logging=False
    )
    img_names, img_bbox, img_words = synth_dataset.load_data(bbox="word")

    total_img_path = []
    total_imgs_bboxes = []
    for index in range(len(img_bbox[:num_samples])):
        img_path = os.path.join(data_folder, img_names[index][0])
        total_img_path.append(img_path)

        try:
            wordbox = img_bbox[index].transpose((2, 1, 0))
        except ValueError:
            # The three-axis transpose fails when the array has fewer axes
            # (presumably an image with a single word box - TODO confirm);
            # add a leading axis so the result is indexed per word as well.
            wordbox = np.expand_dims(img_bbox[index], axis=0)
            wordbox = wordbox.transpose((0, 2, 1))

        # Split every transcription entry on spaces/newlines into single words.
        words = [re.split(" \n|\n |\n| ", t.strip()) for t in img_words[index]]
        words = list(itertools.chain(*words))
        words = [t for t in words if len(t) > 0]

        if len(words) != len(wordbox):
            # Fail loudly instead of dropping into an interactive debugger.
            raise ValueError(
                f"Word/box count mismatch for {img_path}: "
                f"{len(words)} words vs {len(wordbox)} boxes"
            )

        single_img_bboxes = [
            {"points": wordbox[j], "text": words[j], "ignore": False}
            for j in range(len(words))
        ]
        total_imgs_bboxes.append(single_img_bboxes)

    return total_imgs_bboxes, total_img_path
|
||||
|
||||
|
||||
def load_icdar2015_gt(dataFolder, isTraing=False):
    """Load ICDAR 2015 word boxes and the matching image paths.

    Ground-truth files are read from
    ``ch4_{training,test}_localization_transcription_gt`` under ``dataFolder``.
    The image for ``gt_<name>.txt`` is ``<name>.jpg`` in the corresponding
    ``ch4_{training,test}_images`` folder.

    Each ground-truth line is comma separated: the first eight fields are the
    integer quadrilateral coordinates and everything after them is the
    transcription (re-joined with commas, so commas in the text survive).

    Args:
        dataFolder: Directory containing the ICDAR 2015 sub-folders.
        isTraing: Use the training split when true, the test split otherwise.

    Returns:
        A tuple ``(total_imgs_bboxes, total_img_path)``, both ordered by
        ground-truth file name. Each element of ``total_imgs_bboxes`` is a
        list of dicts with ``"points"`` (``int32`` array of shape ``(4, 2)``),
        ``"text"`` and ``"ignore"`` (``True`` when the text is ``"###"``).
    """
    if isTraing:
        img_folderName = "ch4_training_images"
        gt_folderName = "ch4_training_localization_transcription_gt"
    else:
        img_folderName = "ch4_test_images"
        gt_folderName = "ch4_test_localization_transcription_gt"

    gt_dir = os.path.join(dataFolder, gt_folderName)
    img_dir = os.path.join(dataFolder, img_folderName)

    total_imgs_bboxes = []
    total_img_path = []
    # Sorted so the returned order does not depend on the file system.
    for gt_name in sorted(os.listdir(gt_dir)):
        gt_path = os.path.join(gt_dir, gt_name)
        # Rewrite only the file name; substituting on the whole path could
        # corrupt a data root that happens to contain "gt_" or ".txt".
        img_name = gt_name.replace(".txt", ".jpg").replace("gt_", "")
        img_path = os.path.join(img_dir, img_name)

        with open(gt_path, encoding="utf-8") as gt_file:
            lines = gt_file.readlines()

        single_img_bboxes = []
        for line in lines:
            # The encode/decode round trip drops a UTF-8 BOM if one is present.
            cleaned = line.strip().encode("utf-8").decode("utf-8-sig")
            if not cleaned:
                # Blank lines carry no box; skipping avoids int('') failures.
                continue
            box_info = cleaned.split(",")

            box_points = [int(box_info[j]) for j in range(8)]
            word = ",".join(box_info[8:])
            box_points = np.array(box_points, np.int32).reshape(4, 2)

            single_img_bboxes.append(
                {"points": box_points, "text": word, "ignore": word == "###"}
            )
        total_imgs_bboxes.append(single_img_bboxes)
        total_img_path.append(img_path)
    return total_imgs_bboxes, total_img_path
|
||||
|
||||
|
||||
def load_icdar2013_gt(dataFolder, isTraing=False):
    """Load ICDAR 2013 word boxes and the matching image paths.

    Ground-truth files are read from ``Challenge2_Test_Task1_GT`` under
    ``dataFolder``. The image for ``gt_<name>.txt`` is ``<name>.jpg`` in
    ``Challenge2_Test_Task12_Images``.

    Each ground-truth line is comma separated: the first four fields are the
    integers ``x_min, y_min, x_max, y_max`` of an axis-aligned box and the
    remaining fields are re-joined with commas as the transcription.

    Args:
        dataFolder: Directory containing the ICDAR 2013 sub-folders.
        isTraing: Accepted for signature parity with ``load_icdar2015_gt``.
            It has no effect: both branches of the previous implementation
            selected the test split.

    Returns:
        A tuple ``(total_imgs_bboxes, total_img_path)``, both ordered by
        ground-truth file name. Each element of ``total_imgs_bboxes`` is a
        list of dicts with ``"points"`` (four ``[x, y]`` corners as nested
        lists, starting at the top-left and going clockwise), ``"text"`` and
        ``"ignore"`` (``True`` when the text is ``"###"``).
    """
    # Only the test split is wired up, regardless of ``isTraing``.
    img_folderName = "Challenge2_Test_Task12_Images"
    gt_folderName = "Challenge2_Test_Task1_GT"

    gt_dir = os.path.join(dataFolder, gt_folderName)
    img_dir = os.path.join(dataFolder, img_folderName)

    total_imgs_bboxes = []
    total_img_path = []
    # Sorted so the returned order does not depend on the file system.
    for gt_name in sorted(os.listdir(gt_dir)):
        gt_path = os.path.join(gt_dir, gt_name)
        # Rewrite only the file name; substituting on the whole path could
        # corrupt a data root that happens to contain "gt_" or ".txt".
        img_name = gt_name.replace(".txt", ".jpg").replace("gt_", "")
        img_path = os.path.join(img_dir, img_name)

        with open(gt_path, encoding="utf-8") as gt_file:
            lines = gt_file.readlines()

        single_img_bboxes = []
        for line in lines:
            # The encode/decode round trip drops a UTF-8 BOM if one is present.
            cleaned = line.strip().encode("utf-8").decode("utf-8-sig")
            if not cleaned:
                # Blank lines carry no box; skipping avoids int('') failures.
                continue
            box_info = cleaned.split(",")

            x_min, y_min, x_max, y_max = (int(box_info[j]) for j in range(4))
            word = ",".join(box_info[4:])
            # Expand the two-corner rectangle into four clockwise corners.
            box = [
                [x_min, y_min],
                [x_max, y_min],
                [x_max, y_max],
                [x_min, y_max],
            ]

            single_img_bboxes.append(
                {"points": box, "text": word, "ignore": word == "###"}
            )

        total_imgs_bboxes.append(single_img_bboxes)
        total_img_path.append(img_path)

    return total_imgs_bboxes, total_img_path
|
||||
|
||||
|
||||
def test_net(
    net,
    image,
    text_threshold,
    link_threshold,
    low_text,
    cuda,
    poly,
    canvas_size=1280,
    mag_ratio=1.5,
):
    """Run the detector on one image and return boxes, polygons and heatmaps.

    Args:
        net: Callable model; it is invoked on a ``[b, c, h, w]`` batch and is
            expected to return a pair whose first element is indexed as
            ``[batch, h, w, channel]`` (channel 0: text score, 1: link score).
        image: Input image, resized by ``imgproc.resize_aspect_ratio``.
        text_threshold: Forwarded to ``getDetBoxes``.
        link_threshold: Forwarded to ``getDetBoxes``.
        low_text: Forwarded to ``getDetBoxes``.
        cuda: Move the input batch to the GPU when true.
        poly: Forwarded to ``getDetBoxes``.
        canvas_size: Forwarded to ``imgproc.resize_aspect_ratio``.
        mag_ratio: Forwarded to ``imgproc.resize_aspect_ratio``.

    Returns:
        ``(boxes, polys, render_img)`` where ``render_img`` is the list
        ``[text_heatmap, link_heatmap]`` and any ``None`` polygon has been
        replaced by the box at the same index.
    """
    # Resize the input; the same inverse scale is applied on both axes later.
    resized_image, scale, heatmap_size = imgproc.resize_aspect_ratio(
        image, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio
    )
    inverse_scale = 1 / scale
    ratio_w = inverse_scale
    ratio_h = inverse_scale

    # Normalize, then go [h, w, c] -> [c, h, w] -> [b, c, h, w].
    normalized = imgproc.normalizeMeanVariance(resized_image)
    chw_tensor = torch.from_numpy(normalized).permute(2, 0, 1)
    batch = Variable(chw_tensor.unsqueeze(0))
    if cuda:
        batch = batch.cuda()

    # Inference only: no gradients are needed.
    with torch.no_grad():
        y, _feature = net(batch)

    # Split the prediction into the text (region) and link (affinity) maps.
    score_text = y[0, :, :, 0].cpu().data.numpy().astype(np.float32)
    score_link = y[0, :, :, 1].cpu().data.numpy().astype(np.float32)

    # Crop both maps to the heatmap size reported by the resize step.
    rows = slice(None, heatmap_size[0])
    cols = slice(None, heatmap_size[1])
    score_text = score_text[rows, cols]
    score_link = score_link[rows, cols]

    # Turn the score maps into boxes/polygons and map them back to the
    # coordinate system of the input image.
    boxes, polys = getDetBoxes(
        score_text, score_link, text_threshold, link_threshold, low_text, poly
    )
    boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = adjustResultCoordinates(polys, ratio_w, ratio_h)

    # Fall back to the plain box wherever no polygon was produced.
    for idx, polygon in enumerate(polys):
        if polygon is None:
            polys[idx] = boxes[idx]

    # Heatmap renderings of both score maps (text first, then link).
    score_text = score_text.copy()
    render_img = [
        imgproc.cvt2HeatmapImg(score_text),
        imgproc.cvt2HeatmapImg(score_link),
    ]

    return boxes, polys, render_img
|
||||
142
trainer/craft/utils/util.py
Normal file
142
trainer/craft/utils/util.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from data import imgproc
|
||||
from utils import craft_utils
|
||||
|
||||
|
||||
def copyStateDict(state_dict):
    """Return a copy of ``state_dict`` without a leading ``module.`` prefix.

    The decision is taken from the first key only: when its first
    dot-separated component is exactly ``"module"``, that component is
    removed from every key. Otherwise the keys are copied unchanged.

    Args:
        state_dict: Mapping from parameter name to value.

    Returns:
        A new ``OrderedDict`` with the (possibly shortened) keys in the same
        order and the same values. An empty input yields an empty result.
    """
    if not state_dict:
        # Nothing to inspect; avoids indexing into an empty key list.
        return OrderedDict()

    first_key = next(iter(state_dict))
    # Compare the whole first component so that names which merely begin with
    # "module" (e.g. "modulelist.0.weight") are not truncated.
    start_idx = 1 if first_key.split(".")[0] == "module" else 0

    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = ".".join(k.split(".")[start_idx:])
        new_state_dict[name] = v
    return new_state_dict
|
||||
|
||||
|
||||
def saveInput(
    imagename, vis_dir, image, region_scores, affinity_scores, confidence_mask
):
    """Write a side-by-side visualization of one training input to disk.

    The saved image is, from left to right: the input with the boxes detected
    on the target score maps drawn in red, the region heatmap overlay, the
    affinity heatmap overlay, and the confidence mask rendering. It is
    written to ``<vis_dir>/<imagename>_input.jpg``.

    Args:
        imagename: Output name. A non-string value is indexed with ``[0]`` and
            reduced to the last ``/``-separated component minus its final
            four characters (presumably the SynthText name format - TODO
            confirm against the caller).
        vis_dir: Directory prefix of the output file; created when missing.
        image: Input image; it is converted with ``cv2.COLOR_RGB2BGR`` before
            drawing, so RGB channel order is assumed.
        region_scores: Region score map, passed to ``craft_utils.getDetBoxes``.
        affinity_scores: Affinity score map, passed to the same call.
        confidence_mask: Confidence mask rendered as the right-most panel.
    """
    image = np.uint8(image.copy())
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    boxes, polys = craft_utils.getDetBoxes(
        region_scores, affinity_scores, 0.85, 0.2, 0.5, False
    )

    # Score maps at half (or lower) resolution: scale boxes up to image size.
    if image.shape[0] / region_scores.shape[0] >= 2:
        boxes = np.array(boxes, np.int32) * 2
    else:
        boxes = np.array(boxes, np.int32)

    if len(boxes) > 0:
        # Clip in place; a bare np.clip() call returns a new array and would
        # leave ``boxes`` untouched.
        np.clip(boxes[:, :, 0], 0, image.shape[1], out=boxes[:, :, 0])
        np.clip(boxes[:, :, 1], 0, image.shape[0], out=boxes[:, :, 1])
        for box in boxes:
            cv2.polylines(image, [np.reshape(box, (-1, 1, 2))], True, (0, 0, 255))

    target_gaussian_heatmap_color = imgproc.cvt2HeatmapImg(region_scores)
    target_gaussian_affinity_heatmap_color = imgproc.cvt2HeatmapImg(affinity_scores)
    confidence_mask_gray = imgproc.cvt2HeatmapImg(confidence_mask)

    # Bring every panel to the image size, then blend the heatmaps onto it.
    height, width, channel = image.shape
    overlay_region = cv2.resize(target_gaussian_heatmap_color, (width, height))
    overlay_aff = cv2.resize(target_gaussian_affinity_heatmap_color, (width, height))
    confidence_mask_gray = cv2.resize(
        confidence_mask_gray, (width, height), interpolation=cv2.INTER_NEAREST
    )
    overlay_region = cv2.addWeighted(image, 0.4, overlay_region, 0.6, 5)
    overlay_aff = cv2.addWeighted(image, 0.4, overlay_aff, 0.7, 6)

    gt_scores = np.concatenate([overlay_region, overlay_aff], axis=1)
    output = np.concatenate([gt_scores, confidence_mask_gray], axis=1)
    output = np.hstack([image, output])

    # synthtext
    if not isinstance(imagename, str):
        imagename = imagename[0].split("/")[-1][:-4]

    outpath = vis_dir + f"/{imagename}_input.jpg"
    # exist_ok already covers the "directory is present" case.
    os.makedirs(os.path.dirname(outpath), exist_ok=True)
    cv2.imwrite(outpath, output)
|
||||
# print(f'Logging train input into {outpath}')
|
||||
|
||||
|
||||
def saveImage(
    imagename,
    vis_dir,
    image,
    bboxes,
    affi_bboxes,
    region_scores,
    affinity_scores,
    confidence_mask,
):
    """Write a side-by-side visualization of an annotated image to disk.

    The saved image is, from left to right: the input with every box in
    ``bboxes`` drawn in red and every box in ``affi_bboxes`` drawn in blue,
    the region heatmap overlay, the affinity heatmap overlay, and the
    confidence mask rendering. It is written to ``<vis_dir>/<imagename>.jpg``.

    Args:
        imagename: Output name. A non-string value is indexed with ``[0]`` and
            reduced to the last ``/``-separated component minus its final
            four characters.
        vis_dir: Directory prefix of the output file; created when missing.
        image: Input image; converted with ``cv2.COLOR_RGB2BGR`` for drawing.
        bboxes: Sequence of box groups; each group is cast to ``int32`` and
            every box in it is drawn as a closed polyline.
        affi_bboxes: Sequence of arrays drawn as closed polylines.
        region_scores: Region score map rendered as a heatmap.
        affinity_scores: Affinity score map rendered as a heatmap.
        confidence_mask: Confidence mask rendered as the right-most panel.
    """
    red = (0, 0, 255)
    blue = (255, 0, 0)

    canvas = cv2.cvtColor(np.uint8(image.copy()), cv2.COLOR_RGB2BGR)

    # Boxes in red, one closed outline per box of every group.
    for box_group in bboxes:
        group_boxes = np.int32(box_group)
        for single_box in group_boxes:
            outline = np.reshape(single_box, (-1, 1, 2))
            cv2.polylines(canvas, [outline], True, red)

    # Affinity boxes in blue.
    for affinity_box in affi_bboxes:
        outline = np.reshape(affinity_box.astype(np.int32), (-1, 1, 2))
        cv2.polylines(canvas, [outline], True, blue)

    region_heatmap = imgproc.cvt2HeatmapImg(region_scores)
    affinity_heatmap = imgproc.cvt2HeatmapImg(affinity_scores)
    confidence_panel = imgproc.cvt2HeatmapImg(confidence_mask)

    # Resize both heatmaps to the image size and blend them onto the input.
    height, width, _channels = image.shape
    target_size = (width, height)
    region_panel = cv2.addWeighted(
        image.copy(), 0.4, cv2.resize(region_heatmap, target_size), 0.6, 5
    )
    affinity_panel = cv2.addWeighted(
        image.copy(), 0.4, cv2.resize(affinity_heatmap, target_size), 0.6, 5
    )
    heat_panels = np.concatenate([region_panel, affinity_panel], axis=1)

    # synthtext
    if type(imagename) is not str:
        imagename = imagename[0].split("/")[-1][:-4]

    composite = np.concatenate([canvas, heat_panels, confidence_panel], axis=1)

    outpath = vis_dir + f"/{imagename}.jpg"
    target_dir = os.path.dirname(outpath)
    if not os.path.exists(target_dir):
        os.makedirs(target_dir, exist_ok=True)

    cv2.imwrite(outpath, composite)
|
||||
# print(f'Logging original image into {outpath}')
|
||||
|
||||
|
||||
def save_parser(args):
    """Append every option in ``args`` to ``<results_dir>/opt.txt`` and print it.

    Args:
        args: Object whose attributes are the options (anything ``vars()``
            accepts); it must provide ``results_dir``.
    """
    header = "------------ Options -------------\n"
    footer = "---------------------------------------\n"

    with open(f"{args.results_dir}/opt.txt", "a", encoding="utf-8") as opt_file:
        # One "name: value" line per option, in attribute order.
        option_lines = [
            f"{str(key)}: {str(value)}\n" for key, value in vars(args).items()
        ]
        report = header + "".join(option_lines) + footer
        print(report)
        opt_file.write(report)
|
||||
Reference in New Issue
Block a user