jcreed blog > Attention for Type Theorists

Attention for Type Theorists

I don't think I have enough patience or motivation to actually jump on the big NN bandwagon and get any GPU computations working, but I was curious about what exactly "attention mechanisms" were, i.e. what "Attention Is All You Need" is even talking about.

Because my general feeling about neural nets is: well, I understand gradient descent. If you can tell me what function you're computing from what inputs and what parameters, then I might be slow about it, but I'm confident that I could reinvent backprop from scratch if I needed to. And if I'm lazy I can just ask PyTorch to do it for me.

What I really wanted was a compact description of:

  1. What is the type of inputs?
  2. What is the type of outputs?
  3. What is the function that computes outputs from inputs and parameters?

So what is an attention block?

You, the NN architecture designer, get to choose numbers $c, d, m : \N$. Then:
  1. What is the type of inputs?
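    It is the record type \[ \{ q : \R^d, k : \R^{d \x m}, v : \R^{c \x m} \} \] In more detail, this record type consists of a "query" $q : \R^d$; $m$ "keys" $k_1, \ldots, k_m : \R^d$, the columns of the matrix $k$; and $m$ "values" $v_1, \ldots, v_m : \R^c$, the columns of the matrix $v$.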
  2. What is the type of outputs?
    It is $\R^c$.
  3. What is the function that computes from inputs to outputs?
    It's essentially a weighted sum over all the "values", where the weights are the softmax of the dot products between the "query" and each "key". \[\mathsf{compute} : \{ q : \R^d, k : \R^{d \x m}, v : \R^{c \x m} \} \to \R^c \] \[ \mathsf{compute} \{q, k, v\} = {\sum_{i=1}^m e^{q\cdot k_i} v_i \over \sum_{i=1}^m e^{q\cdot k_i}}\]
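The $\mathsf{compute}$ function above can be sketched in plain Python. This is a minimal sketch assuming the keys and values are passed as lists of their columns $k_i$ and $v_i$; the names `dot` and `compute` are my own. It follows the unscaled formula above, whereas "Attention Is All You Need" additionally divides each $q \cdot k_i$ by $\sqrt{d}$ ("scaled dot-product attention"). Subtracting the largest score before exponentiating is a standard numerical-stability trick: the common factor cancels between numerator and denominator.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    """Dot product of two equal-length vectors."""
    assert len(a) == len(b)
    return sum(x * y for x, y in zip(a, b))


def compute(q: Vector, k: List[Vector], v: List[Vector]) -> Vector:
    """Softmax-weighted sum of the values.

    q: the query, of length d
    k: the m keys k_i, each of length d (columns of the d x m matrix)
    v: the m values v_i, each of length c (columns of the c x m matrix)
    Returns a vector of length c.
    """
    m = len(k)
    assert m == len(v) and m > 0
    scores = [dot(q, k_i) for k_i in k]            # q . k_i for each i
    top = max(scores)                               # shift scores for stability;
    weights = [math.exp(s - top) for s in scores]   # the shift cancels in the ratio
    total = sum(weights)
    c = len(v[0])
    return [sum(w * v_i[j] for w, v_i in zip(weights, v)) / total
            for j in range(c)]
```

For example, with $q = 0$ every score is $0$, so every weight is equal and `compute([0.0, 0.0], [[1.0, 2.0], [3.0, -1.0]], [[1.0, 0.0], [3.0, 4.0]])` is just the average of the two values, `[2.0, 2.0]`. When one $q \cdot k_i$ is much larger than the rest, the output is close to that single $v_i$.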

Wait, Where are the Parameters?

When I refer to "query", "keys", and "values" above, I mean things that already live in the appropriate linear spaces, so that I can take the dot product of $q$ with each $k_i$ and use the results as the softmax weights.

In practice (as far as I understand!), these inputs aren't wired up directly to other parts of the neural net. Instead you throw a whole matrix's worth of parameters in front of each of them: the "raw" data lives in some other linear space, and the parameters say how to smash it into the right linear space to do the dot products.

By doing this we let the network learn what "queries" it should be making, and what "keys" and "values" other data should produce.
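Here's a sketch of how those parameters might sit in front of $\mathsf{compute}$, assuming one learned matrix per input (the names `Wq`, `Wk`, `Wv`, and `attend` are hypothetical), no bias terms, and a single attention head. The raw query source `x` and the raw context items `ys` can live in spaces of any dimension; the matrices map them into $\R^d$ (queries and keys) and $\R^c$ (values). Training would then adjust the entries of these matrices by gradient descent.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # a matrix stored as a list of rows


def matvec(W: Matrix, x: Vector) -> Vector:
    """Multiply a (rows x cols) matrix by a vector of length cols."""
    assert all(len(row) == len(x) for row in W)
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def compute(q: Vector, k: List[Vector], v: List[Vector]) -> Vector:
    """Softmax-weighted sum of the values v_i, weighted by q . k_i."""
    scores = [sum(a * b for a, b in zip(q, k_i)) for k_i in k]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v_i[j] for w, v_i in zip(weights, v)) / total
            for j in range(len(v[0]))]


def attend(x: Vector, ys: List[Vector],
           Wq: Matrix, Wk: Matrix, Wv: Matrix) -> Vector:
    """Hypothetical single-head attention with learned projections.

    Wq (d x p) maps the raw query source x to the query q in R^d.
    Wk (d x r) maps each raw context item y_i to a key k_i in R^d.
    Wv (c x r) maps each y_i to a value v_i in R^c.
    """
    q = matvec(Wq, x)
    k = [matvec(Wk, y) for y in ys]
    v = [matvec(Wv, y) for y in ys]
    return compute(q, k, v)
```

As a quick check: if `Wq` is all zeros, every score is $0$, so `attend([1.0], [[1.0, 0.0], [0.0, 1.0]], [[0.0], [0.0]], [[1.0, 0.0], [0.0, 1.0]], [[2.0, 4.0]])` averages the projected values `[2.0]` and `[4.0]` to get `[3.0]`.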

What's the Intuition for How This is Used?

Any ol' differentiable function could be chucked into the neural net stew. Why is this one useful in practice?

Here I'm less confident I understand what's going on, but Karpathy's videos (especially "Let's build GPT") gave me some sense of it.

I think the important point is that the transformer architecture is an alternative to RNNs. Instead of your forward function being a whole lot of recursive applications of the net to its own hidden-state output, you put some attention units in. Then individual tokens in your stream dictate (via the "query") what sort of past data they're interested in, while the past data itself figures out (via the "key") how to advertise itself as interesting to future tokens, and (via the "value") what signal to contribute for them to use.
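To make the token-stream picture concrete, here's a sketch of causal self-attention. It assumes every token produces its own query, key, and value from a single feature vector via shared matrices (hypothetical names `Wq`, `Wk`, `Wv`), and that token $t$ may only attend to positions $0, \ldots, t$. Real implementations compute all positions at once with a mask over a matrix of scores; the explicit loop here is just for readability.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # a matrix stored as a list of rows


def matvec(W: Matrix, x: Vector) -> Vector:
    """Multiply a (rows x cols) matrix by a vector of length cols."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def compute(q: Vector, k: List[Vector], v: List[Vector]) -> Vector:
    """Softmax-weighted sum of the values v_i, weighted by q . k_i."""
    scores = [sum(a * b for a, b in zip(q, k_i)) for k_i in k]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v_i[j] for w, v_i in zip(weights, v)) / total
            for j in range(len(v[0]))]


def causal_self_attention(xs: List[Vector],
                          Wq: Matrix, Wk: Matrix, Wv: Matrix) -> List[Vector]:
    """One output per token; token t sees only tokens 0..t."""
    outputs = []
    for t, x_t in enumerate(xs):
        past = xs[: t + 1]                    # positions 0..t only
        q = matvec(Wq, x_t)                   # what token t is looking for
        k = [matvec(Wk, y) for y in past]     # how each past token advertises itself
        v = [matvec(Wv, y) for y in past]     # what each past token contributes
        outputs.append(compute(q, k, v))
    return outputs
```

Because token $t$ never looks at later positions, changing a later token leaves all earlier outputs unchanged, and the very first token can only attend to itself, so its output is just its own value.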

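Nothing above says anything about the position of tokens in a stream: permuting the key/value pairs together leaves the weighted sum unchanged. As I understand it, trigonometric positional encodings are combined with the feature vectors (added to them in the original Transformer paper; some variants concatenate them instead) so that the computation of query/key/value can depend usefully on position.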
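For reference, here's a sketch of the sinusoidal encoding described in "Attention Is All You Need": for an encoding of even dimension $D$, entry $2i$ at position $pos$ is $\sin(pos / 10000^{2i/D})$ and entry $2i+1$ is the cosine of the same angle. The helpers `add_positions` and `concat_positions` are hypothetical names that show the two ways of combining the encoding with token embeddings.

```python
import math
from typing import List

Vector = List[float]


def sinusoidal_encoding(pos: int, dim: int) -> Vector:
    """Sinusoidal positional encoding of even length dim for position pos."""
    assert dim % 2 == 0
    enc = []
    for i in range(dim // 2):
        angle = pos / (10000 ** (2 * i / dim))  # lower frequency as i grows
        enc.append(math.sin(angle))             # entry 2i
        enc.append(math.cos(angle))             # entry 2i + 1
    return enc


def add_positions(embeddings: List[Vector]) -> List[Vector]:
    """Add each position's encoding to its embedding (same dimension)."""
    return [[e + p for e, p in zip(emb, sinusoidal_encoding(pos, len(emb)))]
            for pos, emb in enumerate(embeddings)]


def concat_positions(embeddings: List[Vector], dim: int) -> List[Vector]:
    """Append a dim-length encoding to each embedding (dimension grows)."""
    return [emb + sinusoidal_encoding(pos, dim)
            for pos, emb in enumerate(embeddings)]
```

Adding keeps the feature dimension fixed, while concatenating keeps position and content in separate coordinates at the cost of a wider vector. Either way, the projections in front of the attention block can then produce queries and keys that depend on where a token sits.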