Note
This article was originally published on Medium and has been adapted for Red Hat Developer.
If you have ever wondered what powers PyTorch's tensor operations across CPUs, GPUs, and various accelerators, the answer is ATen, the foundational tensor library for PyTorch. In my work on PyTorch engineering, I've seen firsthand how ATen's architecture provides the simplicity developers love and the performance production systems demand.
This article explores how ATen works and how it handles tensor operations across different hardware backends.
What is ATen?
ATen is short for A Tensor Library, a modest name for what is arguably one of the most advanced tensor computation frameworks available.
At its core, ATen is a C++ library that serves as the foundation for PyTorch. It provides the backend for hundreds of mathematical operations. The framework is device-agnostic and dispatches operations to backends such as CPU, CUDA, or MPS. This layer sits below autograd: it handles tensor operations without automatic differentiation.
Think of ATen as the engine under the PyTorch hood. When you write tensor.matmul(other) in Python, you call the ATen C++ implementation.
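You can see this mapping directly from Python: ATen operators are exposed through the public torch.ops.aten namespace, so the familiar method call and the explicit ATen operator reach the same kernel. A small sketch, using only the public API:
import torch
a = torch.randn(2, 3)
b = torch.randn(3, 4)
c1 = a.matmul(b)                  # the familiar Python API
c2 = torch.ops.aten.matmul(a, b)  # the same ATen operator, called explicitly
print(torch.allclose(c1, c2))     # True: both paths reach the same C++ kernel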
The PyTorch architecture stack
To understand ATen's role, let's visualize PyTorch's architecture (Figure 1).

ATen sits right in the middle, bridging high-level operations with low-level hardware-specific implementations.
Core architecture components
The core components of the ATen library are the fundamental building blocks for PyTorch tensor operations. These components include a C++ class library for multi-dimensional arrays and a backend system for hardware-specific execution.
The Tensor class
The Tensor class in ATen uses intrusive reference counting. This is similar to std::shared_ptr but more efficient.
namespace at {
class Tensor {
  c10::intrusive_ptr<TensorImpl> impl_;
 public:
  Tensor() = default;
  // Operations like add, mul, matmul...
  Tensor add(const Tensor& other) const;
  Tensor matmul(const Tensor& other) const;
  // ... hundreds more operations
};
}
Key design decisions:
- Reference counting: Multiple Tensor objects can share the same underlying data (TensorImpl). When you write Tensor b = a;, you create a new handle that points to the same data (see the short Python sketch after this list).
- Lightweight views: Operations like transpose() or view() create new Tensor objects that share storage but carry different metadata, such as shape and stride.
- Copy-on-write: For shared storage, modifications trigger copies only when necessary.
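Here is a minimal Python sketch of the first two points; it relies only on the public API (TensorImpl itself is not exposed to Python):
import torch
a = torch.ones(3, 3)
b = a                                # new handle, same underlying TensorImpl
b.add_(1)                            # an in-place change is visible through both handles
print(a[0, 0])                       # tensor(2.)
v = a.transpose(0, 1)                # lightweight view: same storage, new metadata
print(v.data_ptr() == a.data_ptr())  # True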
TensorImpl
Users interact with Tensor, but TensorImpl does the heavy lifting:
struct **TensorImpl** {
// Storage: the actual data buffer
c10::Storage storage_;
// Metadata
c10::SmallVector<int64_t, 5> sizes_; // Shape [2, 3, 224, 224]
c10::SmallVector<int64_t, 5> strides_; // Memory layout
int64_t storage_offset_; // Offset in storage
// Type information
c10::ScalarType dtype_; // float32, int64, etc.
c10::Device device_; // CPU, CUDA:0, etc.
c10::Layout layout_; // Strided, Sparse, etc.
// Dispatch key set
c10::DispatchKeySet key_set_;
};
This separation between Tensor (handle) and TensorImpl (implementation) is crucial for:
- Efficient memory management
- Supporting views without copying data
- Thread-safe reference counting
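The fields stored in TensorImpl are observable from Python. A short sketch (CPU only, no backend-specific assumptions):
import torch
t = torch.randn(2, 3, 224, 224)
print(t.shape)                      # torch.Size([2, 3, 224, 224])  -> sizes_
print(t.stride())                   # (150528, 50176, 224, 1)       -> strides_
print(t.storage_offset())           # 0                             -> storage_offset_
print(t.dtype, t.device, t.layout)  # torch.float32 cpu torch.strided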
The dispatch system
ATen determines whether to execute an operation on CPU, CUDA, or the Apple Metal GPU using the dispatch system.
PyTorch declares every operator in the central native_functions.yaml file:
- func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
  variants: function, method
  dispatch:
    CPU: add_cpu
    CUDA: add_cuda
    MPS: add_mps
    CompositeImplicitAutograd: add_impl
This declarative approach provides a single source of truth for all operations. It automatically generates code for Python bindings and establishes clear dispatch rules for different backends.
Dynamic dispatch flow
When you call a.add(b), here's what happens:
- Python call: The user calls a.add(b) in the Python environment.
- C++ bindings: Auto-generated bindings from native_functions.yaml transition the call to the C++ layer.
- Dispatcher: The dispatcher examines the DispatchKeySet of the tensor.
- Kernel selection: The system selects the appropriate kernel, such as CPU or CUDA.
- Execution: The backend-specific implementation executes.
- Return: The system returns the result tensor.
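From Python, the same call is routed to different kernels based purely on the tensors' dispatch keys. A small sketch (the CUDA branch runs only if a GPU is available):
import torch
a = torch.randn(4, 4)
b = torch.randn(4, 4)
c = a.add(b)             # DispatchKey::CPU -> CPU kernel
print(c.device)          # cpu
if torch.cuda.is_available():
    ag, bg = a.cuda(), b.cuda()
    cg = ag.add(bg)      # DispatchKey::CUDA -> CUDA kernel
    print(cg.device)     # cuda:0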
Example in code:
// ATen dispatcher logic (simplified)
Tensor Tensor::add(const Tensor& other, Scalar alpha) const {
  // Extract dispatch key from this tensor
  DispatchKey key = this->key_set().highestPriorityTypeId();
  // Dispatch to appropriate implementation
  switch (key) {
    case DispatchKey::CPU:
      return add_cpu(*this, other, alpha);
    case DispatchKey::CUDA:
      return add_cuda(*this, other, alpha);
    case DispatchKey::MPS:
      return add_mps(*this, other, alpha);
    // ... more backends
  }
}
In reality, PyTorch uses a more sophisticated table-based dispatch system for performance, but the concept is the same.
Operator categories
ATen organizes operations into distinct categories: leaf, composite, native, and custom operations.
Leaf operations
Leaf operations are basic operations that have dedicated backend implementations:
// aten/src/ATen/native/cpu/AddKernel.cpp
Tensor add_cpu(const Tensor& self, const Tensor& other, Scalar alpha) {
  // Vectorized CPU implementation using AVX/AVX2/AVX512
  // ...
}
Composite operations
Developers build composite operations from other operations to ensure they remain device-agnostic:
// Composite implementation - works on any device
Tensor sigmoid_backward(const Tensor& grad, const Tensor& output) {
  return grad * output * (1 - output);
}
Native operations
Native operations are the standard functions that ship with PyTorch under the aten:: namespace.
- aten::matmul
- aten::conv2d
- aten::relu
Custom operations
Custom operations are user-defined functions that the system registers at runtime:
TORCH_LIBRARY(my_ops, m) {
m.def("my_custom_op(Tensor self) -> Tensor");
}
TORCH_LIBRARY_IMPL(my_ops, CPU, m) {
  m.impl("my_custom_op", my_custom_op_cpu);
}
Directory structure
Understanding ATen's codebase organization helps navigate its complexity:
aten/
├── src/
│ ├── ATen/
│ │ ├── core/ # Core tensor abstractions
│ │ │ ├── Tensor.h
│ │ │ ├── TensorImpl.h
│ │ │ └── DispatchKey.h
│ │ ├── native/ # Operator implementations ★
│ │ │ ├── Add.cpp # Generic implementations
│ │ │ ├── cpu/ # CPU-specific (AVX, etc.)
│ │ │ ├── cuda/ # CUDA implementations
│ │ │ ├── mps/ # Apple Metal
│ │ │ ├── mkl/ # Intel MKL
│ │ │ ├── cudnn/ # NVIDIA cuDNN
│ │ │ └── native_functions.yaml # ★★ THE REGISTRY
│ │ ├── ops/ # Generated operator headers
│ │ └── TensorIterator.h # Efficient iteration engine
│ └── THC/ # Legacy CUDA support
└── tools/ # Code generation scripts
The golden rule: When adding a new operator:
- Declare it in native_functions.yaml
- Implement kernels in aten/src/ATen/native/
- Let code generation handle the rest
How TensorIterator works
One of ATen's most underrated components is TensorIterator, a sophisticated system for writing efficient, device-agnostic element-wise operations.
The problem
Writing a simple operation like c = a + b is deceptively complex:
- Different shapes (broadcasting)
- Different strides (memory layout)
- Different dtypes (type promotion)
- Different devices (CPU versus GPU)
- Vectorization opportunities
The solution
TensorIterator handles all this complexity:
void add_kernel(TensorIteratorBase& iter) {
  // TensorIterator handles:
  // - Broadcasting
  // - Type casting
  // - Memory layout
  // - Splitting work across threads/blocks
  AT_DISPATCH_ALL_TYPES(iter.dtype(), "add_cpu", [&] {
    cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
      return a + b; // Your logic is just this!
    });
  });
}
What TensorIterator does:
- Broadcasting: Automatically handles shape mismatches ([3, 1] + [3, 4]); see the snippet after this list
- Type promotion: Converts dtypes as needed (int + float → float)
- Vectorization: Uses SIMD instructions (AVX512, NEON)
- Memory coalescing: Optimizes memory access patterns
- Parallelization: Splits work across CPU threads or GPU blocks
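The first two points are easy to observe from Python; a minimal sketch (vectorization and parallelization happen transparently inside the kernel):
import torch
a = torch.ones(3, 1)
b = torch.ones(3, 4)
print((a + b).shape)   # torch.Size([3, 4]) -- broadcasting
i = torch.tensor([1, 2, 3], dtype=torch.int32)
f = torch.tensor([0.5, 0.5, 0.5])
print((i + f).dtype)   # torch.float32 -- type promotion (int + float -> float)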
Strategies to improve performance
You can optimize tensor execution and reduce memory overhead by using techniques such as kernel fusion, in-place operations, and zero-copy views.
Kernel fusion
ATen supports kernel fusion to minimize memory bandwidth:
// Instead of: d = (a + b) * c
// which creates an intermediate tensor.
// Fused version (simplified sketch, assuming c is a scalar tensor):
Tensor fused_add_mul(const Tensor& a, const Tensor& b, const Tensor& c) {
  Tensor result = at::empty_like(a);
  auto iter = TensorIterator::binary_op(result, a, b);
  AT_DISPATCH_ALL_TYPES(iter.dtype(), "fused_add_mul", [&] {
    scalar_t c_val = c.item<scalar_t>();
    // Single kernel computes both operations
    cpu_kernel(iter, [c_val](scalar_t x, scalar_t y) -> scalar_t {
      return (x + y) * c_val;
    });
  });
  return result;
}
In-place operations
ATen distinguishes between two types of operations:
- add(a, b): Creates a new tensor and allocates memory.
- add_(a, b): Modifies tensor a in place without allocating new memory (note the trailing underscore).
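A quick way to verify the difference from Python is to compare storage addresses with data_ptr(); a small sketch:
import torch
a = torch.ones(3)
b = torch.ones(3)
out = a.add(b)                         # new tensor, new storage
print(out.data_ptr() == a.data_ptr())  # False
before = a.data_ptr()
a.add_(b)                              # in place: the existing storage is reused
print(a.data_ptr() == before)          # True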
The (a!) annotation in the native_functions.yaml file tells ATen this is an in-place operation.
- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
View operations
Many operations are zero-copy views:
Tensor a = torch::randn({100, 100});
Tensor b = a.transpose(0, 1); // No data copy!
Tensor c = a.view({10, 10, 100}); // No data copy!
// b and c share storage with a, just different metadata
assert(a.data_ptr() == b.data_ptr());
Backend support: The ecosystem
The ATen dispatch system makes supporting new hardware easier. Table 1 lists specialized backends and libraries supported across architectures.
| Backend | Hardware | Key libraries | Description |
|---|---|---|---|
| CPU | x86/ARM processors | MKL, oneDNN, Eigen | Standard CPU with SIMD vectorization (AVX, AVX2, AVX512, NEON) |
| CUDA | NVIDIA GPUs | cuBLAS, cuDNN, cuSPARSE, NCCL | Full GPU acceleration with custom CUDA kernels |
| MPS | Apple M1/M2/M3/M4 | Metal, MPS | Metal Performance Shaders for Apple Silicon |
| ROCm | AMD GPUs (MI100+) | hipBLAS, MIOpen, RCCL | HIP-based (CUDA-compatible) acceleration |
| XLA | Google TPUs | XLA compiler | Accelerated Linear Algebra compiler |
| Intel XPU | Intel Arc, Ponte Vecchio | oneMKL, oneDNN | oneAPI-based GPU acceleration |
| HPU | Habana Gaudi/Gaudi2 | Habana SDK | Custom AI accelerator for training |
| MTIA | Meta Training/Inference Accelerator | Internal | Meta's custom AI accelerator |
| Vulkan | Cross-platform GPU | Vulkan SDK | Portable GPU compute API |
| PrivateUse1/2/3 | Custom hardware | User-defined | User-extensible dispatch keys |
Adding a new backend
Here's the skeleton for supporting a custom accelerator:
// 1. Register dispatch key
namespace at {
constexpr DispatchKey CustomAccelerator = DispatchKey::PrivateUse1;
}
// 2. Implement operators
TORCH_LIBRARY_IMPL(aten, CustomAccelerator, m) {
m.impl("add.Tensor", add_custom_accelerator);
m.impl("matmul", matmul_custom_accelerator);
// ... register more ops
}
// 3. Implement kernel
Tensor add_custom_accelerator(const Tensor& self, const Tensor& other,
                              Scalar alpha) {
  Tensor result = at::empty_like(self);
  // Your hardware-specific implementation
  custom_device_add(self.data_ptr(), other.data_ptr(), ...);
  return result;
}
Code generation
ATen uses code generation to automate repetitive tasks. The torchgen tool processes native_functions.yaml and generates several components:
- Python bindings for the PyTorch API
- C++ headers for operator declarations
- Dispatcher registrations for routing tables
- Type stubs for IDE autocomplete
A single entry in the YAML file generates several outputs:
- The torch.relu() Python function
- The at::relu() C++ function
- The Tensor.relu() method
- Dispatcher registration for CPU and CUDA
- Type hints for IDEs
- func: relu(Tensor self) -> Tensor
  dispatch:
    CPU: relu_cpu    # custom CPU implementation
    CUDA: relu_cuda  # custom CUDA implementation
This automation allows PyTorch to support more than 2,000 operations without repetitive manual coding.
Matrix multiplication example
The following steps trace a complete matmul operation:
import torch
a = torch.randn(1000, 1000, device='cuda')
b = torch.randn(1000, 1000, device='cuda')
c = torch.matmul(a, b)
What actually happens:
- Python call: Python calls torch.matmul(a, b).
- C++ binding: The generated binding in torch/csrc/autograd/generated/python_torch_functions.cpp receives the call.
- Autograd wrapper: torch::matmul() records the operation for autograd tracking.
- ATen dispatch: The dispatcher calls at::matmul().
- Key extraction: Both tensors point to DispatchKey::CUDA.
- Kernel lookup: native_functions.yaml maps the call to matmul_cuda.
- Implementation: The kernel resides in aten/src/ATen/native/cuda/Blas.cpp.
- Library call: The kernel calls cublas_gemm() from cuBLAS.
- Hardware execution: The GPU executes the kernel.
- Return: A new CUDA tensor containing the result is returned.
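One way to confirm which ATen operators were actually dispatched is the PyTorch profiler. A quick sketch on CPU tensors (so it runs without a GPU), where the trace shows entries such as aten::matmul and aten::mm:
import torch
from torch.profiler import profile
a = torch.randn(1000, 1000)
b = torch.randn(1000, 1000)
with profile() as prof:
    c = torch.matmul(a, b)
# The table lists the dispatched ATen operators (e.g., aten::matmul, aten::mm)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))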
Memory management: Storage and views
ATen's memory model is advanced yet intuitive.
The Storage class represents the data buffer:
class Storage {
  void* data_ptr_;        // Raw memory pointer
  size_t size_bytes_;     // Allocation size
  Allocator* allocator_;  // CPU malloc()/free(), CUDA (cudaMalloc), etc.
  // Reference counted: automatically tracks how many tensors share this storage
  c10::intrusive_ptr<StorageImpl> impl_;
};
Views allow zero-copy transformations:
a = torch.randn(100, 100)
b = a.t() # Transpose view
c = a[10:20, :]          # Slice view
d = a.view(10, 10, 100)  # Reshape view
# All share the same storage; only the metadata differs
print(a.data_ptr() == b.data_ptr())  # True
How it works:
struct TensorImpl {
  Storage storage_;       // Shared data
  int64_t offset_;        // Start position
  IntArrayRef sizes_;     // [100, 100] vs [10, 10, 100]
  IntArrayRef strides_;   // Memory step sizes
};
Strides are the key: they define how the system traverses memory to read tensor elements.
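A short Python sketch makes the stride mechanics concrete (using a contiguous tensor so the flattened storage order is easy to index):
import torch
a = torch.arange(12).reshape(3, 4)
print(a.stride())         # (4, 1): step 4 elements per row, 1 per column
b = a.t()                 # transpose swaps the strides, not the data
print(b.stride())         # (1, 4)
print(b.is_contiguous())  # False: same storage, traversed in a different order
# Element [i, j] lives at storage_offset + i * stride[0] + j * stride[1]
i, j = 1, 2
print(a[i, j] == a.view(-1)[i * a.stride(0) + j * a.stride(1)])  # tensor(True)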
Advanced features
ATen includes specialized capabilities for autograd integration, sparse tensor support, and model quantization.
Autograd integration
ATen doesn't provide autograd, but it's designed for simple integration.
// In torch/csrc/autograd/ (simplified)
class AddBackward : public Node {
  std::vector<Tensor> apply(std::vector<Tensor> grad_outputs) override {
    // Gradient of add is pass-through to both inputs
    return {grad_outputs[0], grad_outputs[0]};
  }
};
Every forward operation in ATen can have an autograd Node attached.
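You can see the attached Node from Python; a minimal sketch (the exact class name printed may vary slightly across releases):
import torch
x = torch.randn(3, requires_grad=True)
y = x.add(2.0)    # ATen computes the result; autograd attaches a backward Node
print(y.grad_fn)  # <AddBackward0 object at 0x...>
y.sum().backward()
print(x.grad)     # tensor([1., 1., 1.]) -- the pass-through gradient of add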
Sparse tensors
ATen supports multiple sparse formats:
- COO (coordinate format): SparseTensorImpl
- CSR (compressed sparse row): SparseCsrTensorImpl
- BSR (block sparse row)
auto indices = torch::tensor({{0, 1, 1}, {2, 0, 2}});
auto values = torch::tensor({3, 4, 5});
auto sparse = torch::sparse_coo_tensor(indices, values, {2, 3});
Quantized tensors
ATen also supports quantized computation:
// QTensor: quantized representation
struct QTensorImpl : public TensorImpl {
  double scale_;        // Quantization step (resolution)
  int64_t zero_point_;  // Which integer represents 0.0?
  // Data is stored as int8 or uint8 instead of wider floating-point values
};
This support enables efficient inference on edge devices.
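A small Python sketch using the eager-mode per-tensor quantization API shows the scale and zero point in action (API availability may vary across PyTorch releases):
import torch
x = torch.randn(4)
q = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.qint8)
print(q.dtype)                        # torch.qint8
print(q.q_scale(), q.q_zero_point())  # 0.1 0
print(q.int_repr())                   # the stored int8 values
print(q.dequantize())                 # approximate reconstruction: (int - zero_point) * scale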
Performance insights
ATen achieves high performance through several architectural optimizations. Zero-copy operations such as views and transposes minimize data movement, while kernel fusion and JIT compilation reduce memory traffic. The library also uses SIMD instructions for vectorization and a caching allocator for memory pooling. Finally, ATen integrates with hardware-specific libraries, including cuBLAS and MKL, to maximize hardware utilization.
Benchmarks
Benchmarks on a typical NVIDIA A100 GPU demonstrate that ATen performs near hardware limits:
- matmul (4096x4096): ~20 TFLOPS (approaching hardware peak)
- conv2d with cuDNN: ~90% of theoretical maximum
- Element-wise operations: limited by memory bandwidth, not compute
Common patterns for contributors
Contributors can follow established patterns when adding support for new operations or hardware backends.
Create a CPU operator
You can implement a CPU operator as follows:
// aten/src/ATen/native/MyOp.cpp
namespace at { namespace native {
Tensor my_op_cpu(const Tensor& self) {
  Tensor output = at::empty_like(self);
  auto iter = TensorIteratorConfig()
      .add_output(output)
      .add_input(self)
      .build();
  AT_DISPATCH_ALL_TYPES(self.scalar_type(), "my_op", [&] {
    cpu_kernel(iter, [](scalar_t x) -> scalar_t {
      return x * x + 1; // Your logic
    });
  });
  return output;
}
}}
Create a CUDA operator
CUDA operators require a kernel function and a wrapper that calculates grid and block dimensions.
// aten/src/ATen/native/cuda/MyOp.cu
template <typename scalar_t>
__global__ void my_op_cuda_kernel(
    scalar_t* output, const scalar_t* input, int64_t n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    output[idx] = input[idx] * input[idx] + 1;
  }
}
Tensor my_op_cuda(const Tensor& self) {
  auto output = torch::empty_like(self);
  int threads = 256;
  int blocks = (self.numel() + threads - 1) / threads;
  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "my_op_cuda", [&] {
    my_op_cuda_kernel<scalar_t><<<blocks, threads>>>(
        output.data_ptr<scalar_t>(),
        self.data_ptr<scalar_t>(),
        self.numel()
    );
  });
  return output;
}
Register an operator
After you implement the kernels, register the operation in the native_functions.yaml file.
# aten/src/ATen/native/native_functions.yaml
- func: my_op(Tensor self) -> Tensor
  variants: function, method
  dispatch:
    CPU: my_op_cpu
    CUDA: my_op_cuda
    CompositeImplicitAutograd: my_op_impl
Best practices
Use these strategies to maintain code quality and performance:
- Use TensorIterator instead of manual loops to ensure high performance.
- Respect strides and never assume memory is contiguous.
- Use AT_DISPATCH_* macros for correct type dispatching.
- Reuse buffers to avoid unnecessary allocations.
- Test CPU, CUDA, and fallback paths.
- Profile your code with the PyTorch profiler before you optimize.
Debugging tips
Print tensor metadata to view internal values such as size, stride, and device:
std::cout << "Tensor metadata:\n"
<< " sizes: " << self.sizes() << "\n"
<< " strides: " << self.strides() << "\n"
<< " dtype: " << self.dtype() << "\n"
<< " device: " << self.device() << "\n"
<< " data\_ptr: " << self.data\_ptr() << "\n";Check dispatch keys to verify routing:
std::cout << "Dispatch keys: " << self.key\_set() << "\n";Enable dispatcher logging with the following command:
export PYTORCH\_DISPATCHER\_LOG=1
python your\_script.pyConclusion
ATen represents an efficient approach to library design. The library provides abstractions that maintain performance and an extensible architecture for new backends. Extensive code generation increases productivity, while the framework achieves performance levels comparable to hand-tuned libraries. Whether you're a PyTorch user wanting to understand what happens under the hood or a contributor adding new operations, understanding ATen is essential.
The next time you write result = (a * b).sum(), you'll know that ATen is orchestrating a complex dance that includes device detection, kernel dispatch, vectorized execution, memory management, and result construction—all within microseconds.
That's the power of ATen—and that's what makes PyTorch feel like magic.
Further reading
- PyTorch C++ API documentation
- ATen source code
- native_functions.yaml reference
- Dispatcher RFC
- TensorIterator guide
Have questions or want to discuss ATen internals? Connect with me on LinkedIn.