AutoRound, a state‑of‑the‑art post‑training quantization (PTQ) algorithm developed by Intel, is now integrated into LLM Compressor. This collaboration delivers:
- Higher accuracy for low bit-width quantization
- Lightweight tuning (hundreds of steps instead of thousands)
- Zero additional inference overhead
- Seamless compatibility with compressed tensors and direct serving in vLLM
- A streamlined workflow that lets you quantize and serve models with just a few lines of code
Broader quantization schemes and model coverage are coming next—try it now and help shape what we build.
What is AutoRound?
AutoRound is a post-training quantization (PTQ) algorithm: it reduces the size of large language models (LLMs) and vision-language models (VLMs) after training is complete. For each quantized tensor it learns three kinds of parameters: v (a rounding offset that adjusts how each weight is rounded), and α and β (which control the learned clipping range). AutoRound processes decoder layers sequentially and uses signed gradient descent to optimize rounding and clipping jointly, minimizing the reconstruction error of each block's output.
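To make the roles of v, α, and β more concrete, here is a deliberately simplified NumPy sketch for a single weight row. It is a toy under stated assumptions, not the intel/auto-round implementation: it uses asymmetric min-max fake quantization, tunes only v (α and β stay fixed at 1), approximates the gradient of rounding with a straight-through estimator, reconstructs one linear output instead of a whole decoder block, and bounds v to [-0.5, 0.5]. The function names, step size, and synthetic data are our own.

```python
import numpy as np

def fake_quant(w, v, alpha=1.0, beta=1.0, bits=4):
    """Toy asymmetric min-max quantize-dequantize of one weight row.

    v shifts each weight's rounding decision; alpha and beta scale the
    maximum and minimum of the clipping range.
    """
    qmax = 2**bits - 1
    w_max, w_min = w.max() * alpha, w.min() * beta
    scale = (w_max - w_min) / qmax
    zero_point = np.round(-w_min / scale)
    q_float = np.round(w / scale + v) + zero_point
    in_range = (q_float >= 0) & (q_float <= qmax)  # where the STE lets gradients through
    q = np.clip(q_float, 0, qmax)
    return scale * (q - zero_point), scale, in_range

def tune_rounding(w, x, iters=200, lr=None, bits=4):
    """Toy signed gradient descent on v only; alpha = beta = 1 stay fixed."""
    lr = 1.0 / iters if lr is None else lr  # illustrative step size
    y_ref = x @ w  # full-precision output on calibration inputs

    def loss_of(offsets):
        w_hat, _, _ = fake_quant(w, offsets, bits=bits)
        return np.mean((y_ref - x @ w_hat) ** 2)

    v = np.zeros_like(w)  # v = 0 is plain round-to-nearest (RTN)
    best_v, best_loss = v.copy(), loss_of(v)
    for _ in range(iters):
        w_hat, scale, in_range = fake_quant(w, v, bits=bits)
        residual = y_ref - x @ w_hat
        grad_w_hat = -2.0 / len(x) * (x.T @ residual)  # dLoss/dw_hat
        grad_v = grad_w_hat * scale * in_range  # straight-through estimator for round()
        v = np.clip(v - lr * np.sign(grad_v), -0.5, 0.5)  # signed step, v kept in [-0.5, 0.5]
        loss = loss_of(v)
        if loss < best_loss:  # keep the best offsets seen so far
            best_v, best_loss = v.copy(), loss
    return best_v, best_loss

rng = np.random.default_rng(0)
w = rng.normal(size=64)          # one output row of a linear layer
x = rng.normal(size=(256, 64))   # synthetic "calibration" activations
w_rtn, _, _ = fake_quant(w, np.zeros_like(w))
rtn_loss = np.mean((x @ w - x @ w_rtn) ** 2)
v, tuned_loss = tune_rounding(w, x)
print(f"RTN output MSE: {rtn_loss:.5f}, tuned output MSE: {tuned_loss:.5f}")
```

Note the update uses only the sign of the gradient, so every step moves each offset by the same fixed amount. Because this sketch keeps the best offsets seen, its tuned result can never be worse than round-to-nearest on the calibration inputs.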
Core strengths:
- Superior accuracy, especially at very low bit‑widths
- Supports multiple data types: W4A16, MXFP8, MXFP4, FP8, NVFP4, with more coming soon
- Mixed‑bit, layer‑wise precision search for flexible accuracy–efficiency trade‑offs
- Works with both large language models (LLMs) and vision-language models (VLMs)
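The mixed-bit trade-off can be illustrated with a generic greedy heuristic: start every layer at the lowest bit-width, then repeatedly upgrade the layer that buys the largest error reduction per extra bit, as long as the parameter-weighted average bit-width stays within a budget. This is a hypothetical sketch for intuition only; AutoRound's actual layer-wise search may work differently, and the `LayerOption` type, `greedy_mixed_bits` helper, layer names, and error numbers are all made up.

```python
from dataclasses import dataclass

@dataclass
class LayerOption:
    name: str
    num_params: int
    error_by_bits: dict  # hypothetical quantization error at each candidate bit-width

def average_bits(assignment, layers):
    """Parameter-weighted average bit-width of a per-layer assignment."""
    total = sum(layer.num_params for layer in layers)
    return sum(assignment[layer.name] * layer.num_params for layer in layers) / total

def greedy_mixed_bits(layers, budget_bits):
    """Greedily upgrade layers while the average bit-width stays within budget."""
    assignment = {layer.name: min(layer.error_by_bits) for layer in layers}
    while True:
        best = None  # (error reduction per extra bit, layer name, new bit-width)
        for layer in layers:
            current = assignment[layer.name]
            higher = sorted(b for b in layer.error_by_bits if b > current)
            if not higher:
                continue
            candidate = higher[0]
            trial = {**assignment, layer.name: candidate}
            if average_bits(trial, layers) > budget_bits:
                continue  # this upgrade would exceed the bit budget
            gain = layer.error_by_bits[current] - layer.error_by_bits[candidate]
            ratio = gain / ((candidate - current) * layer.num_params)
            if best is None or ratio > best[0]:
                best = (ratio, layer.name, candidate)
        if best is None:
            return assignment
        assignment[best[1]] = best[2]

layers = [
    LayerOption("attn.q_proj", 1000, {4: 0.10, 8: 0.01}),
    LayerOption("mlp.down_proj", 3000, {4: 0.50, 8: 0.05}),
    LayerOption("mlp.up_proj", 3000, {4: 0.05, 8: 0.04}),
]
plan = greedy_mixed_bits(layers, budget_bits=6.0)
print(plan, round(average_bits(plan, layers), 3))
```

With a 6-bit average budget, only the most sensitive layer in this made-up example (`mlp.down_proj`) is promoted to 8 bits; the others stay at 4 bits.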
AutoRound creates quantized models in low-bit formats that accelerate inference on Intel Xeon processors, Intel Gaudi AI accelerators, Intel Data Center GPUs, Intel Arc B-Series Graphics, as well as other GPUs (for example, CUDA-based devices).
Looking forward, Intel is adding native support for FP8, MXFP8, and MXFP4 formats to its next-generation Intel Data Center GPU codenamed Crescent Island. Models quantized with AutoRound will naturally scale to take advantage of these data types across the Intel AI hardware portfolio. This creates a consistent path from algorithmic innovation to real-world deployment.
For more details, refer to the paper AutoRound (EMNLP 2024) and the GitHub repository intel/auto-round.
Why integrate into LLM Compressor?
LLM Compressor already provides a unified, modular system for compression techniques such as quantization and pruning. Integrating AutoRound into this ecosystem:
- Aligns with the existing modifier architecture (for example, GPTQModifier)
- Reuses the sequential calibration and layer‑onloading infrastructure
- Enables future interoperability with richer multi‑modifier recipes
- Produces quantized models that are ready for vLLM serving, enabling a clean workflow from compression to deployment
Integration overview
We completed the first stage of the integration by introducing a new AutoRoundModifier into LLM Compressor (PR #1994). It produces wNa16 compressed models (N-bit weights with 16-bit activations, for example W4A16) that load directly in vLLM. With a straightforward configuration (just specify your model and calibration data), you can quickly generate high-quality low-bit checkpoints. This initial stage supports a range of dense LLMs, including the Llama and Qwen model families, and the resulting checkpoints are ready for practical deployment.
Try it now: Quick start
This quick start walks you through the process from installation to evaluating the quantized model's performance.
1. Install
Start by cloning the repository and installing the necessary Python package.
```bash
git clone https://github.com/vllm-project/llm-compressor.git
cd llm-compressor
pip install -e .
```

2. Load model and tokenizer
Load the model and tokenizer from the Hugging Face Model Hub, specifying the desired model ID.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "Qwen/Qwen3-8B"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
```

3. Prepare calibration data
Set up the calibration dataset: a small, unlabeled set of text samples used to tune the quantization parameters.
```python
from auto_round.calib_dataset import get_dataset

NUM_CALIBRATION_SAMPLES = 128
MAX_SEQUENCE_LENGTH = 2048

ds = get_dataset(
    tokenizer=tokenizer,
    seqlen=MAX_SEQUENCE_LENGTH,
    nsamples=NUM_CALIBRATION_SAMPLES,
)
```

4. Run quantization using AutoRound
AutoRound quantization can run on a variety of devices, including CPUs and GPUs, and quantization and serving do not need to happen on the same device. For example, you can quantize on a GPU workstation and later deploy the model on an AI PC.
```python
from llmcompressor import oneshot
from llmcompressor.modifiers.autoround import AutoRoundModifier

# Quantize all Linear layers to 4-bit weights / 16-bit activations,
# keeping the output head (lm_head) in its original precision.
recipe = AutoRoundModifier(
    targets="Linear",
    scheme="W4A16",
    ignore=["lm_head"],
    iters=200,  # AutoRound tuning steps per decoder block
)

oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    shuffle_calibration_samples=False,
)

# Save in the compressed format that vLLM can load directly.
SAVE_DIR = MODEL_ID.split("/")[-1] + "-W4A16-G128-AutoRound"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```

In practice, 128 calibration samples and about 200 iterations often reach stable convergence. Increase the number of samples or iterations if you are targeting extremely low bit-widths or tighter accuracy targets.
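The `G128` suffix in `SAVE_DIR` refers to a group size of 128: each group of 128 consecutive weights shares one quantization scale, so the per-weight overhead is small. The back-of-the-envelope sketch below assumes one 16-bit scale per group and no stored zero point (as in symmetric quantization), ignores layers left unquantized such as `lm_head`, and uses a round 8-billion-parameter count for illustration rather than the exact size of Qwen3-8B. The helper names are hypothetical.

```python
def effective_bits_per_weight(weight_bits, group_size, scale_bits=16, zero_point_bits=0):
    # Each group of `group_size` weights carries one scale (and optionally a zero point).
    return weight_bits + (scale_bits + zero_point_bits) / group_size

def weight_memory_gb(num_params, bits_per_weight):
    # Bits -> bytes -> gigabytes (decimal GB).
    return num_params * bits_per_weight / 8 / 1e9

n = 8e9  # illustrative parameter count
bf16_gb = weight_memory_gb(n, 16)
w4_bits = effective_bits_per_weight(4, 128)
w4_gb = weight_memory_gb(n, w4_bits)
print(f"BF16: {bf16_gb:.2f} GB, W4A16-G128: {w4_bits} bits/weight -> {w4_gb:.3f} GB")
# prints: BF16: 16.00 GB, W4A16-G128: 4.125 bits/weight -> 4.125 GB
```

Under these assumptions, the group scales add only 0.125 bits per weight, for roughly a 3.9x reduction in weight storage compared with BF16.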
5. Serve in vLLM
Once quantization is complete, the same compressed model can be served on different hardware, independent of the device used for tuning. For example, you can serve the quantized Qwen3‑8B‑W4A16‑G128‑AutoRound model on a single Intel Arc Pro B60 GPU:
```bash
vllm serve Qwen3-8B-W4A16-G128-AutoRound \
    --dtype=bfloat16 \
    --gpu-memory-utilization 0.8 \
    --max-num-batched-tokens 8192 \
    --enforce-eager
```

Note: Install vLLM from PR 29484 to serve this model. When serving on XPU, you must run vLLM with the `--enforce-eager` flag.
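`vllm serve` exposes an OpenAI-compatible HTTP API, by default on port 8000, and by default serves the model under the name passed on the command line. A minimal standard-library client for the completions endpoint might look like the sketch below; the host, port, helper names, and sampling values are assumptions to adapt to your deployment.

```python
import json
import urllib.request

def build_completion_request(prompt, model="Qwen3-8B-W4A16-G128-AutoRound",
                             base_url="http://localhost:8000", max_tokens=128):
    # Assumed defaults: a local server on vLLM's default port, model served under its path name.
    payload = {"model": model, "prompt": prompt, "max_tokens": max_tokens, "temperature": 0.0}
    return urllib.request.Request(
        f"{base_url}/v1/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )

def complete(prompt, **kwargs):
    # Sends the request and returns the generated text of the first choice.
    with urllib.request.urlopen(build_completion_request(prompt, **kwargs), timeout=60) as resp:
        body = json.load(resp)
    return body["choices"][0]["text"]
```

With the server from the previous step running, `complete("What is 12 * 7?")` returns the model's generated text as a string.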
6. Evaluate (Example: GSM8K with lm_eval)
Finally, you can evaluate the quantized model's accuracy on a benchmark dataset using the lm_eval command-line interface from lm-evaluation-harness.
```bash
lm_eval --model vllm \
    --model_args pretrained="./Qwen3-8B-W4A16-G128-AutoRound,max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False,enforce_eager=True" \
    --tasks gsm8k \
    --num_fewshot 5 \
    --limit 1000 \
    --seed 42 \
    --batch_size 128
```

|Tasks|Version| Filter |n-shot| Metric | |Value| |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.911|± | 0.009|
| | |strict-match | 5|exact_match|↑ |0.911|± | 0.009|

Conclusion and future plans
With this first integration, AutoRound and LLM Compressor already provide a practical, production-oriented path to low-bit LLMs: W4A16 quantization works end-to-end, the workflow takes only a few lines of configuration, and dense model families such as Llama and Qwen are supported.
Looking ahead, we plan to extend support to additional schemes such as FP8, MXFP4, MXFP8, and NVFP4, add automatic mixed‑bit search for fine‑grained per‑layer optimization, and cover more model families, including Mixture‑of‑Experts (MoE) models. We also aim to deepen interoperability with other algorithms in LLM Compressor, which will allow AutoRound to be combined into richer multi‑modifier recipes that serve both community use cases and Intel production workloads.
If you’d like to influence which formats, models, and workflows we prioritize next, join the discussion in RFC #1968 and share your benchmarks or deployment requirements, or bring your feedback to the Intel community so we can align the roadmap with real‑world needs.
Acknowledgements
We wish to acknowledge the contributions of the LLM Compressor community. Specifically, we thank Kyle Sayers, Dipika Sikka, Brian Dellabetta, Charles Hernandez, and Robert Shaw for their invaluable feedback on the early proposal and their diligent review of the pull requests.