Advancing low‑bit quantization for LLMs: AutoRound x LLM Compressor

Achieve faster, more efficient LLM serving without sacrificing accuracy

December 9, 2025
Intel Neural Compressor Team, Red Hat AI Model Optimization Team
Related topics:
Artificial intelligence
Related products:
Red Hat AI

    AutoRound, a state‑of‑the‑art post‑training quantization (PTQ) algorithm developed by Intel, is now integrated into LLM Compressor. This collaboration delivers:

    • Higher accuracy for low bit-width quantization
    • Lightweight tuning (hundreds of steps instead of thousands)
    • Zero additional inference overhead
    • Seamless compatibility with compressed tensors and direct serving in vLLM
    • A streamlined workflow that lets you quantize and serve models with just a few lines of code

    Broader quantization schemes and model coverage are coming next—try it now and help shape what we build.

    What is AutoRound?

    AutoRound is a post‑training quantization (PTQ) algorithm that reduces the size of large language models (LLMs) and vision-language models (VLMs) after training. It introduces three trainable parameters per quantized tensor: v (a rounding offset/adjustment) and α and β (learned clipping range controls). By processing decoder layers sequentially and applying signed gradient descent, AutoRound jointly optimizes rounding and clipping to minimize block‑wise output reconstruction error.
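
    The following is a minimal, conceptual PyTorch sketch of that idea, not the library's implementation: it fake‑quantizes a single weight matrix per‑tensor (AutoRound itself works group‑wise, block by block) and nudges v, α, and β by the sign of their gradients to shrink the output reconstruction error.

    import torch

    def ste_round(t):
        # Straight-through estimator: round in the forward pass, identity gradient backward
        return (torch.round(t) - t).detach() + t

    def fake_quant(w, v, alpha, beta, bits=4):
        # alpha and beta scale the learned clipping range
        wmax, wmin = w.max() * alpha, w.min() * beta
        scale = (wmax - wmin) / (2 ** bits - 1)
        zero = torch.round(-wmin / scale).detach()
        # v nudges each weight's rounding decision up or down
        q = torch.clamp(ste_round(w / scale + v) + zero, 0, 2 ** bits - 1)
        return (q - zero) * scale

    torch.manual_seed(0)
    w = torch.randn(256, 256)    # frozen weights of one linear layer
    x = torch.randn(64, 256)     # calibration activations for this block
    ref = x @ w.T                # full-precision reference output

    v = torch.zeros_like(w, requires_grad=True)
    alpha = torch.ones((), requires_grad=True)
    beta = torch.ones((), requires_grad=True)
    lr = 5e-3
    for _ in range(200):         # a few hundred steps, as described above
        loss = torch.nn.functional.mse_loss(x @ fake_quant(w, v, alpha, beta).T, ref)
        loss.backward()
        with torch.no_grad():
            for p in (v, alpha, beta):
                p -= lr * p.grad.sign()   # signed gradient descent
                p.grad = None
            v.clamp_(-0.5, 0.5)           # keep the rounding offset within half a quantization bin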

    Core strengths:

    • Superior accuracy, especially at very low bit‑widths
    • Supports multiple data types: W4A16, MXFP8, MXFP4, FP8, NVFP4, with more coming soon
    • Mixed‑bit, layer‑wise precision search for flexible accuracy–efficiency trade‑offs
    • Works with both large language models (LLMs) and vision-language models (VLMs)

    AutoRound creates quantized models in low-bit formats that accelerate inference on Intel Xeon processors, Intel Gaudi AI accelerators, Intel Data Center GPUs, and Intel Arc B-Series Graphics, as well as other GPUs (for example, CUDA-based devices).

    Looking forward, Intel is adding native support for FP8, MXFP8, and MXFP4 formats to its next-generation Intel Data Center GPU codenamed Crescent Island. Models quantized with AutoRound will naturally scale to take advantage of these data types across the Intel AI hardware portfolio. This creates a consistent path from algorithmic innovation to real-world deployment.

    For more details, refer to the paper AutoRound (EMNLP 2024) and the GitHub repository intel/auto-round.

    Why integrate into LLM Compressor?

    LLM Compressor already provides a unified, modular system for compression techniques such as quantization and pruning. Integrating AutoRound into this ecosystem:

    • Aligns with the existing modifier architecture (for example, GPTQModifier); see the sketch after this list
    • Reuses the sequential calibration and layer‑onloading infrastructure
    • Enables future interoperability with richer multi‑modifier recipes
    • Produces quantized models that are ready for vLLM serving, enabling a clean workflow from compression to deployment
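
    As an illustration of that alignment, a configuration for the new modifier has the same shape as one for an existing modifier such as GPTQModifier; the parameter values below are placeholders mirroring the quick start later in this post.

    from llmcompressor.modifiers.quantization import GPTQModifier
    from llmcompressor.modifiers.autoround import AutoRoundModifier

    # Existing modifier recipe style...
    gptq_recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

    # ...and the new AutoRound modifier follows the same configuration shape,
    # so it can slot into the same oneshot workflow with minimal changes.
    autoround_recipe = AutoRoundModifier(
        targets="Linear", scheme="W4A16", ignore=["lm_head"], iters=200
    )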

    Integration overview

    We completed the first stage of integration by introducing the new AutoRoundModifier into LLM Compressor, enabling production of wNa16 (for example, W4A16) compressed models that load in vLLM, as implemented in PR #1994. With a straightforward configuration—just specify your model and calibration data—you can quickly generate high‑quality low‑bit checkpoints. This initial stage supports quantizing a range of dense LLMs, including the Llama and Qwen model families, and demonstrates robust compatibility for practical deployment.

    Try it now: Quick start

    This quick start walks you through the process from installation to evaluating the quantized model's performance.

    1. Install

    Start by cloning the repository and installing the necessary Python package.

    git clone https://github.com/vllm-project/llm-compressor.git
    cd llm-compressor
    pip install -e .

    2. Load model and tokenizer

    Load the model and tokenizer from the Hugging Face Model Hub, specifying the desired model ID.

    from transformers import AutoModelForCausalLM, AutoTokenizer
    MODEL_ID = "Qwen/Qwen3-8B"
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    3. Prepare calibration data

    Set up the calibration dataset, which is a small, unlabeled subset of data used to train the quantization parameters.

    from auto_round.calib_dataset import get_dataset
    NUM_CALIBRATION_SAMPLES = 128
    MAX_SEQUENCE_LENGTH = 2048
    ds = get_dataset(tokenizer=tokenizer,
                     seqlen=MAX_SEQUENCE_LENGTH,
                     nsamples=NUM_CALIBRATION_SAMPLES)

    4. Run quantization using AutoRound

    AutoRound quantization can run on a variety of devices, including CPUs and GPUs. Quantization and serving do not have to happen on the same device: for example, you can quantize on a GPU workstation and later deploy on an AI PC.

    from llmcompressor import oneshot
    from llmcompressor.modifiers.autoround import AutoRoundModifier

    # W4A16: 4-bit weights, 16-bit activations; leave the lm_head unquantized
    recipe = AutoRoundModifier(
        targets="Linear",
        scheme="W4A16",
        ignore=["lm_head"],
        iters=200,  # number of AutoRound tuning steps
    )

    # Apply the recipe in a single pass over the calibration data
    oneshot(
        model=model,
        dataset=ds,
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
        shuffle_calibration_samples=False,
    )

    # Save in the compressed-tensors format that vLLM can load directly
    SAVE_DIR = MODEL_ID.split("/")[-1] + "-W4A16-G128-AutoRound"
    model.save_pretrained(SAVE_DIR, save_compressed=True)
    tokenizer.save_pretrained(SAVE_DIR)

    In practice, 128 calibration samples and roughly 200 iterations are often enough to reach stable convergence. Increase the number of samples or iterations if you are targeting extremely low bit‑widths or stricter accuracy requirements.
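
    For example, a variant of the recipe above with a larger calibration budget might look like the following; the sample and iteration counts here are illustrative assumptions, not tuned recommendations.

    # Illustrative only: a larger calibration budget for lower-bit or stricter targets
    from auto_round.calib_dataset import get_dataset
    from llmcompressor.modifiers.autoround import AutoRoundModifier

    ds = get_dataset(tokenizer=tokenizer, seqlen=2048, nsamples=512)  # more samples than the 128 above
    recipe = AutoRoundModifier(
        targets="Linear",
        scheme="W4A16",
        ignore=["lm_head"],
        iters=1000,  # more tuning steps than the 200 used above
    )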

    5. Serve in vLLM

    Once quantization is complete, the same compressed model can be served on different hardware, independent of the device used for tuning. For example, you can serve the quantized Qwen3‑8B‑W4A16‑G128‑AutoRound model on a single Intel Arc Pro B60 GPU:

    vllm serve Qwen3-8B-W4A16-G128-AutoRound \
        --dtype=bfloat16 \
        --gpu-memory-utilization 0.8 \
        --max-num-batched-tokens 8192

    Note

    Install vLLM from PR 29484 to serve this model. When serving on XPU, you must run vLLM with the --enforce-eager flag.
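
    As a quick sanity check before standing up the server, you can also load the checkpoint with vLLM's offline Python API. This sketch assumes the checkpoint directory produced above and the same vLLM build noted here, and passes enforce_eager=True to mirror the XPU requirement.

    from vllm import LLM, SamplingParams

    # Offline sanity check of the compressed checkpoint; enforce_eager mirrors
    # the --enforce-eager flag required when serving on XPU.
    llm = LLM(
        model="./Qwen3-8B-W4A16-G128-AutoRound",
        dtype="bfloat16",
        gpu_memory_utilization=0.8,
        enforce_eager=True,
    )
    params = SamplingParams(temperature=0.0, max_tokens=64)
    outputs = llm.generate(["Briefly explain post-training quantization."], params)
    print(outputs[0].outputs[0].text)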

    6. Evaluate (Example: GSM8K with lm_eval)

    Finally, you can evaluate the quantized model's performance on a benchmark dataset using the command-line interface utility.

    lm_eval --model vllm \
      --model_args pretrained="./Qwen3-8B-W4A16-G128-AutoRound,max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False,enforce_eager=True" \
      --tasks gsm8k \
      --num_fewshot 5 \
      --limit 1000 \
      --seed 42 \
      --batch_size 128

    | Tasks | Version | Filter           | n-shot | Metric      | Value | Stderr  |
    |-------|---------|------------------|--------|-------------|-------|---------|
    | gsm8k | 3       | flexible-extract | 5      | exact_match | 0.911 | ± 0.009 |
    | gsm8k | 3       | strict-match     | 5      | exact_match | 0.911 | ± 0.009 |

    Conclusion and future plans

    With this first integration, AutoRound and LLM Compressor already provide a practical, production‑oriented path to low‑bit LLMs: W4A16 quantization is supported end‑to‑end, the workflow is simple to configure, and dense models such as Llama and Qwen are supported. The setup is robust, streamlined, and ready for practical deployment.

    Looking ahead, we plan to extend support to additional schemes such as FP8, MXFP4, MXFP8, and NVFP4, add automatic mixed‑bit search for fine‑grained per‑layer optimization, and cover more model families, including Mixture‑of‑Experts (MoE) models. We also aim to deepen interoperability with other algorithms in LLM Compressor, which will allow AutoRound to be combined into richer multi‑modifier recipes that serve both community use cases and Intel production workloads.

    If you’d like to influence which formats, models, and workflows we prioritize next, join the discussion in RFC #1968 and share your benchmarks or deployment requirements, or bring your feedback to the Intel community so we can align the roadmap with real‑world needs.

    Acknowledgements

    We wish to acknowledge the contributions of the LLM Compressor community. Specifically, we thank Kyle Sayers, Dipika Sikka, Brian Dellabetta, Charles Hernandez, and Robert Shaw for their invaluable feedback on the early proposal and their diligent review of the pull requests.

    Related RFCs and PRs

    • llm-compressor#1968
    • llm-compressor#1994
    • llm-compressor#2055
    • llm-compressor#2062
    • auto-round#993
    • auto-round#1053
    • auto-round#1055
    • auto-round#1072
    • vllm#29484
