金曜日, 9月 12, 2025
金曜日, 9月 12, 2025
- Advertisment -
ホームニューステックニュースTriton を使って CUDA を書かずに高速な GPU カーネルを実装する

Triton を使って CUDA を書かずに高速な GPU カーネルを実装する



Triton とは

https://github.com/triton-lang/triton

Tritonは、2021年にOpenAIがリリースしたソフトウェアで、リリース記事曰く以下のような特徴があります。

an open-source Python-like programming language which enables researchers with no CUDA experience to write highly efficient GPU code—most of the time on par with what an expert would be able to produce.

(翻訳) オープンソースのPython風プログラミング言語で、CUDAの経験がない研究者でも、多くの場合専門家と同等の効率でGPU向けコードを記述できるように設計されている。

Tritonは、タイル(Tile) または ブロック(Block) と呼ばれるデータの小さな塊を効率的に処理することに特化しています。プログラマがブロック単位での演算を記述すると、Tritonコンパイラが自動的にメモリ管理、キャッシュの最適化、並列化などを行い、GPUの性能を最大限に引き出すマシンコードを生成します。

PyTorchとシームレスに連携できるため、既存のPyTorchのワークフローにカスタムカーネルを簡単に追加できるのも大きな魅力です。

実験

toy example として、マンデルブロ集合の計算が Python + CPU 1 thread で行う場合に比べてどの程度高速化されるか見てみようと思います。また、Triton よりも高レベルな jax と比べたときの速度や、実装による速度の違いなども見ていこうと思います。

Kaggle の Notebook の T2 x 2 のインスタンスで実験を行いました (GPU 0 のみ使用)。実験結果は以下のページからも確認できます。

https://www.kaggle.com/code/zaburo/compute-mandelbrot-set-with-triton

Baseline 1: Python + CPU 1 thread

https://speakerdeck.com/yuyamaguchi/kaggleniyi-li-tugao-su-hua-bing-lie-hua-tekunituku?slide=12

こちらの実装を参考にします。

max_iter = 500
xmin, xmax = -1.75, 0.75
ymin, ymax = -1.25, 1.25
width, height = 4096, 4096

def mandelbrot_kernel(c: complex, max_iter: int = 500) -> int:
    z = c
    for i in range(max_iter):
        z = z * z + c
        
        if abs(z) > 2:
            return i
    return max_iter

def compute_mandelbrot() -> None:
    image = [[0 for _ in range(width)] for _ in range(height)]
    dx = (xmax - xmin) / width
    dy = (ymax - ymin) / height

    for j in range(height):
        for i in range(width):
            y = ymin + j * dy
            x = xmin + i * dx
            image[j][i] = mandelbrot_kernel(complex(x, y))
    return image

計測結果は以下の通りでした。IPython の %%timeit を使って計測しています。

4min 26s ± 380 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Baseline 2: jax

jax は複素数も扱えるためとてもシンプルに GPU を使った処理を書くことができます。

import jax
import jax.numpy as jnp

@jax.jit
def compute_mandelbrot_jax():
    xs = jnp.linspace(xmin, xmax, width, dtype=jnp.float32)
    ys = jnp.linspace(ymin, ymax, height, dtype=jnp.float32)
    X, Y = jnp.meshgrid(xs, ys)

    C = X + 1j * Y

    Z = C
    iters = jnp.zeros(C.shape, dtype=jnp.int32)
    mask = jnp.ones(C.shape, dtype=jnp.bool)

    for i in range(max_iter):
        Z = jnp.where(mask, Z ** 2 + C, Z)
        mask = jnp.abs(Z)  2.0
        iters += mask.astype(jnp.int32)
    return iters

GPU での計算は非同期で行われるので、計測時は以下のように計算が終わるのを待つようにします。

%%timeit
img_jax = compute_mandelbrot_jax()
img_jax.block_until_ready()

計測結果は以下のとおりです

194 ms ± 452 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

(4 * 60 + 26) / 0.194 ≒ 1371 なので、この時点で1000倍以上は高速化されていますね。実用上はこれで良いというケースは多そうです。

Triton: アンチパターン

Triton の特徴として、ブロック (タイル) 単位の処理を記述することで、メモリやキャッシュ管理、並列化といった複雑な最適化をコンパイラに任せることができる、という点が挙げられます。これを体感するため、あえてブロックではなく、1点ごとの処理を記述する形にして速度を計測してみようと思います。

import torch
import triton
import triton.language as tl

@triton.jit
def mandelbrot_kernel_triton_unblocked(
    output_ptr,
    width: tl.int32,
    height: tl.int32,
    max_iter: tl.int32,
    xmin: tl.float32,
    xmax: tl.float32,
    ymin: tl.float32,
    ymax: tl.float32,
) -> None:
    pid_x = tl.program_id(axis=0)
    pid_y = tl.program_id(axis=1)

    c_real = xmin + pid_x * (xmax - xmin) / width
    c_imag = ymin + pid_y * (ymax - ymin) / height

    z_real = c_real
    z_imag = c_imag
    iters = 0
    not_converge = True

    for i in range(max_iter):
        if not_converge:
            z_real_new = z_real * z_real - z_imag * z_imag + c_real
            z_imag_new = 2 * z_real * z_imag + c_imag

            z_real = z_real_new
            z_imag = z_imag_new

            not_converge = z_real * z_real + z_imag * z_imag  4.0
            if not_converge:
                iters += 1

    offset = pid_y * width + pid_x
    tl.store(output_ptr + offset, iters)


def compute_mandelbrot_triton_unblocked(device='cuda') -> torch.Tensor:
    output = torch.empty(width * height, dtype=torch.int32, device=device)
    grid = (width, height)
    mandelbrot_kernel_triton_unblocked[grid](
        output,
        width,
        height,
        max_iter,
        xmin,
        xmax,
        ymin,
        ymax,
    )
    return output.reshape(height, width)

mandelbrot_kernel_triton_unblocked は x=pid_x, y=pid_y の 1 点の処理について記述する形になっています。

計測方法および結果は以下のとおりです。

%%timeit
img_unblocked = compute_mandelbrot_triton_unblocked()
torch.cuda.synchronize()
1.12 s ± 2.88 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

残念ながら jax より遅くなってしまいました。

Triton: ブロック処理


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': block_size}, num_warps=num_warps)
        for num_warps in [2, 4, 8, 16, 32]
        for block_size in [256, 512, 1024, 2048, 4096]
    ],
    key=["width", "height", "max_iter"],
)
@triton.jit
def mandelbrot_kernel_triton(
    output_ptr,
    width: tl.int32,
    height: tl.int32,
    max_iter: tl.int32,
    xmin: tl.float32,
    xmax: tl.float32,
    ymin: tl.float32,
    ymax: tl.float32,
    BLOCK_SIZE: tl.constexpr,
) -> None:
    pid_y = tl.program_id(axis=1)
    pids_x = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    c_reals = xmin + pids_x * ((xmax - xmin) / width)
    c_imag = ymin + pid_y * ((ymax - ymin) / height)

    z_reals = c_reals
    z_imags = tl.full((BLOCK_SIZE,), c_imag, dtype=tl.float32)
    iters = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)
    not_diverged_mask = tl.full((BLOCK_SIZE,), True, dtype=tl.int1)

    for i in range(max_iter):
        z_reals_new = z_reals * z_reals - z_imags * z_imags + c_reals
        z_imags_new = 2 * z_reals * z_imags + c_imag

        z_reals = tl.where(not_diverged_mask, z_reals_new, z_reals)
        z_imags = tl.where(not_diverged_mask, z_imags_new, z_imags)

        not_diverged_mask = (z_reals * z_reals + z_imags * z_imags)  4.0
        iters += not_diverged_mask

    mask = pids_x  width
    output_offsets = pid_y * width + pids_x
    tl.store(output_ptr + output_offsets, iters, mask=mask)


def compute_mandelbrot_triton(device='cuda'):
    output = torch.empty(width * height, dtype=torch.int32, device=device)
    grid = lambda meta: (triton.cdiv(width, meta['BLOCK_SIZE']), height)
    mandelbrot_kernel_triton[grid](
        output,
        width,
        height,
        max_iter,
        xmin,
        xmax,
        ymin,
        ymax,
    )
    return output.reshape(height, width)

mandelbrot_kernel_triton は x=pid_x, y=range(pid_y * BLOCK_SIZE, (pid_y + 1) * BLOCK_SIZE)BLOCK_SIZE 点の処理について記述する形になっています。Array に対する処理を記述していく形になるので、jax の実装と雰囲気は近くなりますね (複素数がサポートされていないので記述自体は少し長くなっていますが)

速度を詰めるために triton.autotune を使って、BLOCK_SIZEnum_warps といった速度に影響するパラメータを実行時に (同じ “width”, “height”, “max_iter” ならその最初の call で) tuning するようにしました。勘所がわかっていないため適当に広めの数字を設定しています。

計測方法および結果は以下のとおりです。

%%timeit
img_triton = compute_mandelbrot_triton()
torch.cuda.synchronize()
20.3 ms ± 166 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

jax よりも更に10倍程度速くなりました!

まとめ

マンデルブロ集合を題材に Triton を使ってみたという記事でした。ブロック単位で処理を記述することで、効率的な GPU カーネルを作ることができました。

実のところは、マンデルブロ集合のような座標ごとに独立した計算だけで良いケースでは、CUDA でもシンプルに記述できると思うので、Triton を使うありがたみがあまりないのですが、sum などの aggregation が入ってくるような処理も CUDA に比べるとかなり高レベルな書き方ができるので、うまく使いこなして高効率なGPUカーネルを量産していきたいと思います。



Source link

Views: 0

RELATED ARTICLES

返事を書く

あなたのコメントを入力してください。
ここにあなたの名前を入力してください

- Advertisment -