您的当前位置:首页神经网络多分类的实现总结

神经网络多分类的实现总结

来源:小侦探旅游网

首先要了解

one-hot

在多分类中,必须知道one-hot编码。独热编码即 One-Hot 编码,又称一位有效编码。其方法是使用 N位 状态寄存器来对 N个状态 进行编码,每个状态都有它独立的寄存器位,并且在任意时候,其中只有一位有效。

举例
假设我们有四个样本(行),每个样本有三个特征(列),如图:

import torch.nn.functional as F
import torch

num_class = 5
label = torch.tensor([0, 2, 1, 4, 1, 3])
one_hot = F.one_hot(label, num_classes=5 )
print(one_hot)
"""
tensor([[1, 0, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 0, 0, 1],
        [0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0]])
"""

基于numpy的one-hot代码

import numpy as np
# 设置类别的数量
num_classes = 10
# 需要转换的整数
arr = [1, 3, 4, 5, 9]
# 将整数转为一个10位的one hot编码
print(np.eye(10)[arr])

多分类的损失函数

多分类问题一般用softmax作为神经网络的最后一层,然后计算交叉熵损失。

softmax原理

softmax函数的作用是将每个类别所对应的输出分量归一化,使各个分量的和为1。可以理解为将每个输出分量转化为对应的概率。

计算公式:

交叉熵损失函数

交叉熵的原理

关于样本集的两个概率分布p和q,设p为真实的分布,比如[1,0,0]表示样本属于第一类,q为预测的概率分布,比如[0.7,0.2,0.1]

按照真实分布p来衡量识别一个样本所需的编码长度的期望,即平均编码长度(信息熵):

多分类任务中的交叉熵损失函数


PyTorch中的交叉熵损失函数实现

pytorch提供了两个类来计算交叉熵,分别是CrossEntropyLoss() 和NLLLoss()。
对于torch.nn.CrossEntropyLoss()定义如下:

torch.nn.CrossEntropyLoss(
    weight=None,
    ignore_index=-100,
    reduction="mean",
)

表示一个样本的非softmax输出(网络输出的预测标签不需要经过softmax,因为torch.nn.CrossEntropyLoss已经自带softmax),c表示该样本的标签,则损失函数公式描述如下,


如果weight被指定,

其中,

import torch
import torch.nn as nn

model = nn.Linear(10, 3)
criterion = nn.CrossEntropyLoss()

x = torch.randn(16, 10)
y = torch.randint(0, 3, size=(16,))  # (16, )
logits = model(x)  # (16, 3)

loss = criterion(logits, y)

对于torch.nn.NLLLoss()定义如下:

torch.nn.NLLLoss(
    weight=None,
    ignore_index=-100,
    reduction="mean",
)

表示一个样本对每个类别的对数似然(log-probabilities),c表示该样本的标签,损失函数公式描述如下:

其中,

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(10, 3),
    nn.LogSoftmax()
)
criterion = nn.NLLLoss()

x = torch.randn(16, 10)
y = torch.randint(0, 3, size=(16,))  # (16, )
out = model(x)  # (16, 3)

loss = criterion(out, y)

心得

1.CrossEntropyLoss() 和NLLLoss()的使用是不一样的,使用CrossEntropyLoss() 的网络架构不需要最后一层加入softmax,而NLLLoss()则需要加入softmax.
2.在使用CrossEntropyLoss() 和NLLLoss()的时候,不需要进行one_hot编码,函数会自动处理
3.由于不需要进行one_hot编码,由于网络预测结果的维度会比label的维度多1.当预测的维度大于2时,第二维度是label的个数(由损失函数的源码可知).
参考链接:

因篇幅问题不能全部显示,请点此查看更多更全内容