Note
This article was originally published on Medium and has been adapted for Red Hat Developer.
How does PyTorch calculate gradients when you call backward()? The answer is the autograd engine. This component provides automatic differentiation, which calculates the gradients required for training deep learning models in PyTorch.
This article examines how the autograd engine builds and executes the computational graph, manages memory during backpropagation, and optimizes performance. Whether you are a researcher debugging gradient flows or an engineer optimizing training pipelines, understanding the internals of the autograd engine will improve your PyTorch skills.
What is autograd?
The autograd (automatic gradient) engine is PyTorch's automatic differentiation tool that:
- Records operations during the forward pass to build a computational graph.
- Automatically calculates gradients using the chain rule during the backward pass.
- Manages memory efficiently by saving only necessary intermediate values.
- Optimizes the graph by pruning unnecessary computations.
The autograd engine tracks every operation performed on tensors that require gradients, then traverses these operations in reverse to compute gradients.
import torch
# Simple example
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2 + 3 * x + 1
y.backward()
print(x.grad) # Output: tensor(7.) # dy/dx = 2*x + 3 = 2*2 + 3 = 7
Behind this simple API lies a sophisticated system that we'll unpack piece by piece.
Computational graphs
A computational graph is a directed acyclic graph (DAG). In this graph, nodes represent operations or functions, and edges represent the flow of data via tensors. Leaves are the input tensors, while roots are the output tensors.
Example: y = (x + 2) * 3
Forward Graph:
x (leaf) → [+2] → temp → [*3] → y (root)
Backward Graph (reversed):
y.grad = dy/dy = 1.0
→ [*3]' (y = temp * 3, so dy/d(temp) = 3)
→ temp.grad = 3 (dy/d(temp))
→ [+2]' (temp = x + 2, so d(temp)/dx = 1)
→ x.grad = 3 * 1 = 3 (dy/dx)
Forward pass vs. backward pass
During the forward pass, the engine executes the computation y = f(x), builds the computational graph dynamically, and saves the intermediate values required for the backward pass.
During the backward pass, the engine traverses the graph in reverse order. It applies the chain rule to compute each input gradient from the output gradient (dx = dy * df/dx) and accumulates those gradients at the leaf nodes.
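To make the two passes concrete, here is a minimal pure-Python sketch of reverse-mode differentiation on scalars. The `Value` class is a hypothetical illustration, not a PyTorch API: each operation computes its result (forward pass) and records its inputs together with their local derivatives, and `backward()` visits the recorded nodes in reverse topological order, multiplying by each local derivative (the chain rule) and summing the contributions that reach the same input.

```python
class Value:
    """A scalar that records the operations applied to it (hypothetical sketch)."""

    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        # Each parent is a (Value, local_derivative) pair recorded in the forward pass
        self.parents = parents

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data + other.data, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data * other.data, ((self, other.data), (other, self.data)))

    def __pow__(self, n):
        # n is a plain number: d(v^n)/dv = n * v^(n-1)
        return Value(self.data ** n, ((self, n * self.data ** (n - 1)),))

    def backward(self):
        # Order nodes so that every node comes after all of its inputs
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent, _ in node.parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0  # d(self)/d(self)
        for node in reversed(order):
            for parent, local in node.parents:
                parent.grad += node.grad * local  # chain rule, accumulated


# Same computation as the PyTorch example that follows
x = Value(3.0)
y = x * 2
z = y + 1
loss = z ** 2
loss.backward()
print(x.grad)  # 28.0
```

Because gradients are summed with `+=`, a value used in two places (for example `a * 2 + (a + 3)`) receives both contributions, giving a gradient of 3.0 at `a`.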
# Forward: build graph + compute output
x = torch.tensor(3.0, requires_grad=True)
y = x * 2 # grad_fn: MulBackward0
z = y + 1 # grad_fn: AddBackward0
loss = z ** 2 # grad_fn: PowBackward0
# Backward: traverse graph + compute gradients
loss.backward()
print(x.grad) # tensor(28.)
# Computation: dloss/dx = dloss/dz * dz/dy * dy/dx = 2*z * 1 * 2 = 2*7*2 = 28
Define-and-run vs. define-by-run
There are two main paradigms for building the computational graph that automatic differentiation works on: define-and-run (static graphs) and define-by-run (dynamic graphs).
Define-and-run frameworks use a static, predefined graph structure that must be compiled before execution. While this allows the framework to optimize the entire graph ahead of time, it makes debugging difficult because the graph definition is separate from the runtime. This approach also makes dynamic control flow, such as if or while statements, challenging to implement.
# Pseudo-code for the define-and-run (static graph) approach
# Define the graph structure FIRST
graph = ComputationalGraph()
x = graph.placeholder()
y = graph.multiply(x, 2)
z = graph.add(y, 1)
loss = graph.power(z, 2)
# Compile the graph
session = graph.compile()
# THEN execute with actual data
result = session.run(loss, feed_dict={x: 3.0})
gradients = session.run(grad(loss, x), feed_dict={x: 3.0})
PyTorch uses the define-by-run approach, where the graph is built dynamically during the forward pass. This means the engine differentiates exactly the code you run and supports arbitrary Python control flow naturally.
This method is easy to debug because of eager execution, and the graph structure can change between iterations. However, dynamic graph construction does introduce a slight performance overhead.
# PyTorch: Define and run simultaneously
x = torch.tensor(3.0, requires_grad=True)
y = x * 2 # Graph is built HERE during execution
z = y + 1 # Graph continues to build
loss = z ** 2 # Still building...
# The graph exists NOW and can be inspected
print(loss.grad_fn) # <PowBackward0>
print(loss.grad_fn.next_functions) # ((<AddBackward0>, 0),)
# Execute backward pass
loss.backward()
The define-by-run approach allows you to use standard Python features, such as loops and conditional statements, to build a model.
def forward(x, n):
    result = x
    for i in range(n):  # n can change at runtime!
        if result.sum() > 0:
            result = result * 2
        else:
            result = result + 1
    return result.sum()
# Different values of n create different graphs
x = torch.randn(3, requires_grad=True)
# Iteration 1: n=3 creates a 3-layer graph
loss1 = forward(x, n=3)
loss1.backward()
# Iteration 2: n=5 creates a 5-layer graph
loss2 = forward(x.detach().requires_grad_(), n=5)
loss2.backward()
In this example, the forward function uses a for loop and an if statement to determine the path of the data. Because PyTorch builds the graph as the code runs, the autograd engine can differentiate through these dynamic structures. You do not need to predefine every possible path before running the model.
This flexibility supports complex architectures, such as recurrent networks with variable sequence lengths and recursive networks with tree structures. It also simplifies neural architecture search and experimental research by allowing the model structure to change during execution.
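The following pure-Python sketch illustrates why define-by-run handles control flow for free: each operation appends an entry to a tape as it executes, so the tape records only the branches and loop iterations that actually ran. The `Tape` class and its methods are hypothetical names for this illustration; PyTorch links `Node` objects rather than keeping one flat list, but the principle is the same.

```python
class Tape:
    """Records operations in the order they execute (a minimal, hypothetical tape)."""

    def __init__(self):
        self.values = []   # value of every variable, indexed by variable id
        self.entries = []  # (output_id, [(input_id, local_derivative), ...])

    def var(self, value):
        self.values.append(value)
        return len(self.values) - 1

    def mul_const(self, i, c):
        out = self.var(self.values[i] * c)
        self.entries.append((out, [(i, c)]))
        return out

    def add_const(self, i, c):
        out = self.var(self.values[i] + c)
        self.entries.append((out, [(i, 1.0)]))
        return out

    def grad(self, out, wrt):
        grads = [0.0] * len(self.values)
        grads[out] = 1.0
        for node, inputs in reversed(self.entries):  # replay the tape backwards
            for i, local in inputs:
                grads[i] += grads[node] * local
        return grads[wrt]


def forward(tape, x, n):
    result = x
    for _ in range(n):                # loop length decided at run time
        if tape.values[result] > 0:   # branch decided by the data
            result = tape.mul_const(result, 2.0)
        else:
            result = tape.add_const(result, 1.0)
    return result


tape = Tape()
x = tape.var(-1.5)
out = forward(tape, x, n=3)
print(len(tape.entries), tape.values[out], tape.grad(out, x))  # 3 1.0 2.0
```

Starting from -1.5, the loop takes the add branch twice and the multiply branch once, so the gradient is 2.0. Starting a fresh tape from 2.0 instead takes the multiply branch three times and yields a gradient of 8.0: a different recorded graph for the same code.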
How PyTorch builds the backward graph
PyTorch constructs the backward graph by tracking mathematical operations through specific data structures. When you perform a calculation on a tensor, the engine records the operation to ensure it can calculate gradients later.
The Node class
Every operation in the computational graph is represented by a Node object, which PyTorch often refers to as a grad_fn. This object acts as a functional unit that knows how to compute the local gradient for a specific operation.
// Simplified from torch/csrc/autograd/function.h
struct Node {
std::vector<Edge> next_edges_; // Connections to previous nodes
uint64_t sequence_nr_; // For topological sorting
// The backward function: converts output gradients to input gradients
virtual variable_list apply(variable_list&& inputs) = 0;
};
// Edge connects nodes in the graph
struct Edge {
std::shared_ptr<Node> function; // The node to call
uint32_t input_nr; // Which input of that node
};
In this C++ structure, the next_edges_ vector contains Edge objects that link the current Node to the previous operations in the sequence. The apply method is the core function that executes the actual gradient calculation during the backward pass.
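To connect these declarations to behavior, here is a Python mirror of the same idea, a sketch under simplifying assumptions rather than PyTorch's code. Each node's `apply` turns the gradient of its output into gradients for its inputs, and `next_edges` says where those gradients go. The class names `MulConstBackward`, `AddConstBackward`, and `AccumulateGrad` echo PyTorch's naming but are hypothetical, and the nodes are called by hand instead of by an engine.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Edge:
    function: Optional["Node"]  # node that receives the gradient
    input_nr: int                # which input slot of that node


@dataclass
class Node:
    next_edges: List[Edge] = field(default_factory=list)

    def apply(self, grad_outputs):
        raise NotImplementedError


@dataclass
class MulConstBackward(Node):
    const: float = 1.0

    def apply(self, grad_outputs):
        # y = a * const  ->  dy/da = const
        return [grad_outputs[0] * self.const]


@dataclass
class AddConstBackward(Node):
    def apply(self, grad_outputs):
        # y = a + const  ->  dy/da = 1
        return [grad_outputs[0]]


@dataclass
class AccumulateGrad(Node):
    grad: float = 0.0

    def apply(self, grad_outputs):
        self.grad += grad_outputs[0]  # leaf: store the gradient, propagate nothing
        return []


# Backward graph for y = (x + 2) * 3, built by hand
x_acc = AccumulateGrad()
add_node = AddConstBackward(next_edges=[Edge(x_acc, 0)])
mul_node = MulConstBackward(next_edges=[Edge(add_node, 0)], const=3.0)

# Walk the chain from the root, starting with dy/dy = 1.0
# (every node here has exactly one input, so one edge to follow)
grads, node = [1.0], mul_node
while node is not None:
    out = node.apply(grads)
    if not node.next_edges:
        break
    grads, node = out, node.next_edges[0].function
print(x_acc.grad)  # 3.0
```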
Graph structure and visualization
You can inspect the graph structure in Python by accessing the grad_fn attribute of a tensor. This attribute provides a handle to the Node at the root of the graph. By following the next_functions attribute, you can trace the computation history back to the leaf tensors.
x = torch.tensor(2.0, requires_grad=True)
y = x * 3
z = y + 2
loss = z ** 2
print(f"loss.grad_fn: {loss.grad_fn}")
# <PowBackward0>
print(f"Next: {loss.grad_fn.next_functions}")
# ((<AddBackward0>, 0),)
print(f"Next-Next: {loss.grad_fn.next_functions[0][0].next_functions}")
# ((<MulBackward0>, 0), (None, 0))
The resulting graph is a directed acyclic graph where the loss tensor is the root. The engine traverses this structure in reverse to move from the output back to the input tensors.
x (leaf, requires_grad=True)
|
| [MulBackward0: y = x * 3]
|
y (intermediate)
|
| [AddBackward0: z = y + 2]
|
z (intermediate)
|
| [PowBackward0: loss = z ** 2]
|
loss (root)
In this structure, the intermediate tensors y and z connect the operations. Each time you call an operation like multiplication or addition, PyTorch creates a new Node and updates the links between them.
Forward pass: Recording operations
When you perform operations on tensors with requires_grad=True, PyTorch executes the operation, such as matrix multiplication. The engine then creates a Node for the backward pass, such as MmBackward0, and attaches it to the output tensor using the grad_fn property. Finally, the engine saves the necessary inputs for the backward computation.
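A plain-Python sketch shows why a node like `MmBackward0` needs both inputs. Here matrices are lists of rows, and `matmul_with_grad` is a hypothetical helper (not an ATen or PyTorch function) that returns the product together with a backward closure. The closure keeps references to `a` and `b` because grad_a = grad_out @ b^T needs `b` and grad_b = a^T @ grad_out needs `a`.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(m):
    return [list(row) for row in zip(*m)]


def matmul_with_grad(a, b):
    """Forward pass that also returns a backward function (hypothetical helper)."""
    out = matmul(a, b)
    saved_a, saved_b = a, b  # kept alive for backward, like save_for_backward

    def backward(grad_out):
        grad_a = matmul(grad_out, transpose(saved_b))  # dL/da = dL/dout @ b^T
        grad_b = matmul(transpose(saved_a), grad_out)  # dL/db = a^T @ dL/dout
        return grad_a, grad_b

    return out, backward


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
out, backward = matmul_with_grad(a, b)
# For the loss L = sum(out), the upstream gradient is a matrix of ones
grad_a, grad_b = backward([[1.0, 1.0], [1.0, 1.0]])
print(out)     # [[19.0, 22.0], [43.0, 50.0]]
print(grad_a)  # [[11.0, 15.0], [11.0, 15.0]]
print(grad_b)  # [[4.0, 4.0], [6.0, 6.0]]
```

With an all-ones upstream gradient, each entry of grad_a is a row sum of b (5 + 6 = 11) and each entry of grad_b is a column sum of a (1 + 3 = 4).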
a = torch.randn(3, 3, requires_grad=True)
b = torch.randn(3, 3, requires_grad=True)
c = torch.matmul(a, b)
print(c.grad_fn) # <MmBackward0>
# This node knows:
# - How to compute gradients: d(matmul)/da and d(matmul)/db
# - What it needs: saved_a (for grad_b) and saved_b (for grad_a)
# - Where to send gradients: edges to 'a' and 'b'
The following simplified C++ logic shows how PyTorch records these operations internally:
// When you call torch.matmul(a, b)
Tensor matmul(const Tensor& a, const Tensor& b) {
// 1. Compute the actual result
Tensor result = at::matmul(a, b); // ATen operation
// 2. If any input requires grad, create backward node
if (a.requires_grad() || b.requires_grad()) {
auto grad_fn = std::make_shared<MmBackward0>();
// 3. Save inputs needed for backward
grad_fn->save_for_backward({a, b});
// 4. Connect to inputs via edges (simplified: a leaf input has no grad_fn, so its edge points to its AccumulateGrad node)
grad_fn->set_next_edges({
Edge(a.grad_fn(), a.output_nr()),
Edge(b.grad_fn(), b.output_nr())
});
// 5. Attach to output
result.set_grad_fn(grad_fn);
}
return result;
}
Backward pass: Traversing the graph
When you call .backward(), PyTorch starts at the root output tensor and performs a topological sort to determine the execution order. The engine then calls the backward function for each node in reverse order and accumulates the gradients at the leaf nodes.
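That ordering can be obtained without building an explicit sorted list. The pure-Python sketch below counts, for every node, how many gradients it will receive, and runs a node only once all of them have arrived. PyTorch's engine uses a dependency-counting scheme of this kind, but `run_backward` and the dict-based node format are hypothetical. The example graph is the diamond y = x * 2, z = x + 3, loss = y + z, where x receives gradients from two paths.

```python
from collections import deque


def run_backward(root, grad_root=1.0):
    """Reverse-mode traversal using dependency counts (hypothetical sketch).

    Each node is a dict with:
      'name'  - label used for the returned leaf gradients
      'local' - local derivatives, one per input
      'next'  - input nodes (None for inputs that need no gradient)
    Leaves have an empty 'next' list.
    """
    # 1. Count how many gradients each node will receive
    deps, stack, seen = {}, [root], {id(root)}
    while stack:
        node = stack.pop()
        for child in node["next"]:
            if child is None:
                continue
            deps[id(child)] = deps.get(id(child), 0) + 1
            if id(child) not in seen:
                seen.add(id(child))
                stack.append(child)

    # 2. Run a node only when its gradient buffer is complete
    buffers = {id(root): grad_root}
    ready, leaf_grads = deque([root]), {}
    while ready:
        node = ready.popleft()
        grad = buffers.pop(id(node))
        if not node["next"]:
            leaf_grads[node["name"]] = grad  # plays the role of AccumulateGrad
            continue
        for child, local in zip(node["next"], node["local"]):
            if child is None:
                continue
            buffers[id(child)] = buffers.get(id(child), 0.0) + grad * local
            deps[id(child)] -= 1
            if deps[id(child)] == 0:  # all contributions have arrived
                ready.append(child)
    return leaf_grads


# Diamond graph: y = x * 2, z = x + 3, loss = y + z
x = {"name": "x", "local": [], "next": []}
y = {"name": "y", "local": [2.0], "next": [x]}  # dy/dx = 2
z = {"name": "z", "local": [1.0], "next": [x]}  # dz/dx = 1
loss = {"name": "loss", "local": [1.0, 1.0], "next": [y, z]}
print(run_backward(loss))  # {'x': 3.0}
```

Node x is processed only after both y and z have contributed, so its buffer holds the complete sum 2 + 1 = 3 before it is read.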
# Example: Detailed backward pass
x = torch.tensor(2.0, requires_grad=True)
y = x * 3 # MulBackward0
z = y + 2 # AddBackward0
loss = z ** 2 # PowBackward0
loss.backward()
# What happens:
# Step 1: loss.backward() starts with grad_output = 1.0 (dloss/dloss)
#
# Step 2: PowBackward0.apply(grad_output=1.0)
# Formula: d(z^2)/dz = 2*z
# grad_z = 1.0 * 2 * z = 1.0 * 2 * 8 = 16.0
#
# Step 3: AddBackward0.apply(grad_output=16.0)
# Formula: d(y+2)/dy = 1
# grad_y = 16.0 * 1 = 16.0
#
# Step 4: MulBackward0.apply(grad_output=16.0)
# Formula: d(x*3)/dx = 3
# grad_x = 16.0 * 3 = 48.0, accumulated into x.grad
# (Only leaf tensors store .grad by default; y.grad and z.grad stay None.)
print(x.grad) # tensor(48.) ✓
The engine: Orchestrating the backward pass
The engine (defined in torch/csrc/autograd/engine.cpp) is responsible for executing the backward pass:
// Simplified engine logic
class Engine {
variable_list execute(
const edge_list& roots, // Starting nodes
const variable_list& grad_inputs // Initial gradients
) {
// 1. Topological sort of the graph
auto sorted_nodes = topological_sort(roots);
// 2. Initialize gradient buffers
std::unordered_map<Node*, variable_list> grad_map;
// 3. Execute nodes in reverse topological order
for (auto& node : sorted_nodes) {
// Get accumulated gradient for this node
auto grad_inputs = grad_map[node];
// Call the node's backward function
auto grad_outputs = node->apply(std::move(grad_inputs));
// Send gradients to next nodes
for (size_t i = 0; i < node->next_edges_.size(); ++i) {
auto& edge = node->next_edges_[i];
if (edge.function) {
grad_map[edge.function.get()].push_back(grad_outputs[i]);
}
}
}
return collect_leaf_grads();
}
};
SavedTensor mechanism
The autograd engine must efficiently manage saved tensors, which are the intermediate values required for backward computation.
Why save tensors? Many operations require specific inputs or intermediate values to compute gradients. The following table describes what the engine saves for common operations.
| Operation | Forward pass | Backward pass | What to save |
| --- | --- | --- | --- |
| Square | y = x * x | dx = 2 * x * dy | Save x |
| Exponential | y = e^x | dx = e^x * dy = y * dy | Save y (output) |
| Matrix multiply | y = x @ w | dx = dy @ w.T; dw = x.T @ dy | Save both x and w |
| ReLU | y = max(0, x) | dx = (x > 0) * dy | Save the x > 0 mask (PyTorch saves y, since y > 0 is the same mask) |
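The table can be checked with a small pure-Python sketch. Each hypothetical `*_forward` function returns its output plus only the value listed in the "What to save" column, each backward function uses nothing but that saved value and the incoming gradient, and a central finite difference confirms the result. None of these helpers are PyTorch functions.

```python
import math


def square_forward(x):
    return x * x, x                 # save the input x


def square_backward(saved_x, dy):
    return 2 * saved_x * dy


def exp_forward(x):
    y = math.exp(x)
    return y, y                     # save the output y, not x


def exp_backward(saved_y, dy):
    return saved_y * dy             # e^x == y, so x itself is not needed


def relu_forward(x):
    y = max(0.0, x)
    return y, y                     # save the output; y > 0 recovers the mask


def relu_backward(saved_y, dy):
    return (1.0 if saved_y > 0 else 0.0) * dy


def numeric_grad(f, x, eps=1e-6):
    """Central finite difference of a forward function's output."""
    return (f(x + eps)[0] - f(x - eps)[0]) / (2 * eps)


for name, fwd, bwd in [("square", square_forward, square_backward),
                       ("exp", exp_forward, exp_backward),
                       ("relu", relu_forward, relu_backward)]:
    for x in (-1.3, 0.7, 2.0):
        _, saved = fwd(x)
        analytic = bwd(saved, 1.0)
        print(name, x, round(analytic, 4), round(numeric_grad(fwd, x), 4))
```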
The SavedVariable class
PyTorch uses the SavedVariable class to store tensors:
// Simplified from torch/csrc/autograd/saved_variable.h
class SavedVariable {
private:
// The actual tensor data
at::Tensor data_;
// Metadata for reconstruction
std::shared_ptr<Node> grad_fn_;
uint32_t output_nr_;
uint32_t version_counter_; // To detect in-place modifications
// Optional hooks for custom save/load behavior
std::unique_ptr<SavedVariableHooks> hooks_;
public:
// Constructor: saves a variable during forward pass
SavedVariable(const Variable& variable, bool is_output);
// Unpack: reconstructs the variable during backward pass
Variable unpack(std::shared_ptr<Node> saved_for = nullptr) const;
};
Memory optimization: What gets saved?
PyTorch optimizes memory by saving outputs instead of inputs when possible:
# ReLU: can reconstruct gradient from output
# Instead of saving input x, save output y
def relu_backward(y, grad_output):
    return (y > 0).float() * grad_output # Uses output, not input!
It also avoids copying data for view operations:
# For view operations like transpose, no data is copied
x = torch.randn(1000, 1000, requires_grad=True)
y = x.t() # Transpose: a view that shares x's storage
# y differs from x only in its stride metadata, NOT a copy
# TBackward0 saves no tensors: its backward is simply grad.t()
It also reduces memory usage by recomputing inexpensive values during the backward pass:
# For cheap element-wise functions such as GELU,
# it can be better to save only the input and recompute the rest in backward
def gelu_backward(x, grad_output):
    # Could save intermediate values, but recomputing them from x is often
    # cheaper than storing and reloading them
    return grad_output * gelu_gradient(x) # gelu_gradient: hypothetical helper that recomputes parts of forward
SavedTensor hooks: Custom memory management
PyTorch lets you customize how tensors are saved:
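# Example: Save tensors to CPU to reduce GPU memory
def pack_hook(tensor):
    return tensor.cpu() # Called when autograd saves a tensor: move it to CPU
def unpack_hook(packed):
    return packed.cuda() # Called when backward needs it: move it back to the GPU
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    # Your forward pass here
    x = torch.randn(10000, 10000, device='cuda', requires_grad=True)
    y = expensive_operation(x) # expensive_operation: placeholder for your own model code
    loss = y.sum()
loss.backward() # unpack_hook moves the saved tensors back to the GPU for backward
# PyTorch also provides this pattern built in: torch.autograd.graph.save_on_cpu()
Activation checkpointing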
Activation checkpointing reduces memory usage during training by trading computation time for memory space. Instead of storing all intermediate activations during the forward pass, PyTorch recomputes them during the backward pass.
# Trade compute for memory by recomputing activations
def checkpoint_function(function, *args):
    """
    During forward: Don't save intermediate activations
    During backward: Rerun forward to recompute them
    """
    class CheckpointFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, run_function, *args):
            ctx.run_function = run_function
            ctx.save_for_backward(*args) # Only save inputs
            with torch.no_grad():
                outputs = run_function(*args)
            return outputs
        @staticmethod
        def backward(ctx, *grad_outputs):
            # Detach the saved inputs so the recomputed graph is independent
            inputs = [inp.detach().requires_grad_(inp.requires_grad)
                      for inp in ctx.saved_tensors]
            # Recompute forward to get intermediate values
            with torch.enable_grad():
                outputs = ctx.run_function(*inputs)
            # Backpropagate through the recomputed graph; this also
            # accumulates gradients into the layer's parameters
            torch.autograd.backward(outputs, grad_outputs)
            return (None,) + tuple(inp.grad for inp in inputs)
    return CheckpointFunction.apply(function, *args)
# Usage: Train huge models with limited memory (huge_model is a placeholder)
# At least one input must require grad, or no backward node is recorded
for layer in huge_model.layers:
    x = checkpoint_function(layer, x) # Recompute activations in backward
In this example, the CheckpointFunction uses ctx.save_for_backward to store only the input tensors. During the backward pass, its backward method calls the forward function again within a torch.enable_grad context to recreate the intermediate values, then backpropagates through them. This approach is effective for training large models, such as deep neural networks with hundreds of layers, on hardware with limited memory. In practice, torch.utils.checkpoint.checkpoint provides a maintained implementation of this pattern.
Version counters: Detecting invalid in-place operations
PyTorch uses version counters to detect when an in-place modification interferes with a gradient calculation:
# Problem: in-place modification after saving
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
a = x.clone() # Non-leaf copy (in-place ops on a leaf that requires grad fail immediately)
y = a ** 2 # PowBackward0 saves a, because dy/da = 2 * a
a.add_(1) # DANGER: modifying a after it was saved bumps its version counter
# This will raise an error during backward:
# "RuntimeError: one of the variables needed for gradient computation
# has been modified by an inplace operation"
try:
    y.backward(torch.ones_like(y))
except RuntimeError as e:
    print(e)
The engine increments a version counter whenever an in-place operation occurs to ensure data integrity during backpropagation. The following C++ structures demonstrate how the engine validates these versions:
// Each tensor has a version counter
struct TensorImpl {
uint32_t version_counter_ = 0;
void bump_version() {
version_counter_++;
}
};
// SavedVariable stores the version at save time
class SavedVariable {
uint32_t saved_version_;
Variable unpack() {
TORCH_CHECK(
tensor.version_counter() == saved_version_,
"Variable has been modified by an in-place operation"
);
return tensor;
}
};
Gradient accumulation
Gradient accumulation adds multiple gradients to the same parameter. This process occurs when a variable is used in multiple paths of a graph or when you accumulate gradients across several batches using multiple backward passes.
Multiple paths: Automatic accumulation
Automatic accumulation occurs when a variable is used in multiple operations within the same computational graph.
# Example: x is used twice
x = torch.tensor(2.0, requires_grad=True)
y = x * 2 # First use
z = x + 3 # Second use
loss = y + z # Both paths contribute to x.grad
loss.backward()
# x.grad = dy/dx + dz/dx = 2 + 1 = 3
print(x.grad) # tensor(3.)
Graph visualization:
x (requires_grad=True)
/ \
/ \
[*2] [+3] <- Two different operations using x
| |
y z
\ /
\ /
[+]
|
loss
During the backward pass, gradients from both paths are accumulated at x.
AccumulateGrad node
Leaf variables, such as model parameters, use a specific node called AccumulateGrad to handle incoming gradients.
// Simplified from torch/csrc/autograd/functions/accumulate_grad.h
class AccumulateGrad : public Node {
Variable variable; // Reference to the parameter
variable_list apply(variable_list&& grads) override {
auto& grad = grads[0];
// First gradient: just assign
if (!variable.grad().defined()) {
variable.mutable_grad() = grad;
}
// Subsequent gradients: accumulate (add)
else {
variable.mutable_grad() += grad; // THIS IS THE KEY!
}
return {}; // Leaf nodes don't propagate further
}
};
Multi-batch gradient accumulation
You can use gradient accumulation to simulate a larger batch size when hardware memory is limited.
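The equivalence behind this pattern can be checked without a GPU. The pure-Python sketch below assumes a one-parameter model, prediction = w * x, trained with a mean-squared-error loss (mean reduction); `grad_mse` is a hypothetical helper that returns the analytic gradient with respect to w, and the data are made-up toy values. Summing the gradients of loss / accumulation_steps over four equal micro-batches gives the same result as the gradient of the mean loss over all eight samples.

```python
def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


# Made-up toy data: 8 samples
xs = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]
ys = [1.1, 1.9, 3.2, 3.9, 5.1, 6.2, 6.8, 8.1]
w = 0.5
accumulation_steps = 4
micro = len(xs) // accumulation_steps  # 2 samples per micro-batch

# Accumulate scaled micro-batch gradients, as loss / accumulation_steps does
accumulated = 0.0
for step in range(accumulation_steps):
    batch = slice(step * micro, (step + 1) * micro)
    accumulated += grad_mse(w, xs[batch], ys[batch]) / accumulation_steps

full_batch = grad_mse(w, xs, ys)
print(accumulated, full_batch)  # equal up to floating-point rounding
```

The equality relies on equal-sized micro-batches; if the last micro-batch is smaller, its samples end up weighted slightly differently.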
# Simulate batch_size=128 with 4 batches of size 32
# (MyModel, criterion, and dataloader are placeholders for your own code)
model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
accumulation_steps = 4
optimizer.zero_grad()
for i, (x, y) in enumerate(dataloader): # batch_size=32
    # Forward pass
    output = model(x)
    loss = criterion(output, y) / accumulation_steps # Scale the loss!
    # Backward pass: gradients accumulate in parameter.grad
    loss.backward()
    # Only update every 4 batches
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
Scaling the loss for consistency
Why scale the loss? When you accumulate gradients across multiple batches, you must scale the loss to maintain a consistent gradient average.
# Without scaling:
# total_grad = grad_batch1 + grad_batch2 + grad_batch3 + grad_batch4
# This is 4x larger than a single batch!
# With scaling (loss / 4):
# total_grad = grad_batch1/4 + grad_batch2/4 + grad_batch3/4 + grad_batch4/4
# This equals the average, same as a single batch of size 128 (assuming criterion uses reduction='mean')
Thread safety in gradient accumulation
The autograd engine uses synchronization primitives to ensure data integrity during multithreaded execution.
// Simplified sketch based on torch/csrc/autograd/functions/accumulate_grad.cpp
void AccumulateGrad_apply_impl(
    variable_list&& grads,
    at::Tensor& variable,
    at::Tensor& variable_grad,
    std::mutex* mutex = nullptr
) {
    // Acquire lock for thread safety
    std::optional<std::lock_guard<std::mutex>> lock;
    if (mutex != nullptr) {
        lock.emplace(*mutex);
    }
    const at::Tensor& new_grad = grads[0]; // Incoming gradient
    // Safely accumulate gradient
    if (variable_grad.defined()) {
        variable_grad += new_grad;
    } else {
        variable_grad = new_grad;
    }
}
Distributed gradient accumulation
In a distributed environment, DistributedDataParallel (DDP) hooks into the autograd engine to average gradients across devices.
# PyTorch DDP (DistributedDataParallel) example
# (assumes torch.distributed.init_process_group() has already been called)
model = torch.nn.parallel.DistributedDataParallel(model)
# During backward:
# 1. Each GPU computes local gradients
# 2. As gradient buckets become ready, an all-reduce sums them across GPUs
# 3. DDP divides by the number of processes, so each GPU gets the averaged gradient
loss.backward() # Triggers the all-reduce operations internally
optimizer.step()
Graph pruning and optimization
The autograd engine uses several optimization techniques to improve performance and reduce memory usage. These features allow PyTorch to handle large models while maintaining high execution speeds.
Pruning unused paths
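When none of an operation's inputs require gradients, the engine does not create a backward node for it. In the following example, operations on z don't create backward nodes, which saves memory.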
# Example: Pruning demonstration
x = torch.randn(5, requires_grad=True)
y = torch.randn(5, requires_grad=True)
z = torch.randn(5, requires_grad=False) # No gradient needed
a = x * 2
b = y + 1
c = z * 3 # This path won't have grad_fn!
loss = a.sum() + b.sum()
print(a.grad_fn) # <MulBackward0>
print(b.grad_fn) # <AddBackward0>
print(c.grad_fn) # None <- Pruned!
loss.backward()
print(x.grad) # tensor([2., 2., 2., 2., 2.])
print(y.grad) # tensor([1., 1., 1., 1., 1.])
print(z.grad) # None <- Never computed
In this example, the engine ignores the path for the variable z. This saves memory by excluding unnecessary Node objects from the backward graph.
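The rule behind this pruning can be sketched in a few lines: an operation records a backward node only if at least one of its inputs requires gradients. The scalar `Tensor` class below is a hypothetical stand-in, not `torch.Tensor`. Its `grad_fn` stays None otherwise, mirroring `c.grad_fn` in the PyTorch example, and the `requires_grad` flag propagates, so everything downstream of a pruned path is pruned too.

```python
class Tensor:
    """Scalar stand-in that mimics requires_grad propagation (hypothetical)."""

    def __init__(self, value, requires_grad=False, grad_fn=None):
        self.value = value
        self.grad_fn = grad_fn
        # Outputs of recorded operations need gradients as well
        self.requires_grad = requires_grad or grad_fn is not None

    def _record(self, value, name, inputs):
        if any(t.requires_grad for t in inputs):
            # Keep references to the inputs so a backward pass could reach them
            return Tensor(value, grad_fn=(name, inputs))
        return Tensor(value)  # no backward node: this path is pruned

    def __mul__(self, c):
        return self._record(self.value * c, "MulBackward", [self])

    def __add__(self, c):
        return self._record(self.value + c, "AddBackward", [self])


x = Tensor(1.0, requires_grad=True)
z = Tensor(1.0)          # requires_grad=False
a = x * 2
c = z * 3
print(a.grad_fn[0])      # MulBackward
print(c.grad_fn)         # None
print((c + 1).grad_fn)   # None: pruning propagates downstream
```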
Fusing operations
Operation fusion combines multiple mathematical operations into a single execution step, which reduces the size of the graph. This process minimizes the overhead of managing many small backward nodes and can significantly improve execution speed.
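The graph-level effect of fusion can be sketched in pure Python. `unfused` records three backward closures for relu((x + 1) * 2), one per operation, while `fused` records a single closure that applies the combined derivative; both give the same gradient. These helpers are hypothetical and model only the bookkeeping, not the kernel-level fusion that TorchScript or torch.compile performs on real hardware.

```python
def unfused(x):
    """Three recorded steps, one backward closure each."""
    nodes = []
    t = x + 1
    nodes.append(lambda g: g * 1.0)                      # AddBackward: dt/dx = 1
    u = t * 2
    nodes.append(lambda g: g * 2.0)                      # MulBackward: du/dt = 2
    w = max(u, 0.0)
    nodes.append(lambda g: g * (1.0 if u > 0 else 0.0))  # ReluBackward
    return w, nodes


def fused(x):
    """One recorded step whose backward applies the combined derivative."""
    u = (x + 1) * 2
    w = max(u, 0.0)
    return w, [lambda g: g * (2.0 if u > 0 else 0.0)]


def backward(nodes, grad=1.0):
    for node in reversed(nodes):  # later steps first
        grad = node(grad)
    return grad


for x in (-3.0, 0.5):
    w1, n1 = unfused(x)
    w2, n2 = fused(x)
    print(x, len(n1), len(n2), backward(n1), backward(n2))
```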
# Multiple element-wise operations can be fused
x = torch.randn(1000, 1000, requires_grad=True)
# Without fusion: 3 separate backward nodes
y = x + 1 # AddBackward0
z = y * 2 # MulBackward0
w = z.relu() # ReluBackward0
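# With TorchScript: the element-wise ops may be fused into one kernel
# (whether fusion happens depends on the device and JIT settings)
@torch.jit.script
def fused_ops(x):
    return ((x + 1) * 2).relu() # One scripted function instead of three eager ops
w = fused_ops(x) # Can be more efficient
In-place operations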
In-place operations modify a tensor directly instead of creating a new one. This saves memory because the engine does not need to allocate extra space for an output tensor. However, you must use in-place operations carefully because they can overwrite data that the engine needs for the backward pass.
x = torch.randn(1000, 1000, requires_grad=True)
h = x * 2 # Intermediate (non-leaf) tensor
# Out-of-place: creates a new tensor
y = h + 1 # Memory: h and y
# In-place: modifies h directly (with care!)
h.add_(1) # Memory: only h
# But be careful: modifying a tensor that the graph saved causes errors!
x = torch.randn(5, requires_grad=True)
y = x.exp() # ExpBackward0 saves its output y
y.add_(1) # In-place change to a saved tensor
y.sum().backward() # RuntimeError: ... has been modified by an inplace operation
While in-place operations are efficient, modifying a tensor after the graph has saved it triggers a RuntimeError during the backward pass. In-place operations on leaf tensors that require gradients, such as model parameters, are rejected immediately while autograd is recording, so apply such updates within a torch.no_grad context.
import torch.nn as nn
model = nn.Linear(10, 10)
# Error: in-place on a leaf that requires grad while autograd is recording
# model.weight.add_(0.01) # RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
# Safe: in-place in no_grad context
with torch.no_grad():
    model.weight.add_(0.01) # OK: autograd is disabled
Graph cleanup and retention
By default, PyTorch frees the graph's saved intermediate values after the backward pass to save memory. If you call .backward() a second time on the same output without intervention, the engine raises an error because those values no longer exist.
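The freeing behavior can be sketched in pure Python. In the hypothetical `SquareNode` below, `backward` releases the saved input unless `retain_graph=True` is passed, so a second call fails, much like PyTorch's "Trying to backward through the graph a second time" error. The class and its message are illustrative only.

```python
class SquareNode:
    """Backward node for y = x ** 2 that frees its saved input (hypothetical)."""

    def __init__(self, x):
        self.saved_x = x  # saved during the forward pass

    def backward(self, grad_output, retain_graph=False):
        if self.saved_x is None:
            raise RuntimeError("saved values were already freed; "
                               "pass retain_graph=True to the earlier backward")
        grad_input = 2 * self.saved_x * grad_output
        if not retain_graph:
            self.saved_x = None  # release memory as soon as possible
        return grad_input


node = SquareNode(3.0)
print(node.backward(1.0, retain_graph=True))  # 6.0, saved value kept
print(node.backward(1.0))                     # 6.0, saved value freed now
try:
    node.backward(1.0)
except RuntimeError as err:
    print("error:", err)
```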
x = torch.randn(3, requires_grad=True)
y = x ** 2 # PowBackward0 saves x for backward
loss = y.sum()
# First backward: graph is freed by default
loss.backward()
# Second backward: ERROR!
try:
loss.backward()
except RuntimeError as e:
    print(e) # "Trying to backward through the graph a second time..."
To run multiple backward passes through the same graph, pass retain_graph=True to every backward call except the last. This keeps the saved values in memory for subsequent calls. To calculate higher-order gradients, pass create_graph=True, which also records the gradient computation itself (retain_graph defaults to the value of create_graph).
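# Use case 1: Multiple backward passes through the same graph
x = torch.randn(3, requires_grad=True)
y = x ** 2
loss1 = y.sum()
loss2 = (y * 3).sum()
loss1.backward(retain_graph=True) # Keep y's saved values
loss2.backward() # Backpropagates through y's graph again; gradients accumulate in x.grad
# Use case 2: Second derivatives
x = torch.randn(3, requires_grad=True)
y = (x ** 2).sum()
grad_x = torch.autograd.grad(y, x, create_graph=True)[0] # 2 * x, with its own graph
# create_graph=True makes grad_x differentiable, so we can backward through it
hvp = torch.autograd.grad(grad_x.sum(), x)[0] # Hessian-vector product H @ ones: tensor([2., 2., 2.])
# For the full Hessian matrix, see torch.autograd.functional.hessian
Gradient checkpointing (rematerialization)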
Gradient checkpointing, also known as rematerialization, allows you to trade computation for memory. Instead of saving every intermediate activation during the forward pass, the engine only saves a subset of them. It recomputes the missing activations as needed during the backward pass.
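The same trade-off can be sketched in pure Python for a chain of 100 scalar layers, each computing a + 0.1 * sin(a) (an arbitrary smooth function chosen for illustration). The hypothetical `checkpointed_grad` keeps only the input of every k-th layer during the forward pass; during the backward pass it recomputes each segment from its checkpoint, last segment first, and backpropagates through it. The result matches `full_grad`, which stores every layer input.

```python
import math


def layer(a):
    return a + 0.1 * math.sin(a)   # an arbitrary smooth scalar "layer"


def layer_grad(a):
    return 1.0 + 0.1 * math.cos(a)  # d(layer)/da


def full_grad(x, n):
    """Store every layer input, then backpropagate through all n layers."""
    inputs, a = [], x
    for _ in range(n):
        inputs.append(a)
        a = layer(a)
    grad = 1.0
    for a_in in reversed(inputs):
        grad *= layer_grad(a_in)
    return grad, len(inputs), 0  # (gradient, stored activations, recomputed layers)


def checkpointed_grad(x, n, k):
    """Store only every k-th layer input; recompute each segment in backward."""
    checkpoints, a = {}, x
    for i in range(n):
        if i % k == 0:
            checkpoints[i] = a  # input to the first layer of a segment
        a = layer(a)
    grad, recomputed = 1.0, 0
    for start in sorted(checkpoints, reverse=True):  # last segment first
        seg, a = [], checkpoints[start]
        for _ in range(start, min(start + k, n)):
            seg.append(a)  # recomputed layer input
            a = layer(a)
            recomputed += 1
        for a_in in reversed(seg):
            grad *= layer_grad(a_in)
    return grad, len(checkpoints), recomputed


g_full, stored_full, _ = full_grad(0.3, 100)
g_ckpt, stored_ckpt, recomputed = checkpointed_grad(0.3, 100, 10)
print(stored_full, stored_ckpt, recomputed)  # 100 10 100
print(g_full == g_ckpt)                      # True
```

Peak storage is the 10 checkpoints plus one segment (10 values) being recomputed, versus 100 stored inputs without checkpointing; the price is that every layer's forward function runs twice.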
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
class DeepModel(nn.Module):
    def __init__(self):
        super().__init__()
        # 100 layers, grouped into 10 blocks of 10 layers each
        self.blocks = nn.ModuleList([
            nn.Sequential(*[nn.Linear(1000, 1000) for _ in range(10)])
            for _ in range(10)
        ])
    def forward(self, x):
        # Without checkpointing: autograd stores the input of every layer (100 tensors)
        # With checkpointing: only each block's input is stored (10 tensors);
        # activations inside a block are recomputed during backward
        for block in self.blocks:
            x = checkpoint(block, x, use_reentrant=False)
        return x
# Stored activations: ~10 instead of ~100, each of shape (batch_size, 1000)
# Compute cost: roughly one extra forward pass, since every block runs twice
Optimization through compilation
PyTorch offers graph-level optimizations through TorchScript (torch.jit) and, since PyTorch 2.0, the torch.compile function.
# TorchScript: ahead-of-time graph optimization
@torch.jit.script
def optimized_function(x, y):
    a = x + y
    b = a * 2
    c = b.relu()
    return c
# Benefits:
# - Dead code elimination
# - Constant folding
# - Operation fusion
# - Memory planning
# torch.compile (PyTorch 2.0+, TorchInductor backend by default): even more aggressive optimization
model = torch.compile(model)
# - Kernel fusion
# - Auto-tuning
# - GPU-specific optimizations
Example: A complete forward and backward pass
Let's trace a complete neural network training step:
import torch
import torch.nn as nn
# Simple network
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 4)
        self.fc2 = nn.Linear(4, 2)
    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x
# Training setup
model = SimpleNet()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Training step
x = torch.randn(1, 3)
target = torch.randn(1, 2)
# ===== FORWARD PASS =====
# Graph is built dynamically during execution
# Layer 1: Linear
# x -> [AddmmBackward0: y = x @ W1.T + b1]
h1 = model.fc1(x)
print(f"h1.grad_fn: {h1.grad_fn}")
# Output: <AddmmBackward0>
# Saved: x, W1 (needed for backward)
# Activation: ReLU
# h1 -> [ReluBackward0: h2 = max(0, h1)]
h2 = torch.relu(h1)
print(f"h2.grad_fn: {h2.grad_fn}")
# Output: <ReluBackward0>
# Saved: h2 (the output; backward derives the mask h2 > 0 from it)
# Layer 2: Linear
# h2 -> [AddmmBackward0: output = h2 @ W2.T + b2]
output = model.fc2(h2)
print(f"output.grad_fn: {output.grad_fn}")
# Output: <AddmmBackward0>
# Saved: h2, W2
# Loss: MSE
# output, target -> [MseLossBackward0: loss = mean((output - target)^2)]
loss = criterion(output, target)
print(f"loss.grad_fn: {loss.grad_fn}")
# Output: <MseLossBackward0>
# Saved: output and target
# ===== BACKWARD PASS =====
optimizer.zero_grad() # Clear previous gradients
loss.backward() # Traverse graph in reverse
# What happens:
# 1. MseLossBackward0.apply(grad_output=1.0)
# -> Computes: grad_output = 2 * (output - target) / output.numel() (mean reduction)
#
# 2. AddmmBackward0.apply (fc2)
# -> Computes: grad_h2 = grad_output @ W2
# -> Computes: grad_W2 = grad_output.T @ h2
# -> Computes: grad_b2 = grad_output.sum(0)
# -> Accumulates in W2.grad and b2.grad
#
# 3. ReluBackward0.apply
# -> Computes: grad_h1 = grad_h2 * (h2 > 0)
#
# 4. AddmmBackward0.apply (fc1)
# -> Computes: grad_W1 = grad_h1.T @ x
# -> Computes: grad_b1 = grad_h1.sum(0)
# -> Accumulates in W1.grad and b1.grad
# -> Skips grad_x = grad_h1 @ W1, because x does not require grad
# ===== OPTIMIZER STEP =====
optimizer.step() # Update parameters using accumulated gradients
# W1 -= lr * W1.grad
# b1 -= lr * b1.grad
# W2 -= lr * W2.grad
# b2 -= lr * b2.grad
print("\nGradients computed:")
print(f"fc1.weight.grad shape: {model.fc1.weight.grad.shape}")
print(f"fc1.bias.grad shape: {model.fc1.bias.grad.shape}")
print(f"fc2.weight.grad shape: {model.fc2.weight.grad.shape}")
print(f"fc2.bias.grad shape: {model.fc2.bias.grad.shape}")
Advanced topics
Use these advanced features to extend the autograd engine with custom differentiation logic or to compute higher-order gradients for specialized mathematical models.
Custom autograd functions
You can define custom operations with custom backward logic:
class CustomExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Forward computation
        result = torch.exp(input)
        ctx.save_for_backward(result) # Save for backward
        return result
    @staticmethod
    def backward(ctx, grad_output):
        # Custom backward computation
        result, = ctx.saved_tensors
        return grad_output * result # d(exp(x))/dx = exp(x)
# Usage
x = torch.tensor(1.0, requires_grad=True)
y = CustomExp.apply(x)
y.backward()
print(x.grad) # tensor(2.7183) = e^1
Higher-order gradients
PyTorch supports gradients of gradients, such as Hessians:
x = torch.tensor(2.0, requires_grad=True)
y = x ** 3 # y = x^3
# First derivative: dy/dx = 3x^2
grad_x = torch.autograd.grad(y, x, create_graph=True)[0]
print(grad_x) # tensor(12., grad_fn=<...>) = 3 * 2^2 (grad_fn is present because create_graph=True)
# Second derivative: d²y/dx² = 6x
grad2_x = torch.autograd.grad(grad_x, x)[0]
print(grad2_x) # tensor(12.) = 6 * 2
Jacobian and Hessian computation
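The following example shows how to calculate a Jacobian matrix by calling the torch.autograd.grad function once for each element of the output tensor. PyTorch also provides torch.autograd.functional.jacobian and torch.autograd.functional.hessian for these computations.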
def compute_jacobian(func, x):
    """Compute Jacobian matrix of func at x."""
    x = x.requires_grad_()
    y = func(x)
    jacobian = torch.zeros(y.shape[0], x.shape[0])
    for i in range(y.shape[0]):
        grad_output = torch.zeros_like(y)
        grad_output[i] = 1 # Select output element i
        jacobian[i] = torch.autograd.grad(y, x, grad_output, retain_graph=True)[0]
    return jacobian
# Example
def f(x):
    return torch.stack([x[0]**2 + x[1], x[0] * x[1]**2])
x = torch.tensor([2.0, 3.0])
J = compute_jacobian(f, x)
print("Jacobian:")
print(J)
# tensor([[ 4., 1.], <- ∂f1/∂x1=2*x1, ∂f1/∂x2=1
# [ 9., 12.]]) <- ∂f2/∂x1=x2^2, ∂f2/∂x2=2*x1*x2
Gradient hooks
Use gradient hooks to intercept and modify gradients during the backward pass:
# Tensor hook: called with this tensor's gradient during backward
def gradient_clipper(grad):
    return torch.clamp(grad, -1, 1)
x = torch.tensor(5.0, requires_grad=True)
y = x ** 2
# Register hook
handle = x.register_hook(gradient_clipper)
y.backward()
print(x.grad) # tensor(1.) instead of tensor(10.)
# Because d(x^2)/dx = 2x = 10, but clipped to [-1, 1]
handle.remove() # Clean up
Graph visualization
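Visualize the computational graph with the third-party torchviz package (pip install torchviz), which also needs the Graphviz software to render images: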
from torchviz import make_dot
x = torch.randn(3, requires_grad=True)
y = x * 2
z = y + 1
loss = z.sum()
# Generate graph visualization
dot = make_dot(loss, params={'x': x})
dot.render('computation_graph', format='png')
Performance tips and best practices
Use these strategies to optimize your code for speed and memory efficiency when working with the autograd engine.
Memory management
Clear gradients before each backward pass, because backward adds new gradients to whatever is already stored in .grad. optimizer.zero_grad() does this for every parameter the optimizer manages, and since PyTorch 2.0 it sets gradients to None by default, which saves memory and skips a zero-fill.
# ✅ Good: Clear gradients properly
optimizer.zero_grad()
# ✅ Also works: manual reset (equivalent to set_to_none=True for these parameters)
for param in model.parameters():
    param.grad = None
# ✅ Explicit: set_to_none=True (the default since PyTorch 2.0)
optimizer.zero_grad(set_to_none=True)
Avoid unnecessary graph construction
You can reduce computational overhead by disabling the autograd engine when you do not need to calculate gradients. This is common during model evaluation or when generating predictions.
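# When you don't need gradients:
with torch.no_grad():
    predictions = model(test_data)  # No graph is recorded
# Or use inference mode, which skips additional autograd bookkeeping and is often faster;
# tensors created here cannot be used later in computations recorded by autograd
with torch.inference_mode():
    predictions = model(test_data)
Detach tensors to stop gradient flow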
The detach() method returns a tensor that shares data with the original but is excluded from the computational graph, so gradients do not flow back through it. This is useful for freezing parts of a model or for using a tensor's value as a constant in a separate calculation.
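A common application is keeping one part of a model fixed while training the rest. In this minimal sketch, backbone and head are hypothetical stand-ins (two nn.Linear layers with arbitrary sizes); detaching the backbone's output means only the head receives gradients:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
backbone = nn.Linear(4, 8)  # hypothetical frozen feature extractor
head = nn.Linear(8, 1)      # hypothetical trainable head

x = torch.randn(16, 4)
features = backbone(x).detach()  # cut the graph here: backbone output is a constant
loss = head(features).pow(2).mean()
loss.backward()

print(backbone.weight.grad)          # None: no gradient reached the backbone
print(head.weight.grad is not None)  # True
```

Calling requires_grad_(False) on the frozen parameters is an alternative that also stops autograd from recording the backbone's operations. The smaller example that follows shows the same cut at the level of individual tensors.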
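# Detach to stop gradient flow
x = torch.randn(3, requires_grad=True)
y = x * 2
z = y.detach() * 3  # z is a constant to autograd: no gradient flows back to y through z
loss = (y + z).sum()
loss.backward()
print(z.requires_grad)  # False (z is cut off from the graph)
print(x.grad)  # tensor([2., 2., 2.]): only the y path contributes
Mixed precision training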
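Mixed precision training runs selected operations in a 16-bit floating-point format (float16 or bfloat16) while keeping others in 32-bit. On hardware with fast 16-bit support, this speeds up computation and reduces memory use, usually with little or no loss of accuracy. With float16, gradient scaling (GradScaler) keeps small gradients from underflowing to zero.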
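Autocast also works on CPU, which makes its effect easy to inspect without a GPU. This minimal sketch assumes a toy nn.Linear model and uses bfloat16 as the CPU autocast type; it shows that the layer's matrix multiply runs in 16-bit while the parameters and their gradients stay in float32:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 4)  # toy model; parameters are float32
x = torch.randn(2, 8)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(x)                     # linear runs in bfloat16 under CPU autocast
    loss = out.float().pow(2).mean()   # compute the loss in float32

loss.backward()
print(out.dtype)                # torch.bfloat16
print(model.weight.dtype)       # torch.float32
print(model.weight.grad.dtype)  # torch.float32
```

bfloat16 has the same exponent range as float32, so this CPU sketch does not need GradScaler. A typical CUDA training loop with float16 autocast and gradient scaling looks like this: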
import torch
from torch.amp import autocast, GradScaler
# MyModel, dataloader, and criterion are placeholders for your own code
model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler()
for data, target in dataloader:
    data, target = data.cuda(), target.cuda()
    optimizer.zero_grad()
    # Forward in mixed precision
    with autocast(device_type='cuda'):
        output = model(data)
        loss = criterion(output, target)
    # Backward with gradient scaling
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
Debugging autograd
Debugging gradient computations calls for tools that expose numerical instabilities and structural errors in the computational graph. Use the following methods to troubleshoot your gradient calculations.
Detecting anomalies in gradients
Use torch.autograd.set_detect_anomaly(True) to locate the source of NaN values in the backward pass. In this mode, the engine records a traceback for each forward operation; if a backward function returns NaN, the engine raises an error and prints the traceback of the forward operation that created that function. Anomaly detection slows down execution, so enable it only while debugging.
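This minimal sketch triggers a NaN on purpose so you can see what anomaly mode reports. The backward formula for sqrt(x) divides the incoming gradient by 2*sqrt(x), which is 0 at x = 0, and an upstream gradient of zero turns that into 0/0 = NaN. The sketch uses the torch.autograd.detect_anomaly() context manager, which enables the same checks only inside the with block:

```python
import torch

x = torch.tensor(0.0, requires_grad=True)
error_message = None
try:
    # Anomaly checks are active only inside this block
    with torch.autograd.detect_anomaly():
        y = torch.sqrt(x)  # backward computes grad / (2 * sqrt(x)), dividing by 0 here
        loss = y * 0.0     # upstream gradient of 0 makes that 0 / 0 = NaN
        loss.backward()
except RuntimeError as e:
    error_message = str(e)

print(error_message)
```

In your own training code, you can instead enable anomaly detection globally and catch the error around the backward call: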
torch.autograd.set_detect_anomaly(True)
try:
    loss = model(x)
    loss.backward()
except RuntimeError as e:
    print(f"Anomaly detected: {e}")
Visualizing gradient flow
Monitoring the magnitude of gradients across layers helps you spot vanishing or exploding gradients. The following function computes the mean absolute gradient of each parameter and plots the values layer by layer, showing how the gradient signal flows through the network during the backward pass.
import matplotlib.pyplot as plt
def plot_grad_flow(named_parameters):
    """Plot the mean absolute gradient of each parameter."""
    ave_grads = []
    layers = []
    for n, p in named_parameters:
        if p.requires_grad and p.grad is not None:
            layers.append(n)
            ave_grads.append(p.grad.abs().mean().item())
    plt.plot(ave_grads, alpha=0.3, color="b")
    plt.hlines(0, 0, len(ave_grads) + 1, linewidth=1, color="k")
    plt.xticks(range(len(ave_grads)), layers, rotation="vertical")
    plt.xlabel("Layers")
    plt.ylabel("Average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.show()
# After backward pass:
plot_grad_flow(model.named_parameters())
Inspect the computational graph structure
You can inspect the structure of the backward graph by recursively following a tensor's grad_fn and each node's next_functions. This confirms that the engine recorded operations in the expected order and that every path leads back to a leaf tensor, which appears as an AccumulateGrad node.
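Before writing a recursive helper, you can look at one level of the graph directly. In this minimal sketch, the leaf tensor has no grad_fn, and the root node's next_functions lists the node that produced its input:

```python
import torch

x = torch.randn(3, requires_grad=True)
loss = (x * 2 + 1).sum()

# Leaf tensors created by the user have no grad_fn
print(x.is_leaf, x.grad_fn)         # True None
# The root node of the backward graph
print(type(loss.grad_fn).__name__)  # SumBackward0
# next_functions holds (node, input_index) pairs for the node's inputs
parents = [type(fn).__name__ for fn, _ in loss.grad_fn.next_functions if fn is not None]
print(parents)                      # ['AddBackward0']
```

To print the whole graph, a recursive helper can follow next_functions down to the leaves: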
import torch
def print_graph(fn, level=0):
    """Recursively print the backward graph starting at a grad_fn node."""
    if fn is None:
        return
    print('  ' * level + type(fn).__name__)
    # next_functions holds (node, input_index) pairs; None marks inputs that need no gradient
    for next_fn, _ in fn.next_functions:
        print_graph(next_fn, level + 1)
x = torch.randn(3, requires_grad=True)
y = x * 2
z = y + 1
loss = z.sum()
print_graph(loss.grad_fn)
# Output:
# SumBackward0
#   AddBackward0
#     MulBackward0
#       AccumulateGrad
Key takeaways for the autograd engine
The PyTorch autograd engine balances flexibility, efficiency, and debuggability. Its define-by-run approach builds the graph as your code executes, so standard Python control flow such as loops and conditionals works unchanged. The engine stays efficient by saving only the intermediate values that the backward pass needs and freeing the graph once backward completes.
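A small sketch of what define-by-run means in practice: in the hypothetical function below, a Python while loop decides at runtime how many times to double the input, so each call records a different graph and the gradient reflects the path that actually ran:

```python
import torch

def double_until_large(x):
    # Data-dependent Python control flow: the loop count depends on x's values
    y = x
    while y.norm() < 10:
        y = y * 2
    return y.sum()

small = torch.tensor([1.0, 1.0], requires_grad=True)
double_until_large(small).backward()
print(small.grad)  # tensor([8., 8.]): three doublings, so y = 8 * x

large = torch.tensor([3.0, 3.0], requires_grad=True)
double_until_large(large).backward()
print(large.grad)  # tensor([4., 4.]): two doublings, so y = 4 * x
```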
Behind the simple .backward() interface, the engine handles all the bookkeeping that differentiation requires. Eager execution also aids debugging: operations run immediately, so errors surface at the line of Python that caused them.
When you call loss.backward(), the engine traverses the dynamically built computational graph in reverse, unpacking saved tensors and checking their version counters to detect in-place modifications. It accumulates gradients that arrive through multiple paths and, unless you pass retain_graph=True, frees the graph's saved tensors when it finishes.
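Three of these behaviors are visible from Python. This minimal sketch, using small illustrative tensors, shows gradient accumulation across two paths, the error raised when you call backward a second time on a freed graph, and the version-counter check that catches an in-place change to a saved tensor:

```python
import torch

# 1. Gradients from multiple paths are summed: x feeds both x * x and the "+ x" term
x = torch.tensor(3.0, requires_grad=True)
y = x * x + x
y.backward()
first_grad = x.grad.clone()
print(first_grad)  # tensor(7.): d/dx (x^2 + x) = 2x + 1

# 2. Saved tensors are freed after backward() unless retain_graph=True
freed_error = None
try:
    y.backward()
except RuntimeError as e:
    freed_error = str(e)

# 3. Version counters catch in-place changes to tensors saved for backward
inplace_error = None
a = torch.randn(3, requires_grad=True)
b = a.sigmoid()  # SigmoidBackward0 saves its output b
b.mul_(2)        # in-place op bumps b's version counter
try:
    b.sum().backward()
except RuntimeError as e:
    inplace_error = str(e)

print(freed_error is not None, inplace_error is not None)  # True True
```

Passing retain_graph=True to the first backward() call keeps the saved tensors, at the cost of holding their memory until the graph is no longer referenced.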
You now have a practical understanding of the internal functions of the autograd engine. To learn more, explore the following links: