1 简介
批归一化(Batch Normalization,BN)是在2015年由Sergey Loffe和Christan Szegedy[1]提出的一种加速深度学习模型收敛的方法。
在模型的训练过程中,每个深度学习模型每一层模块的输出分布都在不断变化,后续的模块需要不断适应新的输入模式,这个问题在Batch Normalization[1]中被称为internal covariate shift。为克服这个问题,Batch Normalization提出在模型内部加入归一化层。归一化层的引入使得模型的训练更加稳定,允许使用更大的学习率,使得模型对参数的初始化没那么敏感。
本文主要谈一谈BN的运行原理和实现细节,最后分享一些BN的使用经验和容易踩到的坑。
2 实现方法
BN用如下公式作输入样本的归一化: \[ \hat x = \frac{x - E(x)}{\sqrt{Var(x) + \epsilon}}, \tag{1}\] 其中\(x\)为BN模块的输入。\(E(x)\)为\(x\)的数学期望,\(Var(x)\)为\(x\)的方差,\(\epsilon\)是防止除零异常的一个接近\(0\)的正数。
随后,归一化的样本经过一层线性层得到BN的输出: \[ y = \hat x\gamma + \beta. \]
2.1 Batch size
实际训练时,输入的tensor是N维的。以图像为例,一般图像特征\(x\)是4维,形状可能是\(b\times d\times h \times w\),其中\(b\)为batch size,\(d\)为特征维度,\(h\)和\(w\)为图像的长宽。在这个例子中,BN对所有\(d\)维的特征向量作归一化(共\(b\times h\times w\)个向量)。
为了获得尽量准确的统计,batch size最好取尽量大些。如果batch size太小,那么\(E(x)\)和\(Var(x)\)的估计不准确,模型的最终性能便可能下降。
2.2 测试和训练阶段的行为不一致
测试推理阶段,模型往往一次只接受一个数据:\(\text{batch size}=1\)。BN不能像训练时那样在大batch size下估计\(E(x)\)和\(Var(x)\). 为了应对这个问题,BN的对策是moving average。在训练阶段,BN使用Batch内统计的均值和方差作归一化,同时使用moving average方法维护均值和方差预备测试时使用。设BN的momentum
参数等于0.1,\(m\)是moving average方法跟踪的一个统计量,那么其更新方法为: \[
\hat m_{t} = \hat m_{t-1} \cdot (1 - \text{momentum}) + m_t \cdot \text{momentum}.
\]
测试时用事先统计的均值和方差的moving average,带入公式 1中作归一化。
3 代码实现
在实现BatchNorm之前,我们不妨先看看pytorch官方的BatchNorm2d
模块,观察BatchNorm层要有哪些参数:
import torch
import torch.nn as nn
= 2, 32, 128, 128
batch, ch, h, w = nn.BatchNorm2d(num_features=ch)
torch_batch_norm for k, v in torch_batch_norm.named_parameters():
print(f'parameter: {k}', v.shape)
for k, v in torch_batch_norm.named_buffers():
print(f'buffer: {k}', v.shape)
parameter: weight torch.Size([32])
parameter: bias torch.Size([32])
buffer: running_mean torch.Size([32])
buffer: running_var torch.Size([32])
buffer: num_batches_tracked torch.Size([])
注意到BatchNorm的参数有两种。一种是parameter(weight
、bias
),一种是buffer(running_mean
、running_var
、num_batches_tracked
)。对于parameter,torch默认其参数是需要梯度反传的;而buffer则用于存储一些不需要梯度反传的模型参数。与parameter
一样,在保存模型时,buffer
参数也会存储到state_dict
中。
下面是本文提供的BN实现:
import torch
import torch.nn as nn
class MyBatchNorm(nn.Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super().__init__()
# weight初始化为1,bias初始化为0.
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.momentum = 0.1
self.eps = eps
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.register_buffer('num_batches_tracked', torch.tensor(0.))
def forward(self, tensor):
= tensor.shape
bs, ch, h, w # 变换一下tensor的尺寸,方便处理
= tensor.permute(0, 2, 3, 1).flatten(0, 2) # bs * h * w, ch
tensor_flatten # 求均值和方差
= torch.mean(tensor_flatten, 0)
mean # 注意方差有biased和unbiased两种。
= torch.var(tensor_flatten, 0, unbiased=False)
var = torch.var(tensor_flatten, 0, unbiased=True)
var_unbiased
if self.training:
# 训练时,我们要执行moving average,统计
# running_mean和running_var,注意此时应
# 使用unbiased版本的方差。
self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
self.running_var.mul_(1 - self.momentum).add_(self.momentum * var_unbiased)
self.num_batches_tracked.add_(1)
# 训练时用batch内的统计量,测试时用moving average
# 保存的统计量。
if self.training:
= (tensor_flatten - mean) / torch.sqrt(var + self.eps)
tensor_flatten else:
= (tensor_flatten - self.running_mean) / torch.sqrt(self.running_var + self.eps)
tensor_flatten
# 归一化完成后,做线性变换。
= tensor_flatten * self.weight + self.bias
ret = ret.view(bs, h, w, ch).permute(0, 3, 1, 2)
ret return ret
接下来验证看看MyBatchNorm
的行为和torch.nn.BatchNorm2d
是否完全一致。
我们先检查训练模式下两者的行为:
= MyBatchNorm(num_features=ch)
my_batch_norm # 因为BN涉及running_mean和running_var的更新,所以我们要多跑几轮来检查moving average的正确性。
for _ in range(10):
= torch.rand(batch, ch, h, w)
a = torch_batch_norm(a)
ret1 = my_batch_norm(a)
ret2 = torch.mean(torch.abs(ret1 - ret2)).item()
diff
= torch_batch_norm.running_mean
running_mean1 = my_batch_norm.running_mean
running_mean2 = torch.mean(torch.abs(running_mean1 - running_mean2)).item()
diff_mean
= torch_batch_norm.running_var
running_var1 = my_batch_norm.running_var
running_var2 = torch.mean(torch.abs(running_var1 - running_var2)).item()
diff_var print('{:.6f};{:.6f};{:.6f}'.format(diff, diff_mean, diff_var))
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
可以看到,输出的所有误差项都为0!这表明MyBatchNorm
的实现和torch的BN相吻合。
不要忘记BN在训练时的行为和测试时的行为不同。我们需要再检查一遍测试阶段下MyBatchNorm
的行为。
# .eval()开启测试模型
eval()
torch_batch_norm.eval()
my_batch_norm.for _ in range(10):
= torch.rand(batch, ch, h, w)
a = torch_batch_norm(a)
ret1 = my_batch_norm(a)
ret2 = torch.mean(torch.abs(ret1 - ret2)).item()
diff
= torch_batch_norm.running_mean
running_mean1 = my_batch_norm.running_mean
running_mean2 = torch.mean(torch.abs(running_mean1 - running_mean2)).item()
diff_mean
= torch_batch_norm.running_var
running_var1 = my_batch_norm.running_var
running_var2 = torch.mean(torch.abs(running_var1 - running_var2)).item()
diff_var print('{:.6f};{:.6f};{:.6f}'.format(diff, diff_mean, diff_var))
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
0.000000;0.000000;0.000000
至此我们已经完成了全部检查,实验结果表明MyBatchNorm
的实现是正确的。
4 正确地使用BN
BN作为一种实用、应用广泛的归一化模块,是计算机视觉领域的一座里程碑。尽管BN的应用确实解决了一些实际问题,但它也存在一些“坑”,是在使用BN时应当注意的。
4.1 BN在训练阶段与测试阶段的行为差异
在训练阶段,BN使用batch内统计的均值和方差作归一化,并记录它们的moving average;而在测试阶段,BN不再统计新数据的均值和方差,也不再更新moving average。这种不一致性(inconsistency)在后续工作[2]中被认为是一种影响性能的潜在因素。
4.2 如何正确地冻结BN模块
设想我们有一个模型经过了充分的预训练,现在我们希望在一个小数据集上微调它。一般步骤包括(以pytorch为例):
- 阻止梯度反传。这可以通过使用
torch.no_grad()
或将该各参数的requires_grad
属性设置为False
做到; - 调用
module.eval()
,关闭train
模式;
针对第2点,一般人们有两种意见。一种看法认为不开BN的eval
模式更好,这有助于让模型学习如何对新数据做归一化。而我倾向于采取的做法则是开启eval
。在我的经验中,如果BN处于训练状态,而模型的其它层则冻结著,那么模型可能因为不适应BN在新数据上归一化参数的改变而引发训练不稳定。
总而言之,BN在训练、迁移学习、测试时的行为不一致有时确实是一个麻烦的问题。如果遇到了这个问题,我建议考虑一下是否要开启BN的eval
模式,或者试试后来的Group Normalization[2]。
4.3 分布式训练
在训练参数量较大的模型时,可以用分布式训练,利用多个进程和多个计算设备执行计算。这种情况下,每张卡只需负责比较小的batch。注意原始的BN在batch size较小时,所产生的均值/方差的统计量不准确。因此,在分布式训练时,我们最好将原BatchNorm模块替换为torch.SyncBatchNorm
。后者能同步所有计算设备,在更大的batch size上统计均值和方差。
4.4 不要递归地使用BN
最后介绍一个我踩过的,印象深刻的坑。 假如有这样一段代码:
= batch_norm(x1)
x2 = batch_norm(x2) x3
或
= conv(x1) # conv中包含BN模块
x2 = conv(x2) x3
前面的一段代码的问题或许容易识别,后者的问题则稍隐蔽些。你能预测到会发生什么吗?在这样的代码中,同一个BN模块在训练时会分别获取x2
和x1
的均值和方差,然后通过moving average将它们计入running_mean
和running_var
。然而,由于x1
和x2
服从不同的分布,因此running_mean
和running_var
的统计将失去意义。 问题的表现是,在训练阶段,我们会观察到损失正常下降。测试时,我们开启eval
模式,模型的表现不如预期;可是如果你关闭eval
模式,也许会发现模型又能正常工作。
类似的问题也存在于特征金字塔(FPN)的实现中。如果你希望在类特征金字塔的结构中实现不同层级共享参数的话,注意卷积的参数也许能共享,但BN的参数不要共享。
5 总结
本文介绍了BN的工作原理,给出了一种基于pytorch的BN模块实现,并提供了详细的代码检查。最后,本文讨论了应用BN过程中容易遇到的几种问题。
在接触深度学习的过程中,Batch Normalization是一个让我反复(大概得有两三次吧)踩坑的模块,每次踩坑都得琢磨好久才能发现问题所在。现在我已经习惯性的选择Group Normalization[2],抛弃BN了。尽管如此,BN仍是一个经典的工作,它背后的思想很值得学习研究。