我就废话不多说了，直接上代码吧！
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
|
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# Supports both multi-class and binary (two-class) classification.
class FocalLoss(nn.Module):
    """Softmax focal loss with optional label smoothing.

    Implements the loss proposed in 'Focal Loss for Dense Object Detection'
    (https://arxiv.org/abs/1708.02002):

        Focal_Loss = -1 * alpha_t * (1 - pt) ** gamma * log(pt)

    where ``pt`` is the softmax probability of the target class and
    ``alpha_t`` is the weight of that class.

    :param num_class: (int) number of classes C.
    :param alpha: per-class weighting factor, stored as a (C, 1) tensor:
        ``None`` -> all ones; list / np.ndarray of length C -> normalized to
        sum to 1; float -> ``alpha`` for class ``balance_index`` and
        ``1 - alpha`` for every other class.
    :param gamma: (float) focusing parameter; gamma > 0 reduces the relative
        loss for well-classified examples (pt > 0.5), putting more focus on
        hard, misclassified examples.
    :param balance_index: (int) class that receives ``alpha`` when alpha is
        a float.
    :param smooth: (float in [0, 1]) when truthy, the one-hot target is
        clamped to ``[smooth, 1 - smooth]`` before computing pt.
    :param size_average: (bool) if True the per-element losses are averaged,
        otherwise they are summed.
    :raises TypeError: if ``alpha`` is not None, a list, an ndarray or a float.
    :raises ValueError: if ``smooth`` is outside [0, 1].

    Input shapes: ``input`` holds raw logits of shape (N, C) or
    (N, C, d1, d2, ...); ``target`` holds class indices with N (or
    N * d1 * d2 * ...) elements.
    """

    def __init__(self, num_class, alpha=None, gamma=2, balance_index=-1, smooth=None, size_average=True):
        super().__init__()
        self.num_class = num_class
        self.alpha = alpha
        self.gamma = gamma
        self.smooth = smooth
        self.size_average = size_average

        # Normalize every supported alpha form to a (num_class, 1) float tensor.
        if self.alpha is None:
            self.alpha = torch.ones(self.num_class, 1)
        elif isinstance(self.alpha, (list, np.ndarray)):
            assert len(self.alpha) == self.num_class
            # np.asarray(..., float32) accepts lists and arrays of any numeric dtype.
            self.alpha = torch.tensor(np.asarray(alpha, dtype=np.float32)).view(self.num_class, 1)
            self.alpha = self.alpha / self.alpha.sum()
        elif isinstance(self.alpha, float):
            alpha = torch.ones(self.num_class, 1) * (1 - self.alpha)
            alpha[balance_index] = self.alpha
            self.alpha = alpha
        else:
            raise TypeError('Not support alpha type')

        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError('smooth value should be in [0,1]')

    def forward(self, input, target):
        """Return the focal loss of raw logits ``input`` against class indices ``target``.

        The result is a 0-dim tensor: the mean (``size_average=True``) or the
        sum of the per-element losses.
        """
        logit = F.softmax(input, dim=1)
        if logit.dim() > 2:
            # N,C,d1,d2,... -> N,C,m -> N,m,C -> N*m,C  (m = d1*d2*...)
            logit = logit.view(logit.size(0), logit.size(1), -1)
            logit = logit.permute(0, 2, 1).contiguous()
            logit = logit.view(-1, logit.size(-1))
        # reshape (rather than view) also accepts non-contiguous targets.
        target = target.reshape(-1, 1)

        epsilon = 1e-10  # keeps log(pt) finite when pt underflows to 0
        idx = target.long().to(logit.device)

        # Build the one-hot target directly on the logits' device.
        one_hot_key = torch.zeros(target.size(0), self.num_class, dtype=torch.float32, device=logit.device)
        one_hot_key.scatter_(1, idx, 1)
        if self.smooth:
            one_hot_key = torch.clamp(one_hot_key, self.smooth, 1.0 - self.smooth)

        pt = (one_hot_key * logit).sum(1) + epsilon  # one value per element
        logpt = pt.log()

        # self.alpha is (C, 1): gather one weight per element and flatten it so
        # it multiplies pt element-wise. Indexing with the (N, 1) idx directly
        # yields (N, 1, 1) and broadcasts the loss to (N, 1, N), mixing the
        # weight of one sample with the focal term of every other sample.
        alpha = self.alpha.to(logit.device)[idx.view(-1)].view(-1)
        loss = -1 * alpha * torch.pow(1 - pt, self.gamma) * logpt

        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss


class BCEFocalLoss(torch.nn.Module):
    """Binary focal loss on raw logits with a fixed alpha.

        p = sigmoid(x)
        loss = -alpha * (1 - p) ** gamma * y * log(p)
               - (1 - alpha) * p ** gamma * (1 - y) * log(1 - p)

    :param gamma: focusing parameter.
    :param alpha: weight of the positive term; the negative term gets 1 - alpha.
    :param reduction: 'elementwise_mean' (or 'mean') -> mean, 'sum' -> sum;
        any other value returns the unreduced element-wise loss.
    """

    def __init__(self, gamma=2, alpha=0.25, reduction='elementwise_mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, _input, target):
        """Return the binary focal loss of logits ``_input`` against 0/1 ``target``."""
        pt = torch.sigmoid(_input)
        alpha = self.alpha
        # log(p) and log(1 - p) == log(sigmoid(-x)) computed via logsigmoid:
        # torch.log(sigmoid(x)) reaches log(0) = -inf for large |x|, and the
        # zero target mask then turns 0 * -inf into NaN.
        log_pt = F.logsigmoid(_input)
        log_one_minus_pt = F.logsigmoid(-_input)
        loss = (-alpha * (1 - pt) ** self.gamma * target * log_pt
                - (1 - alpha) * pt ** self.gamma * (1 - target) * log_one_minus_pt)
        if self.reduction in ('elementwise_mean', 'mean'):
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss
以上这篇 PyTorch 实现 focal_loss 多类别和二分类示例就是小编分享给大家的全部内容了，希望能给大家一个参考，也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/qq_33278884/article/details/91572173