FP16 and Apex
For recent Nvidia GPUs such as the V100 or RTX 8000, enabling half-precision training can yield roughly a 2X speedup almost for free; the only question is how to conduct mixed-precision training properly.
Unlike TPUs, typical Nvidia GPUs follow the IEEE fp16 format rather than bfloat16. Specifically, bfloat16 has 8 exponent bits and 7 mantissa bits, while fp16 has only 5 exponent bits and 10 mantissa bits. With fewer exponent bits, directly applying half precision to deep learning leads to frequent overflow and underflow.
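As a quick illustration (a small runnable PyTorch snippet, not tied to any particular model), values just above fp16's maximum of about 65504 overflow to inf, while values below its smallest subnormal (about 6e-8) underflow to zero:

```python
import torch

# fp16 overflows just above its maximum representable value (~65504)
print(torch.tensor(70000.0, dtype=torch.float16))  # tensor(inf, dtype=torch.float16)

# tiny values underflow to zero (the smallest fp16 subnormal is ~5.96e-8)
print(torch.tensor(1e-8, dtype=torch.float16))     # tensor(0., dtype=torch.float16)
```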
To compensate for the loss of exponent bits, dynamic loss scaling and mixed-precision training have been proposed. The fairseq package comes with native fp16 support, and for custom models and other PyTorch codebases, the Apex package is usually the go-to choice. Beyond the Apex documentation, below are some tips for mixed-precision training.
Dynamic Loss Scaling
TL;DR. Since it is easier to detect overflow than underflow, it is recommended to set a minimum loss scale (e.g., 0.03125) and to use a small scaling window (e.g., 256) to stabilize training.
Unlike underflow, overflow can be detected easily (the gradients contain inf or NaN). Accordingly, it is strategically beneficial to push the model toward overflow rather than underflow. Specifically, loss scaling first multiplies the loss (and hence the gradients) by a constant, then divides the gradients by the same constant after back-propagation. Still, the choice of this constant is an open problem, and dynamic loss scaling updates it during training: it starts from a large value, halves the constant whenever an overflow occurs (which indicates the constant is probably too large), and doubles it when no overflow is detected for N consecutive updates (which indicates the constant could be larger).
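A minimal sketch of this procedure is given below; the class name, default values, and the gradient-overflow check are illustrative rather than Apex's or fairseq's actual implementation.

```python
import torch

class DynamicLossScaler:
    def __init__(self, init_scale=2.0 ** 15, scale_factor=2.0,
                 scale_window=256, min_scale=0.03125):
        self.scale = init_scale           # current loss-scaling constant
        self.scale_factor = scale_factor  # factor used when scaling up/down
        self.scale_window = scale_window  # N clean updates before scaling up
        self.min_scale = min_scale        # floor for the scale
        self._clean_steps = 0             # consecutive updates without overflow

    def has_overflow(self, params):
        # overflow is easy to detect: look for inf/NaN in the gradients
        for p in params:
            if p.grad is not None and not torch.isfinite(p.grad).all():
                return True
        return False

    def update(self, overflow):
        if overflow:
            # halve the scale (never below the floor) and reset the window
            self.scale = max(self.scale / self.scale_factor, self.min_scale)
            self._clean_steps = 0
        else:
            self._clean_steps += 1
            if self._clean_steps % self.scale_window == 0:
                # no overflow for scale_window updates: try a larger scale
                self.scale *= self.scale_factor
```

In use, the loss is multiplied by scaler.scale before backward(), the gradients are checked with has_overflow, the optimizer step is skipped on overflow (and the gradients divided by the scale otherwise), and update is called after every step.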
Mixed Precision Training
TL;DR. Elementwise-affine parameters (e.g., the weight and bias of LayerNorm) are more prone to overflow, and it is helpful to keep this part in fp32.
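For example, here is a minimal sketch (the TransformerEncoderLayer is only a stand-in model) that converts a model to fp16 and then casts every LayerNorm module back to fp32; in a full setup the inputs to these modules also need to be cast to fp32, e.g., via the monkey patching shown below.

```python
import torch.nn as nn

model = nn.TransformerEncoderLayer(d_model=512, nhead=8)  # stand-in model

model.half()  # convert all parameters and buffers to fp16 first

# cast the elementwise-affine (LayerNorm) modules back to fp32
for module in model.modules():
    if isinstance(module, nn.LayerNorm):
        module.float()
```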
Others
- In Apex, `opt_level` can be set to `O0` (full fp32), `O1` (mixed precision), `O2` (almost fp16), and `O3` (full fp16); a usage sketch is given after this list.
- To specifically cast a model (or part of it) to fp32:
  - Set selected model parameters, e.g.,

    ```python
    # fp32_keys: substrings identifying parameters to keep in fp32, e.g., ["layer_norm"]
    for n, p in model.named_parameters():
        if any(ki in n for ki in fp32_keys):
            p.data = p.data.float()  # Parameter.float() is not in-place, so reassign .data
    ```
  - Force fp32 computation by monkey patching, e.g.,

    ```python
    import torch

    orig_linear = torch.nn.functional.linear

    def wrapped_linear(*args):
        # cast floating-point tensor arguments to fp32 before the original call
        casted_args = []
        for arg in args:
            if torch.is_tensor(arg) and torch.is_floating_point(arg):
                casted_args.append(arg.float())
            else:
                casted_args.append(arg)
        return orig_linear(*casted_args)

    torch.nn.functional.linear = wrapped_linear
    ```
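As a sketch of the corresponding Apex usage (the model, optimizer, and input below are placeholders), the `opt_level` is passed to `amp.initialize` and the loss is scaled through `amp.scale_loss`:

```python
import torch
from apex import amp

model = torch.nn.Linear(512, 512).cuda()          # placeholder model
optimizer = torch.optim.Adam(model.parameters())  # placeholder optimizer

# O1 patches common functions to run in mixed precision with dynamic loss scaling
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

x = torch.randn(8, 512).cuda()                    # placeholder input
loss = model(x).sum()

# scale the loss so the backward pass produces scaled gradients
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
```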