113 lines
4.7 KiB
Python
113 lines
4.7 KiB
Python
import os
|
|
import time
|
|
import string
|
|
import argparse
|
|
|
|
import torch
|
|
import torch.backends.cudnn as cudnn
|
|
import torch.utils.data
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
from nltk.metrics.distance import edit_distance
|
|
|
|
from utils import CTCLabelConverter, AttnLabelConverter, Averager
|
|
from dataset import hierarchical_dataset, AlignCollate
|
|
from model import Model
|
|
|
|
def validation(model, criterion, evaluation_loader, converter, opt, device):
    """Run the model over an evaluation loader and collect loss/accuracy metrics.

    NOTE(review): this function does not switch the model to eval mode nor
    disable gradient tracking itself; presumably the caller does both -- confirm.

    Args:
        model: Recognition network. Called as ``model(image, text_for_pred)``
            when ``'CTC' in opt.Prediction`` and as
            ``model(image, text_for_pred, is_train=False)`` otherwise.
        criterion: Loss callable. Receives the CTC argument layout
            (log-probs permuted to (1, 0, 2), targets, input sizes, target
            lengths) in the CTC case, flattened predictions/targets otherwise.
        evaluation_loader: Iterable yielding ``(image_tensors, labels)`` batches.
        converter: Label converter providing ``encode`` plus the decode method
            for the active head (``decode_greedy`` / ``decode_beamsearch`` for
            CTC, ``decode`` otherwise).
        opt: Options object; ``batch_max_length`` and ``Prediction`` are read,
            and ``decode`` ('greedy' or 'beamsearch') in the CTC case.
        device: Device the image batch and prediction buffers are moved to.

    Returns:
        Tuple ``(valid_loss, accuracy, norm_ED, preds_str,
        confidence_score_list, labels, infer_time, length_of_data)`` where
        ``valid_loss`` is ``Averager.val()`` over all batch losses,
        ``accuracy`` is the exact-match rate in percent, ``norm_ED`` is the
        mean ICDAR2019 normalized edit-distance score, ``infer_time`` is the
        summed wall-clock forward time and ``length_of_data`` the number of
        samples seen. ``preds_str``, ``confidence_score_list`` and ``labels``
        describe the LAST batch only (they are overwritten on every batch).
        For an empty loader the lists are empty and both metrics are 0.0.

    Raises:
        ValueError: If the CTC head is used with an ``opt.decode`` value other
            than 'greedy' or 'beamsearch'.
    """
    n_correct = 0
    norm_ED = 0
    length_of_data = 0
    infer_time = 0
    valid_loss_avg = Averager()

    # Bound up front so an empty loader still produces a well-formed result.
    preds_str = []
    labels = []
    confidence_score_list = []

    use_ctc = 'CTC' in opt.Prediction
    use_attn = 'Attn' in opt.Prediction

    for image_tensors, labels in evaluation_loader:
        batch_size = image_tensors.size(0)
        length_of_data += batch_size
        image = image_tensors.to(device)

        # For max length prediction
        length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

        text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length)

        start_time = time.time()
        if use_ctc:
            preds = model(image, text_for_pred)
            forward_time = time.time() - start_time

            # Calculate evaluation loss for CTC decoder.
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            # permute 'preds' to use CTCloss format
            cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)

            if opt.decode == 'greedy':
                # Select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_index = preds_index.view(-1)
                preds_str = converter.decode_greedy(preds_index.data, preds_size.data)
            elif opt.decode == 'beamsearch':
                preds_str = converter.decode_beamsearch(preds, beamWidth=2)
            else:
                # Fail loudly instead of hitting an unbound `preds_str` later.
                raise ValueError(
                    f"Unsupported opt.decode {opt.decode!r}; expected 'greedy' or 'beamsearch'"
                )
        else:
            preds = model(image, text_for_pred, is_train=False)
            forward_time = time.time() - start_time

            preds = preds[:, :text_for_loss.shape[1] - 1, :]
            target = text_for_loss[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1))

            # select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)
            labels = converter.decode(text_for_loss[:, 1:], length_for_loss)

        infer_time += forward_time
        valid_loss_avg.add(cost)

        # calculate accuracy & confidence score
        preds_prob = F.softmax(preds, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)
        confidence_score_list = []

        for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
            if use_attn:
                # Prune after the "end of sentence" token ([s]). partition()
                # leaves the text intact when the token is missing, whereas
                # slicing with find()'s -1 would drop the last character.
                gt = gt.partition('[s]')[0]
                pred, eos_token, _ = pred.partition('[s]')
                if eos_token:
                    pred_max_prob = pred_max_prob[:len(pred)]

            if pred == gt:
                n_correct += 1

            # ICDAR2019 Normalized Edit Distance: 1 - ED / max(len); a sample
            # with an empty ground truth or prediction contributes 0.
            if gt and pred:
                norm_ED += 1 - edit_distance(pred, gt) / max(len(gt), len(pred))

            # calculate confidence score (= multiply of pred_max_prob)
            if pred_max_prob.numel() > 0:
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]
            else:
                # empty pred case, when pruned right at the [s] token
                confidence_score = 0
            confidence_score_list.append(confidence_score)

    if length_of_data == 0:
        # Nothing was evaluated; avoid dividing by zero.
        accuracy = 0.0
        norm_ED = 0.0
    else:
        accuracy = n_correct / float(length_of_data) * 100
        norm_ED = norm_ED / float(length_of_data)  # ICDAR2019 Normalized Edit Distance

    return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data
|