
Accelerating transformer inference on my RTX 4090

I haven’t had as much time to keep working on this project as I’d have liked. If I had a bit more time, I’d have added sections for flash attention, speculative decoding, and more aggressive quantization methods. May revisit eventually…

I recently purchased an RTX 4090 GPU. Frankly, I’m not sure why I did this, because for most of my personal and professional computing needs I’m perfectly satisfied with my M1 MacBook Pro, and I don’t really enjoy playing video games. If all I wanted was to mess around with deep learning models, I’d have been better served running a VM on the Lambda cloud, where for the same amount of money one can run a 4090-equivalent datacenter GPU for almost a year — enough time, I expect, for something even better to come out.

Still, one must justify the unjustifiable. I’ve always been interested in ML inference at the systems level, and owning a GPU is a good way to act on that interest. So for the past week, on the advice of my friend Horace He, I’ve been spending my leisure time gradually accelerating the inference of a basic GPT model.

I chose to begin with Andrej Karpathy’s nanoGPT, which is a concise but complete implementation of the GPT model. Within nanoGPT, the model is decomposed into the following PyTorch modules: a top-level GPT module that owns the token and positional embeddings and the LM head; a stack of \(n_\ell\) Block modules; and, inside each Block, a CausalSelfAttention module and an MLP module, glued together with LayerNorm.

For more details on this architecture pray consult the previous post.
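
As a rough orientation, the skeleton below paraphrases that decomposition from memory; the class names follow nanoGPT’s model.py, but the bodies are elided and the details should be treated as approximate:

import torch.nn as nn

class CausalSelfAttention(nn.Module):
    """W_qkv projection, masked softmax attention, output projection."""

class MLP(nn.Module):
    """Linear (d_m -> 4 d_m), GELU, Linear (4 d_m -> d_m)."""

class Block(nn.Module):
    """x = x + attn(ln_1(x)); x = x + mlp(ln_2(x))"""

class GPT(nn.Module):
    """Token and position embeddings, n_layer Blocks, a final LayerNorm, and the LM head."""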

My plan is to rewrite these modules gradually, specifically to accelerate inference rather than training. To guide this project and measure its success, I need a fixed benchmark. I propose the following: a task consists of drawing a random 256-token prompt from Moby Dick and sampling 256 tokens of completion from it, and the figure of merit is the wall-clock time a model takes to complete a batch of such tasks.

A reasonable ansatz is that the distribution of the time a task takes is approximately normal, so we can defer the question of sample size until later. In a file called harness.py we write something like:

import random

import numpy as np
import torch
from tqdm import tqdm

# GPT, decode, and get_moby_dick_tokens are assumed to come from elsewhere in the project


def benchmark(
    gpt: GPT,
    batch_size=1,
    sample_size=1,
) -> list[float]:
    return [
        sample(
            gpt,
            batch_size,
            prompt_tokens=256,
            sampled_tokens=256
        )
        for _ in tqdm(range(sample_size))
    ]


def get_rand_input(batch_size, seq_len):
    # defined before sample() so it can be used as a default argument below;
    # draws batch_size random seq_len-token windows from Moby Dick
    md_tokens = get_moby_dick_tokens()  # memoized
    md_len = len(md_tokens)
    return np.array(
        [
            md_tokens[start_idx : start_idx + seq_len]
            for start_idx in (
                random.randint(0, md_len - seq_len)
                for _ in range(batch_size)
            )
        ]
    )


def sample(
    gpt: GPT,
    batch_size=1,
    prompt_tokens=256,
    sampled_tokens=256,
    get_rand_input=get_rand_input,
) -> float:
    # CUDA events time the GPU work itself, bracketed by explicit synchronizations
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    u = get_rand_input(batch_size, prompt_tokens)
    cuda_u = torch.from_numpy(u).cuda()

    # run inference
    torch.cuda.synchronize()
    start.record()
    idx = gpt.generate(cuda_u, sampled_tokens)
    end.record()
    torch.cuda.synchronize()

    # visualize outputs
    idx = idx.to("cpu").numpy()
    for i in range(batch_size):
        print(
            decode(u[i]),
            "🩹",
            decode(idx[i][len(u[i]) :]),
        )
    # elapsed_time is reported in milliseconds; return seconds
    return start.elapsed_time(end) / 1e3
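
To run the benchmark end to end, I call the harness with something like the following; the from_pretrained call is nanoGPT’s, but this little driver block is a sketch of mine rather than the harness’s verbatim entry point:

if __name__ == "__main__":
    gpt = GPT.from_pretrained("gpt2-xl").cuda().eval()
    times = benchmark(gpt, batch_size=1, sample_size=1)
    print(f"time: {times[-1]:.2f}s")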

When I run the full version of this code on gpt2-xl with a prompt length of 256, the GPU issues forth a long croak of coil whine, and spits out:

loading weights from pretrained gpt: gpt2-xl
number of parameters: 1557.61M
Moby Dick has 305318 tokens

== SAMPLED TEXT FOLLOWS ==

on her way north-eastward towards the island of Java; a gentle air impelling her keel, so that in the surrounding serenity her three tall tapering masts mildly waved to that languid breeze, as three mild palms on a plain. And still, at wide intervals in the silvery night, the lonely, alluring jet would be seen.

But one transparent blue morning, when a stillness almost preternatural spread over the sea, however unattended with any stagnant calm; when the long burnished sun-glade on the waters seemed a golden finger laid across them, enjoining some secrecy 🩹 and secrecy it did not admit of; when those spectacles of azure were staring through the panes, without the least bloodlust, around her passing from star to oar, which she could meet as those she ran, visible to him; her lewness of her shiny of sea-silken've is never mentioned by any, save in some tenth case; And, after thus maintaining a remote silence at anchor, on daybreak she appears to reason on the outside!

At an instant, as she adjusts her masts, the deserted hulk of a powerful ship pierces her leeward. Huge columns of

== END OF SAMPLED TEXT ==

time: 11.12s

Too late do I now realize that I chose the wrong text with which to test model correctness. It is, unfortunately, well within the realm of possibility for Melville to have written of the Pequod’s lew shiny of sea-silken’ve — but even if correctness is indistinguishable from gibberish, we can be assured that any subsequent implementations of the model are correct as long as their final activations match nanoGPT’s.


Surface observations about bottlenecks

A glance at nvtop gives us the story on the inference above:

First, memory usage climbs to 6 GiB. Once the model is loaded, the GPU’s utilization shoots up to 99%. Then, once inference is finished, 11.12 seconds later, it drops back down to zero.

It’s reasonable to assume that we’re compute-bottlenecked right now, because nvidia-smi is telling us that GPU utilization is at 99% and we have plenty of memory to spare. After all, 1557.61M parameters of four bytes each come to 5.8 GiB, which explains the memory utilization figure. Doubling the batch size more or less doubles the time that inference takes, suggesting that a single example already supplies enough parallel work to saturate the GPU.

To figure out where we can squeeze out more compute efficiency, then, let’s review the GPT architecture. In the diagram below, which I’ve lovingly rendered with a modded-out implementation of mermaid, inputs are marked in green, parameters are marked in blue, matmuls are marked in red, and the output of a module is its unique topmost value.

%%{init: { "flowchart": {"useMaxWidth": true, "rankSpacing": 20} } }%%

flowchart BT
subgraph GPT
direction BT
classDef weight fill:#88f,stroke:#00a
classDef input fill:#8f8,stroke:#0a0
classDef matmul fill:#f88,stroke:#a00

Wte["W<sub>te</sub>"] --> WteU["W<sub>te</sub>U<sub>&lt;t</sub>"]
class WteU matmul
Wte --> Logits
class Wte weight
U["U<sub>&lt;t</sub>"] --> WteU
class U input
WteU --> Embedding["W<sub>te</sub>U<sub>&lt;t</sub> + W<sub>pe</sub> "]
Wpe["W<sub>pe</sub>"] ---> Embedding
class Wpe weight;
Embedding --> Block
subgraph Block["Block (x <i>n<sub>l</sub></i>)"]
    direction BT
    xb --> LayerNorm1[LayerNorm] --> CausalSelfAttention
    xb[X] --> resid1["X = X + CSA(X)"]
    class xb input
    CausalSelfAttention --> resid1
    subgraph CausalSelfAttention
        direction BT
        Wqkv["W<sub>qkv</sub>"] --> QKV["W<sub>qkv</sub>X"]
        class Wqkv weight
        xcsa[X] --> QKV
        class QKV matmul
        class xcsa input
        QKV --> Q["Q"]
        QKV --> K["K"]
        QKV --> V["V"]
        Q --> QK["QK<sup>T</sup>"]
        K --> QK
        class QK matmul
        QK --> mQK["Mask(QK<sup>T</sup>)/√d<sub>m</sub>"] --> sQK["Softmax(Mask(QK<sup>T</sup>)/√d<sub>m</sub>)"]
        sQK --> attn
        V -----> attn["Softmax(Mask(QK<sup>T</sup>)/√d<sub>m</sub>)V"]
        class attn matmul
    end
    resid1 --> LayerNorm2[LayerNorm] --> MLP
    resid1 --> resid2["X = X + MLP(X)"]
    MLP --> resid2
    subgraph MLP
        direction BT
        xm[X] --> layer1[W<sub>1</sub>x]
        class xm input
        W1[W<sub>1</sub>] --> layer1
        layer1 --> hlayer1["gelu(W<sub>1</sub>x)"]
        class layer1 matmul
        W2[W<sub>2</sub>] ----> layer2
        hlayer1 --> layer2["W<sub>2</sub>gelu(W<sub>1</sub>x)"]
        class layer2 matmul
        class W1 weight
        class W2 weight
    end
end
Block --> Logits[W<sub>te</sub><sup>T</sup>Z]
class Logits matmul
Logits --> Multinomial["U<sub>t</sub> = Multinomial(exp(W<sub>te</sub><sup>T</sup>Z))"]
end

We care about matmuls here because all of the other nodes in the computation graph take \(O(mn)\) FLOPs, while naïve matrix multiplication \(\mathbb{R}^{m\times n} \times \mathbb{R}^{n \times p} \rightarrow \mathbb{R}^{m \times p}\) takes \(2mnp = O(mnp)\) FLOPs[3], which heuristically suggests that they should account for the largest share.

Here are our five matmuls per transformer block, with \(t\) being the sequence length and all other variables as they are in the previous post:

| Matmul | FLOPs per layer |
|---|---|
| \(W_{qkv}X\) | \(6td_m^2\) |
| \(QK^T\) | \(2t^2d_m\) |
| \(\mathrm{Softmax}(\mathrm{Mask}(QK^T)/\sqrt{d_m})\,V\) | \(2t^2d_m\) |
| \(W_1x\) | \(8td_m^2\) |
| \(W_2\,\mathrm{gelu}(W_1x)\) | \(8td_m^2\) |

On top of these, the LM head \(W_{te}^TZ\) costs \(2n_vd_m\) FLOPs, since during sampling we only need the logits for the final position.

The FLOP requirement of a single forward pass due to matmuls is therefore approximately

\[\begin{align}F(t) &= 6n_\ell td_m^2 + 4n_\ell t^2 d_m + 16n_\ell td_m^2 + 2n_vd_m \\&=22n_\ell td_m^2 + 4n_\ell t^2d_m + 2n_vd_m \\&=(4n_\ell d_m) t^2 + (22n_\ell d_m^2)t + 2n_vd_m. \end{align}\]

The XL model has \(n_\ell = 48, d_m = 1600, n_v=50257\), so our polynomial simplifies to

\[F(t) \approx (3.1\times 10^5) t^2 + (2.7 \times 10^9)t + 1.6 \times 10^8.\]

Here are some values of this function:

| \(t\) | QKV | Attn | MLP | LM head | \(F(t)\) |
|---|---|---|---|---|---|
| \(256\) | \(1.89 \times 10^{11}\) | \(2.01 \times 10^{10}\) | \(5.03 \times 10^{11}\) | \(1.61 \times 10^{8}\) | \(7.12 \times 10^{11}\) |
| \(512\) | \(3.77 \times 10^{11}\) | \(8.05 \times 10^{10}\) | \(1.01 \times 10^{12}\) | \(1.61 \times 10^{8}\) | \(1.46 \times 10^{12}\) |
| \(1024\) | \(7.55 \times 10^{11}\) | \(3.22 \times 10^{11}\) | \(2.01 \times 10^{12}\) | \(1.61 \times 10^{8}\) | \(3.09 \times 10^{12}\) |
| \(2048\) | \(1.51 \times 10^{12}\) | \(1.29 \times 10^{12}\) | \(4.03 \times 10^{12}\) | \(1.61 \times 10^{8}\) | \(6.83 \times 10^{12}\) |

We see that forward passes tend to take several teraflops each.

Because the nanoGPT autoregressive sampling code evaluates the forward pass for \(256 \leq t < 512\), the entire task should take at least \(\sum_{t = 256}^{511} F(t) \approx 2.77 \times 10^{14}\) FLOPs, or 277 TFLOPs. (The RTX 4090 spec says it has a peak FP32 throughput of 82.6 TFLOPS at the boost clock, which would put the ideal time at around 3.4 seconds against the 11.12 seconds measured above, but I won’t pretend I’m using the GPU anywhere near optimally yet.)
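
To keep myself honest about the arithmetic, here is a throwaway script (mine, not part of the harness) that evaluates \(F(t)\) and the summed FLOP count for one task:

n_l, d_m, n_v = 48, 1600, 50257  # gpt2-xl

def matmul_flops(t):
    qkv = 6 * n_l * t * d_m**2
    attn = 4 * n_l * t**2 * d_m
    mlp = 16 * n_l * t * d_m**2
    lm_head = 2 * n_v * d_m  # logits for the final position only
    return qkv + attn + mlp + lm_head

total = sum(matmul_flops(t) for t in range(256, 512))
print(f"F(256)     = {matmul_flops(256):.2e}")              # ~7.1e11
print(f"task total = {total:.2e}")                          # ~2.8e14
print(f"ideal seconds at 82.6 TFLOPS = {total / 82.6e12:.1f}")  # ~3.4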

This suggests our first optimization — an easy one at the algorithmic level, which requires little knowledge of how the hardware works but slashes the FLOP complexity of sampling a full sequence from \(\Theta(t^3)\) to \(\Theta(t^2)\) in the sequence length \(t\).


Day 1: Memoizing Causal Self-Attention (KV caching)

It is often said of causal self-attention that “later tokens in a sequence do not affect the embeddings of earlier ones.” In more concrete terms: if we let the output of the CausalSelfAttention block be the matrix \(A = \{a_{ij}\} \in \mathbb{R}^{t \times d_m}\), we can compute each entry \(a_{ij}\) using just the values in \(X_{\leq i}\in \mathbb{R}^{i \times d_m}\).

It is enlightening to examine why this is the case. Below is the architecture of causal self-attention:

%%{init: { "flowchart": {"useMaxWidth": true, "rankSpacing": 30, "nodeSpacing": 100} } }%%
flowchart BT
classDef weight fill:#88f,stroke:#00a
classDef input fill:#8f8,stroke:#0a0
classDef matmul fill:#f88,stroke:#a00

    Wqkv["W<sub>qkv</sub>"] --> QKV["W<sub>qkv</sub>X"]
    class Wqkv weight
    xcsa[X] --> QKV
    class QKV matmul
    class xcsa input
    QKV --> Q["Q"]
    QKV --> K["K"]
    QKV --> V["V"]
    Q --> QK["QK<sup>T</sup>"]
    K --> QK
    class QK matmul
    QK --> mQK["Mask(QK<sup>T</sup>)/√d<sub>m</sub>"] --> sQK["Softmax(Mask(QK<sup>T</sup>)/√d<sub>m</sub>)"]
    sQK --> attn
    V -----> attn["Softmax(Mask(QK<sup>T</sup>)/√d<sub>m</sub>)V"]
    class attn matmul

Imagine passing in the two matrices \(X_{<t} = \left(\begin{smallmatrix}x_1^T \\ \vdots \\ x_{t - 1}^T \end{smallmatrix}\right) \in \mathbb{R}^{(t - 1) \times d_m}\) and \(X_{\leq t} = \left(\begin{smallmatrix}X_{<t} \\ x_{t}^T \end{smallmatrix}\right)\in \mathbb{R}^{t \times d_m}\). How does each node on the graph differ? Appending \(x_t\) appends one new row (\(q_t\), \(k_t\), \(v_t\)) to each of \(Q\), \(K\), and \(V\); \(QK^T\) gains one new row and one new column; and the causal mask blanks out the new column everywhere except in the new row, so after the softmax the first \(t - 1\) rows of the attention output are exactly what they were for \(X_{<t}\).

That is, each incremental sequence entry \(x_t\) adds the row \(\mathrm{Softmax}(d^{-1/2}_mq_t^T K_{\leq t}^T)V_{\leq t}\) to the output attention matrix. If we can memoize \(K_{<t}\) and \(V_{<t}\) from a previous forward pass of the model, self-attention will only require the matrix multiplications \(W_{qkv}x_t\) (producing \(q_t\), \(k_t\), \(v_t\)), \(q_t^TK_{\leq t}^T\), and \(\mathrm{Softmax}(\mathrm{Mask}(q_t^TK_{\leq t}^T)/\sqrt{d_m})V_{\leq t}\).

Thus, we have gone from \(4t^2d_m+ 6td_m^2 + O(d_m + t)\) to \(4td_m + 6d_m^2 + O(d_m)\) FLOPs in the incremental block. This technique has been referred to elsewhere as KV caching.


flowchart BT
classDef weight fill:#88f,stroke:#00a
classDef input fill:#8f8,stroke:#0a0
classDef matmul fill:#f88,stroke:#a00

subgraph CausalSelfAttention
    direction BT
    Wqkv["W<sub>qkv</sub>"] --> QKV["W<sub>qkv</sub>X"]
    class Wqkv weight
    xcsa[X] --> QKV
    class QKV matmul
    class xcsa input
    QKV --> Q["Q<sub>&lt;t</sub>"]
    QKV --> K["K<sub>&lt;t</sub>"]
    QKV --> V["V<sub>&lt;t</sub>"]
    Q --> QK["Q<sub>&lt;t</sub>K<sub>&lt;t</sub><sup>T</sup>"]
    K --> QK
    class QK matmul
    QK --> mQK["Mask(Q<sub>&lt;t</sub>K<sub>&lt;t</sub><sup>T</sup>)"] --> sQK["Softmax(Mask(Q<sub>&lt;t</sub>K<sub>&lt;t</sub><sup>T</sup>)/√d<sub>m</sub>)"]
    sQK --> attn
    V -----> attn["Softmax(Mask(Q<sub>&lt;t</sub>K<sub>&lt;t</sub><sup>T</sup>)/√d<sub>m</sub>)V<sub>&lt;t</sub>"]
    class attn matmul
end

subgraph IncrementalCausalSelfAttention
    direction BT
    _Wqkv["W<sub>qkv</sub>"] --> _QKV["W<sub>qkv</sub>x<sub>t</sub><sup>T</sup>"]
    class _Wqkv weight
    _xcsa[x] --> _QKV
    class _QKV matmul
    class _xcsa input
    _QKV --> _Q["q<sub>t</sub>"]
    _QKV --> _K["k<sub>t</sub>"]
    _QKV --> _V["v<sub>t</sub>"]
    _Q ---> _QK["q<sub>t</sub><sup>T</sup>K<sub>≤t</sub><sup>T</sup>"]
    _K --> __K
    K --> __K["K<sub>≤t</sub>"] -->_QK
    class _QK matmul
    _QK --> _mQK["Mask(q<sub>t</sub><sup>T</sup>K<sub>≤t</sub><sup>T</sup>)/√d<sub>m</sub>"] --> _sQK["Softmax(Mask(q<sub>t</sub><sup>T</sup>K<sub>≤t</sub><sup>T</sup>)/√d<sub>m</sub>)"]
    _sQK --> _attn
    _V --> __V
    V --> __V["V<sub>≤t</sub>"] -----> _attn["Softmax(Mask(q<sub>t</sub><sup>T</sup>K<sub>≤t</sub><sup>T</sup>)/√d<sub>m</sub>)V<sub>≤t</sub>"]
    class _attn matmul
end

A model that implements KV caching would run a first pass with the entire prompt to populate the cache, then a series of single-token passes that incrementally build on the cached values. There is no need to modify the feed-forward layer because it doesn’t involve interactions between multiple \(t\), but we should make sure to offset the positional embeddings \(W_{pe}\) accordingly.
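
Concretely, the offset just means indexing \(W_{pe}\) starting from the current cache length instead of from zero. A minimal sketch, assuming a hypothetical cache_len argument and nanoGPT’s transformer.wte / transformer.wpe embedding tables:

import torch

def embed_with_offset(gpt, idx, cache_len):
    # idx holds only the new tokens; their absolute positions start at cache_len
    t_new = idx.size(1)
    pos = torch.arange(cache_len, cache_len + t_new, device=idx.device)
    return gpt.transformer.wte(idx) + gpt.transformer.wpe(pos)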

We implement KV caching in implementations/memoized.py. We override the base implementation of CausalSelfAttention with a MemoizedCausalSelfAttention that stores \(K_{<t}\) and \(V_{<t}\) as buffers, and augment it with the ability to function both as a normal self-attention layer, overwriting the cache, and as an incremental self-attention layer, reading from the cache and writing \(k_t\) and \(v_t\) back to it. Then we adapt the generate code in the base class to use our incremental self-attention functionality, and add our new MemoizedGPT to our test harness. The code is here.
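
The real implementation is in the file linked above; the sketch below is my own illustration of the same idea, with invented names (k_cache, v_cache, cache_len) and with dropout and the explicit full/incremental mode switch left out:

import math
import torch
import torch.nn.functional as F


class KVCachedSelfAttention(torch.nn.Module):
    def __init__(self, d_m, n_h, max_batch, max_seq):
        super().__init__()
        self.n_h, self.d_h = n_h, d_m // n_h
        self.c_attn = torch.nn.Linear(d_m, 3 * d_m)
        self.c_proj = torch.nn.Linear(d_m, d_m)
        # persistent K/V buffers of shape B x n_h x L x (d_m / n_h)
        self.register_buffer("k_cache", torch.zeros(max_batch, n_h, max_seq, self.d_h))
        self.register_buffer("v_cache", torch.zeros(max_batch, n_h, max_seq, self.d_h))

    def forward(self, x, cache_len=0):
        B, T, C = x.shape  # the "full" pass has cache_len == 0; incremental passes have T == 1
        q, k, v = self.c_attn(x).split(C, dim=2)
        q, k, v = (z.view(B, T, self.n_h, self.d_h).transpose(1, 2) for z in (q, k, v))
        # write the new keys and values into the cache, then attend over everything cached
        self.k_cache[:B, :, cache_len : cache_len + T] = k
        self.v_cache[:B, :, cache_len : cache_len + T] = v
        k_all = self.k_cache[:B, :, : cache_len + T]
        v_all = self.v_cache[:B, :, : cache_len + T]
        att = (q @ k_all.transpose(-2, -1)) / math.sqrt(self.d_h)
        if T > 1:  # the causal mask only matters when there is more than one query
            mask = torch.tril(torch.ones(T, cache_len + T, device=x.device), diagonal=cache_len)
            att = att.masked_fill(mask == 0, float("-inf"))
        y = F.softmax(att, dim=-1) @ v_all
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)

A generate loop built on this would call forward with cache_len=0 for the prompt and then with cache_len equal to the number of tokens already processed for each single-token pass.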


Interlude: FLOP review

Let’s review where we are FLOP-wise. Our current sampling procedure is to run one “full” forward pass with \(t_{min} = 256\) tokens to populate the \(KV\) cache, then \(t_{max} - (t_{min} + 1) = 255\) “incremental” forward passes to sample autoregressively. Thus: the full pass still costs \(F(256) \approx 7.1 \times 10^{11}\) FLOPs, while each incremental pass costs only around \(22n_\ell d_m^2 + 4n_\ell td_m + 2n_vd_m \approx 3 \times 10^{9}\) FLOPs.

This works out to a total of roughly 1.5 TFLOPs, broken down as follows: about \(7.1 \times 10^{11}\) FLOPs for the full pass over the prompt and about \(7.6 \times 10^{11}\) FLOPs for the 255 incremental passes combined, with the weight matmuls (QKV and MLP) dominating both.
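
The numbers above come from a back-of-the-envelope script in the same spirit as before (again mine, and again counting matmuls only):

n_l, d_m, n_v = 48, 1600, 50257

def full_pass_flops(t):
    # identical to F(t): the prompt still goes through every matmul once
    return 22 * n_l * t * d_m**2 + 4 * n_l * t**2 * d_m + 2 * n_v * d_m

def incremental_pass_flops(t):
    # one new token: the weight matmuls shrink to a single row, and attention reads ~t cached positions
    return 22 * n_l * d_m**2 + 4 * n_l * t * d_m + 2 * n_v * d_m

prefill = full_pass_flops(256)
decode_total = sum(incremental_pass_flops(t) for t in range(257, 512))
print(f"{prefill:.2e} {decode_total:.2e} {prefill + decode_total:.2e}")
# ~7.1e11  ~7.6e11  ~1.5e12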

From this calculation, we can draw the following observations: KV caching has cut the FLOP count of a task from roughly 277 TFLOPs to roughly 1.5 TFLOPs, a nearly 200-fold reduction; the prompt pass and the 255 incremental passes now cost about the same; and almost all of the remaining FLOPs sit in the QKV and MLP weight matmuls, whose cost is set by the parameter count rather than by the sequence length.

To the best of my knowledge, the only obvious inference optimization remaining that reduces FLOPs is to fuse the mask into the matmuls of the “full” attention step, which would skip the upper-triangular half of that work and save on the order of \(10^{10}\) FLOPs, or about 1% of the total. (I could think of a few architectural changes that could get us further, like sparse attention or removing the bias on the \(QKV\) network.) We’ll get to it in due course, but it’s time to start thinking about making better use of the specific features of our hardware.


Day 2: Quantization (fp16, fp8)

KV caching gives us a marked improvement in performance. First, the time taken for the execution of a single task falls from 11 seconds to 7 seconds. More importantly, however, we are able to run much larger batches. Whereas the base model was unable to run with a concurrent batch size greater than 1 (as it was already running with an “effective” batch size of up to 512), the KV-cached model can concurrently execute dozens of tasks.

In fact, the limiting factor turns out not to be FLOPs but memory; my implementation instantiates the buffers with size \(B \times n_h \times L \times (d_m / n_h)\) for fixed constants \((B, L)\) representing the maximum batch size supported and the maximum sequence length supported. Because we have two such buffers for each of the \(n_\ell=48\) layers, and nanoGPT uses 4-byte fp32 for everything, the total memory occupied is \(2 n_\ell BLd_m\cdot 4 = 2\cdot 48\cdot 1600 \cdot 4 \cdot BL = 614400BL\) bytes. And if we set \(L\) to 512 (the minimum value required to execute a task), we come to the unfortunate realization that our buffers take up \(0.29B\) gibibytes and that we can only fit a batch size of 32 on the GPU. Right?
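
As a quick check on that figure (a few lines of mine):

n_l, d_m, L = 48, 1600, 512
fp32_bytes, fp16_bytes = 4, 2

# K and V buffers across all layers for one batch element: 2 * n_l * L * d_m floats
cache_floats = 2 * n_l * L * d_m
print(cache_floats * fp32_bytes / 2**30)  # ~0.29 GiB per batch element in fp32
print(cache_floats * fp16_bytes / 2**30)  # ~0.15 GiB per batch element in fp16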

Let’s say we could make all the values on the GPU take up half as much space as they did before. What should the new batch size be? Well, if the device could previously support a batch size of 32, we could argue heuristically that the parameters and buffers now take up just 12 GiB, and we can fit another 12 GiB / (0.15 GiB) = 80 buffers on the GPU for a total of 112.

We’d be wrong, though, because I’m also using the GPU to drive my 5K monitor, so it’s already got 1.75 GiB permanently set aside. Moreover, a larger batch size also increases the size of other tensors in the graph, which grow with the buffers. It all nets out to being able to support a batch size of 81, and I know this because actually halving the memory usage of all floats on the GPU to test this is the easy part:

# implementations/fp16.py, full text

from implementations.base import GPTConfig
from implementations.memoized import MemoizedGPT


class FP16MemoizedGPT(MemoizedGPT):
    def __init__(self, config: GPTConfig):
        super().__init__(config)
        self.half()

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        gpt = super().from_pretrained(*args, **kwargs)
        gpt.half()
        return gpt
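
Plugging it into the harness is then a one-line change; the snippet below is how I’d invoke it, though the real from_pretrained call may also need whatever arguments size the KV-cache buffers (maximum batch size and sequence length):

from implementations.fp16 import FP16MemoizedGPT

gpt = FP16MemoizedGPT.from_pretrained("gpt2-xl").cuda().eval()
times = benchmark(gpt, batch_size=81, sample_size=1)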

Note that the FP16MemoizedGPT is able to execute 81 tasks in the ballpark of six seconds. That’s already about 150 times more throughput than the base implementation!

A natural question is whether we can make the jump from fp16 to fp8. Theoretically, this would be amazing for two reasons: halving the weights and the \(KV\) cache again would let us push the batch size even higher (the arithmetic in footnote 5 lands somewhere around 136), and the 4090’s fourth-generation tensor cores advertise twice as much throughput for fp8 as for fp16.

PyTorch doesn’t support fp8 yet, but Nvidia provides a Transformer Engine library[6] that effectively acts as an fp8 extension to PyTorch. Transformer Engine is theoretically compatible with the Hopper and Ada GPU architectures, and it contains drop-in replacements for torch.nn.Linear and torch.nn.LayerNorm, as well as higher-level operators like LayerNormMLP and TransformerLayer. Installing it is a little tricky, as it needs to be built against CUDA 11.8, which requires gcc and g++ version 11, which in turn need to be installed separately and symlinked. And when the whole thing is all set up and ready, we receive the final disappointment: despite having promised fp8 inference on Ada, Nvidia is only delivering it in CUDA 12.1 in Q2. How horrible! We’ll have to quantize to int8 instead.
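
For the record, here is roughly what the fp8 path would have looked like, going by Transformer Engine’s documented PyTorch API; I couldn’t actually exercise it on this card, so treat it as an unverified sketch:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# te.Linear is a drop-in replacement for torch.nn.Linear
layer = te.Linear(1600, 6400, bias=True).cuda()
x = torch.randn(1, 256, 1600, device="cuda")

# fp8 scaling factors are managed by a "recipe"
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)
with torch.no_grad(), te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = layer(x)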


  1. Note that I follow the convention \(y = W(\begin{smallmatrix}x \\ [1]\end{smallmatrix})\) rather than \(y = xW + b\). 

  2. The UTF-8 edition from Project Gutenberg has 305k tokens if we strip out the metadata, the leading spaces, and the intra-paragraph newlines. The newline transformation is critical; we need to format the book to resemble OpenWebText data to get the best results. 

  3. Half are multiplications, the other half are additions. 

  4. I haven’t included the bias because I’m lazy. 

  5. Let’s say that rendering my desktop takes 1.75 GiB. Then 22.75 GiB remain for CUDA, consisting of parameters, buffers, and activations. The buffers and activations scale with the batch size and the weights do not, so if we let \(P\) be the total memory taken by the fp32 parameters and \(BA\) the memory taken by the fp32 buffers and activations, we have \(P + 32BA \approx P/2 + 81BA/2 \approx 22.5\), from which we may conclude that each batch element adds about 0.45 GiB of memory usage in the fp32 regime and about 0.11 GiB in the fp8 regime, for a theoretical batch size of around 136. 

  6. Some marketing executive at Nvidia decided to pitch the Transformer Engine as a hardware feature that comes “bundled” with the Hopper architecture — an odd sell, as the Transformer Engine now comes “bundled” with 4000-series consumer GPUs as well.