网易首页 > 网易号 > 正文 申请入驻

持续学习中避免灾难性遗忘的EWC损失数学原理及代码实现

0
分享至

训练人工神经网络最重要的挑战之一是灾难性遗忘。神经网络的灾难性遗忘(catastrophic forgetting)是指在神经网络学习新任务时,可能会忘记之前学习的任务。这种现象特别常见于传统的反向传播算法和深度学习模型中。主要原因是网络在学习新数据时,会调整权重以适应新任务,这可能会导致之前学到的知识被覆盖或忘记,尤其是当新任务与旧任务有重叠时。



在本文中,我们将探讨一种方法来解决这个问题,称为Elastic Weight Consolidation。EWC提供了一种很有前途的方法来减轻灾难性遗忘,使神经网络在获得新技能的同时保留先前学习任务的知识。



在任务a和任务B的灰色和黄色区域中,存在许多具有期望的低误差的最优参数配置。假设我们为任务A找到了一个这样的配置θꭺ*,当继续从这样的配置训练模型到新的任务B时,会出现三种不同的场景:

蓝色箭头:简单地继续在任务B上进行训练而不受惩罚,将在任务B的低水平区域结束,但在任务A上的表现低于预期的准确性。

绿色箭头:使用任务A的权重的L2约束可能太强,使得模型在任务A上表现良好,但在任务B上表现不佳。

红色箭头:这是EWC是提出的解决方案,它将在模型在两个任务上都表现良好的区域(两个区域之间的交叉点)中找到参数。

下面我们将解释这是如何完成的。

费雪信息矩阵(FIM)

EWC方法所基于的FIM(Fisher Information Matrix)。FIM是一种统计度量,用于量化给定数据提供的关于我们要估计的未知参数θ的信息量。在持续学习的背景下,FIM将有助于识别神经网络参数,这些参数从以前的任务中获取的数据信息较少。通过更新这些参数,网络可以学习新的任务,而不会删除存储在参数中的重要信息,这些信息是关于先前学习任务的非常有用的信息。

假设X是一个随机变量,其概率密度函数f(X |θ)参数化为θ。样本x的似然函数(仅在数据固定的情况下为参数函数)为:

和对数拟然:

将FIM定义为:



这表明对数似然函数对参数的微小变化有多敏感。我们可以将FIM视为似然函数二阶导数的负期望:



当求二阶导数时,基本上是在看似然函数的曲率。

可以考虑下面的两个绘制的似然函数的图表。蓝色曲线表示在峰值附近非常窄的分布,表明数据更有可能在θ附近,并且随着远离θ而迅速减少。相反,黑色曲线代表一个更广泛的分布,即使远离θ,数据也保持相似的可能性。

FIM量化了这个概念——数据是多么严格地限制在某个θ值上。较大的FIM(如蓝色曲线所示)意味着参数值的微小变化将导致数据在这些参数下的可能性显著下降。相反,较小的FIM(如黑色曲线所示)意味着参数值的较小变化将导致可能性的较小降低。



事实证明,费雪信息矩阵与数据的方差(或多变量情况下的协方差)成反比。在上面的图表中,如果假设曲线分别代表均值θ 0和方差σ²ᵦₗᵤₑ和σ²ᵦₗ꜀ₖ的两个高斯分布,其中σ²ᵦₗᵤₑ< σ²ᵦₗ꜀ₖFIM等于1/σ²,因此蓝色曲线包含更多信息。

弹性重量固结

给定数据D和一个参数为θ的神经网络,我们的目标是在给定数据p(θ|D)的情况下最大化参数的概率。根据贝叶斯规则,我们得到:

弹性权重保持

弹性权重保持(Elastic Weight Consolidation,EWC)是一种用于减轻神经网络灾难性遗忘问题的方法。它的基本思想是在学习新任务时保护先前任务的关键权重。

给定数据D和一个参数为θ的神经网络,我们的目标是在给定数据p(θ|D)的情况下最大化参数的概率。根据贝叶斯规则,我们得到:



对两边应用对数并不改变最大化的目标,因为对数是一个单调变换。因此目标变成:



假设两个独立任务D = {A, B},我们有:



最后一个是独立于A和B的。这里log(p(B|θ))是任务B的损失,log(p(B))是B的可能性,它可以作为优化的常数,因为它不依赖于θ, log(p(θ| a))是任务a的后验分布,它包含了任务a重要参数的所有信息。

估计log(p(θ|A))比较复杂的,因为计算它将涉及在整个参数空间上对高维函数进行积分。但是它近似为正态分布,其均值为任务a - θꭺ*的最优参数,方差为费雪信息矩阵。这种近似是有意义的,因为我们可以假设A和B任务的新参数θ与任务A的最优参数相差不远。在所有θꭺ*的参数中,会有一些参数对任务A的良好表现更重要,并且不希望它们改变太多,这就是FIM的作用,FIM的值表明在这种情况下,改变某个参数将如何影响任务A的损失。因此,FIM中值越高的参数变化受到的惩罚越大。

现在,我们对任务A的最优权值进行泰勒展开直到第二项

其中log(p(θꭺ*|A))是一个常数,我们可以在优化中忽略它。我们也可以忽略第二项,因为在最优θꭺ*处,梯度为零。这样就找到了log(p(θ|A))的表达式,把它代回到图8的原始公式中:



第二项的二阶导数为Hessian,可以根据图5的定义用费雪信息矩阵近似。log(p(B|θ))是新任务B的损失,例如交叉熵,我们记为Lᵦ(θ)

我们不需要进行二阶导数,只需根据图4中等价于图5的定义,即对数似然梯度的外积,用一阶导数近似FIM即可:



这样优化L(θ)的总损失为:



λ是一个超参数,表示在前一个任务a上保持精度的重要性。

上面涉及梯度向量的外积的定义捕获了梯度的协方差结构。而FIM的对角线近似通常由梯度的平方给出,它只计算参数的方差,但计算成本较低,足以完成任务:



Pytorch实现

上面我们介绍了弹性权重保持的数学原理,下面我们来看看Pytorch的代码实现

让我们首先导入一些库以及分别代表任务A和任务B的MNIST和Fashion MNIST数据集。我们还定义了一个简单的神经网络:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import autograd
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
def get_accuracy(model, dataloader):
model = model.eval()
acc = 0
for input, target in dataloader:
o = model(input.to(device))
acc += (o.argmax(dim=1).long() == target.to(device)).float().mean()
return acc / len(dataloader)
class LinearLayer(nn.Module):
# from https://github.com/shivamsaboo17/Overcoming-Catastrophic-forgetting-in-Neural-Networks/blob/master/elastic_weight_consolidation.py
def __init__(self, input_dim, output_dim, act='relu', use_bn=False):
super(LinearLayer, self).__init__()
self.use_bn = use_bn
self.lin = nn.Linear(input_dim, output_dim)
self.act = nn.ReLU() if act == 'relu' else act
if use_bn:
self.bn = nn.BatchNorm1d(output_dim)
def forward(self, x):
if self.use_bn:
return self.bn(self.act(self.lin(x)))
return self.act(self.lin(x))
class Flatten(nn.Module):
def forward(self, x):
return x.view(x.shape[0], -1)
class Model(nn.Module):
def __init__(self, num_inputs, num_hidden, num_outputs):
super(Model, self).__init__()
self.f1 = Flatten()
self.lin1 = LinearLayer(num_inputs, num_hidden, use_bn=True)
self.lin2 = LinearLayer(num_hidden, num_hidden, use_bn=True)
self.lin3 = nn.Linear(num_hidden, num_outputs)
def forward(self, x):
return self.lin3(self.lin2(self.lin1(self.f1(x))))
# Load MNIST dataset, representint task A
mnist_train = datasets.MNIST("../data", train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST("../data", train=False, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(mnist_train, batch_size = 100, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size = 100, shuffle=False)
# FashiomMNIST is task B
f_mnist_train = datasets.FashionMNIST("../data", train=True, download=True, transform=transforms.ToTensor())
f_mnist_test = datasets.FashionMNIST("../data", train=False, download=True, transform=transforms.ToTensor())
f_train_loader = DataLoader(f_mnist_train, batch_size = 100, shuffle=True)
f_test_loader = DataLoader(f_mnist_test, batch_size = 100, shuffle=False)

现在让我们在MNIST任务上训练模型:

# parameters
EPOCHS = 4
lr=0.001
weight=100000
accuracies = {}
device = 'cuda:1'
criterion = nn.CrossEntropyLoss()
# train model on task A
model = Model(28 * 28, 100, 10).to(device)
optimizer = optim.Adam(model.parameters(), lr)
for _ in range(EPOCHS):
for input, target in tqdm(train_loader):
output = model(input.to(device))
loss = criterion(output, target.to(device))
optimizer.zero_grad()
loss.backward()
optimizer.step()
accuracies['mnist_initial'] = get_accuracy(model, test_loader)

现在可以定义函数来估计FIM和EWC损失中使用的先前参数:

def ewc_loss(model, weight, estimated_fishers, estimated_means):
losses = []
for param_name, param in model.named_parameters():
estimated_mean = estimated_means[param_name]
estimated_fisher = estimated_fishers[param_name]
losses.append((estimated_fisher * (param - estimated_mean) ** 2).sum())
return (weight / 2) * sum(losses)
def estimate_ewc_params(model, train_ds, batch_size=100, num_batch=300, estimate_type='true'):
estimated_mean = {}
for param_name, param in model.named_parameters():
estimated_mean[param_name] = param.data.clone()
estimated_fisher = {}
dl = DataLoader(train_ds, batch_size, shuffle=True)
for n, p in model.named_parameters():
estimated_fisher[n] = torch.zeros_like(p)
model.eval()
for i, (input, target) in enumerate(dl):
if i > num_batch:
break
model.zero_grad()
output = model(input.to(device))
# https://www.inference.vc/on-empirical-fisher-information/ - more on this here
if ESTIMATE_TYPE == 'empirical':
# empirical
label = target.to(device)
else:
# true estimate
label = output.max(1)[1]
loss = F.nll_loss(F.log_softmax(output, dim=1), label)
loss.backward()
# accumulate all the gradients
for n, p in model.named_parameters():
estimated_fisher[n].data += p.grad.data ** 2 / len(dl)
estimated_fisher = {n: p for n, p in estimated_fisher.items()}
return estimated_mean, estimated_fisher

然后继续在任务B上训练EWC损失的网络:

# compute fisher and mean parameters for EWC loss
estimated_mean, estimated_fisher = estimate_ewc_params(model, mnist_train)
# Train task B fashion mnist
for _ in range(EPOCHS):
for input, target in tqdm(f_train_loader):
output = model(input.to(device))
loss = ewc_loss(model, weight, estimated_fisher, estimated_mean) + criterion(output, target.to(device))
optimizer.zero_grad()
loss.backward()
optimizer.step()
accuracies['mnist_EWC'] = get_accuracy(model, test_loader)
accuracies['f_mnist_EWC'] = get_accuracy(model, f_test_loader)

可以得到以下精度:

{'mnist_initial': tensor(0.9772, device='cuda:1'),
'mnist_AB': tensor(0.9717, device='cuda:1'),
'f_mnist': tensor(0.8312, device='cuda:1')}

最后将这些与没有EWC损失的模型进行比较:

{'mnist_initial': tensor(0.9762, device='cuda:1'),
'mnist_AB': tensor(0.1769, device='cuda:1'),
'f_mnist': tensor(0.8672, device='cuda:1')}

可以看到EWC损失有助于保持任务A的准确率几乎不变,而学习任务B的准确率几乎与没有EWC损失的情况相同。

总结

我们看到了一种允许神经网络在继续学习新任务的同时保留其先前学习的知识的技术,虽然EWC在解决灾难性遗忘方面效果显著,但仍有一些挑战,例如对费雪信息矩阵的计算和存储需求较高,以及在复杂的深度神经网络结构中的实施复杂性。

还有还有其他方法可以使模型进行持续学习,比如:

重播记忆(Replay Memory):保存旧数据以便周期性地重训练。

联合训练(Joint Training):同时训练网络以处理旧任务和新任务。

元学习方法(Meta-learning Approaches):通过元学习算法来优化模型,以便快速适应新任务而不会忘记旧任务。

这些方法有助于减轻灾难性遗忘的影响,使神经网络能够持续学习和适应多个任务。

https://avoid.overfit.cn/post/56aee34117764e89a1a707c316fa305f

特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。

Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

相关推荐
热点推荐
性别争议奥运冠军为男性!JK罗琳发文,名嘴呼吁将金牌还给中国

性别争议奥运冠军为男性!JK罗琳发文,名嘴呼吁将金牌还给中国

全景体育V
2024-11-05 08:19:36
台媒爆黄春梅被气病,住院治疗,只因汪小菲断供,心疼女儿压力大

台媒爆黄春梅被气病,住院治疗,只因汪小菲断供,心疼女儿压力大

娱圈小愚
2024-11-05 09:00:14
你们去过埃及的人嘴真严啊!网友:一生报喜不报忧的中国人

你们去过埃及的人嘴真严啊!网友:一生报喜不报忧的中国人

观察鉴娱
2024-11-04 11:06:32
闹大了!警车撞飞女孩10多米被指超速,涉事民警说辞又被网友质疑

闹大了!警车撞飞女孩10多米被指超速,涉事民警说辞又被网友质疑

火山诗话
2024-11-04 20:30:27
中方正式下达逐客令,美方被赶出中国!美高层火速对华做保证

中方正式下达逐客令,美方被赶出中国!美高层火速对华做保证

现代小青青慕慕
2024-11-03 20:58:01
美国人发现,中国神舟十八号,带回全球首个太空科技?

美国人发现,中国神舟十八号,带回全球首个太空科技?

Thurman在昆明
2024-11-05 16:47:33
美国大选,传来最新消息!

美国大选,传来最新消息!

数据宝
2024-11-05 12:24:58
精液啥味道?竟然有人会觉得“香”有人觉得“苦”

精液啥味道?竟然有人会觉得“香”有人觉得“苦”

图灵灵2024
2024-11-04 11:43:21
战火纷飞!库尔斯克地区遭遇乌克兰军方空袭,8000名士兵陷入困境

战火纷飞!库尔斯克地区遭遇乌克兰军方空袭,8000名士兵陷入困境

世界探索者发现
2024-11-04 21:03:33
42岁姚笛加拿大吃中餐,鼻翼变大素颜认不出,近况和文章迥然不同

42岁姚笛加拿大吃中餐,鼻翼变大素颜认不出,近况和文章迥然不同

花花lo先森
2024-11-05 11:32:15
明明是忽悠老外,乱港蟑螂拍的“重庆贫民窟”,却把老外看破防了

明明是忽悠老外,乱港蟑螂拍的“重庆贫民窟”,却把老外看破防了

飞花文史
2024-11-05 09:25:16
窦骁中国杯帆船赛夺冠,网友:毅力太强了,值得学习!

窦骁中国杯帆船赛夺冠,网友:毅力太强了,值得学习!

极目新闻
2024-11-05 12:09:46
动完心脏手术回来了!74岁赵少康叹“死而复生”,称未来要做一个更好的人

动完心脏手术回来了!74岁赵少康叹“死而复生”,称未来要做一个更好的人

海峡导报社
2024-11-05 16:26:08
曾经很火,如今却“沦为笑柄”的5件家居用品,你买过几个呢?

曾经很火,如今却“沦为笑柄”的5件家居用品,你买过几个呢?

阿离家居
2024-11-04 11:07:40
盘后突发,证券市场传来王炸消息,明天的A股剧本直接定调了!

盘后突发,证券市场传来王炸消息,明天的A股剧本直接定调了!

一丛深色花儿
2024-11-05 11:43:45
重磅!明天起执行!房贷利率回到3字头...

重磅!明天起执行!房贷利率回到3字头...

居者
2024-11-05 14:32:38
大米检测出重金属镉超标?“先投毒,再治病”?官方回应

大米检测出重金属镉超标?“先投毒,再治病”?官方回应

环球网资讯
2024-11-05 11:56:14
首秀险送助攻,申花17岁小将一战成名 马莱莱续约有转机 全队放假

首秀险送助攻,申花17岁小将一战成名 马莱莱续约有转机 全队放假

替补席看球
2024-11-05 18:17:28
已经成为姆巴佩的牺牲品,曝曼城愿意花费1.5亿引进皇马失意球星

已经成为姆巴佩的牺牲品,曝曼城愿意花费1.5亿引进皇马失意球星

星耀国际足坛
2024-11-05 00:20:03
郑钦文出线稳了!萨巴伦卡送大礼,中国一姐即时排名再创新高

郑钦文出线稳了!萨巴伦卡送大礼,中国一姐即时排名再创新高

大秦壁虎白话体育
2024-11-05 07:19:30
2024-11-05 20:34:44
deephub
deephub
CV NLP和数据挖掘知识
1488文章数 1417关注度
往期回顾 全部

科技要闻

字节跳动上半年营收直逼Meta:TikTok狂飙

头条要闻

选前“封关”民调:哈里斯领先特朗普4个百分点

头条要闻

选前“封关”民调:哈里斯领先特朗普4个百分点

体育要闻

一个想改变中国足球的日本人

娱乐要闻

周雨彤风波升级!阴阳怪气遭全网怒怼

财经要闻

超配!高盛:AH股未来一年回报率20%

汽车要闻

新款别克世纪将11月12日上市 预售价48.99万起

态度原创

游戏
本地
家居
旅游
军事航空

《战国王朝》11月7日迎来正式版 1.0追加新要素

本地新闻

塞上青城|是课本里的风吹草低见牛羊

家居要闻

纯粹干净空间 极简米灰色基调

旅游要闻

北京环球影城大巡游本周六起回归

军事要闻

中国空军:在适当时机场合会有更"牛"的重器利器露面

无障碍浏览 进入关怀版