paddle中nll_loss()與CrossEntropyLoss()損失函式區別
首先先交代結論:nll_loss()與CrossEntropyLoss()損失函式計算的關係為CrossEntropyLoss()等於對輸入資料先做softmax,再做log處理,再加nll_loss()操作。
1.NLLLoss 的輸入是一個對數概率向量和一個目標標籤,它不會為我們計算對數概率.,適合網路的最後一層是log_softmax損失函式,有時候呼叫預訓練模型修改網路最後的全連接層的時候會把最後一層的輸出改變,這時候就不適合用CrossEntropyLoss()作為損失函式;CrossEntropyLoss()與NLLLoss()相同,唯一的不同是它為我們去做 softmax。
2.CrossEntropyLoss():交叉熵損失函式,交叉熵描述了兩個概率分布之間的距離,當交叉熵越小說明二者之間越接近。
以下提供在paddlepaddle2.0以及pytorch兩種框架中的驗證。
一、paddlepaddle2.0框架中的驗證
這一部分我已經上傳我的ai studio上,上面有分步的詳細解釋,歡迎大家fork學習。
https://aistudio.baidu.com/aistudio/projectdetail/1759396
二、pytorch框架中的驗證
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 | import torch import torch.nn.functional as nn input = torch.tensor([[ 0.9451, -0.1922, -0.5285], [-0.3981, 0.5784, 0.0945], [-1.1648, 0.1565, -0.4067]]) input_aftersoftmax_hang = nn.softmax(input,dim = 1) # print(input_aftersoftmax_hang) ''' tensor([[0.6453, 0.2069, 0.1478], [0.1890, 0.5018, 0.3093], [0.1453, 0.5446, 0.3101]]) ''' input_aftersoftmax_lie = nn.softmax(input,dim = 0) # print(input_aftersoftmax_lie) ''' tensor([[0.7235, 0.2184, 0.2504], [0.1888, 0.4720, 0.4668], [0.0877, 0.3096, 0.2828]]) ''' input_log_hangsoftmax = torch.log(input_aftersoftmax_hang) # print(input_log_hangsoftmax) ''' tensor([[-0.4381, -1.5754, -1.9117], [-1.6661, -0.6896, -1.1735], [-1.9290, -0.6077, -1.1709]]) ''' input_log = torch.log_softmax(input,dim = 1) # print(input_log) ''' tensor([[-0.4381, -1.5754, -1.9117], [-1.6661, -0.6896, -1.1735], [-1.9290, -0.6077, -1.1709]]) ''' # 由此可以說明對input按行做softmax後再做log處理的結果,和直接用torch.log_softmax()的結果是一樣的 # 假設每一行代表一個資料,他們的真實標籤分別為0,2,1,也就是取input_log第一行第一個,第二行第三個,第三行第二個作為預測結果 target = torch.tensor([0,2,1]) # 計算nnl_loss損失函式的損失值 loss_nlllose = nn.nll_loss(input_log_hangsoftmax, target) # print(loss_nlllose) # tensor(0.7398) # 計算CrossEntropyLoss損失函式損失值 loss_CrossEntropyLoss = nn.cross_entropy(input,target) # print(loss_CrossEntropyLoss) # tensor(0.7398) # 由此可以知道,輸入資料按行做softmax,再做log操作,再接nn.nll_loss,等於nn.cross_entropy(input,target) |