
Multimodal model quantization support through LLM Compressor

February 19, 2025
Kyle Sayers, Dipika Sikka, Shubhra Pandit, Mark Kurtz
Related topics:
Artificial intelligence, Open source
Related products:
Red Hat AI

    A compressed summary

    • LLM Compressor version 0.4.0 supports multimodal model quantization, enabling efficient compression of vision-language and audio models with the most popular quantization formats.
    • GPTQ, our most popular algorithm, is fully extended and tested with complex multi-modal architectures, including Whisper and Llama 3.2 Vision.
    • Examples and evaluations confirm the expected high recoverability, with >99% accuracy recovery on sample evaluations while reducing memory and compute requirements.
    • This solution provides seamless integration with vLLM, powering a faster, scalable, and more cost-effective approach for real-world deployments.

    LLM Compressor is a unified library for optimizing models for deployment with vLLM. As of its 0.4.0 release, LLM Compressor now supports multimodal model quantization, enabling efficient compression of vision-language and audio models with the most popular quantization formats. 

    Read on to explore these enhancements, along with step-by-step examples that demonstrate how to use LLM Compressor to apply GPTQ quantization to your own models.

    Productized model compression

    LLM Compressor is an open source library that productizes the latest research in model compression, enabling easy generation of compressed models with minimal effort. The LLM Compressor framework allows users to apply state-of-the-art research across quantization, sparsity, and general compression techniques to improve generative AI models' efficiency, scalability, and performance while maintaining accuracy. With native Hugging Face and vLLM support, optimized models can seamlessly integrate with deployment pipelines for faster, cost-saving inference at scale, powered by the compressed-tensors model format.

    Designed for flexibility, LLM Compressor supports both post-training and training workflows for compression through Modifiers, implementations that apply a specific compression method to a given model. Modifier implementations cover a wide range of compression algorithms and techniques, including:

    • Weight-only quantization (W4A16) for limited hardware or latency-sensitive applications.
    • Weight and activation quantization (W8A8) targeting general server scenarios for both integer and floating point formats.
    • 2:4 semi-structured sparsity for further inference acceleration.
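
    For example, here is a minimal sketch of how Modifiers compose into a recipe for one-shot compression, modeled on the library's quickstart pattern. The model and calibration dataset below are placeholders, and exact arguments may vary between releases.

    from llmcompressor.modifiers.quantization import GPTQModifier
    from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
    from llmcompressor.transformers import oneshot
    # A W8A8 recipe: smooth activation outliers, then apply GPTQ weight quantization.
    recipe = [
        SmoothQuantModifier(smoothing_strength=0.8),
        GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
    ]
    # One-shot (post-training) compression over a small calibration set.
    oneshot(
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # placeholder model
        dataset="open_platypus",                     # placeholder calibration dataset
        recipe=recipe,
        max_seq_length=2048,
        num_calibration_samples=512,
    )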

    With the 0.4.0 release, LLM Compressor adds general support for multimodal models, including vision and audio, and extends GPTQ-based quantization for performant support. The following sections explore these enhancements, their usage, and examples to quantize your own models.

    Multimodal enablement

    LLM Compressor and the GPTQModifier have been expanded to accommodate performant multimodal model compression, enabling SOTA quantization for vision and audio models while maintaining accuracy. This enhancement allows architectures like Whisper and Llama 3.2 Vision to benefit from quantization, making them more efficient for deployment with vLLM.

    The GPTQ algorithm, as described in the original paper, applies quantization sequentially to each model layer, using the quantized outputs of the previous layer as inputs to the next. This approach propagates and compensates for quantization-induced errors, improving accuracy recovery while minimizing memory usage, which is particularly important because each layer requires a large Hessian matrix to calculate and adjust for errors. While this process is straightforward for most decoder-only transformer architectures, identifying the layers and data flow of more complex, multimodal architectures requires a generalized and flexible approach. For example, Whisper’s audio encoder feeds features into each text decoder layer; this data passing must be accounted for to faithfully calibrate the model while minimizing the resources required to do so.
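
    Conceptually, the sequential flow looks like the following pseudocode. This is an illustration of the algorithm described above, not LLM Compressor's internal API; collect_hessian and gptq_quantize are hypothetical helpers.

    # Illustrative pseudocode for sequential, layer-by-layer GPTQ calibration.
    def sequential_gptq(layers, calibration_inputs):
        activations = calibration_inputs
        for layer in layers:
            # Accumulate second-order statistics (the Hessian) from the inputs this
            # layer actually sees, i.e., the outputs of already-quantized layers.
            hessian = collect_hessian(layer, activations)
            # Quantize the layer's weights, using the Hessian to compensate for error.
            gptq_quantize(layer, hessian)
            # Propagate quantized outputs so later layers calibrate against the errors
            # introduced so far; the Hessian can be freed here to bound memory usage.
            activations = [layer(x) for x in activations]
        return layers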

    To address this, the GPTQModifier now integrates tracing, a technique that records a model’s execution to capture its computational graph, which can then be partitioned into layers. This enables sequential calibration and quantization of layers belonging to arbitrary model architectures, including vision-language, audio, and other multimodal models. By applying quantization in a structured, automated way, LLM Compressor distills complicated research flows into a productized framework for both enterprise and developer use cases.
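
    As a rough illustration of what tracing captures, the toy example below uses plain torch.fx symbolic tracing on a stand-in module; the traceable model classes shipped with LLM Compressor handle this for real multimodal architectures.

    from torch import fx, nn
    # A toy two-branch module standing in for a multimodal architecture.
    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(8, 8)  # e.g., an audio or vision encoder
            self.decoder = nn.Linear(8, 8)  # e.g., a text decoder layer
        def forward(self, x):
            features = self.encoder(x)
            return self.decoder(features)
    # Symbolic tracing records the call graph, which can then be partitioned
    # into sequential layers for calibration and quantization.
    graph_module = fx.symbolic_trace(Toy())
    print(graph_module.graph)  # placeholder -> encoder -> decoder -> output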

    While tracing works for most models and datasets out of the box, some may require minor adjustments to ensure compatibility. If you encounter issues, refer to the model tracing guide for tips on modifying your model definition.

    Validated accuracy

    With the latest enhancements to LLM Compressor, several multimodal models were quantized and evaluated across core benchmarks to assess performance and accuracy retention. Llama 3.2 11B and 90B Vision models were evaluated using mistral-evals on the MMMU task with vLLM, recovering baseline accuracy almost entirely (over 99% in three of the four quantized configurations), as seen in Table 1. 

    Table 1: MMMU evaluation results comparing dense and quantized versions of Llama 3.2 Vision models.

    | Model | Baseline (BF16) MMMU | W4A16, per-channel MMMU | Recovery | W4A16, group-size 128 MMMU | Recovery |
    | --- | --- | --- | --- | --- | --- |
    | Llama 3.2 11B Vision | 41.4 | 43.8 | 105.6% | 42.1 | 101.6% |
    | Llama 3.2 90B Vision | 53.9 | 51.1 | 94.9% | 54.8 | 101.7% |

    Similarly, Whisper Large V2 was quantized and evaluated on a sample from the LibriSpeech dataset using Word Error Rate (WER). As shown in Table 2, the compressed version maintains >99% recovery while significantly reducing the memory requirements.

    Table 2: LibriSpeech evaluation results for Whisper Large V2, comparing dense and quantized versions.

    | Model | Baseline (BF16) LibriSpeech WER | W4A16, group-size 128 LibriSpeech WER | Recovery |
    | --- | --- | --- | --- |
    | Whisper Large V2 | 87.4 | 86.5 | 99.0% |
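
    For reference, the WER of the dense and quantized transcriptions can be computed with, for example, the Hugging Face evaluate library. This is one possible approach for illustration; the exact scoring harness is not shown here, and the transcriptions below are placeholders.

    import evaluate
    # Placeholder transcriptions; in practice these come from running the baseline
    # and the quantized model over the same LibriSpeech samples.
    references = ["the quick brown fox jumps over the lazy dog"]
    baseline_predictions = ["the quick brown fox jumps over a lazy dog"]
    quantized_predictions = ["the quick brown fox jumped over a lazy dog"]
    wer = evaluate.load("wer")
    baseline_wer = wer.compute(references=references, predictions=baseline_predictions)
    quantized_wer = wer.compute(references=references, predictions=quantized_predictions)
    print(f"baseline WER: {baseline_wer:.3f}, quantized WER: {quantized_wer:.3f}")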

    Hands-on quantization

    In the following sections, we will review some step-by-step examples of how to apply GPTQ quantization to your own models using LLM Compressor. These examples demonstrate real-world applications of multimodal compression, covering vision-language models (Llama 3.2 Vision) and audio models (Whisper Large V2). You can find a complete list of other available examples in the LLM Compressor examples folder. Additionally, for more examples of running multi-modal models with vLLM, see the provided offline inference examples.

    Environment enablement

    Before running any of the following sections, ensure you have installed LLM Compressor from PyPI into a compatible Python environment.

    pip install "llmcompressor>=0.4.0"
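
    A quick, optional sanity check that the environment is ready (not required by the examples that follow):

    import importlib.metadata
    import torch
    # Verify the installed LLM Compressor version and that a GPU is visible, since
    # calibrating large multimodal models is impractical on CPU.
    print("llmcompressor:", importlib.metadata.version("llmcompressor"))
    print("cuda available:", torch.cuda.is_available())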

    Quantizing vision language models

    We will use the Llama 3.2 Vision model to demonstrate support for multimodal vision architectures.

    First, load the model. The Llama 3.2 Vision architecture requires loading from a custom TraceableMllamaForConditionalGeneration class, which makes minor modifications to the original class definition to support tracing with the GPTQModifier.

    import requests
    import torch
    from PIL import Image
    from transformers import AutoProcessor
    from llmcompressor.modifiers.quantization import GPTQModifier
    from llmcompressor.transformers import oneshot
    from llmcompressor.transformers.tracing import TraceableMllamaForConditionalGeneration
    # Load model.
    model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
    model = TraceableMllamaForConditionalGeneration.from_pretrained(
        model_id, device_map="auto", torch_dtype="auto"
    )
    processor = AutoProcessor.from_pretrained(model_id)

    Next, define your calibration dataset and data collator. For this example, we will use the flickr30k dataset, which contains many scenes and images of objects. You can customize the calibration dataset to reflect your use case.

    # Oneshot arguments
    DATASET_ID = "flickr30k"
    DATASET_SPLIT = {"calibration": "test[:512]"}
    NUM_CALIBRATION_SAMPLES = 512
    MAX_SEQUENCE_LENGTH = 2048
    # Define a oneshot data collator for multimodal inputs
    def data_collator(batch):
        assert len(batch) == 1
        return {key: torch.tensor(value) for key, value in batch[0].items()}

    Now you can apply a one-shot recipe to quantize your model. In this case, we use GPTQ to apply weight-only (W4A16) quantization, as shown in the following recipe. Due to their small size and limited support for quantized acceleration, we ignore the vision model parameters in our recipe.

    # Recipe
    recipe = GPTQModifier(
        targets="Linear",
        scheme="W4A16",
        ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"],
    )
    # Perform oneshot
    oneshot(
        model=model,
        tokenizer=model_id,
        dataset=DATASET_ID,
        splits=DATASET_SPLIT,
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
        trust_remote_code_model=True,
        data_collator=data_collator,
    )
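    # Optional sanity check: run one generation with the compressed model before
    # saving. The image URL and prompt here are illustrative placeholders.
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    raw_image = Image.open(requests.get(url, stream=True).raw)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Please describe the image."},
            ],
        },
    ]
    sample_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    sample_inputs = processor(
        images=raw_image, text=sample_prompt, return_tensors="pt"
    ).to(model.device, model.dtype)
    sample_output = model.generate(**sample_inputs, max_new_tokens=40)
    print(processor.decode(sample_output[0], skip_special_tokens=True))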
    # Save to disk compressed.
    SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128"
    model.save_pretrained(SAVE_DIR, save_compressed=True)
    processor.save_pretrained(SAVE_DIR)

    Finally, you can now deploy the model with vLLM for better inference performance:

    from transformers import AutoProcessor
    from vllm.assets.image import ImageAsset
    from vllm import LLM, SamplingParams
    # prepare the compressed model saved in the previous step
    model_id = "Llama-3.2-11B-Vision-Instruct-W4A16-G128"
    llm = LLM(
        model=model_id,
        max_model_len=4096,
        max_num_seqs=16,
        limit_mm_per_prompt={"image": 1},
    )
    processor = AutoProcessor.from_pretrained(model_id)
    # prepare inputs
    question = "What is the content of this image?"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": f"{question}"},
            ],
        },
    ]
    prompt = processor.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
    inputs = {
        "prompt": prompt,
        "multi_modal_data": {
            "image": image
        },
    }
    # generate response
    print("========== SAMPLE GENERATION ==============")
    outputs = llm.generate(inputs, SamplingParams(temperature=0.2, max_tokens=64))
    print(f"PROMPT  : {outputs[0].prompt}")
    print(f"RESPONSE: {outputs[0].outputs[0].text}")
    print("==========================================")

    Quantizing audio models

    We will use the Whisper Large V2 model to demonstrate multimodal audio architecture support.

    First, load the model. The Whisper architecture requires loading from a custom TraceableWhisperForConditionalGeneration class, which makes minor modifications to the original class definition to support tracing with the GPTQModifier.

    import torch
    from datasets import load_dataset
    from transformers import WhisperProcessor
    from llmcompressor.modifiers.quantization import GPTQModifier
    from llmcompressor.transformers import oneshot
    from llmcompressor.transformers.tracing import TraceableWhisperForConditionalGeneration
    # Select model and load it.
    model_id = "openai/whisper-large-v2"
    model = TraceableWhisperForConditionalGeneration.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype="auto",
    )
    processor = WhisperProcessor.from_pretrained(model_id)

    Next, load and tokenize a calibration dataset. For this example, we will use the MLCommons/peoples_speech dataset, which contains many audio samples and labels. You can customize the calibration dataset to reflect your use case.

    # Configure the processor for the dataset task.
    processor.tokenizer.set_prefix_tokens(language="en", task="transcribe")
    # Select calibration dataset.
    DATASET_ID = "MLCommons/peoples_speech"
    DATASET_SUBSET = "test"
    DATASET_SPLIT = "test"
    # Select number of samples. 512 samples is a good place to start.
    # Increasing the number of samples can improve accuracy.
    NUM_CALIBRATION_SAMPLES = 512
    MAX_SEQUENCE_LENGTH = 2048
    # Load dataset and preprocess.
    ds = load_dataset(
        DATASET_ID,
        DATASET_SUBSET,
        split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]",
        trust_remote_code=True,
    )
    # Preprocess and Tokenize inputs.
    def preprocess_and_tokenize(example):
        audio = example["audio"]["array"]
        sampling_rate = example["audio"]["sampling_rate"]
        text = " " + example["text"].capitalize()
        audio_inputs = processor(
            audio=audio,
            sampling_rate=sampling_rate,
            return_tensors="pt",
        )
        text_inputs = processor(
            text=text,
            add_special_tokens=True,
            return_tensors="pt"
        )
        text_inputs["decoder_input_ids"] = text_inputs["input_ids"]
        del text_inputs["input_ids"]
        return dict(**audio_inputs, **text_inputs)
    ds = ds.map(preprocess_and_tokenize, remove_columns=ds.column_names)
    # Define a oneshot data collator for multimodal inputs.
    def data_collator(batch):
        assert len(batch) == 1
        return {key: torch.tensor(value) for key, value in batch[0].items()}

    Now you can apply a one-shot recipe to quantize the model. In this case, we use GPTQ to apply weight-only (W4A16) quantization.

    # Recipe
    recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])
    # Apply algorithms.
    oneshot(
        model=model,
        dataset=ds,
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
        data_collator=data_collator,
    )
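    # Optional sanity check: transcribe one calibration sample with the compressed
    # model before saving (a quick look at output quality, not a formal evaluation).
    sample = ds[0]
    sample_input = {
        "input_features": torch.tensor(sample["input_features"]).to(model.device, model.dtype),
        "decoder_input_ids": torch.tensor([processor.tokenizer.prefix_tokens]).to(model.device),
    }
    sample_output = model.generate(**sample_input, max_new_tokens=64)
    print(processor.batch_decode(sample_output, skip_special_tokens=True))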
    # Save to disk compressed.
    SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128"
    model.save_pretrained(SAVE_DIR, save_compressed=True)
    processor.save_pretrained(SAVE_DIR)

    Finally, you can now deploy the model with vLLM for better inference performance:

    from vllm.assets.audio import AudioAsset
    from vllm import LLM, SamplingParams
    # prepare model
    llm = LLM(
        model="neuralmagic/whisper-large-v2-W4A16-G128",
        max_model_len=448,
        max_num_seqs=400,
        limit_mm_per_prompt={"audio": 1},
    )
    # prepare inputs
    inputs = {  # Test explicit encoder/decoder prompt
        "encoder_prompt": {
            "prompt": "",
            "multi_modal_data": {
                "audio": AudioAsset("winning_call").audio_and_sample_rate,
            },
        },
        "decoder_prompt": "<|startoftranscript|>",
    }
    # generate response
    print("========== SAMPLE GENERATION ==============")
    outputs = llm.generate(inputs, SamplingParams(temperature=0.0, max_tokens=64))
    print(f"PROMPT  : {outputs[0].prompt}")
    print(f"RESPONSE: {outputs[0].outputs[0].text}")
    print("==========================================")

    Model compression for multimodal AI

    LLM Compressor provides a powerful and flexible framework for compressing models, enabling faster and more efficient inference with vLLM. With the 0.4.0 release, LLM Compressor now supports quantization and sparsification of multimodal models, allowing users to efficiently scale workloads for OCR, spatial reasoning, and audio transcription/translation tasks.

    To get started, explore the latest models, recipes, and examples in the LLM Compressor repository, or experiment with quantization techniques to tailor performance to your needs.

    Ready to deploy faster, more scalable AI? Contact us to learn more about enterprise solutions or contribute to our open source journey today! 
