news 2026/6/24 9:15:08

损失函数 的 硬截断 和 平滑衰减

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
损失函数 的 硬截断 和 平滑衰减

损失函数 的 硬截断 和 平滑衰减

flyfish

在逐样本损失计算完成、取平均之前,对损失过高的样本做权重压制,不删除样本,只削弱它们对梯度的贡献,属于软降权——既保留了样本的监督信号,又避免极端难样本/疑似错标样本带偏整个模型。

损失硬截断

损失硬截断是给单样本损失设置一个上限,超过这个阈值的损失,直接按阈值计算。相当于一刀切,超过上限的样本梯度不再放大。

代码实现

classFocalLossWithSmoothing(nn.Module):def__init__(self,gamma=2,alpha=None,smoothing=0.0,num_classes=2,max_loss=None):""" :param max_loss: 单样本损失上限,None表示不开启截断;设置数值后,单样本损失不会超过该值 """super().__init__()self.gamma=gamma self.alpha=torch.tensor(alpha).to(DEVICE)ifalphaelseNoneself.smoothing=smoothing self.num_classes=num_classes self.max_loss=max_loss# 损失截断阈值defforward(self,inputs,targets):targets_one_hot=torch.zeros_like(inputs).scatter_(1,targets.unsqueeze(1),1)soft_targets=targets_one_hot*(1-self.smoothing)+self.smoothing/self.num_classes log_probs=torch.nn.functional.log_softmax(inputs,dim=1)probs=torch.exp(log_probs)p_t=(probs*targets_one_hot).sum(dim=1,keepdim=True)focal_weight=(1-p_t)**self.gamma ce_loss=(-soft_targets*log_probs).sum(dim=1)loss=focal_weight.squeeze()*ce_lossifself.alphaisnotNone:alpha_t=(self.alpha.unsqueeze(0)*targets_one_hot).sum(dim=1)loss=loss*alpha_t# ========== 损失截断 ==========ifself.max_lossisnotNone:loss=torch.clamp(loss,max=self.max_loss)returnloss.mean()

使用方式

在训练函数里初始化损失时,多加一个max_loss参数即可:

# 示例:单样本损失最高不超过2.0,超过的全部按2.0计算criterion=FocalLossWithSmoothing(gamma=FOCAL_GAMMA,alpha=FOCAL_ALPHA,smoothing=LABEL_SMOOTHING,num_classes=NUM_CLASSES,max_loss=2.0# 开启截断,阈值可按需调整)

平滑衰减降权

硬截断是一刀切:损失超过阈值,直接砍平,损失值瞬间不再增长,像台阶一样突变;
平滑衰减是越涨越慢:损失低于阈值时正常计算,超过阈值后还能继续涨,但增长速度会越来越慢,过渡是顺滑的曲线,没有突变台阶。

它的目的:既保留损失越高、权重越大的相对顺序,又不让极端高损失样本无限放大梯度带偏模型,同时保证训练过程梯度平稳,不会出现跳变

代码实现 只需要把截断部分替换成平滑衰减逻辑即可:

classFocalLossWithSmoothing(nn.Module):def__init__(self,gamma=2,alpha=None,smoothing=0.0,num_classes=3,loss_threshold=1.8):super().__init__()self.gamma=gamma self.alpha=torch.tensor(alpha).to(DEVICE)ifalphaelseNoneself.smoothing=smoothing self.num_classes=num_classes self.loss_threshold=loss_threshold# 平滑衰减阈值defforward(self,inputs,targets):targets_one_hot=torch.zeros_like(inputs).scatter_(1,targets.unsqueeze(1),1)soft_targets=targets_one_hot*(1-self.smoothing)+self.smoothing/self.num_classes log_probs=torch.nn.functional.log_softmax(inputs,dim=1)probs=torch.exp(log_probs)p_t=(probs*targets_one_hot).sum(dim=1,keepdim=True)focal_weight=(1-p_t)**self.gamma ce_loss=(-soft_targets*log_probs).sum(dim=1)loss=focal_weight.squeeze()*ce_lossifself.alphaisnotNone:alpha_t=(self.alpha.unsqueeze(0)*targets_one_hot).sum(dim=1)loss=loss*alpha_t# 平滑衰减降权:压制极端高损失样本ifself.loss_thresholdisnotNone:high_loss_mask=loss>self.loss_threshold loss[high_loss_mask]=self.loss_threshold+torch.log(1+loss[high_loss_mask]-self.loss_threshold)returnloss.mean()

假设设置阈值 = 1.5,看不同原始损失对应的处理结果:

原始单样本损失硬截断后损失变化特点
1.0(正常样本)1.0低于阈值,完全不变
1.4(较难样本)1.4低于阈值,完全不变
1.5(阈值点)1.5刚好等于阈值
1.6(难样本)1.5超过一点点,直接被砍成1.5,瞬间停止增长
3.0(极难/错标样本)1.5不管多高,全砍成1.5,和1.6的样本权重完全一样

硬截断的问题

  1. 阈值点处损失突变,梯度也会突变,训练过程容易出现震荡;
  2. 所有超过阈值的样本,损失都一样,丢失了难分程度的差异信息——3.0的极难样本和1.6的轻微难样本,对模型的贡献变得完全相同,有点矫枉过正。

平滑衰减的逻辑:两段式 + 对数压缩

代码里用的是阈值以下正常计算,阈值以上对数压缩的两段式策略,公式是:
处理后损失={原始损失原始损失≤阈值阈值+log⁡(1+原始损失−阈值)原始损失>阈值 \text{处理后损失} = \begin{cases} \text{原始损失} & \text{原始损失} \le 阈值 \\ 阈值 + \log(1 + \text{原始损失} - 阈值) & \text{原始损失} > 阈值 \end{cases}处理后损失={原始损失阈值+log(1+原始损失阈值)原始损失阈值原始损失>阈值

为什么用 log(对数)函数?

对数函数有两个完美匹配需求的特性:

  1. 单调递增:原始损失越大,处理后的损失也一定越大,不会改变谁更难、谁损失更高的排序,样本的相对权重关系保留了;
  2. 增速递减:x 越大,log(x) 涨得越慢。原始损失越高,压缩力度越强,正好符合极端样本降权更多的需求。

直观对比效果

还是设阈值 = 1.5,算一组真实数值,一眼就能看出区别:

原始单样本损失硬截断后平滑衰减后直观感受
1.01.01.00低于阈值,两者完全一样
1.41.41.40低于阈值,两者完全一样
1.51.51.50阈值点,两者对齐
1.61.51.595只超了一点点,压缩很轻微,几乎和原值差不多
2.01.51.693超了0.5,增长明显放缓,不再是直线涨
3.01.51.946超了1.5,涨幅被大幅压缩,不会涨到3.0
5.01.52.208超了3.5,增速进一步变慢,和3.0的差距被缩小

可以明显看到:
刚超过阈值时,损失几乎不受影响,过渡非常顺滑;
损失越高,被压缩得越厉害,但始终保持越高越重的排序;
不会像硬截断那样,所有高损失全变成同一个值。

对应代码

loss[high_loss_mask]=self.loss_threshold+torch.log(1+loss[high_loss_mask]-self.loss_threshold)

拆解开:

  1. loss[high_loss_mask] - self.loss_threshold:算出损失超出阈值的部分(增量);
  2. 1 + 增量:加1保证对数的输入大于0,避免出现负数报错;
  3. torch.log(...):对超出的增量做对数压缩,让增量涨得变慢;
  4. self.loss_threshold + 压缩后的增量:把基准阈值加回来,保证阈值点处数值连续、没有台阶。

什么时候用硬截断,什么时候用平滑衰减?

方案场景特点
硬截断确定有大量标注错误,想直接屏蔽极端错标的影响简单粗暴,可控性强,调试方便
平滑衰减样本大多是标注正确的难样本(比如小目标、低对比度),只想削弱、不想完全屏蔽更温和,梯度平稳,训练更稳定,保留难样本的相对差异信息
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/24 9:13:06

5分钟快速上手Penpot:开源设计平台团队协作实战指南

5分钟快速上手Penpot:开源设计平台团队协作实战指南 【免费下载链接】penpot Penpot: The open-source design tool for design and code collaboration 项目地址: https://gitcode.com/GitHub_Trending/pe/penpot 你是否正在寻找一款既专业又灵活的设计工具…

作者头像 李华
网站建设 2026/6/24 9:10:26

Rust为何成为AI智能体视觉(TVA)的“免疫系统”(系列)

前沿技术介绍:AI智能体视觉(TVA,Transformer-based Vision Agent)是依托Transformer架构与“因式智能体”理论所构建的颠覆性工业视觉技术,属于“物理AI” 领域的一种全新技术形态,完成了从“虚拟世界”到“…

作者头像 李华
网站建设 2026/6/24 9:07:03

AI进阶三境界:从聊天框到专家团队,你处于哪一层?

文章将使用AI的方式分为三个层级:第一层是基础聊天框,仅用于简单问答;第二层是通过Agent平台实现任务自动化;第三层是利用中央调度系统,由专家Agent团队协作完成复杂任务。文章深入解析了Agent的运作机制,将…

作者头像 李华
网站建设 2026/6/24 9:03:14

微软 SQL Server 版本演进史:从诞生到 SQL Server 2025

一、SQL Server 的起源(1989-1996)1989年:SQL Server 1.0微软与 Sybase 合作,推出了第一个版本的 SQL Server,运行于 OS/2 平台。这是微软进入数据库领域的起点,虽然功能相对简单,但为后续发展奠…

作者头像 李华
网站建设 2026/6/24 8:51:34

小程序分销裂变怎么做?实体门店二级分销落地全流程拆解

公域投流成本持续上涨,老客分销成为实体店性价比最高的获客方式。但很多商家开通分销后效果极差:佣金设置不合理、上下级绑定混乱、数据对账繁琐,核心是没有吃透分销底层逻辑。本文结合SaaS商城通用分销模块,拆解完整落地流程。一…

作者头像 李华
网站建设 2026/6/24 8:51:14

AI法律公司Garfield AI赢得首次诉讼案件

英国首家获得监管批准的AI法律公司Garfield AI,在成立约一年后赢得了首次法庭诉讼,成为AI法律服务领域的重要里程碑。案件背景自由职业者塔米雷斯卡迈尔塔基迪尔曾为一家餐饮酒店企业提供人力资源相关服务,在尝试通过庭外途径解决欠费纠纷未果…

作者头像 李华