173 lines
5.6 KiB
Python
173 lines
5.6 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class Loss(nn.Module):
    """Confidence-weighted squared error over region and affinity maps.

    The element-wise squared errors of the region map and the affinity map
    are added, multiplied element-wise by ``conf_map`` and then averaged over
    every element into a single scalar.
    """

    def __init__(self):
        super().__init__()

    def forward(self, gt_region, gt_affinity, pred_region, pred_affinity, conf_map):
        """Return ``mean(((gt_r - pred_r)^2 + (gt_a - pred_a)^2) * conf_map)``.

        All inputs are combined element-wise, so they must be
        broadcast-compatible tensors.
        """
        region_err = (gt_region - pred_region).pow(2)
        affinity_err = (gt_affinity - pred_affinity).pow(2)
        weighted_err = (region_err + affinity_err) * conf_map
        return torch.mean(weighted_err)
|
|
|
|
|
|
class Maploss_v2(nn.Module):
    """Masked MSE loss with hard-negative mining pooled over the whole batch.

    Region and affinity maps are scored independently by
    :meth:`batch_image_loss` and the two results are summed. Pixels are split
    into positives (``label > 0.1``) and negatives (``label <= 0.1``) across
    the entire batch at once (no per-image loop).
    """

    def __init__(self):
        super(Maploss_v2, self).__init__()

    @staticmethod
    def _topk_sum(values, k):
        """Sum the ``k`` largest entries of ``values`` (flattened).

        ``k`` is clamped to the number of elements so that small inputs do not
        make ``torch.topk`` raise. Returns ``(sum, k_used)``.
        """
        flat = values.reshape(-1)
        k = min(int(k), flat.numel())
        return torch.sum(torch.topk(flat, k, sorted=False)[0]), k

    def batch_image_loss(self, pred_score, label_score, neg_rto, n_min_neg):
        """Return ``positive_loss + negative_loss`` for one kind of score map.

        Args:
            pred_score: per-pixel loss values (the caller passes masked MSE).
            label_score: ground-truth scores, same shape as ``pred_score``.
            neg_rto: number of hard negatives kept per positive pixel.
            n_min_neg: number of hard negatives averaged when there are no
                positives, or when there are fewer negatives than
                ``neg_rto * positives``.

        Returns:
            Scalar tensor. ``positive_loss`` is the mean loss over positive
            pixels (0.0 when there are none); ``negative_loss`` is the mean of
            the selected largest negative-region values.
        """
        # positive_loss
        positive_pixel = (label_score > 0.1).float()
        positive_pixel_number = torch.sum(positive_pixel)
        positive_loss_region = pred_score * positive_pixel

        # negative_loss; positive positions are zeroed in this tensor, so they
        # can still be picked by topk when k exceeds the negative count.
        negative_pixel = (label_score <= 0.1).float()
        negative_pixel_number = torch.sum(negative_pixel)
        negative_loss_region = pred_score * negative_pixel

        if positive_pixel_number != 0:
            if negative_pixel_number < neg_rto * positive_pixel_number:
                # Not enough negatives for the ratio: use the n_min_neg
                # largest values (clamped to the tensor size) instead.
                neg_sum, k = self._topk_sum(negative_loss_region, n_min_neg)
                negative_loss = neg_sum / max(k, 1)
            else:
                # Here negatives >= neg_rto * positives, so k never needs
                # clamping; the normalizer is the un-truncated ratio count.
                neg_sum, _ = self._topk_sum(
                    negative_loss_region, neg_rto * positive_pixel_number
                )
                negative_loss = neg_sum / (positive_pixel_number * neg_rto)
            positive_loss = torch.sum(positive_loss_region) / positive_pixel_number
        else:
            # only negative pixels
            neg_sum, k = self._topk_sum(negative_loss_region, n_min_neg)
            negative_loss = neg_sum / max(k, 1)
            positive_loss = 0.0

        return positive_loss + negative_loss

    def forward(
        self,
        region_scores_label,
        affinity_socres_label,  # misspelled name kept for keyword callers
        region_scores_pre,
        affinity_scores_pre,
        mask,
        neg_rto,
        n_min_neg,
    ):
        """Return region loss + affinity loss.

        Per-element squared errors are multiplied by ``mask`` before the
        hard-negative mining in :meth:`batch_image_loss`.
        """
        # reduction="none" is the non-deprecated spelling of the legacy
        # reduce=False / size_average=False pair: keep per-element losses.
        loss_fn = torch.nn.MSELoss(reduction="none")
        assert (
            region_scores_label.size() == region_scores_pre.size()
            and affinity_socres_label.size() == affinity_scores_pre.size()
        ), "prediction and label maps must have matching shapes"

        loss_region = torch.mul(loss_fn(region_scores_pre, region_scores_label), mask)
        loss_affinity = torch.mul(
            loss_fn(affinity_scores_pre, affinity_socres_label), mask
        )

        char_loss = self.batch_image_loss(
            loss_region, region_scores_label, neg_rto, n_min_neg
        )
        affi_loss = self.batch_image_loss(
            loss_affinity, affinity_socres_label, neg_rto, n_min_neg
        )
        return char_loss + affi_loss
|
|
|
|
|
|
class Maploss_v3(nn.Module):
    """Masked MSE loss with hard-negative mining done per image.

    Each image in the batch is split into positives (``label >= 0.1``) and
    negatives (``label < 0.1``); the per-image losses are summed and divided
    by the batch size. Region and affinity maps are scored independently and
    the two results are summed.
    """

    def __init__(self):
        super(Maploss_v3, self).__init__()

    @staticmethod
    def _topk_sum(values, k):
        """Sum the ``k`` largest entries of ``values`` (flattened).

        ``k`` is clamped to the number of elements so that small inputs do not
        make ``torch.topk`` raise. Returns ``(sum, k_used)``.
        """
        flat = values.reshape(-1)
        k = min(int(k), flat.numel())
        return torch.sum(torch.topk(flat, k)[0]), k

    def single_image_loss(self, pre_loss, loss_label, neg_rto, n_min_neg):
        """Return the batch-averaged positive + hard-negative loss.

        Despite the name, this takes a whole batch and iterates over its first
        dimension, computing the loss for one image at a time.

        Args:
            pre_loss: per-pixel loss values (the caller passes masked MSE).
            loss_label: ground-truth scores, same shape as ``pre_loss``.
            neg_rto: number of hard negatives kept per positive pixel.
            n_min_neg: lower bound on the number of hard negatives, and the
                count used for images with no positive pixels.

        Returns:
            ``(sum of per-image positive and negative losses) / batch_size``.
        """
        batch_size = pre_loss.shape[0]

        positive_loss, negative_loss = 0, 0
        for single_loss, single_label in zip(pre_loss, loss_label):
            # positive_loss; the clamp only matters when there are no
            # positives, in which case the numerator is 0 as well.
            pos_pixel = (single_label >= 0.1).float()
            n_pos_pixel = torch.sum(pos_pixel)
            pos_loss_region = single_loss * pos_pixel
            positive_loss += torch.sum(pos_loss_region) / n_pos_pixel.clamp(min=1e-12)

            # negative_loss
            neg_pixel = (single_label < 0.1).float()
            n_neg_pixel = torch.sum(neg_pixel)
            neg_loss_region = (single_loss * neg_pixel).reshape(-1)
            n_elem = neg_loss_region.numel()

            if n_pos_pixel != 0:
                if n_neg_pixel < neg_rto * n_pos_pixel:
                    # Average over all negatives. An image with no negative
                    # pixels lands here too; clamping the count to 1 makes it
                    # contribute 0 instead of 0/0 = NaN.
                    negative_loss += torch.sum(neg_loss_region) / n_neg_pixel.clamp(
                        min=1.0
                    )
                else:
                    n_hard_neg = max(float(n_min_neg), float(neg_rto * n_pos_pixel))
                    neg_sum, _ = self._topk_sum(neg_loss_region, n_hard_neg)
                    # Normalize by n_hard_neg as before, or by the element
                    # count when n_hard_neg exceeds it; the floor of 1 only
                    # affects cases where no element was selected (sum is 0).
                    negative_loss += neg_sum / max(min(n_hard_neg, n_elem), 1.0)
            else:
                # only negative pixels
                neg_sum, k = self._topk_sum(neg_loss_region, n_min_neg)
                negative_loss += neg_sum / max(k, 1)

        total_loss = (positive_loss + negative_loss) / batch_size
        return total_loss

    def forward(
        self,
        region_scores_label,
        affinity_scores_label,
        region_scores_pre,
        affinity_scores_pre,
        mask,
        neg_rto,
        n_min_neg,
    ):
        """Return region loss + affinity loss.

        Per-element squared errors are multiplied by ``mask`` before the
        per-image hard-negative mining in :meth:`single_image_loss`.
        """
        # reduction="none" is the non-deprecated spelling of the legacy
        # reduce=False / size_average=False pair: keep per-element losses.
        loss_fn = torch.nn.MSELoss(reduction="none")

        assert (
            region_scores_label.size() == region_scores_pre.size()
            and affinity_scores_label.size() == affinity_scores_pre.size()
        ), "prediction and label maps must have matching shapes"

        loss_region = torch.mul(loss_fn(region_scores_pre, region_scores_label), mask)
        loss_affinity = torch.mul(
            loss_fn(affinity_scores_pre, affinity_scores_label), mask
        )

        char_loss = self.single_image_loss(
            loss_region, region_scores_label, neg_rto, n_min_neg
        )
        affi_loss = self.single_image_loss(
            loss_affinity, affinity_scores_label, neg_rto, n_min_neg
        )
        return char_loss + affi_loss
|