
Gradient Checkpointing

Gradient checkpointing (also called activation checkpointing) is a memory optimization technique for neural network training that trades computation for memory. By recomputing activations during the backward pass instead of storing them, it reduces activation memory from O(n) to O(√n) for an n-layer network. Originally proposed by Chen et al. (2016), gradient checkpointing enables training models 2-10x larger on the same hardware with only a 20-33% increase in training time. As of October 2025, it is integrated into PyTorch, TensorFlow, and all major training frameworks, and is essential for training large language models and long-sequence transformers.


Overview

During neural network training, the activations from the forward pass must be stored so that gradients can be computed in the backward pass. For deep networks and long sequences this can require gigabytes of memory per batch. Gradient checkpointing solves this by (1) storing activations only at selected checkpoints (every k layers) and (2) recomputing the intermediate activations from the nearest checkpoint during the backward pass. Memory drops from O(n) to O(√n) for n layers; compute rises by roughly 33%, since each activation is recomputed once. The result: 2-10x larger batch sizes or sequence lengths on the same hardware. It is essential for Llama 3 70B training, long-context transformers (32K+ tokens), QLoRA fine-tuning, and high-resolution image models.
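The √n spacing can be expressed directly with PyTorch's checkpoint_sequential, which splits a sequential stack into segments and stores activations only at segment boundaries. The sketch below is illustrative: the layer sizes, batch size, and CUDA device are assumptions, not part of any specific recipe.

import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

n_layers = 24
# A plain 24-layer stack standing in for a real network
model = nn.Sequential(*[nn.Linear(768, 768) for _ in range(n_layers)]).cuda()
x = torch.randn(8, 768, device="cuda", requires_grad=True)

# Split into ~sqrt(n) segments: only segment-boundary activations are kept;
# everything inside a segment is recomputed during the backward pass
segments = int(math.sqrt(n_layers))  # 4 segments of 6 layers each
out = checkpoint_sequential(model, segments, x, use_reentrant=False)
out.sum().backward()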

Key Benefits

  • Memory reduction: 2-10x less activation memory
  • Larger batch sizes: Train with 2-4x larger batches
  • Longer sequences: Enable 32K-128K token contexts
  • Compute overhead: Only 20-33% slower training
  • No quality loss: Mathematically equivalent to standard training
  • Easy integration: One call in Hugging Face Transformers (gradient_checkpointing_enable()); a small wrapper around each layer in raw PyTorch
  • Configurable: Choose checkpoint granularity (every layer, every N layers; see the sketch after this list)
  • Essential for large models: Required for training 70B+ parameter models
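To illustrate the granularity knob mentioned above, here is a minimal sketch that checkpoints only every k-th layer; the layer definition and sizes are illustrative assumptions, and k trades memory savings against recompute overhead.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class ConfigurableCheckpointStack(nn.Module):
    def __init__(self, dim, n_layers, checkpoint_every=2):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Sequential(nn.Linear(dim, dim), nn.ReLU()) for _ in range(n_layers)]
        )
        # checkpoint_every=1 checkpoints every layer (max savings, max recompute);
        # larger values checkpoint fewer layers (less savings, less overhead)
        self.checkpoint_every = checkpoint_every

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            if self.training and i % self.checkpoint_every == 0:
                x = checkpoint(layer, x, use_reentrant=False)  # recompute in backward
            else:
                x = layer(x)  # store activations as usual
        return x

model = ConfigurableCheckpointStack(768, 24, checkpoint_every=2).cuda()
x = torch.randn(16, 768, device="cuda", requires_grad=True)
model(x).sum().backward()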

Code Examples

# PyTorch gradient checkpointing
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Method 1: Manual checkpointing for custom layers
class MyLayer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.linear2 = nn.Linear(dim, dim)
    
    def forward(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return x

class MyModel(nn.Module):
    def __init__(self, dim, n_layers):
        super().__init__()
        self.layers = nn.ModuleList([MyLayer(dim) for _ in range(n_layers)])
    
    def forward(self, x):
        for layer in self.layers:
            # Recompute this layer's activations in the backward pass instead of storing them
            x = checkpoint(layer, x, use_reentrant=False)
        return x

model = MyModel(768, 24).cuda()
x = torch.randn(32, 512, 768, device="cuda", requires_grad=True)  # batch=32, seq=512, dim=768
output = model(x)
output.sum().backward()  # activations are recomputed here instead of being read from memory
print(f"Output shape: {output.shape}")

# Method 2: Hugging Face Transformers (automatic)
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype=torch.float16,
    device_map="auto"
)

# Enable gradient checkpointing (only active while the model is in training mode)
model.gradient_checkpointing_enable()
model.config.use_cache = False  # the KV cache is incompatible with checkpointing during training
model.train()

# Training with much lower memory!
input_ids = torch.randint(0, 32000, (4, 4096)).cuda()  # batch=4, seq=4096 (long sequence)
outputs = model(input_ids, labels=input_ids)
loss = outputs.loss
loss.backward()  # Uses checkpointing automatically

# Method 3: PyTorch Transformer with checkpointing
from torch.nn import TransformerEncoderLayer, TransformerEncoder

class CheckpointedTransformer(nn.Module):
    def __init__(self, d_model=512, nhead=8, num_layers=6):
        super().__init__()
        encoder_layer = TransformerEncoderLayer(d_model, nhead)
        self.transformer = TransformerEncoder(encoder_layer, num_layers)
        self.enable_checkpointing = True
    
    def forward(self, x):
        if self.enable_checkpointing and self.training:
            # Checkpoint each encoder layer individually (TransformerEncoder keeps them in .layers);
            # checkpointing the whole stack as one block would store only its input but rebuild
            # every intermediate activation at once during backward, giving little peak-memory benefit
            for layer in self.transformer.layers:
                x = checkpoint(layer, x, use_reentrant=False)
            return x
        return self.transformer(x)

model = CheckpointedTransformer().cuda()
x = torch.randn(512, 32, 512).cuda()  # seq=512, batch=32, dim=512
output = model(x)
print(f"Transformer output: {output.shape}")

Memory Savings Example

For Llama 3 8B training at 2048 sequence length: without checkpointing, activations take roughly 40GB and limit the batch size to 4; with checkpointing, activations drop to roughly 12GB, allowing a batch size of 16 (4x larger). Training time rises from 100 minutes to 130 minutes (+30%), but throughput climbs from 4 samples/min to about 12.3 samples/min (roughly 3x), because the larger batch more than makes up for the ~30% per-step overhead. For GPT-2 scale models (125M parameters), cutting activation memory from 24GB to 8GB makes training feasible on consumer GPUs such as the RTX 3090. The key insight: the memory freed by gradient checkpointing enables much larger batches, which more than compensates for the per-sample slowdown.
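These numbers vary with model size, batch size, and sequence length, so it is worth measuring on your own setup. A minimal sketch, assuming a CUDA device and an illustrative toy stack (not the Llama configuration above), compares peak memory for a forward+backward pass with and without checkpointing:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
    def forward(self, x):
        return self.net(x)

class Stack(nn.Module):
    def __init__(self, dim, n_layers, use_checkpoint):
        super().__init__()
        self.layers = nn.ModuleList([Block(dim) for _ in range(n_layers)])
        self.use_checkpoint = use_checkpoint
    def forward(self, x):
        for layer in self.layers:
            x = checkpoint(layer, x, use_reentrant=False) if self.use_checkpoint else layer(x)
        return x

def peak_mb(use_checkpoint):
    # Reset the allocator's peak counter, run one training step, report peak usage
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    model = Stack(1024, 24, use_checkpoint).cuda()
    x = torch.randn(8, 2048, 1024, device="cuda", requires_grad=True)
    model(x).sum().backward()
    return torch.cuda.max_memory_allocated() / 1024**2

print(f"without checkpointing: {peak_mb(False):.0f} MiB peak")
print(f"with checkpointing:    {peak_mb(True):.0f} MiB peak")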

When to Use

  • Training large models: Essential for 13B+ parameter models
  • Long sequences: Enable 8K-128K token context windows
  • Limited GPU memory: Train on consumer GPUs (24GB)
  • Batch size constrained: Increase batch size 2-4x
  • Fine-tuning: Combined with QLoRA for maximum memory efficiency (see the sketch after this list)
  • High-resolution images: Vision transformers on large images
  • Multi-GPU training: Reduce per-GPU memory requirements
  • Acceptable tradeoff: A 20-30% per-step slowdown in exchange for 2-10x memory savings
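For the QLoRA case, here is a hedged sketch of how gradient checkpointing is commonly combined with 4-bit quantization and LoRA via Hugging Face peft; the model name, LoRA hyperparameters, and target modules are illustrative assumptions, not a recommended recipe.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",   # illustrative; any causal LM works
    quantization_config=bnb_config,
    device_map="auto",
)

# Enables gradient checkpointing and prepares the 4-bit model for training
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

lora_config = LoraConfig(
    r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)

# The Hugging Face Trainer can also toggle checkpointing directly
training_args = TrainingArguments(
    output_dir="qlora-out",
    per_device_train_batch_size=4,
    gradient_checkpointing=True,
    bf16=True,
)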

Professional Integration Services by 21medien

21medien offers gradient checkpointing optimization services including custom checkpoint strategy design, memory profiling, training pipeline optimization, and large-scale model training setup. Our team helps organizations maximize GPU utilization through optimal checkpoint granularity selection, combined techniques (checkpointing + mixed precision + ZeRO), and performance benchmarking. Contact us for training optimization solutions.

Resources

Original paper: https://arxiv.org/abs/1604.06174 | PyTorch docs: https://pytorch.org/docs/stable/checkpoint.html | Hugging Face guide: https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing