Using the Focal Loss function to address class imbalance: assigning higher weights to minority-class samples.

Published on Aug. 22, 2023, 12:08 p.m.

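Definition of Focal Loss

Definition: Focal Loss is a loss function that down-weights easily classified samples, so that hard-to-classify samples contribute relatively more to the total loss.
Focal Loss is essentially an extension of cross-entropy loss, designed specifically to handle class imbalance.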
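To make the down-weighting concrete, here is a minimal sketch in plain Python. It assumes the commonly used unweighted form FL(p_t) = -(1 - p_t)^γ · log(p_t), where p_t is the predicted probability of the true class. The helper name `focal_term` and the probabilities 0.9 and 0.1 are illustrative choices, not taken from the post.

```python
import math

def focal_term(p_t, gamma=2.0):
    """Focal loss of one sample, given p_t, the predicted probability of its true class.

    Assumes the unweighted form -(1 - p_t)**gamma * log(p_t).
    `focal_term` is a hypothetical helper name used only for this illustration.
    """
    return -((1.0 - p_t) ** gamma) * math.log(p_t)

for p_t in (0.9, 0.1):  # an easy sample and a hard sample (illustrative values)
    ce = -math.log(p_t)   # plain cross-entropy for this sample
    fl = focal_term(p_t)  # focal loss with gamma = 2
    print(f"p_t={p_t}: ce={ce:.4f}, focal={fl:.4f}, kept fraction={fl / ce:.2f}")
```

With γ = 2, the easy sample (p_t = 0.9) keeps about 1% of its cross-entropy loss, while the hard sample (p_t = 0.1) keeps about 81%. With γ = 0 the expression reduces to plain cross-entropy.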

https://github.com/yatengLG/Focal-Loss-Pytorch
https://pypi.org/project/focal-loss-torch/

https://github.com/napoler/Focal-Loss-Pytorch

Example

https://github.com/napoler/Focal-Loss-Pytorch/blob/master/Demo.ipynb



<pre><code># -*- coding: utf-8 -*-
# @Author  : LG
from torch import nn
import torch
from torch.nn import functional as F

class focal_loss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, num_classes=3, size_average=True):
        """
        Focal loss: α * (1 - p_t)**γ * ce_loss(x_i, y_i)
        A step-by-step implementation of focal loss.
        :param alpha:   α, the class weights. A list gives one weight per class; a scalar gives the class weights [α, 1-α, 1-α, ...], commonly used in object detection to suppress the background class. Set to 0.25 in RetinaNet.
        :param gamma:   γ, the parameter that balances easy and hard samples. Set to 2 in RetinaNet.
        :param num_classes:     number of classes
        :param size_average:    how the loss is reduced; the mean by default, otherwise the sum
        """
        super(focal_loss, self).__init__()
        self.size_average = size_average
        if isinstance(alpha, list):
            assert len(alpha) == num_classes   # α may be a list of size [num_classes] to weight each class individually
            print("  Focal_loss alpha = {}, assigning an individual weight to each class  ".format(alpha))
            self.alpha = torch.tensor(alpha, dtype=torch.float32)
        else:
            assert alpha < 1   # a scalar α down-weights the first class, which is the background class in object detection
            print("  Focal_loss alpha = {}, down-weighting the background class; intended for object detection tasks  ".format(alpha))
            self.alpha = torch.zeros(num_classes)
            self.alpha[0] += alpha
            self.alpha[1:] += (1 - alpha)  # α ends up as [α, 1-α, 1-α, 1-α, ...], size: [num_classes]

        self.gamma = gamma

    def forward(self, preds, labels):
        """
        Compute the focal loss.
        :param preds:   predicted class scores (logits). size: [B,N,C] or [B,C], for detection and classification respectively. B: batch size, N: number of boxes, C: number of classes
        :param labels:  ground-truth class indices. size: [B,N] or [B]
        :return: the reduced loss (a scalar tensor)
        """
        # assert preds.dim()==2 and labels.dim()==1
        preds = preds.view(-1, preds.size(-1))
        labels = labels.view(-1, 1)
        # Use a local copy: overwriting self.alpha here would break every call after the first
        alpha = self.alpha.to(preds.device)
        preds_logsoft = F.log_softmax(preds, dim=1)  # log_softmax
        preds_softmax = torch.exp(preds_logsoft)     # softmax

        # Select the true-class entries; this is the nll_loss step (cross_entropy = log_softmax + nll_loss)
        preds_softmax = preds_softmax.gather(1, labels)
        preds_logsoft = preds_logsoft.gather(1, labels)
        alpha = alpha.gather(0, labels.view(-1))     # per-sample class weight
        # torch.pow((1 - preds_softmax), self.gamma) is the (1 - p_t)**γ factor of focal loss
        loss = -torch.mul(torch.pow((1 - preds_softmax), self.gamma), preds_logsoft)

        loss = torch.mul(alpha, loss.view(-1))
        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss
</code></pre>
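The forward steps above (log-softmax, selecting the true-class entry, the (1 - p_t)^γ factor, the per-class α weight, then mean or sum) can be cross-checked on a tiny batch without PyTorch. The following is a dependency-free sketch, not the author's code; the helper name `focal_loss_reference` and the sample logits are assumptions made for illustration.

```python
import math

def focal_loss_reference(logits, labels, alpha, gamma=2.0, size_average=True):
    """Plain-Python focal loss for a small batch, mirroring the module's forward steps.

    logits: list of M per-sample score lists, each of length C
    labels: list of M true class indices
    alpha:  list of C per-class weights
    `focal_loss_reference` is a hypothetical helper intended for cross-checking only.
    """
    losses = []
    for scores, y in zip(logits, labels):
        # log_softmax, subtracting the max score for numerical stability
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        log_p_t = scores[y] - log_z  # log-probability of the true class
        p_t = math.exp(log_p_t)
        # alpha[y] * (1 - p_t)**gamma * (-log p_t)
        losses.append(alpha[y] * ((1.0 - p_t) ** gamma) * (-log_p_t))
    total = sum(losses)
    return total / len(losses) if size_average else total

# Illustrative batch: 2 samples, 3 classes. A scalar alpha of 0.25 corresponds
# to the per-class weights [0.25, 0.75, 0.75] in the module above.
logits = [[2.0, 0.5, 0.1], [0.2, 0.3, 2.5]]
labels = [0, 2]
print(focal_loss_reference(logits, labels, alpha=[0.25, 0.75, 0.75]))
```

Setting `gamma=0` and every α to 1 reduces this to mean cross-entropy, which gives a quick sanity check when comparing against the PyTorch module.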

Evaluation metrics

=====

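Computes the area under the receiver operating characteristic curve (ROC AUC). It works for binary, multi-label, and multi-class problems. In the multi-class case, the value is computed with a one-vs-rest approach.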

https://torchmetrics.readthedocs.io/en/latest/references/modules.html#auroc

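Binary classification example

import torch
from torchmetrics import AUROC
preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
target = torch.tensor([0, 0, 1, 1, 1])
auroc = AUROC(task="binary")  # recent torchmetrics releases require task; older ones used AUROC(pos_label=1)
print(auroc(preds, target))  # tensor(0.5000)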

Multi-class classification example

preds = torch.tensor([[0.90, 0.05, 0.05],
                      [0.05, 0.90, 0.05],
                      [0.05, 0.05, 0.90],
                      [0.85, 0.05, 0.10],
                      [0.10, 0.10, 0.80]])
target = torch.tensor([0, 1, 1, 2, 2])
auroc = AUROC(task="multiclass", num_classes=3)  # older torchmetrics releases used AUROC(num_classes=3)
print(auroc(preds, target))  # tensor(0.7778)
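Both values above can be checked by hand. The sketch below is plain Python and does not use torchmetrics internals. It treats AUC as the fraction of (positive, negative) pairs ranked correctly, counting ties as half, and it assumes the multi-class value is the unweighted (macro) average of the per-class one-vs-rest AUCs. The function names `binary_auc` and `macro_ovr_auc` are hypothetical.

```python
def binary_auc(scores, targets):
    """AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for s, t in zip(scores, targets) if t == 1]
    neg = [s for s, t in zip(scores, targets) if t == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def macro_ovr_auc(probs, targets, num_classes):
    """One-vs-rest AUC per class, averaged with equal weight per class (macro)."""
    per_class = []
    for c in range(num_classes):
        scores_c = [row[c] for row in probs]               # score assigned to class c
        targets_c = [1 if t == c else 0 for t in targets]  # class c vs. all other classes
        per_class.append(binary_auc(scores_c, targets_c))
    return sum(per_class) / num_classes

# Same data as the binary example above
print(binary_auc([0.13, 0.26, 0.08, 0.19, 0.34], [0, 0, 1, 1, 1]))

# Same data as the multi-class example above
probs = [[0.90, 0.05, 0.05],
         [0.05, 0.90, 0.05],
         [0.05, 0.05, 0.90],
         [0.85, 0.05, 0.10],
         [0.10, 0.10, 0.80]]
print(macro_ovr_auc(probs, [0, 1, 1, 2, 2], 3))
```

The per-class AUCs in the multi-class case are 1, 2/3, and 2/3, whose mean is about 0.7778. The binary case ranks 3 of its 6 pairs correctly, giving 0.5.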

torchmetrics provides many other metrics besides this one:

https://torchmetrics.readthedocs.io/en/latest/references/modules.html#auroc