线性模型 - Softmax 回归(参数学习)

news/2025/2/23 23:25:45

本文,我们来学习Softmax 回归的参数学习,在开始之前,我们先了解一下“损失函数”、“风险函数”和“目标函数”这三个核心概念。

一、损失函数、风险函数、目标函数

1. 损失函数(Loss Function)

  • 定义
    损失函数是用来衡量单个样本预测结果与真实标签之间差异的函数。它描述了对一个样本来说,模型犯错误的“代价”或“惩罚”。

  • 例子

    • 对于回归问题,常用的损失函数有均方误差: 
    • 对于二分类问题,常用的损失函数有二元交叉熵:
  • 作用
    它直接衡量单个样本的预测误差,是后续构建总体模型优化目标的基本组成部分。

2. 风险函数(Risk Function)

  • 定义
    风险函数是对所有可能样本(或整个数据分布)上损失函数的期望值,也称为“期望损失”。在统计学习中,风险函数描述了模型在总体数据分布下的平均表现。

  • 数学表达

  • 作用
    风险函数衡量了模型的泛化能力和整体表现,是理论上模型“好坏”的评判标准。实际中,由于总体数据分布未知,我们往往使用经验风险(在训练集上的平均损失)来近似。

3. 目标函数(Objective Function)

  • 定义
    目标函数是优化算法在训练过程中希望最小化(或最大化)的函数。它通常包含经验风险(或损失函数的平均值),有时还会加入正则化项以控制模型复杂度。

  • 数学表达
    例如,对于一个回归问题,目标函数可以写为

    其中第一项是经验风险(经验损失的平均值),Ω(f) 是正则化项,λ 是正则化参数。

  • 作用
    目标函数定义了训练过程中需要优化的具体指标。通过优化目标函数,我们期望找到一个模型,使得在训练数据上损失(以及模型复杂度)达到平衡,进而提升泛化能力。

4. 三者之间的区别与联系

  • 区别

    • 损失函数:针对单个样本,衡量预测与真实标签的差异。
    • 风险函数:是损失函数在整个数据分布上的期望,描述模型在总体数据上的平均表现。
    • 目标函数:是实际优化时用到的函数,通常是经验风险(训练集上平均损失)加上正则化项,作为训练过程中参数调整的依据。
  • 联系

    • 损失函数是风险函数的基本组成部分,风险函数是损失函数在总体数据分布下的平均值。
    • 在实际训练中,由于总体数据分布未知,我们通常用训练集上的平均损失(经验风险)作为目标函数的一部分,通过最小化目标函数来间接地降低总体风险。

总结

  • 损失函数告诉我们每个样本犯错的代价;
  • 风险函数(期望损失)描述了模型在整个数据分布上的平均表现;
  • 目标函数是实际优化过程中使用的函数,通常包含经验风险和正则化项,用来指导模型参数的学习

这种分层的定义帮助我们从单个样本的误差度量,扩展到整个数据集甚至总体数据分布的模型评估,再到实际训练过程中具体优化的目标,构成了机器学习模型训练的完整理论框架。

二、Softmax 回归的损失函数:交叉熵损失

(一)模型结构

1、模型定义

2、模型表示

(二)交叉熵损失函数

这里提到one-hot编码,独热分布可以参考:线性模型 - 二分类问题的损失函数-CSDN博客

为了让大家直观理解交叉熵损失函数选择的意义,我还是把对数函数的图像放到这里:

这样大家可以比较直观的看到,随着概率P从0~1的变化,对应的损失的变化情况。这里需要大家具备基本的对数函数的知识。

三、参数学习

1、构造似然函数

通过最大似然估计,目标是找到一组参数 {w_k, b_k}使得训练数据的似然最大。取对数后,最大化对数似然等价于最小化交叉熵损失。

2、求梯度

通过计算损失函数关于参数的梯度,利用梯度下降或其变种(如随机梯度下降、Mini-batch SGD、Adam 等)迭代更新参数。

根据计算得到的梯度,采用梯度下降的更新规则(例如,参数更新公式为:

其中 α是学习率。

重复上述计算和更新过程,直至损失函数收敛到较低值,或达到预设的迭代次数,从而得到最优参数。

3、确定最优参数

  1. 优化算法选择

    • 批量梯度下降(BGD):使用全体训练数据计算梯度,收敛稳定但计算成本高。

    • 随机梯度下降(SGD):每次随机选择一个样本更新参数,速度快但波动大。

    • 小批量梯度下降(Mini-batch GD):折中方案,常用批量大小为32~256。

  2. 学习率调整

    • 固定学习:简单但需手动调参(如α=0.01)。

    • 自适应学习:使用Adam、RMSprop等优化器自动调整。

  3. 正则化技术

    • L2正则化:在损失函数中加入,防止权重过大。

    • 早停法(Early Stopping):在验证集损失不再下降时终止训练。

  4. 收敛判定

    • 损失变化阈值:当损失下降幅度小于预设阈值(如10^−5)时停止。

    • 最大迭代次数:设置训练轮次上限(如1000轮)。

四、参数学习举例:

栗子1:

栗子2:

场景:手写数字识别(3个类别,2个特征)

五、为什么要使用交叉熵损失作为损失函数 ?

1、概率解释

交叉熵损失衡量的是模型预测的概率分布与真实分布之间的差异。在二分类或多分类问题中,真实标签通常以独热编码(one-hot encoding)的形式表示,而模型输出的是一个概率分布。交叉熵正好可以量化这两个分布之间的不匹配程度,从而指导模型改进预测。

2、与最大似然估计的一致性

在逻辑回归和 Softmax 回归中,我们假设样本服从伯努利分布或多项分布。最大似然估计(MLE)的目标是最大化数据的似然,而取对数后,MLE 的目标函数就转化为最小化交叉熵损失。这种方法从理论上保证了最优参数的学习

3、良好的数值性质

交叉熵损失通常是凸的(对于线性模型),这使得利用梯度下降等优化方法能够有效地找到全局最优解。同时,当模型预测错误且非常自信时,交叉熵损失会急剧增大,从而促使模型大幅度调整参数。

4、梯度信息丰富

交叉熵损失提供的梯度信息通常比较丰富,尤其是在预测概率偏离真实值较远时,梯度较大,可以帮助模型更快地学习和纠正错误。

5、举例说明

在垃圾邮件检测任务中,假设真实标签 y=1 表示垃圾邮件,而模型预测出邮件为垃圾邮件的概率为 y^。

  • 当邮件确实是垃圾邮件(y=1)且 y^ 很接近1时,交叉熵损失 −log⁡(y^) 很小,表示模型预测正确;
  • 反之,如果邮件为垃圾邮件但模型预测 y^ 较低(例如0.3),则损失 −log⁡(0.3) 会非常大,迫使模型调整参数以提高预测概率。

因此,使用交叉熵损失作为损失函数,可以让模型在训练过程中有效地衡量预测概率与真实标签之间的差距,通过最小化该损失,我们能够以最大似然估计的方式获得最优参数,同时交叉熵损失具有良好的数学性质和梯度信息,有助于稳定高效地进行优化。

五、总结:

步骤关键点
前向传播计算类别得分 → Softmax归一化为概率 → 计算交叉熵损失
反向传播损失对得分的梯度 = 预测概率 - 真实标签 → 链式法则求权重和偏置的梯度
参数更新通过梯度下降法(或变体)更新权重和偏置
正则化添加L2正则化项控制模型复杂度,防止过拟合
最优参数判定结合验证集监控,通过早停法或损失收敛阈值确定训练终止点
  • 参数学习过程
    Softmax 回归的参数学习通过最大似然估计转化为最小化交叉熵损失,然后使用梯度下降等优化算法更新参数,最终得到能够输出合理概率分布的模型。
  • 确定最优参数
    通过不断迭代更新,直到损失函数收敛或达到预定的训练轮数,从而得到在训练数据上表现最优的参数。
  • 直观效果
    模型最终可以将输入 x 映射为每个类别的概率,并通过 argmax 操作输出预测类别。

这种训练过程不仅在数学上严谨,而且在实际应用中非常高效,适用于多类别分类任务。

通过优化交叉熵损失函数,Softmax回归能够有效学习多分类问题的决策边界。实际应用中需注意学习率调整、正则化强度选择及优化算法的适应性。


http://www.niftyadmin.cn/n/5863860.html

相关文章

Linux 内核中的 container_of 宏:以 ipoib_rx_poll_rss 函数为例

在 Linux 内核编程中,container_of 是一个非常实用的宏,主要用于通过结构体的成员指针来获取包含该成员的整个结构体的指针。rx_ring = container_of(napi, struct ipoib_recv_ring, napi); 在代码中就是利用了这个宏,下面我们详细分析它的作用和工作原理。 背景知识 在内…

断开ssh连接程序继续运行

在使用 SSH 远程连接服务器时,我们常希望在断开连接后仍然让程序继续运行,以下是几种常见的方法: 1. 使用 screen 或 tmux screen 和 tmux 是两款非常强大的终端复用工具,它们允许你在后台运行会话,即使断开 SSH 连接…

【Python爬虫(45)】Python爬虫新境界:分布式与大数据框架的融合之旅

【Python爬虫】专栏简介:本专栏是 Python 爬虫领域的集大成之作,共 100 章节。从 Python 基础语法、爬虫入门知识讲起,深入探讨反爬虫、多线程、分布式等进阶技术。以大量实例为支撑,覆盖网页、图片、音频等各类数据爬取&#xff…

创建型模式-Prototype 模式(原型模式)

原型模式 ‌原型模式(Prototype Pattern)是一种创建型设计模式,通过复制现有对象来创建新对象,避免了重复创建对象的开销‌。原型模式的核心在于通过复制现有的实例对象来生成新的实例对象,从而提升效率。‌ 场景假设…

GStreamer源码安装1.24版本

从官网下载 1.24的源码包 https://gitlab.freedesktop.org/gstreamer/gstreamer/-/tree/1.24?ref_typeheads#getting-started ,尝试过使用git clone 的方式,但速度贼慢,就选择了下载源码包的方式安装依赖 sudo apt install libssl-dev g me…

【Leetcode 每日一题】2506. 统计相似字符串对的数目

问题背景 给你一个下标从 0 0 0 开始的字符串数组 w o r d s words words。 如果两个字符串由相同的字符组成,则认为这两个字符串 相似 。 例如,“abca” 和 “cba” 相似,因为它们都由字符 ‘a’、‘b’、‘c’ 组成。然而,“…

八大排序算法(1)插入排序-直接插入排序 和 希尔排序

直接插入排序(Insertion Sort) 直接插入排序是最基本的插入排序算法,工作原理如下: 从第二个元素开始,将其与前面已经排好序的部分进行比较。 找到合适的位置后,将该元素插入到合适的位置,同…

【20250221更新】WebStorm2024.3.3版本安装+使用方法

1、官网下载正版WebStorm,链接如下 Thank you for downloading WebStorm! 2、获取使用教程,给博主留言【压缩包有密码,见下面】 通过百度网盘分享的文件:【2025022… 链接:https://pan.baidu.com/s/1UMMEDKbRwlGcffAhOlwR5g?pw…