Triton Language 101

Recently, I have had the opportunity to use OpenAI’s Triton language for GPU kernel development. As I was learning Triton, I found the documentation difficult to follow: it consists almost entirely of API reference material and end-to-end code examples of real-world kernels, which can gloss over some of the interesting nuances, and which draw little distinction between features of the language and common conventions followed by Triton programmers.

My goal with this article is to present a more approachable introduction to Triton, stepping past the “how” (how to write a matrix multiplication kernel, how to call a particular Triton API, etc.) and examining the “what” and the “why”, hopefully demystifying this powerful and interesting language.

Writing Triton

The great thing about writing Triton kernels, as opposed to writing in pure CUDA, is that we can stay entirely within Python; there is no separate compilation or build process to set up, and no Python bindings to write. We simply mark a function as a Triton kernel by decorating it with @triton.jit.

import triton

@triton.jit
def hello_triton():
    print("Hello from Triton!")

If we try to run this kernel as if it were any other Python function, however, we will see an error:

>>> hello_triton()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/lib/python3/dist-packages/triton/runtime/jit.py", line 713, in __call__
    raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
RuntimeError: Cannot call @triton.jit'd outside of the scope of a kernel

Launch Parameters

Like their CUDA counterparts, Triton kernels require us to specify launch parameters at the time of execution. Recall that CUDA kernels launch many threads that execute instructions in parallel; these threads are organized into 3-dimensional thread blocks; these thread blocks, in turn, are organized into a 3-dimensional launch grid.

// TODO - threads, blocks, grid visual
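The relationship between these pieces can be summarized with a bit of arithmetic: the total number of threads a CUDA launch creates is the product of the grid dimensions times the product of the block dimensions. A minimal plain-Python sketch (the dimension values below are arbitrary examples, not anything CUDA mandates):

```python
# Total CUDA threads launched = (product of grid dims) * (product of block dims).
from math import prod

def total_threads(grid_dims, block_dims):
    """Threads created by a launch with the given 3-D grid and block shapes."""
    return prod(grid_dims) * prod(block_dims)

# One block of 128 threads:
assert total_threads((1, 1, 1), (128, 1, 1)) == 128
# A 2x2 grid of blocks, each 64x2 threads:
assert total_threads((2, 2, 1), (64, 2, 1)) == 512
```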

Launch parameters in CUDA are specified using the triple chevron notation (i.e. with thread block and launch grid dimensions provided within enclosing <<<...>>> signs). With Triton, we do something similar(ish) by providing the launch grid dimensions, as a tuple, within square brackets ([...]) between the function name and parameters, like this:

launch_grid = (1,1,1) # 3-dimensional 1x1x1 grid
hello_triton[launch_grid]()
pid (0, 0, 0) Hello from Triton!
pid (0, 0, 0) Hello from Triton!
pid (0, 0, 0) Hello from Triton!
...
pid (0, 0, 0) Hello from Triton!
pid (0, 0, 0) Hello from Triton!
pid (0, 0, 0) Hello from Triton!

With the launch grid dimensions specified, we can now run the Triton kernel without error, and we see the "Hello from Triton!" message printed several times (128 times to be exact, but more on this shortly). Before moving on, let’s linger for a moment on this peculiar square bracket syntax. There are two aspects of this method for passing launch parameters that felt odd to me at first.

(1) Is this even valid Python?

We don’t normally see code like this in plain Python:

function_name[...](...)

CUDA extends the C++ language to add support for things like the triple chevron notation, so we might be forgiven for thinking something similar is happening with Triton and Python. This Triton launch parameter notation is, however, valid Python. Think of a dict where each key is a tuple, and each value is a callable. e.g.

my_strange_dict = {
    (1,2,3): lambda x: x+1
}

In this case, it is perfectly valid to call the underlying lambda just like calling a Triton kernel, with:

my_strange_dict[(1,2,3)](2)
3

So does this mean our Triton kernel function, hello_triton, is now a dictionary instead of a function? In a way, yes. When we add the @triton.jit decorator to our function, Triton converts it into something called a JITFunction.

type(hello_triton)
triton.runtime.jit.JITFunction

A quick search in the Triton GitHub repository shows that JITFunction inherits from KernelInterface, which implements the built-in __getitem__ method used to support dictionary- or list-like indexing with square brackets.

# From https://github.com/triton-lang/triton/blob/7c56a5e40f7fd928dfd5c72902d5def0097db73a/python/triton/runtime/jit.py#L355-L370
class JITFunction(JITCallable, KernelInterface[T]):
    ...

class KernelInterface(Generic[T]):
    run: T

    ...

    def __getitem__(self, grid) -> T:
        """
        A JIT function is launched with: fn[grid](*args, **kwargs).
        Hence JITFunction.__getitem__ returns a callable proxy that
        memorizes the grid.
        """
        return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)

What we get by passing launch grid dimensions in via square brackets, therefore, is a callable that invokes an underlying self.run function. This means we could avoid this quirky syntax by invoking self.run directly:

launch_grid = (1,1,1)
hello_triton.run(grid=launch_grid, warmup=False)

In any case, the function_name[(...grid dimensions...)](...kernel parameters...) pattern is followed in any Triton code you will find in the wild, so we stick with this going forward.
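To convince ourselves that no language extension is involved, we can mimic the KernelInterface pattern in a few lines of plain Python. FakeKernel and its run method below are illustrative names, not Triton APIs:

```python
# A plain-Python mimic of Triton's KernelInterface.__getitem__ pattern.
class FakeKernel:
    def run(self, *args, grid=None, **kwargs):
        # Stand-in for JITFunction.run, which actually launches the kernel.
        return f"launched on grid {grid} with args {args}"

    def __getitem__(self, grid):
        # Indexing returns a callable proxy that remembers the grid,
        # just like JITFunction does.
        return lambda *args, **kwargs: self.run(*args, grid=grid, **kwargs)

kernel = FakeKernel()
# Same fn[grid](*args) shape as a real Triton launch:
assert kernel[(1, 1, 1)]("hello") == "launched on grid (1, 1, 1) with args ('hello',)"
```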

(2) Where are the Thread Block Dimensions?

In CUDA, there are two sets of three dimensions that we provide at every kernel launch: the launch grid dimensions and the thread block dimensions. In Triton, however, we only ever provide the launch grid dimensions. What are the thread block dimensions, then? And how do we specify them? There is a clue to the first part of the question if we count the number of times the print statement in the hello_triton kernel is executed.

python -c "from hello import hello_triton; hello_triton[(1,1,1)]()" | wc -l 
128

The print statement is executed 128 times, indicating that there are 128 threads in the single thread block that we launched (only a single thread block was launched because our grid dimensions were 1x1x1).

The full answer to the question about thread block dimensions can be found by searching through the Triton reference documentation. Every kernel accepts an argument called num_warps, which, according to the docs, is the number of warps to use for the kernel when compiled for GPUs.

In GPU programming, a warp is a collection of 32 threads executing the same kernel code in parallel. Thus, the number of threads in each thread block is computed as num_warps * 32. We can test this out by passing num_warps into our hello_triton kernel and seeing how many times the print statement is executed.

python -c "from hello import hello_triton; hello_triton[(1,1,1)](num_warps=1)" | wc -l
32
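As a quick sanity check, the relationship between num_warps and block size can be written out in plain Python (threads_per_block is an illustrative helper, not a Triton API; 32 is the warp size discussed above):

```python
# In Triton, the thread-block size is implied by num_warps rather than
# passed as explicit block dimensions.
WARP_SIZE = 32  # threads per warp

def threads_per_block(num_warps: int) -> int:
    return num_warps * WARP_SIZE

assert threads_per_block(1) == 32   # matches the 32 prints above
assert threads_per_block(4) == 128  # matches the 128 prints we saw by default
```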

Implementing a Kernel

Now that we understand how to define and invoke a Triton kernel, let’s look at how we can actually implement a kernel that does something useful. As a minimal example, we will write a kernel to increment every element in a given array.

The simplest way to pass data into a Triton kernel is to use PyTorch tensors, as these can be easily moved to GPU memory and support all the functionality that Triton expects (there are workarounds to get Triton working with arrays from other frameworks like JAX or CuPy).

We pass tensors into Triton kernels just like we would into any other Python function:

import triton, torch

@triton.jit()
def increment_kernel(a):
    pass  # we will fill in the kernel body shortly

if __name__ == "__main__":
    # Set default device to CUDA so PyTorch tensors are 
    # stored on GPU by default
    torch.set_default_device("cuda")
    array = torch.tensor([1,2,3,4], dtype=torch.int32)
    grid = (1,1,1)
    increment_kernel[grid](array, num_warps=1)

Tensor Pointers

One of the first and most important nuances of Triton kernel development you will encounter is that, when we pass a PyTorch tensor to a Triton kernel, Triton doesn’t see a tensor object; instead, it only sees a pointer to the first element of that tensor. We can see what Triton sees by calling the data_ptr() function of the tensor outside of the kernel, or by printing the tensor within the kernel:

import triton, torch

@triton.jit()
def increment_kernel(a):
    print("Array (in Triton):", a, hex=True)

if __name__ == "__main__":
    torch.set_default_device("cuda")
    array = torch.tensor([1,2,3,4], dtype=torch.int32)

    grid = (1,1,1)
    print("="*60)
    print("Array ptr (in Python):", hex(array.data_ptr()))
    print("="*60)
    increment_kernel[grid](array, num_warps=1)
============================================================
Array ptr (in Python): 0x7b0b3aa00000
============================================================
pid (0, 0, 0) idx () Array (in Triton): 0x7b0b3aa00000
pid (0, 0, 0) idx () Array (in Triton): 0x7b0b3aa00000
pid (0, 0, 0) idx () Array (in Triton): 0x7b0b3aa00000
...
pid (0, 0, 0) idx () Array (in Triton): 0x7b0b3aa00000
pid (0, 0, 0) idx () Array (in Triton): 0x7b0b3aa00000
pid (0, 0, 0) idx () Array (in Triton): 0x7b0b3aa00000

Here, what is passed into the Triton kernel is a pointer to the memory address 0x7b0b3aa00000, the same memory address returned by the PyTorch data_ptr() function. We can also try passing different slices of the tensor to the kernel, to observe how that affects the pointer that Triton ultimately sees. For example, if we pass in a slice that omits the first element, we will see that the pointer in Triton is offset from the original data_ptr by 4 bytes:

============================================================
Array ptr (in Python): 0x780154a00000
============================================================
pid (0, 0, 0) idx () Array (in Triton): 0x780154a00004
pid (0, 0, 0) idx () Array (in Triton): 0x780154a00004
pid (0, 0, 0) idx () Array (in Triton): 0x780154a00004
...
pid (0, 0, 0) idx () Array (in Triton): 0x780154a00004
pid (0, 0, 0) idx () Array (in Triton): 0x780154a00004
pid (0, 0, 0) idx () Array (in Triton): 0x780154a00004

This is as expected: each element in the tensor occupies 32 bits (or 4 bytes), as the tensor was initialized with dtype=torch.int32. Therefore the second element in the tensor is stored at an address 4 bytes removed from the first element.
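The offset arithmetic behind this slice experiment can be sketched in plain Python (the base address below is the example value from the output above; slice_ptr is an illustrative helper, not a Triton or PyTorch API):

```python
# How slicing shifts the base pointer: each torch.int32 element occupies
# 4 bytes, so tensor[i:] starts i * 4 bytes past the base address.
ELEMENT_SIZE = 4  # bytes per torch.int32 element

def slice_ptr(base_ptr: int, start_index: int) -> int:
    """Address Triton would see for tensor[start_index:]."""
    return base_ptr + start_index * ELEMENT_SIZE

base = 0x780154A00000  # example base address from the output above
assert hex(slice_ptr(base, 1)) == "0x780154a00004"
```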

Loading Data

Now that we know how to pass a pointer to a tensor into our Triton kernel, we next need a way to load the underlying data stored at the location indicated by the pointer.

This is the purpose of the triton.language.load function, which accepts a pointer as input. Let’s add a call to triton.language.load in our kernel to load data from the pointer:

import triton, torch
import triton.language as tl

@triton.jit()
def increment_kernel(a):
    print("Array (in Triton):", a, hex=True)
    a_data = tl.load(a)
    print("a_data (in Triton):", a_data)

If we execute this kernel with the new print statement added, the output will look like this:

============================================================
Array ptr (in Python): 0x773de4a00000
============================================================
pid (0, 0, 0) idx () Array (in Triton): 0x773de4a00000
pid (0, 0, 0) idx () Array (in Triton): 0x773de4a00000
pid (0, 0, 0) idx () Array (in Triton): 0x773de4a00000
pid (0, 0, 0) idx () Array (in Triton): 0x773de4a00000
...
pid (0, 0, 0) idx () a_data (in Triton): 1
pid (0, 0, 0) idx () a_data (in Triton): 1
pid (0, 0, 0) idx () a_data (in Triton): 1
pid (0, 0, 0) idx () a_data (in Triton): 1

After calling load(), a_data holds the actual data stored at the location indicated by the pointer (i.e. the first element of the tensor). We can use pointer arithmetic to tell load() to read from a different memory location.
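To build intuition for what load() does with an offset pointer, here is a pure-Python sketch (not Triton code) that mimics reading int32 elements out of a flat byte buffer; memory and load_int32 are illustrative names. Note that in Triton, adding 1 to an int32 pointer advances it by one element (4 bytes), since pointer arithmetic is typed:

```python
# Pure-Python sketch of what tl.load(ptr + i) conceptually does for an
# int32 array: read 4 bytes at base + i * 4 and interpret them as an int.
import struct

# Simulated GPU memory holding the int32 array [1, 2, 3, 4].
memory = struct.pack("<4i", 1, 2, 3, 4)

def load_int32(byte_offset: int) -> int:
    """Mimics tl.load on an int32 pointer, offset in bytes."""
    return struct.unpack_from("<i", memory, byte_offset)[0]

assert load_int32(0) == 1      # like tl.load(a):     first element
assert load_int32(1 * 4) == 2  # like tl.load(a + 1): second element
```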

But now we have another problem: each thread loaded the exact same element (the first element, 1); the entire point of running compute on GPUs is to parallelize our computation, meaning that we generally want each thread to operate on some subset of the input.

In CUDA, we accomplish this by using the thread index to uniquely identify each thread, which allows us to assign each thread some particular portion of the computation. In Triton, we don’t have thread indices.