LLM Compressor 0.7.0 release recap

Featuring new transforms, mixed-precision, and block quantization

August 25, 2025
Dipika Sikka, Kyle Sayers, Brian Dellabetta, Helen Zhao
Related topics:
Artificial intelligence
Related products:
Red Hat AI

    LLM Compressor 0.7.0 was recently released, introducing a range of significant enhancements designed to improve the performance of quantizing and deploying large language models. This release features three notable additions:

    • QuIP and SpinQuant-style transforms
    • Mixed-precision support and FP4 enhancements
    • DeepSeek v3-style block quantization

    1. QuIP and SpinQuant-style transforms

    This release introduces two new modifiers, QuIPModifier and SpinQuantModifier.

    These modifiers facilitate the injection of Hadamard-based rotations into the model's computational graph, thereby rotating weights and activations to mitigate quantization sensitivity. Applying these transforms can minimize the effect of quantization error and enhance accuracy, particularly in cases involving low-bit weight and activation quantization.

    Rotating the weight space helps even out outliers, which can improve the fidelity of post-training quantization. To accomplish this, QuIP rotates inputs into the rotated space, applies quantization, and then rotates the outputs back into the original output space to preserve correctness.
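
    As a toy illustration of why this helps (a minimal NumPy sketch, not LLM Compressor code), an orthonormal Hadamard rotation spreads a single outlier across all coordinates, shrinking the dynamic range the quantizer has to cover, while its inverse undoes the rotation exactly:

    import numpy as np

    # Orthonormal 4x4 Hadamard matrix: H @ H.T == identity, so the rotation is lossless.
    H = np.array([[1,  1,  1,  1],
                  [1, -1,  1, -1],
                  [1,  1, -1, -1],
                  [1, -1, -1,  1]]) / 2.0

    w = np.array([0.02, -0.03, 0.01, 1.50])  # one outlier dominates the quantization range

    rotated = H @ w
    print(np.abs(w).max() / np.abs(w).mean())              # ~3.8: outlier-dominated
    print(np.abs(rotated).max() / np.abs(rotated).mean())  # ~1.0: evened out
    print(np.allclose(H.T @ rotated, w))                   # True: rotating back preserves correctness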

    Example: Using QuIPModifier for QuIP-style transforms

    from transformers import AutoModelForCausalLM
    
    from llmcompressor import oneshot
    from llmcompressor.modifiers.quantization import QuantizationModifier
    from llmcompressor.modifiers.transform import QuIPModifier
    
    # Select model and load it.
    MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
    
    # Configure the quantization algorithm to run.
    #   * apply QuIP-style transforms to the model to make quantization easier
    #   * quantize the weights to 4 bits with a group size of 128
    recipe = [
        QuIPModifier(transform_type="random-hadamard", targets="Linear"),
        QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
    ]
    
    # Apply algorithms.
    oneshot(model=model, recipe=recipe, pipeline="datafree")
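
    If you also save the compressed model (a minimal sketch follows; the output directory name is illustrative), the transform configuration is written into its config.json alongside the quantization configuration:

    # Persist the compressed model; save_compressed=True writes the compressed-tensors format.
    model.save_pretrained("Llama-3.1-8B-Instruct-quip-w4a16", save_compressed=True)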

    For a sample model produced from the QuIPModifier example above, you can see where and how the rotations are applied to the model by looking at the transform_config in the model's config.json.

    "u": {
      "apply": [
        {
          "ignore": [
            "lm_head"
          ],
          "inverse": false,
          "location": "weight_output",
          "targets": [
            "Linear"
          ]
        },
        {
          "ignore": [
            "lm_head"
          ],
          "inverse": true,
          "location": "output",
          "targets": [
            "Linear"
          ]
        }
      ]
    }

    In this case, the Hadamard transform (denoted by the config_group u) is applied to the layer weights, and the inverse matrix is applied to each of the layer's outputs.
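
    If you prefer to inspect this programmatically rather than opening the file by hand, here is a small sketch (the path is illustrative, and the key's exact nesting can vary, hence the recursive search):

    import json

    def find_key(obj, key):
        # Recursively search the parsed config for the first occurrence of `key`.
        if isinstance(obj, dict):
            if key in obj:
                return obj[key]
            for value in obj.values():
                found = find_key(value, key)
                if found is not None:
                    return found
        elif isinstance(obj, list):
            for value in obj:
                found = find_key(value, key)
                if found is not None:
                    return found
        return None

    with open("Llama-3.1-8B-Instruct-quip-w4a16/config.json") as f:  # illustrative path
        transform_config = find_key(json.load(f), "transform_config")
    print(json.dumps(transform_config, indent=2))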

    Example: Using SpinQuantModifier

    SpinQuant and QuaRot build upon the ideas of QuIP and apply rotations which span across activations, allowing for more efficient weight and activation quantization. In addition, many of the added rotations are considered to be "offline" rotations (known as R1 and R2), meaning that these rotations are fused directly into the model's weights prior to quantization, allowing for rotation without additional runtime cost. See Figure 1.

    Figure 1: Note that as of now, only R1 and R2 rotations are available. R3 and R4 rotations will be available in a future release.

    from transformers import AutoModelForCausalLM

    from llmcompressor import oneshot
    from llmcompressor.modifiers.quantization import QuantizationModifier
    from llmcompressor.modifiers.transform import SpinQuantModifier

    MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")

    # Configure the quantization algorithm to run.
    #   * apply SpinQuant transforms to the model to reduce quantization loss
    #   * quantize the weights to 4 bits with a group size of 128
    recipe = [
        SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard"),
        QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
    ]

    # Apply algorithms.
    oneshot(model=model, recipe=recipe, pipeline="datafree")

    A similar transform_config is created for models produced using the SpinQuantModifier.
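
    To see why the "offline" R1 and R2 rotations add no runtime cost, here is a toy NumPy check (not LLM Compressor code): fusing an orthogonal rotation into a pair of consecutive linear weights leaves the end-to-end computation, and therefore the operation count, unchanged, while the hidden activations in between now live in the rotated space where they are easier to quantize.

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=(2, 4))                      # toy input activations
    W1 = rng.normal(size=(4, 4))                     # first linear layer
    W2 = rng.normal(size=(4, 4))                     # second linear layer
    R = np.linalg.qr(rng.normal(size=(4, 4)))[0]     # random orthogonal rotation

    W1_fused = W1 @ R                                # rotate the first layer's outputs
    W2_fused = R.T @ W2                              # un-rotate at the second layer's inputs

    hidden = x @ W1_fused                            # hidden activations are now in the rotated space
    print(np.allclose(hidden @ W2_fused, x @ W1 @ W2))   # True: same output, no extra ops at runtime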

    2. Mixed-precision support and FP4 enhancements

    LLM Compressor v0.7.0 also brings robust mixed-precision capabilities. FP4 quantization (specifically NVFP4) for both weights and activations has now been integrated with MoEs (such as Llama 4) and non-uniform quantization schemes.

    With non-uniform quantization, you can combine NVFP4 and FP8 quantization, selectively applying certain quantization schemes to specific layers for improved accuracy. This functionality is enabled by supporting multiple compressors within a given model.

    Example: Non-uniform quantization

    From a sample model with both NVFP4 and FP8 quantization (FP8 targeting down_proj weights, NVFP4 targeting all other attention and MLP linear layer weights), the quantization_config of the compressed model looks like this:

    "quantization_config": {
       "config_groups": {
         "group_0": {
           "format": "nvfp4-pack-quantized",
           "input_activations": {
             "actorder": null,
             "block_structure": null,
             "dynamic": "local",
             "group_size": 16,
             "num_bits": 4,
             "observer": "minmax",
             "observer_kwargs": {},
             "strategy": "tensor_group",
             "symmetric": true,
             "type": "float"
           },
           "output_activations": null,
           "targets": [
             "re:.*mlp.gate_proj.*",
             "re:.*mlp.up_proj.*",
             "re:.*self_attn.k_proj.*",
             "re:.*self_attn.o_proj.*",
             "re:.*self_attn.q_proj.*",
             "re:.*self_attn.v_proj.*"
           ],
           "weights": {
             "actorder": null,
             "block_structure": null,
             "dynamic": false,
             "group_size": 16,
             "num_bits": 4,
             "observer": "minmax",
             "observer_kwargs": {},
             "strategy": "tensor_group",
             "symmetric": true,
             "type": "float"
           }
         },
         "group_1": {
           "format": "float-quantized",
           "input_activations": {
             "actorder": null,
             "block_structure": null,
             "dynamic": true,
             "group_size": null,
             "num_bits": 8,
             "observer": null,
             "observer_kwargs": {},
             "strategy": "token",
             "symmetric": true,
             "type": "float"
           },
           "output_activations": null,
           "targets": [
             "re:.*mlp.down_proj.*"
           ],
           "weights": {
             "actorder": null,
             "block_structure": null,
             "dynamic": false,
             "group_size": null,
             "num_bits": 8,
             "observer": "minmax",
             "observer_kwargs": {},
             "strategy": "channel",
             "symmetric": true,
             "type": "float"
           }
         }
       },
       "format": "mixed-precision"

    This configuration shows how you can assign different quantization formats (e.g., nvfp4-pack-quantized and float-quantized, each handled by a separate compressor in compressed-tensors) per layer group, mixing NVFP4 with FP8. This provides finer control over per-layer quantization, allowing more precise handling of layers that are especially sensitive to certain quantization types.

    As of vLLM v0.10.1, models with multiple compressors are directly runnable in vLLM. We can run sample evaluations using the lm-evaluation-harness with the above mixed-precision model, comparing it with its NVFP4-only counterpart.
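
    For example, the compressed checkpoint can be loaded directly with vLLM (a minimal sketch; the model path is illustrative):

    from vllm import LLM, SamplingParams

    # vLLM reads the compressed-tensors config and applies the appropriate scheme per layer group.
    llm = LLM(model="path/to/nvfp4-fp8-mixed-model", max_model_len=4096)
    outputs = llm.generate(["What is quantization?"], SamplingParams(max_tokens=64))
    print(outputs[0].outputs[0].text)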

    Using the following lm-eval command on a single B200 GPU for each model, we get the following results:

    lm_eval \
      --model vllm \
      --model_args pretrained=model_path,dtype=auto,max_model_len=4096,tensor_parallel_size=1,enable_chunked_prefill=True,enforce_eager=True \
      --tasks gsm8k_llama \
      --apply_chat_template \
      --fewshot_as_multiturn \
      --batch_size auto

    NVFP4 only:

    |   Tasks   |Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
    |-----------|------:|----------------|-----:|-----------|---|-----:|---|-----:|
    |gsm8k_llama|      3|flexible_extract|     8|exact_match|↑  |0.7278|±  |0.0123|
    |           |       |strict_match    |     8|exact_match|↑  |0.6285|±  |0.0133|

    NVFP4 with FP8 down_proj:

    |   Tasks   |Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
    |-----------|------:|----------------|-----:|-----------|---|-----:|---|-----:|
    |gsm8k_llama|      3|flexible_extract|     8|exact_match|↑  |0.7536|±  |0.0119|
    |           |       |strict_match    |     8|exact_match|↑  |0.6914|±  |0.0127|
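
    In this run, adding FP8 for the down_proj layers recovers accuracy over the NVFP4-only baseline: roughly 2.6 points on the flexible-extract match and 6.3 points on the strict match.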

    3. DeepSeek v3-style block quantization

    Another notable addition is block-wise quantization inspired by DeepSeek v3. This method enables more efficient model compression without needing a calibration dataset. Block quantization partitions weights into blocks and quantizes each independently, minimizing the influence of outliers while preserving accuracy.
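
    As a rough illustration of the mechanism (a NumPy sketch, not LLM Compressor code; the 128x128 block size follows the DeepSeek v3 convention), each block gets its own scale, so a single outlier only inflates the scale of the block it lives in:

    import numpy as np

    def block_scales(weight, block=128):
        # One FP8 scale per (block x block) tile of the weight matrix.
        rows, cols = weight.shape
        scales = np.empty((rows // block, cols // block))
        for i in range(0, rows, block):
            for j in range(0, cols, block):
                tile = weight[i:i + block, j:j + block]
                scales[i // block, j // block] = np.abs(tile).max() / 448.0  # 448 = FP8 E4M3 max
        return scales

    W = np.random.default_rng(0).normal(size=(256, 256))
    W[0, 0] = 50.0                       # a single outlier
    print(block_scales(W).round(4))      # only the top-left block's scale blows up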

    Example: Specify a recipe with FP8 block quantization for Qwen/Qwen3-30B-A3B

    from llmcompressor.modifiers.quantization import QuantizationModifier
    recipe = QuantizationModifier(
        targets="Linear",
        scheme="FP8_BLOCK",
        ignore=["lm_head", "re:.*mlp.gate$"],
    )
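
    To apply the recipe end to end, the flow mirrors the earlier examples (a minimal sketch; the save directory name is illustrative, and because no calibration data is needed the data-free pipeline is assumed to suffice):

    from transformers import AutoModelForCausalLM

    from llmcompressor import oneshot

    MODEL_ID = "Qwen/Qwen3-30B-A3B"
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")

    # Apply the FP8 block-quantization recipe without any calibration dataset.
    oneshot(model=model, recipe=recipe, pipeline="datafree")

    # Persist the compressed model (directory name is illustrative).
    model.save_pretrained("Qwen3-30B-A3B-FP8-BLOCK", save_compressed=True)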

    Summary of key new features in LLM Compressor 0.7.0

    • Transforms: QuIPModifier and SpinQuantModifier, which apply Hadamard-based rotations to reduce quantization error.
    • Mixed-precision support: FP4 quantization with MoE and non-uniform quantization support.
    • Block quantization: DeepSeek v3-style block-wise quantization for calibration-free, efficient compression.

    Conclusion

    LLM Compressor bridges the gap between fine-tuning and production with robust support for quantization, sparsity, calibration, and seamless integration with vLLM. Whether you're optimizing for cost, latency, or innovation, LLM Compressor is the foundation for next-generation AI inference.

    To get started, explore our Hugging Face collection, contribute to the GitHub repo, and join the community of developers and researchers advancing efficient AI.

