Porting microgpt to Futhark, Part I

The forward pass

I have been wanting to find a project to try out the data-parallel language Futhark. Its developers write a very good blog that I've been following for years, but I've never actually written anything in the language itself.

Andrej Karpathy's microgpt, a self-contained implementation of a GPT-2-like neural network in 200 lines of Python, finally provided the excuse. I like microgpt, but it does not scale at all. Efficiency is obviously not the point of that implementation, but the problem is not just that it's slow: you also can't scale up to even slightly larger networks, because you quickly hit Python's recursion depth limit.

So, I was curious whether I could port it as 1-to-1 as possible and get much better scaling without losing too much concision. The answer, as it turns out, is: sort of. The port scales much better but is not as concise, though parts of it translate quite nicely.

This post, Part I, will start with just the forward pass. I'll alternate code from Karpathy's original Python version with my Futhark translation, attempting to keep things as similar as possible, while still taking advantage of Futhark's parallel primitives.

LLM parameters

First, the data structures holding the LLM parameters (weights). We will assume these are pretrained (the training code will come in Part II).

Python:

n_layer = 1
n_embd = 16
block_size = 16
n_head = 4
head_dim = n_embd // n_head

matrix = lambda nout, nin, std=0.08: [[Value(random.gauss(0, std)) for _ in range(nin)] for _ in range(nout)]
state_dict = {'wte': matrix(vocab_size, n_embd), 'wpe': matrix(block_size, n_embd), 'lm_head': matrix(vocab_size, n_embd)}
for i in range(n_layer):
    state_dict[f'layer{i}.attn_wq'] = matrix(n_embd, n_embd)
    state_dict[f'layer{i}.attn_wk'] = matrix(n_embd, n_embd)
    state_dict[f'layer{i}.attn_wv'] = matrix(n_embd, n_embd)
    state_dict[f'layer{i}.attn_wo'] = matrix(n_embd, n_embd)
    state_dict[f'layer{i}.mlp_fc1'] = matrix(4 * n_embd, n_embd)
    state_dict[f'layer{i}.mlp_fc2'] = matrix(n_embd, 4 * n_embd)
params = [p for mat in state_dict.values() for row in mat for p in row]

Futhark:

def n_layer : i64 = 1
def n_embd : i64 = 16
def block_size : i64 = 16
def n_head : i64 = 4
def head_dim : i64 = n_embd / n_head

type params [v] = {
  wte:      [v][n_embd]f32,                   -- token embeddings
  wpe:      [block_size][n_embd]f32,          -- position embeddings
  lm_head:  [v][n_embd]f32,                   -- output projection
  attn_wq:  [n_layer][n_embd][n_embd]f32,     -- query weights
  attn_wk:  [n_layer][n_embd][n_embd]f32,     -- key weights
  attn_wv:  [n_layer][n_embd][n_embd]f32,     -- value weights
  attn_wo:  [n_layer][n_embd][n_embd]f32,     -- output weights
  mlp_fc1:  [n_layer][4 * n_embd][n_embd]f32, -- MLP up-projection
  mlp_fc2:  [n_layer][n_embd][4 * n_embd]f32  -- MLP down-projection
}
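
One thing with no direct counterpart here is the random initialization: since the forward pass assumes pretrained weights, they just arrive as plain arrays from outside. As a sketch (the helper and its name are mine, not something the rest of this post uses), packing those arrays into the record would look like this:

-- Hypothetical constructor: packs host-supplied weight arrays into the record.
def mk_params [v]
    (wte: [v][n_embd]f32)
    (wpe: [block_size][n_embd]f32)
    (lm_head: [v][n_embd]f32)
    (attn_wq: [n_layer][n_embd][n_embd]f32)
    (attn_wk: [n_layer][n_embd][n_embd]f32)
    (attn_wv: [n_layer][n_embd][n_embd]f32)
    (attn_wo: [n_layer][n_embd][n_embd]f32)
    (mlp_fc1: [n_layer][4 * n_embd][n_embd]f32)
    (mlp_fc2: [n_layer][n_embd][4 * n_embd]f32)
    : params [v] =
  { wte = wte, wpe = wpe, lm_head = lm_head
  , attn_wq = attn_wq, attn_wk = attn_wk, attn_wv = attn_wv, attn_wo = attn_wo
  , mlp_fc1 = mlp_fc1, mlp_fc2 = mlp_fc2 }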

Model components

Next, the basic components of the particular model architecture Karpathy chose.

Python:

def linear(x, w):
    return [sum(wi * xi for wi, xi in zip(wo, x)) for wo in w]

def softmax(logits):
    max_val = max(val.data for val in logits)
    exps = [(val - max_val).exp() for val in logits]
    total = sum(exps)
    return [e / total for e in exps]

def rmsnorm(x):
    ms = sum(xi * xi for xi in x) / len(x)
    scale = (ms + 1e-5) ** -0.5
    return [xi * scale for xi in x]

Futhark:

def linear [n][m] (x: [n]f32) (w: [m][n]f32) : [m]f32 =
  map (\w_row -> reduce (+) 0f32 (map2 (*) w_row x)) w

def softmax [n] (logits: [n]f32) : [n]f32 =
  let max_val = reduce f32.max f32.lowest logits
  let exps = map (\v -> f32.exp (v - max_val)) logits
  let total = reduce (+) 0f32 exps
  in map (/ total) exps

def rmsnorm [n] (x: [n]f32) : [n]f32 =
  let ms = reduce (+) 0f32 (map (\xi -> xi * xi) x) / f32.i64 n
  let scale = 1f32 / f32.sqrt (ms + 1e-5)
  in map (* scale) x

I'm pleased with how nicely these three functions translate. The explicit typing does add a little syntactic noise, and arguably reduce (+) 0f32 is not as nice a way to spell sum. But the result is generally readable, especially if you are already familiar with these kinds of functional combinators, and the number of lines of code stayed exactly the same.
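
For what it's worth, the sum spelling is easy to get back with a one-liner (and I believe the prelude's f32.sum already provides it); for example:

def sum [n] (xs: [n]f32) : f32 = reduce (+) 0f32 xs

-- which lets rmsnorm read a little closer to the Python:
def rmsnorm' [n] (x: [n]f32) : [n]f32 =
  let ms = sum (map (\xi -> xi * xi) x) / f32.i64 n
  let scale = 1f32 / f32.sqrt (ms + 1e-5)
  in map (* scale) x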

GPT forward pass

And finally, the GPT forward pass, complete with a KV cache.

Karpathy's Python original:

def gpt(token_id, pos_id, keys, values):
    tok_emb = state_dict['wte'][token_id]         # token embedding
    pos_emb = state_dict['wpe'][pos_id]           # position embedding
    x = [t + p for t, p in zip(tok_emb, pos_emb)] # joint token and position embedding
    x = rmsnorm(x) # note: not redundant due to backward pass via the residual connection

    for li in range(n_layer):
        # 1) Multi-head Attention block
        x_residual = x
        x = rmsnorm(x)
        q = linear(x, state_dict[f'layer{li}.attn_wq'])
        k = linear(x, state_dict[f'layer{li}.attn_wk'])
        v = linear(x, state_dict[f'layer{li}.attn_wv'])
        keys[li].append(k)
        values[li].append(v)
        x_attn = []
        for h in range(n_head):
            hs = h * head_dim
            q_h = q[hs:hs+head_dim]
            k_h = [ki[hs:hs+head_dim] for ki in keys[li]]
            v_h = [vi[hs:hs+head_dim] for vi in values[li]]
            attn_logits = [sum(q_h[j] * k_h[t][j] for j in range(head_dim)) / head_dim**0.5 for t in range(len(k_h))]
            attn_weights = softmax(attn_logits)
            head_out = [sum(attn_weights[t] * v_h[t][j] for t in range(len(v_h))) for j in range(head_dim)]
            x_attn.extend(head_out)
        x = linear(x_attn, state_dict[f'layer{li}.attn_wo'])
        x = [a + b for a, b in zip(x, x_residual)]
        # 2) MLP block
        x_residual = x
        x = rmsnorm(x)
        x = linear(x, state_dict[f'layer{li}.mlp_fc1'])
        x = [xi.relu() for xi in x]
        x = linear(x, state_dict[f'layer{li}.mlp_fc2'])
        x = [a + b for a, b in zip(x, x_residual)]

    logits = linear(x, state_dict['lm_head'])
    return logits

My Futhark port:

def gpt [v]
  (p: params [v])
  (token_id: i64) (pos_id: i64)
  (keys:   *[n_layer][block_size][n_embd]f32)
  (values: *[n_layer][block_size][n_embd]f32)
  : ([v]f32,
     *[n_layer][block_size][n_embd]f32,
     *[n_layer][block_size][n_embd]f32) =

  let tok_emb = p.wte[token_id]     -- token embedding
  let pos_emb = p.wpe[pos_id]       -- position embedding
  let x = map2 (+) tok_emb pos_emb  -- joint token and position embedding
  let x = rmsnorm x

  let (x, keys, values) =
    loop (x, keys, values) for li < n_layer do
      -- 1) Multi-head Attention block
      let x_residual = x
      let x_norm = rmsnorm x
      let q     = linear x_norm p.attn_wq[li]
      let k     = linear x_norm p.attn_wk[li]
      let v_vec = linear x_norm p.attn_wv[li]
      let keys   = keys   with [li, pos_id] = k
      let values = values with [li, pos_id] = v_vec
      let x_attn = flatten (
        tabulate n_head (\h ->
          let hs = h * head_dim
          let q_h = tabulate head_dim (\j -> q[hs + j])
          let scale = 1f32 / f32.sqrt (f32.i64 head_dim)
          let attn_logits = tabulate block_size (\t ->
            let dot = reduce (+) 0f32 (
              tabulate head_dim (\j -> q_h[j] * keys[li, t, hs + j])
            )
            in if t <= pos_id then dot * scale else -1e30f32
          )
          let attn_weights = softmax attn_logits
          in tabulate head_dim (\j ->
            reduce (+) 0f32 (
              tabulate block_size (\t -> attn_weights[t] * values[li, t, hs + j])
            )
          )
        )
      ) :> [n_embd]f32
      let x_out = linear x_attn p.attn_wo[li]
      let x = map2 (+) x_out x_residual

      -- 2) MLP block
      let x_residual = x
      let x_norm = rmsnorm x
      let x_mlp = linear x_norm p.mlp_fc1[li]
      let x_mlp = map (f32.max 0) x_mlp
      let x_mlp = linear x_mlp p.mlp_fc2[li]
      let x = map2 (+) x_mlp x_residual
      in (x, keys, values)
  let logits = linear x p.lm_head
  in (logits, keys, values)

The for loops mostly translate to tabulate, which you can think of as essentially a parallel for loop. The exception is the outer loop over layers, which needs to stay sequential, so it becomes a sequential loop. The MLP block at the end translates pretty directly too. The attention block was a bit hairier to wrangle into Futhark's constraints, because the original relies on imperative Python features and destructively updated data structures, but it was not too bad. The main change was to preallocate the KV cache as a fixed-size array (of shape [n_layer][block_size][n_embd]) and then, in the attn_logits calculation, mask out "future" tokens to keep the model causal. The masking isn't needed in the Python version because its cache lists only ever contain tokens up to the current position.
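
In case these constructs are new to you, here are two toy definitions (unrelated to the port itself) showing the difference:

-- tabulate builds an array from an index function, evaluated in parallel:
-- squares 5 == [0, 1, 4, 9, 16]
def squares (n: i64) : [n]i64 = tabulate n (\i -> i * i)

-- loop threads an accumulator through a sequential iteration:
-- sum_to 5 == 0 + 1 + 2 + 3 + 4 == 10
def sum_to (n: i64) : i64 = loop acc = 0 for i < n do acc + i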

The total lines of code for this function (excluding comments and blank lines) crept up from 33 to 51, but that is partly because I broke statements across lines more liberally than the original did.

I will admit that, even though I like this style of functional programming, the end result is arguably less readable. The deep nesting in particular is a bit hard to follow: once you're inside a lambda inside a tabulate inside a reduce inside a tabulate inside another tabulate inside a flatten, it can be easy to lose track of what's going on. This could probably be refactored to be more readable, but for now I stuck to as close a translation as possible: since Karpathy used a bunch of list comprehensions inside nested loops, I kept the same structure.
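
As a rough sketch of what such a refactoring might look like (this is not part of the port above, and attn_head is my own name for it), one attention head could be pulled out into a helper, which removes the innermost levels of nesting from the call site:

-- One attention head over the cached keys/values, attending to positions
-- 0..pos_id (sketch only; not part of the port above).
def attn_head (pos_id: i64)
              (q_h: [head_dim]f32)
              (k_h: [block_size][head_dim]f32)
              (v_h: [block_size][head_dim]f32) : [head_dim]f32 =
  let scale = 1f32 / f32.sqrt (f32.i64 head_dim)
  let attn_logits = tabulate block_size (\t ->
    if t <= pos_id
    then scale * reduce (+) 0f32 (map2 (*) q_h k_h[t])
    else -1e30f32)
  let attn_weights = softmax attn_logits
  in tabulate head_dim (\j ->
       reduce (+) 0f32 (tabulate block_size (\t -> attn_weights[t] * v_h[t, j])))

The per-head slicing of q, keys[li], and values[li] would then live at the call site, and the flatten over heads would stay as it is.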

There is also one minor annoyance needed to please the size-typing system: Futhark infers that the result of flatten is a 1D array of size n_head * head_dim (because we flattened an array of shape [n_head][head_dim]), but it isn't able to further infer that this is the same as a 1D array of size n_embd, so we need the :> size coercion operator. On the other hand, there is a minor readability improvement: map2, which maps a two-argument function elementwise across two arrays, is, in my opinion, more intuitive than Python's zip-based version.
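
In isolation, the coercion looks like this (a toy example, not from the port):

-- flatten gives a [n_head * head_dim]f32; using it where an [n_embd]f32 is
-- expected needs an explicit coercion, even though the two sizes are equal
-- by construction.
def coerced : [n_embd]f32 =
  flatten (replicate n_head (replicate head_dim 0f32)) :> [n_embd]f32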

* * *

Missing here, of course, is the star of the show: training the model! That will come in Part II, along with some benchmarks.