Efficient Training of LLMs
This post is mainly based on
- Efficient Training on a Single GPU
- What Every User Should Know About Mixed Precision Training in PyTorch
Data Types
- FP32 / full precision: 4 bytes
- BF16 and FP16 / half-precision: 2 bytes
- INT8: 1 byte
Mixed Precision Training
- Peak float16 matrix multiplication and convolution performance is 16x faster than peak float32 performance on A100 GPUs
torch.amp
- mixed precision training is 1.5x to 5.5x faster over float32 on V100 GPUs
- PyTorch >= 1.6
Code:
import torch

# Create the gradient scaler once at the beginning of training
scaler = torch.cuda.amp.GradScaler()

for data, label in data_iter:
    optimizer.zero_grad()
    # Cast operations to mixed precision
    with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(data)
    # Scale the loss and call backward()
    # to create scaled gradients
    scaler.scale(loss).backward()
    # Unscale gradients and call
    # (or skip) optimizer.step()
    scaler.step(optimizer)
    # Update the scale for the next iteration
    scaler.update()
Enabling TensorFloat32 (TF32) mode
torch.set_float32_matmul_precision(precision)
- Best practice
- High Performance Computing (HPC) applications, regression tasks, and generative networks may simply require full float32 IEEE precision to converge as expected
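As a quick sketch of how this knob is used (the string values below are the ones PyTorch accepts; "high" is what allows TF32 matmuls on Ampere GPUs):
Code:
import torch

# "highest": keep full FP32 matmul precision (default)
# "high":    allow TF32 (or comparable reduced-precision) matmuls
# "medium":  allow even lower internal precision, e.g. BF16 matmuls
torch.set_float32_matmul_precision("high")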
Huggingface: Efficient Training on a Single GPU
nvidia-ml-py3
- Monitor the memory usage of the models from within Python
- Access information from nvidia-smi (see the sketch below)
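A small helper along these lines (the HF guide defines a similar print_gpu_utilization; GPU index 0 is assumed here):
Code:
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo

def print_gpu_utilization():
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)  # assumes the first GPU
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU memory occupied: {info.used // 1024**2} MB.")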
TrainingArguments, Trainer and Dataset
from transformers import TrainingArguments, Trainer

# default_args, ds (the dataset), and print_summary are helpers defined in the HF guide
training_args = TrainingArguments(per_device_train_batch_size=4, **default_args)
trainer = Trainer(model=model, args=training_args, train_dataset=ds)
result = trainer.train()
print_summary(result)
Results:
Time: 57.82
Samples/second: 8.86
GPU memory occupied: 14949 MB.
Anatomy of Model’s Operations
- Transformers architecture includes 3 main groups of operations
- Tensor Contractions
- Linear layers / components of Multi-Head Attention
- Batched matrix-matrix multiplications
- Most compute-intensive part of training a transformer
- Statistical Normalizations
- Softmax / layer normalization
- One or more reduction operations, the result of which is then applied via a map
- Less compute-intensive than tensor contractions
- Element-wise Operators
- Biases, dropout, activations, and residual connections
- Least compute-intensive operations
Anatomy of Model’s Memory
- Model training requires far more VRAM than inference
- Mixed Precision Training
- Optimizer: AdamW
- 18 bytes per model parameter, plus activation memory
- Mixed Precision Inference
- 6 bytes per model parameter, plus activation memory
- Components breakdown
- Model Weights
- 4 bytes * number of parameters for FP32 training
- 6 bytes * number of parameters for mixed precision training (FP32 + FP16 copy)
- Optimizer States
- 8 bytes * number of parameters for normal AdamW (maintains 2 states)
- 2 bytes * number of parameters for 8-bit AdamW optimizers like bitsandbytes
- 4 bytes * number of parameters for optimizers like SGD with momentum (maintains only 1 state)
- Gradients
- 4 bytes * number of parameters for either FP32 or mixed precision training (gradients are always kept in FP32)
- Forward Activations
- Size depends on many factors, the key ones being sequence length, hidden size and batch size
- Temporary Memory
- All kinds of temporary variables which get released once the calculation is done
- It is important to explicitly free these as soon as they are no longer needed
- Functionality-specific memory
- Generating text using beam search: maintain multiple copies of inputs and outputs
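Putting the per-parameter numbers together, a back-of-the-envelope estimate for mixed precision training with AdamW (the 1.3B parameter count is only an illustrative assumption; activations come on top):
Code:
n_params = 1.3e9          # assumed model size (~1.3B parameters)

weights   = 6 * n_params  # FP32 master weights + FP16 copy
optimizer = 8 * n_params  # AdamW keeps two FP32 states per parameter
gradients = 4 * n_params  # gradients are kept in FP32

total_gib = (weights + optimizer + gradients) / 1024**3
print(f"~{total_gib:.1f} GiB before activations")  # ~21.8 GiB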
.forward() vs .backward() Execution Speed
- Convolutions and linear layers: 2x flops in the backward compared to the forward (~2x slower)
- Activations are usually bandwidth-limited; it is typical for an activation to have to read more data in the backward pass than in the forward pass
Batch sizes
- Tensor Core Requirements define the multiplier based on the dtype and the hardware
- For FP16 a multiple of 8 is recommended, but on A100 it’s 64
- For parameters that are small, there are also Dimension Quantization Effects to consider
Gradient Accumulation
- Instead of calculating the gradients for the whole batch at once, compute them in smaller steps and accumulate them (see the plain-PyTorch sketch after this list)
- Goal: increase the overall batch size to numbers that would never fit into the GPU’s memory
- Potentially slow down the training a bit
TrainingArguments(gradient_accumulation_steps=4)
- Samples/second: 8.86 -> 7.75
- GPU memory occupied: 14949 MB -> 8681 MB
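A plain-PyTorch sketch of the same idea (model, optimizer, and data_iter are the assumed names from the torch.amp example above; the Trainer does this internally when gradient_accumulation_steps is set):
Code:
accumulation_steps = 4  # same effect as gradient_accumulation_steps=4

optimizer.zero_grad()
for step, (data, label) in enumerate(data_iter, start=1):
    loss = model(data)
    # Scale so the accumulated gradient matches the average over the large batch
    (loss / accumulation_steps).backward()
    if step % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()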
Gradient Checkpointing
- Different gradient storage strategies
- Save all: all activations from the forward pass are normally saved
- Forget all: recompute all on demand during the backward pass
- Gradient checkpointing: saves strategically selected activations throughout the computational graph
TrainingArguments(gradient_checkpointing=True)
- Samples/second: 8.86 -> 5.99
- GPU memory occupied: 14949 MB -> 6775 MB
- A general rule: gradient checkpointing slows down training by about 20%
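Outside of the Trainer flag, the same idea can be written with torch.utils.checkpoint; a minimal sketch (the toy Block module is just for illustration; use_reentrant=False needs a reasonably recent PyTorch):
Code:
import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.ff = torch.nn.Sequential(torch.nn.Linear(dim, dim), torch.nn.GELU())

    def forward(self, x):
        # Activations inside self.ff are not stored; they are recomputed during backward
        return x + checkpoint(self.ff, x, use_reentrant=False)

x = torch.randn(8, 256, requires_grad=True)
Block(256)(x).sum().backward()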
Mixed Precision Training
- Goal: speed up training and reduce memory usage
- Use: both FP16 and FP32 in training
- Speed up training
- Modern GPUs are designed to be more efficient at computing FP16
- A100
- Peak FP32: 19.5 TFLOPS
- Peak FP16: 78 TFLOPS
- Peak TF32 Tensor Core: 156 TFLOPS
- Peak FP16 Tensor Core: 312 TFLOPS
- Reduce memory usage
- Maintains model weights in both FP32 and FP16
- However, activations, gradients, and intermediate values in the forward and backward pass can be kept in FP16
- Gradients are computed in FP16 and converted back to FP32 for the optimization step
- FP32 copy of the weights allows precise updates during the optimization step
- Allows larger batch sizes
TrainingArguments(fp16=True)
- Samples/second: 8.86 -> 18.64
- GPU memory occupied: 14949 MB -> 13939 MB
- Combining Gradient Accumulation, Gradient Checkpointing, and Mixed Precision Training gives slightly faster training with roughly half the VRAM (see the combined TrainingArguments example below)
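For example, the three techniques combine into a single TrainingArguments call (per_device_train_batch_size=1 below is just an illustrative choice):
Code:
training_args = TrainingArguments(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    fp16=True,
    **default_args,
)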
GPU Floating Data Types
- FP32 (float32)
- FP16 (float16)
- BF16 (bfloat16)
- Supported on Ampere or newer hardware
- Compared to FP16: worse precision, but a significantly bigger dynamic range (max ~65504 vs ~3.39e+38), which prevents numeric overflow (see the small demo after this list)
TrainingArguments(bf16=True)
- TF32 (CUDA internal data type)
- Supported on Ampere or newer hardware
- Compared to FP32: worse precision, but same dynamic range
- Uses only 19 bits in total
- Up to 3x throughput improvement
torch.backends.cuda.matmul.allow_tf32 = True
TrainingArguments(tf32=True)
- Accelerating AI Training with NVIDIA TF32 Tensor Cores
- Majority of machine learning training shouldn’t be impacted
- The same perplexity and convergence as the FP32 training
- A100 benchmarks
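A small demo of the dynamic-range point from the BF16 bullet above:
Code:
import torch

print(torch.tensor(70000.0, dtype=torch.float16))   # inf (above FP16's max of ~65504)
print(torch.tensor(70000.0, dtype=torch.bfloat16))  # stays finite (rounded to a nearby representable value)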
Optimizer
- Adam or AdamW
- Fast convergence by storing the rolling average of the previous gradients
- Additional memory footprint of the order of the number of model parameters
- Fastest implementation: NVIDIA/apex
--optim adamw_apex_fused
- Adafactor
- Works well for some models but often it has instability issues
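With the Trainer, Adafactor can be selected through the optim argument (a minimal sketch; default_args is the same helper as in the earlier example):
TrainingArguments(optim="adafactor", **default_args)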
- 8bit BNB optimizer
- Quantization: stores the optimizer state in lower precision and dequantizes it only for the optimizer update
import bitsandbytes as bnb

# lr here is just an illustrative value
adam_bnb_optim = bnb.optim.Adam8bit(model.parameters(), lr=2e-5)
- Example: T5-3B model
- AdamW: 8 bytes/parameter, so the optimizer will need 8 × 3 = 24 GB of GPU memory
- Adafactor: slightly more than 4 bytes/parameter, so 4 × 3 = 12 GB and then some extra
- 8bit BNB: only 2 × 3 = 6 GB if all optimizer states are quantized
_multi_tensor
- Significantly speeds up the optimizers in situations with lots of small feature tensors
- https://github.com/huggingface/transformers/issues/9965
Accelerate
- For distributed training
- Gradient Synchronization & PyTorch’s distributed module
Code:
from accelerate import Accelerator

accelerator = Accelerator(fp16=training_args.fp16)
model, optimizer, dataloader = accelerator.prepare(model, adam_bnb_optim, dataloader)

model.train()
for step, batch in enumerate(dataloader, start=1):
    loss = model(**batch).loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
DataLoader
- DataLoader(pin_memory=True, ...): data is preloaded into pinned memory on the CPU, which speeds up CPU-to-GPU transfers
- DataLoader(num_workers=4, ...): use several worker processes to preload data faster (useful when CPU utilization is low)
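Both options go straight into the standard DataLoader constructor (ds and the batch size of 4 are carried over from the Trainer example as an assumption):
Code:
from torch.utils.data import DataLoader

dataloader = DataLoader(ds, batch_size=4, pin_memory=True, num_workers=4)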
DeepSpeed ZeRO
- TBD
Using PyTorch native attention
torch.nn.functional.scaled_dot_product_attention
- Uses fused GPU kernels such as memory-efficient attention and FlashAttention
- Requires the optimum package: https://huggingface.co/docs/optimum/index
model = model.to_bettertransformer()
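A minimal sketch of calling the fused kernel directly (shapes are arbitrary; half-precision CUDA tensors are what make the FlashAttention / memory-efficient backends eligible):
Code:
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim); requires a CUDA device for the fused kernels
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)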
Model size estimation
- 1 MiB = 1048576 bytes
- 1 FP32 = 4 bytes
- VRAM estimate (MiB) = n_params * 24 / 1048576
- Actual VRAM required
- GPU RAM for the PyTorch session only
- GPU RAM including extra driver buffers