摘要:
少样本学习领域最近有了长足的进步。这些进步中的大多数来自将少样本学习构建为元学习问题。目前,Model Agnostic Meta Learning或MAML是通过元学习进行少样本学习的最佳方法之一。MAML简单,优雅且功能强大,但是它具有许多问题,例如对神经网络结构非常敏感,通常会导致训练过程中的不稳定,需要艰巨的超参数搜索来稳定训练并实现高泛化,在训练和推理时都非常耗费算力。在本文中,我们提出了对MAML的各种修改,这些修改不仅可以稳定系统,而且可以大大提高MAML的泛化性能,收敛速度和计算开销,我们称之为MAML++。
MAML
3.1 MAML的问题
MAML的简单,优雅和高性能使其成为元学习的非常强大的框架。但是,MAML也有许多问题,使其难以使用。
梯度不稳定性:
如上图所示,受到神经网络结构和全局超参数设置的影响,MAML在训练过程中可能非常不稳定。优化outer loop涉及多次穿过由同一网络组成的未展开的inner loop进行导数的反向传播。仅此一项就可能导致梯度问题。但是,模型架构进一步加剧了梯度问题,标准4层卷积网络但没有skip-connections。缺少任何skip-connections意味着每个梯度必须多次通过每个卷积层。实际上,梯度将被多次乘以相同的参数集。经过多次反向传播后,展开网络的深度结构和skip-connections的缺失会分别引起梯度爆炸和梯度消失问题。
二阶导数成本:
通过梯度更新步骤进行优化需要计算二阶梯度,而二阶梯度的计算成本非常高昂。MAML的作者建议使用一阶近似将处理速度提高三倍,但是使用这些近似可能会对最终的泛化误差产生负面影响。已经在Reptile(Nichol et al., 2018)中尝试了进一步使用一阶方法的尝试,作者在基本模型上应用标准SGD,然后更新其初始化参数向N步更新后的参数方向迈出一步。Reptile的结果变化较大,在某些情况下超过MAML,而在另一些情况下则不如MAML。尚未提出减少计算时间而不牺牲泛化性能的方法。
缺少Batch Normalization统计量累积:
影响生成性能的另一个问题是原始MAML论文中在实验中使用Batch Normalization的方式。不是累积运行统计信息,而是将当前batch的统计信息用于Batch Normalization。这导致Batch Normalization的效果较差,因为学习的偏差必须适应各种不同的均值和标准差,而不是单个均值和标准差。另一方面,如果Batch Normalization使用累积的运行统计信息,则最终将收敛到某些全局平均值和标准偏差。这样就只剩下一个均值和标准偏差来学习偏差了。使用running统计信息而不是batch统计信息,可以极大地提高收敛速度,稳定性和泛化性能,因为归一化的特征将导致更平滑的优化环境(Santurkar et al.,2018)。
共享(跨step)Batch Normalization偏差:
MAML中的批处理规范化的另一个问题源于以下事实:Batch Normalization偏差未在inner loop中更新;相反,在基础模型的所有迭代中都使用相同的偏差。隐式地执行此操作将假定所有基本模型在整个inner loop更新中都是相同的,因此通过它们传递的特征具有相同的分布。这是一个错误的假设,因为在每次inner loop更新时,都会实例化一个新的基础模型,该基础模型与前一个基础模型的差异足以从偏差估计的角度将其视为新模型。因此,为基本模型的所有迭代学习单个偏差集会限制性能。
共享的inner loop(跨step和跨参数)学习率:
影响泛化和收敛速度(就训练迭代而言)的一个问题是对所有参数和所有更新步骤使用共享学习率的问题。这样做会带来两个主要问题。具有固定的学习率要求进行多次超参数搜索,以找到特定数据集的正确学习率; 根据搜索的完成方式,此过程可能在计算上非常昂贵。
(Li et al。,2017)中的作者建议为网络的每个参数学习学习率并更新方向。这样做解决了手动搜索正确学习率的问题,并且还允许各个参数具有较小或较大的学习率。然而,这种方法带来了自己的问题。由于网络包含40K到50K的参数(取决于数据点的维数),因此学习每个网络参数的学习率意味着要增加计算量并增加内存使用量。
固定outer loop学习率:
在MAML中,作者使用具有固定学习率的Adam来优化元目标。事实证明,使用阶跃或余弦函数对学习率进行退火对于在多种情况下实现最新的泛化性能至关重要(Loshchilov & Hutter, 2016; He et al., 2016; Larsson et al., 2016; Huang et al., 2017)。因此,我们认为使用静态学习率会降低MAML的泛化性能,这也可能是优化速度较慢的原因。此外,具有固定的学习速率可能意味着必须花费更多(计算)时间来调整学习速率。
稳定,自动和改进的MAML
在本节中,我们提出了解决MAML框架问题的方法,如第3.1节所述。每个解决方案都有一个与要解决的问题相同的参考。
梯度不稳定性→多步损失优化(MSL):
MAML最小化完成对support set任务的所有inner-loop更新后的基础网络所计算出的在target set的loss。相反,我们建议最小化完成对support set任务的每一步更新的基础网络所计算出的在target set的loss。更具体地说,我们建议最小化的loss是每步support set loss更新后target set loss的加权总和。更正式地:
其中$\beta$是学习率,$L_{T_b}(f_{\theta^b_i})$表示在$i$向最小化support set任务loss更新后的基本网络权重在任务$b$的target set loss,$v_i$表示步骤$i$中target set loss的重要性权重, 用于计算加权和。
通过使用上面提出的multi-step loss,我们改善了梯度传播,因为现在每一步的基础网络权重都直接(对于当前步loss)和间接(来自后续步的loss)接收梯度。使用第3节中描述的原始方法,由于反向传播,除最后一步外,每个步骤的基础网络权重都被隐式优化,这导致了MAML的许多不稳定问题。但是,如图1所示,使用multi-step loss可以缓解此问题。此外,我们对每步损耗采用了退火加权。最初,所有损失都对损失具有相同的贡献,但是随着迭代次数的增加,我们会减少早期步骤的权重,并逐渐增加后续步骤的权重。这样做是为了确保随着训练的进行,最终步数loss会受到优化器的更多关注,从而确保其达到可能的最低损失。如果不使用退火,我们发现最终损失可能会高于原始方法。
二阶导数成本→导数退火(DA):
使MAML具有更高的计算效率的一种方法是减少所需的inner-loop更新次数,这可以通过本报告后续部分中介绍的某些方法来实现。但是,在本段中,我们提出了一种直接减少per-step计算开销的方法。MAML的作者提出了梯度导数的一阶近似的用法。但是,他们在整个训练阶段都采用了一阶近似。相反,我们建议随着训练的进行对微分阶数进行退火。更具体地说,我们建议在训练阶段的前50个epochs使用一阶梯度,然后在训练阶段的其余时间使用二阶梯度。我们凭经验证明,这样做可以大大加快前50个epochs的速度,同时允许进行二阶训练,以实现二阶梯度提供给模型的强大泛化性能。另一个有趣的观察结果是,与更不稳定的仅二阶实验相反,微分阶数退火实验没有出现梯度爆炸或消失的事件。在开始使用二阶导数之前使用一阶可以用作一种强大的预训练方法,该方法可以学习不太可能产生梯度爆炸/减小问题的参数。
缺少Batch Normalization统计信息累积→Per-Step Batch Normalization运行统计信息(BNRS):
在MAML Finn et al. (2017)的原始实现中,作者仅使用当前batch统计信息作为Batch Normalization统计信息。我们认为,这导致了3.1节中描述的各种不良影响。为了缓解这些问题,我们建议使用running batch统计信息进行Batch Normalization。要在MAML上下文中简单地实现Batch Normalization,就需要在inner-loop fast-knowledge获取过程的所有更新步骤之间共享running batch统计信息。然而,这样做将导致不希望的结果,即所存储的统计信息将在网络的所有inner loop更新之间共享。由于能在跨网络参数的多个更新上工作的学习参数的复杂性不断提高,因此这将导致优化问题,并有可能减慢或完全停止优化。更好的替代方法是按步骤收集统计信息。要按步骤收集running统计信息,需要实例化网络中每个Batch Normalization层的N组running平均值和running标准偏差集(其中N是inner loop更新步骤的总数),并使用优化过程中采取的步骤分别更新running统计信息。per-step batch normalization方法应加快MAML的优化速度,同时潜在地提高泛化性能。
共享(跨步骤)Batch Normalization偏差→Per-Step Batch Normalization权重和偏差(BNWB):
在MAML论文中,作者训练他们的模型去学到对每一层的一组偏差。这样做是假设通过网络传递的特征的分布是相似的。但是,这是一个错误的假设,因为基本模型已更新了许多次,从而使特征分布彼此之间越来越不相似。为了解决这个问题,我们建议在inner-loop更新过程中每步学习一组偏差。这样做意味着Batch Normalization将学习特定于在每个集合处看到的特征分布的偏差,这将提高收敛速度,稳定性和泛化性能。
共享的inner loop学习率(跨步和跨参数)→学习每层每步学习率和梯度方向(LSLR):
Li et al. (2017)的先前工作,证明了学习基础网络中每个参数的学习率和梯度方向可以提高系统的泛化性能。然而,这导致参数数量增加和计算开销增加的结果。因此,我们建议改为学习网络中每一层的学习率和方向,以及随着基础网络的逐步适应而学习不同的学习率。学习每层而不是每个参数的学习率和方向应该减少所需的内存和计算,同时在更新步骤中提供更多的灵活性。此外,对于学习到的每个学习率,将有N个实例的学习率,每个步骤要采用一个实例。通过这样做,参数可以自由学习在每步降低学习率,这可以帮助减轻过拟合的情况。
固定outer loop学习率→元优化器学习率的余弦退火(CA):
在MAML中,作者在元模型的优化器上使用静态学习率。通过使用阶跃函数(He et al., 2016)或余弦函数(Loshchilov&Hutter,2016)对学习率进行退火已被证明对于具有更高泛化能力的学习模型至关重要。余弦退火调度在产生最新技术结果方面特别有效,同时消除了对学习速率空间进行任何超参数搜索的需求。因此,我们建议将余弦退火调度应用于元模型的优化器(即元优化器)。退火学习率可使模型更有效地拟合训练集,结果可能会产生更高的泛化性能。