详解大型项目中的AMP训练

1 · HikariLi · Jan. 26, 2024, 6:46 a.m.
1 什么是AMPAutomatic Mixed Precision是百度联合英伟达一起推出的一个训练trick,通过在训练过程中部分使用FP16的半精度数据来极大节省内存,同时能加快训练速度。最开始要使用Apex框架来开启AMP训练,但现在PyTorch已经自带AMP相关功能。 2 AMP训练的挑战AMP训练一般会遇到几个问题,第一是有可能遇到数值下溢和数值上溢,由于FP16能表示的范围要小很多,所以当数据转换为FP16之后可能会发生下溢出(0),和上溢出(inf),为了解决这个问题,我们需要调用torch.cuda.amp.GradScaler()来缩放梯度,让梯度始终在FP16的表示范围内。而缩放的大小torch会自动帮我们决定。在缩放梯度进行反向传播之后,我们需要在优化器step前将其缩放回去,这样才不会与原始设定的学习率的尺度产生冲突。另外一个问题是在某些算子中(例如BatchNorm),AMP训练会造成不稳定,我们最好是能够随时监控梯度的大小。 3 大厂项目中AMP训练的代码12345678class NativeScalerWithGradNormCount: st...