Files
easyocr/trainer/craft/loss/mseloss.py

173 lines
5.6 KiB
Python
Raw Permalink Normal View History

2025-07-10 19:42:57 +08:00
import torch
import torch.nn as nn
class Loss(nn.Module):
    """Confidence-weighted squared-error loss over region and affinity maps."""

    def __init__(self):
        super().__init__()

    def forward(self, gt_region, gt_affinity, pred_region, pred_affinity, conf_map):
        """Return the mean confidence-weighted squared error.

        The squared region error and the squared affinity error are added
        element-wise, scaled by ``conf_map`` and averaged over every element.

        Args:
            gt_region: Target region map.
            gt_affinity: Target affinity map.
            pred_region: Predicted region map.
            pred_affinity: Predicted affinity map.
            conf_map: Per-element weight applied to the summed squared error.
                (All five tensors are assumed to be broadcast-compatible --
                verify against the caller.)

        Returns:
            A scalar tensor holding the mean weighted error.
        """
        region_error = (gt_region - pred_region) ** 2
        affinity_error = (gt_affinity - pred_affinity) ** 2
        weighted_error = (region_error + affinity_error) * conf_map
        return weighted_error.mean()
class Maploss_v2(nn.Module):
    """Masked MSE loss on region/affinity maps with hard-negative selection.

    The positive/negative pixel split and the hard-negative selection are
    computed once over the whole input tensor (all images pooled together).
    """

    def __init__(self):
        super(Maploss_v2, self).__init__()

    def batch_image_loss(self, pred_score, label_score, neg_rto, n_min_neg):
        """Combine the positive-pixel loss with the hardest negative-pixel losses.

        A pixel is positive when its label is ``> 0.1`` and negative otherwise.

        * Positive term: mean of ``pred_score`` over positive pixels
          (``0.0`` when there are none).
        * Negative term, when positives exist and there are at least
          ``neg_rto * n_positive`` negatives: sum of the
          ``int(neg_rto * n_positive)`` largest negative losses divided by
          ``neg_rto * n_positive``.
        * Negative term otherwise: mean of the ``n_min_neg`` largest negative
          losses. ``n_min_neg`` is clamped to the number of available pixels
          because ``torch.topk`` raises when ``k`` exceeds the tensor size; in
          that case the mean is taken over the pixels actually selected.

        Args:
            pred_score: Element-wise loss values (already masked by the caller).
            label_score: Ground-truth map with the same shape as ``pred_score``.
            neg_rto: Ratio of negatives to positives to keep.
            n_min_neg: Number of negatives to keep when the ratio rule is not
                applicable.

        Returns:
            Scalar tensor: positive term plus negative term.
        """
        positive_pixel = (label_score > 0.1).float()
        positive_pixel_number = torch.sum(positive_pixel)

        negative_pixel = (label_score <= 0.1).float()
        negative_pixel_number = torch.sum(negative_pixel)
        # reshape (not view) so a non-contiguous product can still be flattened.
        negative_loss_flat = (pred_score * negative_pixel).reshape(-1)

        has_positive = bool(positive_pixel_number != 0)
        too_few_negatives = bool(
            negative_pixel_number < neg_rto * positive_pixel_number
        )

        if has_positive and not too_few_negatives:
            # Enough negatives: keep neg_rto hard negatives per positive pixel.
            k = int(neg_rto * positive_pixel_number)
            hardest = torch.topk(negative_loss_flat, k, sorted=False)[0]
            negative_loss = torch.sum(hardest) / (positive_pixel_number * neg_rto)
        else:
            # No positives, or too few negatives for the ratio rule: fall back
            # to a fixed count, never asking topk for more than exists.
            k = min(int(n_min_neg), negative_loss_flat.numel())
            hardest = torch.topk(negative_loss_flat, k, sorted=False)[0]
            negative_loss = torch.sum(hardest) / max(k, 1)

        if has_positive:
            positive_loss = (
                torch.sum(pred_score * positive_pixel) / positive_pixel_number
            )
        else:
            positive_loss = 0.0

        return positive_loss + negative_loss

    def forward(
        self,
        region_scores_label,
        affinity_socres_label,
        region_scores_pre,
        affinity_scores_pre,
        mask,
        neg_rto,
        n_min_neg,
    ):
        """Return the summed region and affinity losses.

        Args:
            region_scores_label: Ground-truth region map.
            affinity_socres_label: Ground-truth affinity map. (The misspelled
                name is kept so existing keyword callers keep working.)
            region_scores_pre: Predicted region map, same size as its label.
            affinity_scores_pre: Predicted affinity map, same size as its label.
            mask: Weight multiplied into both element-wise squared errors.
            neg_rto: Forwarded to :meth:`batch_image_loss`.
            n_min_neg: Forwarded to :meth:`batch_image_loss`.

        Returns:
            Scalar tensor: region loss plus affinity loss.
        """
        # reduction="none" is the supported spelling of the deprecated
        # reduce=False / size_average=False combination.
        loss_fn = torch.nn.MSELoss(reduction="none")
        assert (
            region_scores_label.size() == region_scores_pre.size()
            and affinity_socres_label.size() == affinity_scores_pre.size()
        )
        loss1 = loss_fn(region_scores_pre, region_scores_label)
        loss2 = loss_fn(affinity_scores_pre, affinity_socres_label)
        loss_region = torch.mul(loss1, mask)
        loss_affinity = torch.mul(loss2, mask)
        char_loss = self.batch_image_loss(
            loss_region, region_scores_label, neg_rto, n_min_neg
        )
        affi_loss = self.batch_image_loss(
            loss_affinity, affinity_socres_label, neg_rto, n_min_neg
        )
        return char_loss + affi_loss
class Maploss_v3(nn.Module):
    """Masked MSE loss on region/affinity maps with per-image hard negatives.

    The positive/negative split and the hard-negative selection are done
    separately for each entry along the first dimension, then averaged over
    that dimension.
    """

    def __init__(self):
        super(Maploss_v3, self).__init__()

    def single_image_loss(self, pre_loss, loss_label, neg_rto, n_min_neg):
        """Average the per-image positive and hard-negative losses.

        For each image a pixel is positive when its label is ``>= 0.1`` and
        negative otherwise.

        * Positive term: mean of the loss over positive pixels (zero when the
          image has none).
        * Negative term:
            - no negative pixels: nothing is added (a plain mean would be
              0/0 and turn the whole loss into NaN);
            - positives exist and negatives are fewer than
              ``neg_rto * n_positive``: mean over all negative pixels;
            - positives exist otherwise: the
              ``max(n_min_neg, neg_rto * n_positive)`` largest negative losses,
              divided by that count;
            - no positives: the ``n_min_neg`` largest negative losses, divided
              by ``n_min_neg``.
          The top-k count is clamped to the number of pixels in the image,
          because ``torch.topk`` raises when ``k`` exceeds the tensor size; in
          that case the sum is divided by the number actually selected.

        Args:
            pre_loss: Element-wise loss values, iterated along dimension 0.
            loss_label: Ground-truth map, iterated in lockstep with
                ``pre_loss``.
            neg_rto: Ratio of negatives to positives to keep.
            n_min_neg: Minimum number of hard negatives to keep.

        Returns:
            ``(sum of positive terms + sum of negative terms) / batch_size``.
        """
        batch_size = pre_loss.shape[0]
        positive_loss, negative_loss = 0, 0
        for single_loss, single_label in zip(pre_loss, loss_label):
            # positive_loss
            pos_pixel = (single_label >= 0.1).float()
            n_pos_pixel = torch.sum(pos_pixel)
            pos_loss_region = single_loss * pos_pixel
            # The tiny floor avoids 0/0 when the image has no positive pixel.
            positive_loss += torch.sum(pos_loss_region) / max(n_pos_pixel, 1e-12)

            # negative_loss
            neg_pixel = (single_label < 0.1).float()
            n_neg_pixel = torch.sum(neg_pixel)
            # reshape (not view) so a non-contiguous product can be flattened.
            neg_loss_flat = (single_loss * neg_pixel).reshape(-1)

            if n_neg_pixel == 0:
                # No negative pixel in this image: contributes nothing.
                continue

            has_positive = bool(n_pos_pixel != 0)
            if has_positive and n_neg_pixel < neg_rto * n_pos_pixel:
                # Too few negatives to sub-sample: use all of them.
                negative_loss += torch.sum(neg_loss_flat) / n_neg_pixel
                continue

            if has_positive:
                n_hard_neg = max(n_min_neg, neg_rto * n_pos_pixel)
            else:
                # only negative pixel
                n_hard_neg = n_min_neg

            k = int(n_hard_neg)
            n_available = neg_loss_flat.numel()
            if k > n_available:
                # Never ask topk for more elements than the image has.
                k = n_available
                n_hard_neg = k
            if k > 0:
                negative_loss += (
                    torch.sum(torch.topk(neg_loss_flat, k)[0]) / n_hard_neg
                )

        total_loss = (positive_loss + negative_loss) / batch_size
        return total_loss

    def forward(
        self,
        region_scores_label,
        affinity_scores_label,
        region_scores_pre,
        affinity_scores_pre,
        mask,
        neg_rto,
        n_min_neg,
    ):
        """Return the summed region and affinity losses.

        Args:
            region_scores_label: Ground-truth region map.
            affinity_scores_label: Ground-truth affinity map.
            region_scores_pre: Predicted region map, same size as its label.
            affinity_scores_pre: Predicted affinity map, same size as its label.
            mask: Weight multiplied into both element-wise squared errors.
            neg_rto: Forwarded to :meth:`single_image_loss`.
            n_min_neg: Forwarded to :meth:`single_image_loss`.

        Returns:
            Scalar tensor: region loss plus affinity loss.
        """
        # reduction="none" is the supported spelling of the deprecated
        # reduce=False / size_average=False combination.
        loss_fn = torch.nn.MSELoss(reduction="none")
        assert (
            region_scores_label.size() == region_scores_pre.size()
            and affinity_scores_label.size() == affinity_scores_pre.size()
        )
        loss1 = loss_fn(region_scores_pre, region_scores_label)
        loss2 = loss_fn(affinity_scores_pre, affinity_scores_label)
        loss_region = torch.mul(loss1, mask)
        loss_affinity = torch.mul(loss2, mask)
        char_loss = self.single_image_loss(
            loss_region, region_scores_label, neg_rto, n_min_neg
        )
        affi_loss = self.single_image_loss(
            loss_affinity, affinity_scores_label, neg_rto, n_min_neg
        )
        return char_loss + affi_loss