Self-Attention from First Principles

import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn

torch.set_printoptions(precision=3, sci_mode=False)
river_tokens = ["stream", "bank", "mud"]
finance_tokens = ["money", "bank", "loan"]

We’ll assume that each token has a 4-dimensional embedding.

def embed(tokens):
    vocab = {
        "stream": torch.tensor([1.2, 0.0, 0.0, 0.3]),
        "mud": torch.tensor([0.9, 0.0, 0.0, 0.9]),
        "money": torch.tensor([0.0, 1.4, 0.0, 0.1]),
        "loan": torch.tensor([0.0, 1.1, 0.0, 0.6]),
        "bank": torch.tensor([0.8, 0.8, 0.2, 0.0]),
    }
    return torch.stack([vocab[token] for token in tokens])
embed(river_tokens)
tensor([[1.200, 0.000, 0.000, 0.300],
        [0.800, 0.800, 0.200, 0.000],
        [0.900, 0.000, 0.000, 0.900]])
embed(finance_tokens)
tensor([[0.000, 1.400, 0.000, 0.100],
        [0.800, 0.800, 0.200, 0.000],
        [0.000, 1.100, 0.000, 0.600]])

We can see that the middle row for bank is identical in both contexts.

A fully connected layer that ignores structure has no built-in way to use neighboring information, so the output of the layer for bank is again identical in both contexts.

W = torch.tensor(
    [
        [1.0, 0.2, 0.0],
        [0.1, 1.1, 0.0],
        [0.5, 0.5, 1.0],
        [0.0, 0.3, 0.8],
    ]
)
embed(river_tokens) @ W
tensor([[1.200, 0.330, 0.240],
        [0.980, 1.140, 0.200],
        [0.900, 0.450, 0.720]])
embed(finance_tokens) @ W
tensor([[0.140, 1.570, 0.080],
        [0.980, 1.140, 0.200],
        [0.110, 1.390, 0.480]])

In the same way that CNNs help models understand images by baking the locality assumption into the architecture, we can help models understand text by acknowledging that context is dynamic. The right word to look at changes from sentence to sentence, so tokens need a content-based way to interact.

If bank is going to mean different things in different contexts, its new representation has to depend on the other tokens in the sentence.

The smallest useful idea is:

  1. Compare bank with every token in the sequence.
  2. Turn those comparisons into weights.
  3. Take a weighted combination of token vectors.
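
The three steps above can be sketched for the single token bank, restating the river embeddings inline so the snippet runs on its own (softmax, introduced properly below, is one way to turn scores into weights):

```python
import torch

# Same toy embeddings as above, restated so this snippet is self-contained
embeddings = torch.stack([
    torch.tensor([1.2, 0.0, 0.0, 0.3]),  # stream
    torch.tensor([0.8, 0.8, 0.2, 0.0]),  # bank
    torch.tensor([0.9, 0.0, 0.0, 0.9]),  # mud
])

bank = embeddings[1]

# 1. Compare bank with every token in the sequence (dot products).
scores = embeddings @ bank

# 2. Turn those comparisons into weights that sum to 1.
weights = torch.softmax(scores, dim=-1)

# 3. Take a weighted combination of token vectors.
new_bank = weights @ embeddings
print(new_bank)
```

The result is a context-dependent replacement for the bank vector: it leans toward whichever neighbors scored highest.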

A raw dot product gives a quick similarity score: larger means “these two token vectors point in a more similar direction”.

river_embeddings = embed(river_tokens)
river_embeddings @ river_embeddings.T
tensor([[1.530, 0.960, 1.350],
        [0.960, 1.320, 0.720],
        [1.350, 0.720, 1.620]])
finance_embeddings = embed(finance_tokens)
finance_embeddings @ finance_embeddings.T
tensor([[1.970, 1.120, 1.600],
        [1.120, 1.320, 0.880],
        [1.600, 0.880, 1.570]])

Softmax turns those arbitrary scores into positive weights that sum to 1, so each token can take a weighted average over the sequence.
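
As a quick sanity check, applying softmax row-wise to the river scores computed above gives rows that each sum to 1:

```python
import torch

# Dot-product scores for the river sentence (copied from above)
scores = torch.tensor([[1.53, 0.96, 1.35],
                       [0.96, 1.32, 0.72],
                       [1.35, 0.72, 1.62]])

weights = torch.softmax(scores, dim=-1)
print(weights.sum(dim=-1))  # each row sums to 1
```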

Now the representation of bank can change because the available context tokens have changed.

river_embeddings = embed(river_tokens)
river_scores = river_embeddings @ river_embeddings.T
river_weights = torch.softmax(river_scores, dim=-1)
river_weights @ river_embeddings
tensor([[1.001, 0.188, 0.047, 0.438],
        [0.949, 0.356, 0.089, 0.313],
        [0.987, 0.150, 0.037, 0.520]])
finance_embeddings = embed(finance_tokens)
finance_scores = finance_embeddings @ finance_embeddings.T
finance_weights = torch.softmax(finance_scores, dim=-1)
finance_weights @ finance_embeddings
tensor([[0.161, 1.181, 0.040, 0.243],
        [0.325, 1.078, 0.081, 0.190],
        [0.158, 1.163, 0.040, 0.278]])

The previous example used the raw embeddings for everything.

That is useful for intuition, but restrictive.

A learned attention layer should be able to decide:

  1. what each token is looking for,
  2. what each token offers for matching,
  3. and what information each token contributes once it is selected.

That is why attention learns three projections: a query, a key, and a value matrix.

W_query = torch.tensor(
    [
        [1.0, 0.0],
        [0.0, 1.0],
        [0.2, 0.2],
        [0.0, 0.0],
    ]
)
W_key = torch.tensor(
    [
        [1.0, 0.0],
        [0.0, 1.0],
        [0.0, 0.0],
        [0.1, 0.1],
    ]
)
W_value = torch.tensor(
    [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.0, 0.0, 0.5],
    ]
)
river_queries = river_embeddings @ W_query
river_keys = river_embeddings @ W_key
river_values = river_embeddings @ W_value

river_scores = river_queries @ river_keys.T / math.sqrt(river_keys.shape[-1])
river_weights = torch.softmax(river_scores, dim=-1)
river_weights @ river_values
tensor([[0.992, 0.221, 0.261],
        [0.957, 0.314, 0.256],
        [0.986, 0.232, 0.263]])
finance_queries = finance_embeddings @ W_query
finance_keys = finance_embeddings @ W_key
finance_values = finance_embeddings @ W_value

finance_scores = finance_queries @ finance_keys.T / math.sqrt(finance_keys.shape[-1])
finance_weights = torch.softmax(finance_scores, dim=-1)
finance_weights @ finance_values
tensor([[0.188, 1.158, 0.169],
        [0.297, 1.089, 0.180],
        [0.204, 1.146, 0.172]])

Here is the recipe compacted into a class:

  1. project x into queries, keys, and values
  2. score queries against keys
  3. normalize with softmax
  4. take a weighted sum of values

The only new detail below is the scaling by sqrt(d_k), which keeps dot-product scores from growing too large as the key dimension increases.
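
A small illustration of why the scaling matters (the numbers here are arbitrary random draws, not from the example above): with a large key dimension, unscaled dot products have high variance, and softmax over them collapses toward a one-hot distribution.

```python
import math

import torch

torch.manual_seed(0)
d_k = 256
q = torch.randn(d_k)
keys = torch.randn(8, d_k)

raw = keys @ q                 # variance grows with d_k
scaled = raw / math.sqrt(d_k)  # variance stays near 1

print(torch.softmax(raw, dim=-1))     # nearly one-hot
print(torch.softmax(scaled, dim=-1))  # smoother distribution
```

Without the scaling, the near-one-hot weights would also produce vanishingly small gradients through the softmax during training.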

class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / math.sqrt(keys.shape[-1]), dim=-1)
        context_vec = attn_weights @ values
        return context_vec


torch.manual_seed(0)
self_attention = SelfAttention(d_in=4, d_out=3)
self_attention(river_embeddings)
tensor([[0.540, 0.705, 1.030],
        [0.538, 0.706, 1.030],
        [0.541, 0.703, 1.025]], grad_fn=<MmBackward0>)
self_attention(finance_embeddings)
tensor([[0.220, 0.418, 0.642],
        [0.213, 0.404, 0.624],
        [0.216, 0.409, 0.630]], grad_fn=<MmBackward0>)

A linear layer applies the same transformation to each token, but it has no mechanism for exchanging information across positions.

Self-attention still reuses the same learned parameters at every position, but the output for each token can change from one sequence to the next because the attention weights are computed from token-token interactions inside the sequence.
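
One caveat about the class above: keys.T only transposes a 2-D matrix, so forward expects a single unbatched sequence. A minimal batched variant (a sketch, not part of the original class) swaps .T for transpose(-2, -1), which also works on a (batch, seq_len, d_in) input:

```python
import math

import torch
import torch.nn as nn


class BatchedSelfAttention(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        # x: (batch, seq_len, d_in)
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value

        # transpose(-2, -1) swaps the last two dims, so this handles
        # both batched and unbatched input
        attn_scores = queries @ keys.transpose(-2, -1)
        attn_weights = torch.softmax(attn_scores / math.sqrt(keys.shape[-1]), dim=-1)
        return attn_weights @ values


torch.manual_seed(0)
attn = BatchedSelfAttention(d_in=4, d_out=3)
x = torch.rand(2, 3, 4)  # batch of 2 sequences, 3 tokens each
print(attn(x).shape)  # torch.Size([2, 3, 3])
```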