首先要了解
在多分类中,必须知道one-hot编码。独热编码即 One-Hot 编码,又称一位有效编码。其方法是使用 N位 状态寄存器来对 N个状态 进行编码,每个状态都有它独立的寄存器位,并且在任意时候,其中只有一位有效。
举例
假设我们有四个样本(行),每个样本有三个特征(列),如图:
import torch.nn.functional as F
import torch
num_class = 5
label = torch.tensor([0, 2, 1, 4, 1, 3])
one_hot = F.one_hot(label, num_classes=5 )
print(one_hot)
"""
tensor([[1, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 0, 0, 1],
[0, 1, 0, 0, 0],
[0, 0, 0, 1, 0]])
"""
基于numpy的one-hot代码
import numpy as np
# 设置类别的数量
num_classes = 10
# 需要转换的整数
arr = [1, 3, 4, 5, 9]
# 将整数转为一个10位的one hot编码
print(np.eye(10)[arr])
多分类问题一般用softmax作为神经网络的最后一层,然后计算交叉熵损失。
softmax函数的作用是将每个类别所对应的输出分量归一化,使各个分量的和为1。可以理解为将每个输出分量转化为对应的概率。
计算公式:
关于样本集的两个概率分布p和q,设p为真实的分布,比如[1,0,0]表示样本属于第一类,q为预测的概率分布,比如[0.7,0.2,0.1]
按照真实分布p来衡量识别一个样本所需的编码长度的期望,即平均编码长度(信息熵):
pytorch提供了两个类来计算交叉熵,分别是CrossEntropyLoss() 和NLLLoss()。
对于torch.nn.CrossEntropyLoss()定义如下:
torch.nn.CrossEntropyLoss(
weight=None,
ignore_index=-100,
reduction="mean",
)
表示一个样本的非softmax输出(网络输出的预测标签不需要经过softmax,因为torch.nn.CrossEntropyLoss已经自带softmax),c表示该样本的标签,则损失函数公式描述如下,
如果weight被指定,
其中,
import torch
import torch.nn as nn
model = nn.Linear(10, 3)
criterion = nn.CrossEntropyLoss()
x = torch.randn(16, 10)
y = torch.randint(0, 3, size=(16,)) # (16, )
logits = model(x) # (16, 3)
loss = criterion(logits, y)
对于torch.nn.NLLLoss()定义如下:
torch.nn.NLLLoss(
weight=None,
ignore_index=-100,
reduction="mean",
)
表示一个样本对每个类别的对数似然(log-probabilities),c表示该样本的标签,损失函数公式描述如下:
其中,
import torch
import torch.nn as nn
model = nn.Sequential(
nn.Linear(10, 3),
nn.LogSoftmax()
)
criterion = nn.NLLLoss()
x = torch.randn(16, 10)
y = torch.randint(0, 3, size=(16,)) # (16, )
out = model(x) # (16, 3)
loss = criterion(out, y)
1.CrossEntropyLoss() 和NLLLoss()的使用是不一样的,使用CrossEntropyLoss() 的网络架构不需要最后一层加入softmax,而NLLLoss()则需要加入softmax.
2.在使用CrossEntropyLoss() 和NLLLoss()的时候,不需要进行one_hot编码,函数会自动处理
3.由于不需要进行one_hot编码,由于网络预测结果的维度会比label的维度多1.当预测的维度大于2时,第二维度是label的个数(由损失函数的源码可知).
参考链接:
因篇幅问题不能全部显示,请点此查看更多更全内容