GPU benchmarking and how to choose a GPU framework

August 29, 2024
Kenny Ge
Related topics:
Artificial intelligence, C, C#, C++, Python
Related products:
Red Hat Enterprise Linux AI


    This is the final installment of our 4-part series on GPU programming. Catch up on the previous articles:

    • What is GPU programming

    • Your first GPU algorithm: Scan/prefix sum

    • Your second GPU algorithm: Quicksort

    With those topics covered, we can now move on to the subject of this article: how to choose a GPU framework/library.

    How to choose a GPU framework

    Generally, when you’re benchmarking some GPU code, you want to make sure to do the following:

    • Run the algorithm more than once, and report the median/quartiles or mean/standard deviation of the running times.
    • Synchronize your CUDA kernels to make sure the program has fully finished executing. By default, CUDA launches kernels asynchronously (in the background) so that the host can continue processing other work.
    • Clear the L2 cache between runs, usually by writing a throwaway buffer large enough to evict everything in it.
    • Warm up each of your scripts. This is especially important for Triton, because the autotuner conducts a full grid search to determine the right hyperparameters (e.g., BLOCK_SIZE) for each kernel. If you're using Triton and choose to compile ahead of time, this happens at compile time instead.
    • Ensure you have enough disk space. Strangely, this ended up being a big deal for us on our NVIDIA machine, where Triton was severely underperforming (by a factor of 50) until we cleared up enough disk space. We speculate that the Triton JIT needed disk space to run effectively, whereas the Torch JIT didn't, but we're not entirely sure.
    • Use torch.cuda.Event(enable_timing=True) for timings, because CUDA events are recorded on the GPU and tend to be more accurate than host-side timers such as time.time().

    You can see an example of the results a bad benchmark produces in Figure 1, and an example of a good benchmark in Figure 2. 

    Figure 1: What our benchmark looked like when we didn't add any warmup time (notice the spike at the beginning when Triton first compiles the kernel) and when we didn't have enough disk space (even after adding a warmup, the gap remained just as large at all other matrix sizes). The top line, representing Triton, performs significantly worse; the CUDA kernel executes almost instantly in comparison. (Created by Kenneth Ge)

     

    Figure 2: What our benchmark looked like on the same machine, after only adding a warmup and clearing storage space. The two kernels perform very similarly, with Triton only slightly slower at the larger matrix sizes. (Created by Kenneth Ge)

    Implementation details

    We can clear the L2 cache by forcing a large throwaway array into GPU memory and writing to it, as follows:

    import torch

    # Allocate ~256 MB on the GPU and write to it, evicting anything still in the L2 cache.
    cache_size = 256 * 1024 * 1024
    cache = torch.empty(int(cache_size), dtype=torch.int8, device='cuda')
    cache.zero_()

    We can queue synchronized CUDA timing events as follows:

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    end_event.record()  # record this after launching the kernel being timed
    torch.cuda.synchronize()
    elapsed_ms = start_event.elapsed_time(end_event)

    When we call event.record(), we queue a record operation on the CUDA stream. Instead of forcing the host to wait for the kernel to finish, the CUDA driver records the timestamp as soon as the previously queued operations have completed.
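
    Putting these pieces together, a minimal benchmarking loop might look like the following sketch. The callable run_kernel, the warmup count, and the repetition count are placeholders for whatever you are measuring, not code from this series.

    import statistics
    import torch

    def clear_l2_cache():
        # Allocate and write ~256 MB so anything left in the L2 cache gets evicted.
        cache = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device='cuda')
        cache.zero_()

    def benchmark(run_kernel, warmup=10, reps=100):
        # Warmup: let JIT compilation and autotuning happen outside the timed region.
        for _ in range(warmup):
            run_kernel()
        torch.cuda.synchronize()

        times_ms = []
        for _ in range(reps):
            clear_l2_cache()
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            run_kernel()
            end.record()
            torch.cuda.synchronize()
            times_ms.append(start.elapsed_time(end))

        # Report the median and quartiles rather than a single run.
        return statistics.median(times_ms), statistics.quantiles(times_ms, n=4)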

    Differences between frameworks

    There are a handful of differences between the frameworks, detailed below.

    Triton

    Triton claims to be more efficient because instead of specifying individual instructions, you specify block instructions. It will automatically parallelize within each block for you, so you don't have to worry about data races inside a block.

    The downside is that to get the most out of it, you really must use block operations. So, you must:

    • Be familiar with numpy/torch block ops and how to work with numpy-style indexing.
    • Have a data workload that can be defined using only the predefined torch block operations.

    In certain implementations, Triton can use block operations to optimize low-precision floating point math. For example, if you're using 8-bit or 4-bit floats, it's very possible that Triton block operations will automatically optimize this for you by, say, processing 64 bits at a time rather than losing efficiency by running one thread per 4-bit item.

    However, one important practical consideration is that, in many of these workloads, if you're limited to just the PyTorch/Triton block functions, it is often easier to use torch.compile to build the kernel automatically, or even to skip the kernel altogether and use a series of torch.XX operations instead. Also, if you end up writing any custom functions (such as Conway's Game of Life), you may not be able to take full advantage of the automatically parallelized block operations.
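
    As a concrete illustration of that alternative, here is a minimal sketch of letting torch.compile generate the kernel for an element-wise workload. The function fused_op is a made-up example, not something from this series.

    import torch

    # A workload expressible purely with built-in torch tensor operations.
    def fused_op(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.relu(x * y + 1.0)

    # Let torch.compile fuse and generate the GPU kernel instead of hand-writing Triton.
    compiled_op = torch.compile(fused_op)

    x = torch.randn(1 << 20, device='cuda')
    y = torch.randn(1 << 20, device='cuda')
    out = compiled_op(x, y)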

    Also, race conditions can still occur across blocks. For example, when I was developing a matrix multiplication algorithm, I came across a race condition with the following code:

    @triton.jit
    def matrix_mult(A, B, C, M, K, N, BLOCK_SIZE_M : tl.constexpr, BLOCK_SIZE_K : tl.constexpr, BLOCK_SIZE_N : tl.constexpr):
        """
        C := A x B
        """
        pid_m = tl.program_id(axis=0)
        pid_k = tl.program_id(axis=1)
        pid_n = tl.program_id(axis=2)
    
        # current block m-wise
        offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
        # current block n-wise
        offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
        # current block k-wise
        offs_k = (pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) % K
    
        # get m x k block in A
        a_ptrs = A + (offs_am[:, None] * K + offs_k[None, :] * 1)
        # get k x n block in B
        b_ptrs = B + (offs_k[:, None] * N + offs_bn[None, :] * 1)
        
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K, other=0.0)
    
        c_start = C + BLOCK_SIZE_M * pid_m * N + pid_n * BLOCK_SIZE_N * 1
    
        this_block = c_start + tl.arange(0, BLOCK_SIZE_M)[:, None] * N + tl.arange(0, BLOCK_SIZE_N)[None, :] * 1
    
        # We accumulate along the K dimension.
        ans = tl.dot(a, b)
        tl.store(this_block, tl.load(this_block) + ans)

    This is because multiple blocks were accessing the same block indices in our output array C, and each was attempting to read, add, and write to that block. The best way to avoid this is, unfortunately, to redesign your code. A more thorough review of Triton matrix multiplication and performance optimizations without race conditions can be found here.
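
    One common redesign, in the spirit of the standard Triton matrix multiplication tutorial, is to give each program exclusive ownership of one output tile and to accumulate over K inside the kernel rather than parallelizing across it. The sketch below is simplified (it assumes M and N are multiples of their block sizes, since only the K dimension is masked in the loads) and is not tuned for performance; treat it as an illustration of the idea rather than the article's original code.

    @triton.jit
    def matrix_mult_no_race(A, B, C, M, K, N, BLOCK_SIZE_M : tl.constexpr, BLOCK_SIZE_K : tl.constexpr, BLOCK_SIZE_N : tl.constexpr):
        # One program owns one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C, so no other
        # program ever writes to the same output addresses.
        pid_m = tl.program_id(axis=0)
        pid_n = tl.program_id(axis=1)

        offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        offs_k = tl.arange(0, BLOCK_SIZE_K)

        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

        # Accumulate along the K dimension inside this program instead of across programs.
        for k in range(0, K, BLOCK_SIZE_K):
            a = tl.load(A + offs_m[:, None] * K + (k + offs_k)[None, :],
                        mask=(k + offs_k)[None, :] < K, other=0.0)
            b = tl.load(B + (k + offs_k)[:, None] * N + offs_n[None, :],
                        mask=(k + offs_k)[:, None] < K, other=0.0)
            acc += tl.dot(a, b)

        c_ptrs = C + offs_m[:, None] * N + offs_n[None, :]
        tl.store(c_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))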

    OpenCL

    Unlike Triton, OpenCL is a lot more low-level. Even in PyOpenCL, you end up writing C-style OpenCL kernel code, and you also need to worry about setting up a context and so on. With Triton, on the other hand, you use PyTorch as normal and simply make sure to move your memory over to the device (i.e., device='cuda').

    However, OpenCL can actually feel easier to code for certain tasks, because you write kernels that execute for every thread, rather than for every block as in Triton. This makes it easier to reason about what you're doing in situations like the following.

    Triton:

    @triton.jit
    def scan(Y, nextY, stride, BLOCK_SIZE: tl.constexpr):
        pid_row = tl.program_id(0)
    
        for j in tl.static_range(BLOCK_SIZE):
            current_idx = pid_row * BLOCK_SIZE + j
            if current_idx - stride >= 0:
                Yj = tl.load(Y + current_idx)
                Yjminstride = tl.load(Y + current_idx - stride)
                
                tl.store(nextY + current_idx, Yj + Yjminstride)
            else:
                tl.store(nextY + current_idx, tl.load(Y + current_idx))

    OpenCL:

    __kernel void psum_step(__global float *src, __global float *dest, int stride)
    {
        int gid = get_global_id(0);
    
        if(gid - stride >= 0)
            dest[gid] = src[gid] + src[gid - stride];
    }

    The above is a kernel representing one step in a prefix sum. Notice that in Triton you have to work in terms of block operations, whereas in OpenCL you can assume that you're working directly on each individual index. Both approaches have their pluses and minuses for different types of problems. Also notice that OpenCL allows much more convenient syntax, whereas in Triton you must use fairly verbose operations like tl.store, tl.load, and tl.static_range.
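
    For context, a host-side driver for the Triton kernel above might look like the following sketch. The function name prefix_sum is ours, and it assumes the array length is a multiple of BLOCK_SIZE, since the kernel does not bounds-check individual indices.

    import torch
    import triton

    def prefix_sum(x: torch.Tensor, BLOCK_SIZE: int = 128) -> torch.Tensor:
        y = x.clone()
        n = y.numel()
        grid = (triton.cdiv(n, BLOCK_SIZE),)
        stride = 1
        while stride < n:
            next_y = torch.empty_like(y)
            scan[grid](y, next_y, stride, BLOCK_SIZE=BLOCK_SIZE)
            y = next_y
            stride *= 2
        return y  # inclusive prefix sum of x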

    Notice that it's not clear whether the Triton version is necessarily even more efficient. Because it works on a per-block basis, it is entirely possible that the entire function is just running on a single thread that processes the whole block serially. The Triton team claims to have done serious work under the hood on automatic parallelism and compiler optimizations, but due to a lack of documentation, it is difficult to determine which custom block operations will be fully parallelized.

    However, one of the disadvantages of working at such a low level is needing to worry about corrupted data if your types don’t match. Take a look at the following code:

    import numpy as np
    import pyopencl as cl
    
    ctx = cl.create_some_context()
    queue = cl.CommandQueue(ctx)
    
    mf = cl.mem_flags
    
    prg = cl.Program(ctx, """__kernel void fill(__global float *res_g)
    {
        int gid = get_global_id(0);
    
        res_g[gid] = 1.0;
    }""").build()
    
    # create numpy starting array
    a_np = np.arange(8)
    a_g = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=a_np)
    res_g = cl.Buffer(ctx, mf.READ_WRITE, a_np.nbytes)
    res_np = np.empty_like(a_np)
    
    prg.fill(queue, res_np.shape, None, res_g)
    
    cl.enqueue_copy(queue, res_np, res_g)
    
    print(res_np)
    print(a_np)

    This benign-looking program creates a numpy array using the arange function, allocates a device buffer of the same size, res_g, along with a host array res_np to copy the result into, and attempts to use the fill kernel to set all elements of res_g to 1.0. The output is as follows:

    [4575657222473777152 4575657222473777152 4575657222473777152
     4575657222473777152                   0                   0
                       0                   0]
    [0 1 2 3 4 5 6 7]

    The kernel didn't work! In fact, the data in the resulting buffer is completely garbled. It turns out the reason is that our kernel expects 32-bit floats, but np.arange produces int64 by default. If we simply add the line a_np = a_np.astype(np.float32), we get the correct answer. This bug is entirely non-obvious; in fact, I spent about two days stuck here.
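
    Concretely, converting the host array before the buffers are created fixes the mismatch. The expected output in the comment is our reading of the corrected program, not taken from the article.

    # Convert to 32-bit floats so the host data matches the kernel's float type.
    a_np = np.arange(8).astype(np.float32)
    a_g = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=a_np)
    res_g = cl.Buffer(ctx, mf.READ_WRITE, a_np.nbytes)
    res_np = np.empty_like(a_np)  # now float32, so the copied bytes are interpreted correctly

    prg.fill(queue, res_np.shape, None, res_g)
    cl.enqueue_copy(queue, res_np, res_g)
    print(res_np)  # expected: [1. 1. 1. 1. 1. 1. 1. 1.]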

    Numba

    Surprisingly, Numba is actually one of the most highly recommended and well-documented tools that NVIDIA mentions on its Python GPU programming page.

    It combines the low-level control of CUDA or OpenCL with the ease of Python. It works similarly to Triton in that you decorate a function (with @cuda.jit) and don't have to manage an environment beyond moving data to and from the GPU.

    Overall, in a practical setting, I would recommend using Numba for a few reasons:

    • Numba is incredibly well supported, on both the GPU and the CPU side, so the library is likely to be less buggy and more likely to just work.
    • Because it transpiles down to CUDA/HCC code under the hood, it's more likely to benefit from optimizations that are otherwise unavailable to OpenCL due to its age and neglect.
    • It still gives you the same level of control that CUDA/OpenCL does, without forcing you to use block operations the way Triton does. Not that block operations are bad, but their optimizations come at the cost of potentially more difficult development.

    Here’s a Numba kernel for the prefix sum:

    @cuda.jit
    def numba_psum(input_array, output_array, stride, N):
        # Thread id in a 1D block
        tx = cuda.threadIdx.x
        # Block id in a 1D grid
        ty = cuda.blockIdx.x
        # Block width, i.e. number of threads per block
        bw = cuda.blockDim.x
        # Compute flattened index inside the array
        pos = tx + ty * bw
    
        back_pos = pos - stride
    
        if pos >= N:
            return
        
        if back_pos < 0:
            output_array[pos] = input_array[pos]
        else:
            output_array[pos] = input_array[pos] + input_array[back_pos]
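
    For completeness, a host-side driver for this kernel might look like the sketch below. The array size, thread count, and the ping-pong buffering are our choices, not code from the original article.

    from numba import cuda
    import numpy as np

    x = np.arange(1024, dtype=np.float32)
    d_in = cuda.to_device(x)
    d_out = cuda.device_array_like(d_in)

    threads = 256
    blocks = (x.size + threads - 1) // threads

    stride = 1
    while stride < x.size:
        numba_psum[blocks, threads](d_in, d_out, stride, x.size)
        d_in, d_out = d_out, d_in  # ping-pong: the latest result is now in d_in
        stride *= 2

    result = d_in.copy_to_host()  # inclusive prefix sum of x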

    Conclusion

    Congrats! You've completed this four-part GPU getting started guide. You now have a good grasp of some of the foundational algorithms and techniques used in parallel programming. With these tools, I hope you are one step closer to writing the next generation of high-performance algorithms.

