# Self-Attention from First Principles



``` python
import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn

torch.set_printoptions(precision=3, sci_mode=False)
```

``` python
river_tokens = ["stream", "bank", "mud"]
finance_tokens = ["money", "bank", "loan"]
```

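We’ll hand-pick a 4-dimensional embedding for each token. Loosely, the
first dimension carries a “river” meaning and the second a “finance”
meaning, with `bank` sitting between the two.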

``` python
def embed(tokens):
    vocab = {
        "stream": torch.tensor([1.2, 0.0, 0.0, 0.3]),
        "mud": torch.tensor([0.9, 0.0, 0.0, 0.9]),
        "money": torch.tensor([0.0, 1.4, 0.0, 0.1]),
        "loan": torch.tensor([0.0, 1.1, 0.0, 0.6]),
        "bank": torch.tensor([0.8, 0.8, 0.2, 0.0]),
    }
    return torch.stack([vocab[token] for token in tokens])
```
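
In a trained model these vectors would live in a learned lookup table
rather than a hand-written dict. Below is a minimal sketch of the same
lookup using PyTorch’s `nn.Embedding.from_pretrained`. It assumes a
hypothetical `vocab_index` that assigns each token one row of the table.

```python
import torch
import torch.nn as nn

# Hypothetical vocabulary index: maps each token string to a row of the table.
vocab_index = {"stream": 0, "mud": 1, "money": 2, "loan": 3, "bank": 4}

# The same hand-picked vectors as `embed`, stacked in vocab_index order.
table = torch.tensor(
    [
        [1.2, 0.0, 0.0, 0.3],  # stream
        [0.9, 0.0, 0.0, 0.9],  # mud
        [0.0, 1.4, 0.0, 0.1],  # money
        [0.0, 1.1, 0.0, 0.6],  # loan
        [0.8, 0.8, 0.2, 0.0],  # bank
    ]
)

# Wrap the matrix as a lookup table; freeze=False keeps the rows trainable.
embedding = nn.Embedding.from_pretrained(table, freeze=False)

# Look up the river sentence by integer ids instead of strings.
token_ids = torch.tensor([vocab_index[t] for t in ["stream", "bank", "mud"]])
river_lookup = embedding(token_ids)  # shape (3, 4)
```

The rows come out in sentence order, matching `embed(river_tokens)`.
The only difference is that gradients can now flow into the table.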

``` python
embed(river_tokens)
```

    tensor([[1.200, 0.000, 0.000, 0.300],
            [0.800, 0.800, 0.200, 0.000],
            [0.900, 0.000, 0.000, 0.900]])

``` python
embed(finance_tokens)
```

    tensor([[0.000, 1.400, 0.000, 0.100],
            [0.800, 0.800, 0.200, 0.000],
            [0.000, 1.100, 0.000, 0.600]])

We can see that the middle row for `bank` is identical in both contexts.

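A fully connected layer that ignores sequence structure has no built-in
way to use neighboring tokens: it transforms each token vector on its
own. So when we multiply by the same weight matrix `W`, the output row
for `bank` is again identical in both sentences.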

``` python
W = torch.tensor(
    [
        [1.0, 0.2, 0.0],
        [0.1, 1.1, 0.0],
        [0.5, 0.5, 1.0],
        [0.0, 0.3, 0.8],
    ]
)
```

``` python
embed(river_tokens) @ W
```

    tensor([[1.200, 0.330, 0.240],
            [0.980, 1.140, 0.200],
            [0.900, 0.450, 0.720]])

``` python
embed(finance_tokens) @ W
```

    tensor([[0.140, 1.570, 0.080],
            [0.980, 1.140, 0.200],
            [0.110, 1.390, 0.480]])
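
The same holds for any weights, not only our hand-picked `W`. Here is a
small check that uses a randomly initialized, bias-free `nn.Linear` as a
stand-in for `W`. The random layer is our own choice and is not part of
the original example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same hand-picked embeddings as above: one row per token.
river = torch.tensor([[1.2, 0.0, 0.0, 0.3],   # stream
                      [0.8, 0.8, 0.2, 0.0],   # bank
                      [0.9, 0.0, 0.0, 0.9]])  # mud
finance = torch.tensor([[0.0, 1.4, 0.0, 0.1],   # money
                        [0.8, 0.8, 0.2, 0.0],   # bank
                        [0.0, 1.1, 0.0, 0.6]])  # loan

# A bias-free linear layer with random weights, acting on the last dimension.
linear = nn.Linear(4, 3, bias=False)

with torch.no_grad():
    river_out = linear(river)
    finance_out = linear(finance)

# Row 1 is `bank` in both sentences.
same_bank = torch.allclose(river_out[1], finance_out[1])
print(same_bank)  # True
```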

CNNs help models understand images because locality matters, and
convolutions bake that assumption into the architecture. Text calls for
a different assumption: the relevant context is dynamic. The right word
to look at can change from sentence to sentence, so we need a
content-based way for tokens to interact.

If `bank` is going to mean different things in different contexts, its
new representation has to depend on the other tokens in the sentence.

The smallest useful idea is:

1.  Compare `bank` with every token in the sequence.
2.  Turn those comparisons into weights.
3.  Take a weighted combination of token vectors.
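
Here are those three steps written out for the single token `bank` in
the river sentence. This sketch uses plain dot products for the
comparison and softmax for the normalization, which are the same
choices made in the code that follows.

```python
import torch

# Hand-picked river embeddings (stream, bank, mud), as above.
x = torch.tensor([[1.2, 0.0, 0.0, 0.3],
                  [0.8, 0.8, 0.2, 0.0],
                  [0.9, 0.0, 0.0, 0.9]])
bank = x[1]

# 1. Compare `bank` with every token: one dot product per token.
scores = x @ bank  # shape (3,)

# 2. Turn the scores into positive weights that sum to 1.
weights = torch.softmax(scores, dim=0)

# 3. Take a weighted combination of the token vectors.
new_bank = weights @ x  # shape (4,), roughly [0.949, 0.356, 0.089, 0.313]
```

Doing this for every token at once is just the matrix version
`softmax(x @ x.T) @ x`. Its `bank` row holds the same values.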

A raw dot product gives a quick similarity score: larger means “these
two token vectors point in a more similar direction”, although longer
vectors also produce larger scores.
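
Concretely, the dot product equals the length of one vector times the
length of the other times the cosine of the angle between them. The
check below uses two of the hand-picked vectors, with
`torch.nn.functional.cosine_similarity` supplying the angle term.

```python
import torch
import torch.nn.functional as F

stream = torch.tensor([1.2, 0.0, 0.0, 0.3])
mud = torch.tensor([0.9, 0.0, 0.0, 0.9])

dot = torch.dot(stream, mud)                   # 1.35, as in the table below
cos = F.cosine_similarity(stream, mud, dim=0)  # direction only, in [-1, 1]
lengths = torch.linalg.norm(stream) * torch.linalg.norm(mud)

# Equal to `dot` up to float rounding: the score mixes angle and length.
reconstructed = lengths * cos
```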

``` python
river_embeddings = embed(river_tokens)
river_embeddings @ river_embeddings.T
```

    tensor([[1.530, 0.960, 1.350],
            [0.960, 1.320, 0.720],
            [1.350, 0.720, 1.620]])

``` python
finance_embeddings = embed(finance_tokens)
finance_embeddings @ finance_embeddings.T
```

    tensor([[1.970, 1.120, 1.600],
            [1.120, 1.320, 0.880],
            [1.600, 0.880, 1.570]])

Softmax turns those arbitrary scores into positive weights that sum to
1, so each token can take a weighted average over the sequence.
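
Softmax exponentiates each score and divides by the sum of the
exponentials. Below is a hand-rolled version for the `bank` row of the
river scores, checked against `torch.softmax`.

```python
import torch

# Scores for `bank` against (stream, bank, mud), from the river table above.
scores = torch.tensor([0.96, 1.32, 0.72])

# softmax(s)_i = exp(s_i) / sum_j exp(s_j).
# Subtracting the max first leaves the result unchanged (it cancels in the
# ratio) but keeps exp() from overflowing on large scores.
shifted = scores - scores.max()
weights = torch.exp(shifted) / torch.exp(shifted).sum()
```

Exponentiation is monotonic, so the highest score still gets the largest
weight. Here `bank` attends most to itself (about 0.45), then to
`stream` (about 0.31), then to `mud` (about 0.24).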

Now the representation of `bank` can change because the available
context tokens have changed.

``` python
river_embeddings = embed(river_tokens)
river_scores = river_embeddings @ river_embeddings.T
river_weights = torch.softmax(river_scores, dim=-1)
river_weights @ river_embeddings
```

    tensor([[1.001, 0.188, 0.047, 0.438],
            [0.949, 0.356, 0.089, 0.313],
            [0.987, 0.150, 0.037, 0.520]])

``` python
finance_embeddings = embed(finance_tokens)
finance_scores = finance_embeddings @ finance_embeddings.T
finance_weights = torch.softmax(finance_scores, dim=-1)
finance_weights @ finance_embeddings
```

    tensor([[0.161, 1.181, 0.040, 0.243],
            [0.325, 1.078, 0.081, 0.190],
            [0.158, 1.163, 0.040, 0.278]])

The previous example used the raw embeddings for everything: the same
vectors were compared and then averaged.

That is useful for intuition, but restrictive.

A learned attention layer should be able to decide:

- what features a token uses when it is *looking for* context
- what features a token exposes when it is being *matched against*
- what information should actually be *passed along* once it is attended
  to

That is why attention learns three projections:

- queries: what each token is looking for
- keys: what each token contains for matching
- values: what each token contributes to the weighted sum
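
In terms of shape, the three projections differ only in output width.
Queries and keys are compared by dot product, so they must share a width
(`d_k`). Values are only averaged, so their width (`d_v`) is free. That
is why `W_value` below has three columns while `W_query` and `W_key`
have two. The sketch uses bias-free `nn.Linear` layers as stand-ins for
the matrices and random placeholder inputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

num_tokens, d_in = 3, 4
d_k, d_v = 2, 3  # queries/keys must share d_k; values may use a different d_v

x = torch.rand(num_tokens, d_in)  # placeholder embeddings

# Bias-free linear layers play the role of W_query, W_key, W_value.
# nn.Linear stores its weight as (out, in) and computes x @ weight.T.
to_query = nn.Linear(d_in, d_k, bias=False)
to_key = nn.Linear(d_in, d_k, bias=False)
to_value = nn.Linear(d_in, d_v, bias=False)

queries = to_query(x)  # (3, 2)
keys = to_key(x)       # (3, 2)
values = to_value(x)   # (3, 3)

scores = queries @ keys.T  # (3, 3): one score per (query token, key token) pair
out = torch.softmax(scores / d_k ** 0.5, dim=-1) @ values  # (3, 3)
```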

``` python
W_query = torch.tensor(
    [
        [1.0, 0.0],
        [0.0, 1.0],
        [0.2, 0.2],
        [0.0, 0.0],
    ]
)
W_key = torch.tensor(
    [
        [1.0, 0.0],
        [0.0, 1.0],
        [0.0, 0.0],
        [0.1, 0.1],
    ]
)
W_value = torch.tensor(
    [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.0, 0.0, 0.5],
    ]
)
```

``` python
river_queries = river_embeddings @ W_query
river_keys = river_embeddings @ W_key
river_values = river_embeddings @ W_value

river_scores = river_queries @ river_keys.T / math.sqrt(river_keys.shape[-1])
river_weights = torch.softmax(river_scores, dim=-1)
river_weights @ river_values
```

    tensor([[0.992, 0.221, 0.261],
            [0.957, 0.314, 0.256],
            [0.986, 0.232, 0.263]])

``` python
finance_queries = finance_embeddings @ W_query
finance_keys = finance_embeddings @ W_key
finance_values = finance_embeddings @ W_value

finance_scores = finance_queries @ finance_keys.T / math.sqrt(finance_keys.shape[-1])
finance_weights = torch.softmax(finance_scores, dim=-1)
finance_weights @ finance_values
```

    tensor([[0.188, 1.158, 0.169],
            [0.297, 1.089, 0.180],
            [0.204, 1.146, 0.172]])
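
As a cross-check, PyTorch 2.0 and later provide a built-in
`torch.nn.functional.scaled_dot_product_attention`. By default it
applies the same `1/sqrt(d_k)` scaling. The sketch below compares it
with the manual computation for the river sentence and assumes a PyTorch
version that includes this function.

```python
import torch
import torch.nn.functional as F

# Hand-picked river embeddings and projection matrices from above.
x = torch.tensor([[1.2, 0.0, 0.0, 0.3],
                  [0.8, 0.8, 0.2, 0.0],
                  [0.9, 0.0, 0.0, 0.9]])
W_query = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.2, 0.2], [0.0, 0.0]])
W_key = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.1, 0.1]])
W_value = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0],
                        [0.0, 0.0, 1.0], [0.0, 0.0, 0.5]])

q, k, v = x @ W_query, x @ W_key, x @ W_value

# Manual scaled dot-product attention, as in the cells above.
manual = torch.softmax(q @ k.T / k.shape[-1] ** 0.5, dim=-1) @ v

# Built-in version (PyTorch 2.0+). Add leading (batch, head) dims of size 1,
# then strip them from the result.
builtin = F.scaled_dot_product_attention(
    q[None, None], k[None, None], v[None, None]
)[0, 0]
```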

Here is the recipe packaged into a class:

1.  project `x` into queries, keys, and values
2.  score queries against keys
3.  normalize with softmax
4.  take a weighted sum of values

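The class keeps the scaling by `sqrt(d_k)` from the previous example.
This stops dot-product scores from growing too large as the key
dimension increases. Very large scores push softmax toward putting
nearly all the weight on one token. The new part is that the three
projection matrices are now `nn.Parameter`s with random initial values,
so they can be learned instead of hand-picked.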
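
The growth is easy to see numerically. If query and key entries are
independent with zero mean and unit variance, their dot product has
variance `d_k`, so its typical magnitude grows like `sqrt(d_k)`. Here is
a quick simulation, assuming standard-normal random vectors.

```python
import torch

torch.manual_seed(0)
num_samples = 5_000
raw_stds, scaled_stds = {}, {}

for d_k in (4, 64, 1024):
    # Random query/key vectors with independent standard-normal entries.
    q = torch.randn(num_samples, d_k)
    k = torch.randn(num_samples, d_k)
    raw = (q * k).sum(dim=-1)  # one dot product per (q, k) pair
    scaled = raw / d_k ** 0.5  # divide by sqrt(d_k), as in SelfAttention
    raw_stds[d_k] = raw.std().item()
    scaled_stds[d_k] = scaled.std().item()
    print(f"d_k={d_k:5d}  raw std={raw_stds[d_k]:6.2f}  scaled std={scaled_stds[d_k]:.2f}")
```

The raw spread grows with the dimension (roughly 2, 8 and 32 here). The
scaled spread stays near 1.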

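``` python
class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        # One learnable projection matrix each for queries, keys, and values.
        # torch.rand draws the initial values uniformly from [0, 1).
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        # x: (..., num_tokens, d_in) -> each projection: (..., num_tokens, d_out)
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value

        # (..., num_tokens, num_tokens) query-key scores; transpose(-2, -1)
        # swaps only the last two dims, so a leading batch dim also works.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_k), then normalize each row into weights summing to 1.
        attn_weights = torch.softmax(attn_scores / math.sqrt(keys.shape[-1]), dim=-1)
        # Each output row is a weighted average of the value vectors.
        context_vec = attn_weights @ values
        return context_vec


torch.manual_seed(0)
self_attention = SelfAttention(d_in=4, d_out=3)
```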

``` python
self_attention(river_embeddings)
```

    tensor([[0.540, 0.705, 1.030],
            [0.538, 0.706, 1.030],
            [0.541, 0.703, 1.025]], grad_fn=<MmBackward0>)

``` python
self_attention(finance_embeddings)
```

    tensor([[0.220, 0.418, 0.642],
            [0.213, 0.404, 0.624],
            [0.216, 0.409, 0.630]], grad_fn=<MmBackward0>)

A linear layer applies the same transformation to each token, but it has
no mechanism for exchanging information across positions.

Self-attention still reuses the same learned parameters at every
position, but the output for each token can change from one sequence to
the next because the attention weights are computed from token-token
interactions inside the sequence.
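
Both points can be checked directly. The sketch below reuses the
hand-picked projection matrices from earlier and processes both
sentences as a single batch. It then compares the two `bank` outputs.
The batch layout `(sentences, tokens, features)` and the helper `attend`
are our own illustration.

```python
import torch

# Both sentences stacked into one batch: shape (2, 3, 4).
x = torch.tensor([
    [[1.2, 0.0, 0.0, 0.3], [0.8, 0.8, 0.2, 0.0], [0.9, 0.0, 0.0, 0.9]],  # stream bank mud
    [[0.0, 1.4, 0.0, 0.1], [0.8, 0.8, 0.2, 0.0], [0.0, 1.1, 0.0, 0.6]],  # money bank loan
])

# Hand-picked projections from earlier, shared by every position and sentence.
W_query = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.2, 0.2], [0.0, 0.0]])
W_key = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.1, 0.1]])
W_value = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0],
                        [0.0, 0.0, 1.0], [0.0, 0.0, 0.5]])


def attend(x):
    """Hypothetical helper: scaled dot-product self-attention over the last two dims."""
    q, k, v = x @ W_query, x @ W_key, x @ W_value
    scores = q @ k.transpose(-2, -1) / k.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v


out = attend(x)  # shape (2, 3, 3)
bank_river, bank_finance = out[0, 1], out[1, 1]
```

The `bank` input rows are identical and so are the weights, yet the
outputs differ (about `[0.957, 0.314, 0.256]` versus
`[0.297, 1.089, 0.180]`). The only thing that changed is the neighbors
that `bank` attended to.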
