import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
torch.set_printoptions(precision=3, sci_mode=False)

# Self-Attention from First Principles
river_tokens = ["stream", "bank", "mud"]
finance_tokens = ["money", "bank", "loan"]

We'll assume that each token has a 4-dimensional embedding.
def embed(tokens):
    vocab = {
        "stream": torch.tensor([1.2, 0.0, 0.0, 0.3]),
        "mud": torch.tensor([0.9, 0.0, 0.0, 0.9]),
        "money": torch.tensor([0.0, 1.4, 0.0, 0.1]),
        "loan": torch.tensor([0.0, 1.1, 0.0, 0.6]),
        "bank": torch.tensor([0.8, 0.8, 0.2, 0.0]),
    }
    return torch.stack([vocab[token] for token in tokens])

embed(river_tokens)
tensor([[1.200, 0.000, 0.000, 0.300],
        [0.800, 0.800, 0.200, 0.000],
        [0.900, 0.000, 0.000, 0.900]])
embed(finance_tokens)
tensor([[0.000, 1.400, 0.000, 0.100],
        [0.800, 0.800, 0.200, 0.000],
        [0.000, 1.100, 0.000, 0.600]])
We can see that the middle row for bank is identical in both contexts.
A fully connected layer that ignores structure has no built-in way to use neighboring information, so its output for bank is again identical in both contexts.
W = torch.tensor(
    [
        [1.0, 0.2, 0.0],
        [0.1, 1.1, 0.0],
        [0.5, 0.5, 1.0],
        [0.0, 0.3, 0.8],
    ]
)

embed(river_tokens) @ W
tensor([[1.200, 0.330, 0.240],
        [0.980, 1.140, 0.200],
        [0.900, 0.450, 0.720]])

embed(finance_tokens) @ W
tensor([[0.140, 1.570, 0.080],
        [0.980, 1.140, 0.200],
        [0.110, 1.390, 0.480]])
In the same way that CNNs help models understand images because locality matters and convolutions bake that assumption into the architecture, we can help models understand text by considering that context is often dynamic. The right word to look at can change from sentence to sentence, so we need a content-based way for tokens to interact.
If bank is going to mean different things in different contexts, its new representation has to depend on the other tokens in the sentence.
The smallest useful idea is:
- Compare bank with every token in the sequence.
- Turn those comparisons into weights.
- Take a weighted combination of token vectors.
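Compacted into one function, the three steps look like this (a minimal sketch using the river-context embeddings from above, repeated so the snippet runs on its own):

```python
import torch

def attend(x):
    # 1. compare every token vector with every other token vector
    scores = x @ x.T
    # 2. turn the comparisons into positive weights that sum to 1 per row
    weights = torch.softmax(scores, dim=-1)
    # 3. take a weighted combination of the token vectors
    return weights @ x

x = torch.tensor([
    [1.2, 0.0, 0.0, 0.3],  # stream
    [0.8, 0.8, 0.2, 0.0],  # bank
    [0.9, 0.0, 0.0, 0.9],  # mud
])
context = attend(x)  # one context vector per token, same shape as the input
```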
A raw dot product gives a quick similarity score: larger means “these two token vectors point in a more similar direction”.
river_embeddings = embed(river_tokens)
river_embeddings @ river_embeddings.T
tensor([[1.530, 0.960, 1.350],
        [0.960, 1.320, 0.720],
        [1.350, 0.720, 1.620]])
finance_embeddings = embed(finance_tokens)
finance_embeddings @ finance_embeddings.T
tensor([[1.970, 1.120, 1.600],
        [1.120, 1.320, 0.880],
        [1.600, 0.880, 1.570]])
Softmax turns those arbitrary scores into positive weights that sum to 1, so each token can take a weighted average over the sequence.
Now the representation of bank can change because the available context tokens have changed.
river_embeddings = embed(river_tokens)
river_scores = river_embeddings @ river_embeddings.T
river_weights = torch.softmax(river_scores, dim=-1)
river_weights @ river_embeddings
tensor([[1.001, 0.188, 0.047, 0.438],
        [0.949, 0.356, 0.089, 0.313],
        [0.987, 0.150, 0.037, 0.520]])
finance_embeddings = embed(finance_tokens)
finance_scores = finance_embeddings @ finance_embeddings.T
finance_weights = torch.softmax(finance_scores, dim=-1)
finance_weights @ finance_embeddings
tensor([[0.161, 1.181, 0.040, 0.243],
        [0.325, 1.078, 0.081, 0.190],
        [0.158, 1.163, 0.040, 0.278]])
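As a quick sanity check on both claims — that each softmax row is a valid weighting and that bank's representation now differs by context — here is the same computation end to end, with the embedding table repeated so the snippet is self-contained:

```python
import torch

vocab = {
    "stream": torch.tensor([1.2, 0.0, 0.0, 0.3]),
    "mud": torch.tensor([0.9, 0.0, 0.0, 0.9]),
    "money": torch.tensor([0.0, 1.4, 0.0, 0.1]),
    "loan": torch.tensor([0.0, 1.1, 0.0, 0.6]),
    "bank": torch.tensor([0.8, 0.8, 0.2, 0.0]),
}

def contextualize(tokens):
    x = torch.stack([vocab[t] for t in tokens])
    weights = torch.softmax(x @ x.T, dim=-1)
    # every row of weights is positive and sums to 1
    assert torch.allclose(weights.sum(dim=-1), torch.ones(len(tokens)))
    return weights @ x

river = contextualize(["stream", "bank", "mud"])
finance = contextualize(["money", "bank", "loan"])

# "bank" (row 1) went in identical but comes out different in the two contexts
assert not torch.allclose(river[1], finance[1])
```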
The previous example used the raw embeddings for everything.
That is useful for intuition, but restrictive.
A learned attention layer should be able to decide:
- what features a token uses when it is looking for context
- what features a token exposes when it is being matched against
- what information should actually be passed along once it is attended to
That is why attention learns three projections:
- queries: what each token is looking for
- keys: what each token contains for matching
- values: what each token contributes to the weighted sum
W_query = torch.tensor(
    [
        [1.0, 0.0],
        [0.0, 1.0],
        [0.2, 0.2],
        [0.0, 0.0],
    ]
)
W_key = torch.tensor(
    [
        [1.0, 0.0],
        [0.0, 1.0],
        [0.0, 0.0],
        [0.1, 0.1],
    ]
)
W_value = torch.tensor(
    [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.0, 0.0, 0.5],
    ]
)

river_queries = river_embeddings @ W_query
river_keys = river_embeddings @ W_key
river_values = river_embeddings @ W_value
river_scores = river_queries @ river_keys.T / math.sqrt(river_keys.shape[-1])
river_weights = torch.softmax(river_scores, dim=-1)
river_weights @ river_values
tensor([[0.992, 0.221, 0.261],
        [0.957, 0.314, 0.256],
        [0.986, 0.232, 0.263]])
finance_queries = finance_embeddings @ W_query
finance_keys = finance_embeddings @ W_key
finance_values = finance_embeddings @ W_value
finance_scores = finance_queries @ finance_keys.T / math.sqrt(finance_keys.shape[-1])
finance_weights = torch.softmax(finance_scores, dim=-1)
finance_weights @ finance_values
tensor([[0.188, 1.158, 0.169],
        [0.297, 1.089, 0.180],
        [0.204, 1.146, 0.172]])
Here is the recipe compacted into a class:
- project x into queries, keys, and values
- score queries against keys
- normalize with softmax
- take a weighted sum of values
The class keeps the scaling by sqrt(d_k) we used above, which prevents dot-product scores from growing too large as the key dimension increases.
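To see why the scaling matters, compare the spread of raw and scaled dot-product scores as the key dimension grows (a standalone illustration with random vectors, not part of the running example). For unit-variance entries, a dot product over d_k dimensions has variance roughly d_k, so its standard deviation grows like sqrt(d_k):

```python
import math
import torch

torch.manual_seed(0)

for d_k in (4, 64, 1024):
    q = torch.randn(1000, d_k)
    k = torch.randn(1000, d_k)
    raw = (q * k).sum(dim=-1)       # raw dot products: std grows like sqrt(d_k)
    scaled = raw / math.sqrt(d_k)   # scaled scores: std stays near 1
    print(f"d_k={d_k:5d}  raw std={raw.std():6.2f}  scaled std={scaled.std():.2f}")
```

Without the scaling, large scores push softmax toward near-one-hot weights, so gradients flow to almost nothing but the argmax token.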
class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / math.sqrt(keys.shape[-1]), dim=-1)
        context_vec = attn_weights @ values
        return context_vec
torch.manual_seed(0)
self_attention = SelfAttention(d_in=4, d_out=3)

self_attention(river_embeddings)
tensor([[0.540, 0.705, 1.030],
        [0.538, 0.706, 1.030],
        [0.541, 0.703, 1.025]], grad_fn=<MmBackward0>)
self_attention(finance_embeddings)
tensor([[0.220, 0.418, 0.642],
        [0.213, 0.404, 0.624],
        [0.216, 0.409, 0.630]], grad_fn=<MmBackward0>)
A linear layer applies the same transformation to each token, but it has no mechanism for exchanging information across positions.
Self-attention still reuses the same learned parameters at every position, but the output for each token can change from one sequence to the next because the attention weights are computed from token-token interactions inside the sequence.
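That claim is easy to check directly: feed one module two sequences that share the exact same bank embedding, and bank's output row still differs because its neighbors differ. Everything is redefined here so the snippet runs on its own:

```python
import math
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value
        scores = queries @ keys.T / math.sqrt(keys.shape[-1])
        return torch.softmax(scores, dim=-1) @ values

torch.manual_seed(0)
attn = SelfAttention(d_in=4, d_out=3)

bank = torch.tensor([0.8, 0.8, 0.2, 0.0])
river = torch.stack([torch.tensor([1.2, 0.0, 0.0, 0.3]), bank,
                     torch.tensor([0.9, 0.0, 0.0, 0.9])])
finance = torch.stack([torch.tensor([0.0, 1.4, 0.0, 0.1]), bank,
                       torch.tensor([0.0, 1.1, 0.0, 0.6])])

with torch.no_grad():
    out_river = attn(river)
    out_finance = attn(finance)

# identical parameters, identical "bank" input row, different output rows
assert torch.allclose(river[1], finance[1])
assert not torch.allclose(out_river[1], out_finance[1])
```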