引言
2017 年,Google 团队在论文《Attention Is All You Need》中提出了 Transformer 架构,彻底改变了自然语言处理领域。Transformer 完全基于注意力机制(Attention Mechanism),摒弃了传统的循环神经网络(RNN)和卷积神经网络(CNN)结构,在序列建模任务上取得了突破性进展。
本文将深入讲解 Transformer 中的核心——注意力机制,包括其数学原理、实现细节以及可视化分析。
什么是注意力机制?
注意力机制源于对人类视觉注意力的模拟。当人类观察复杂场景时,会快速扫描全局,然后聚焦于关键区域,忽略无关信息。类似地,在 NLP 任务中,注意力机制让模型学会”关注”输入序列中的重要部分。
传统序列模型的局限
在 Transformer 出现之前,RNN/LSTM 是处理序列数据的主流方法:
RNN 的问题:
- 无法并行化:必须按时间步顺序计算,训练速度慢
- 长距离依赖困难:虽然 LSTM 有所改善,但仍有梯度消失问题
- 信息瓶颈:编码器将整个序列压缩成固定长度的向量,丢失信息
示例: 翻译句子 “The animal didn’t cross the street because it was too tired”
- 人类知道 “it” 指代 “animal”
- RNN 需要记住很远的词义,容易遗忘
- Attention 让模型直接关注 “animal”,建立直接联系
Attention 的核心思想
注意力机制通过计算查询(Query)与键(Key)的相似度,得到注意力权重,然后对**值(Value)**进行加权求和。
用数据库操作类比:
- Query = 搜索关键词
- Key = 索引/标签
- Value = 实际数据内容
模型根据 Query 找到相关的 Key,然后提取对应的 Value 信息。
Scaled Dot-Product Attention
Transformer 使用缩放点积注意力,这是最基本的注意力单元。
数学公式
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
其中:
- $Q \in \mathbb{R}^{n \times d_k}$ - 查询矩阵
- $K \in \mathbb{R}^{m \times d_k}$ - 键矩阵
- $V \in \mathbb{R}^{m \times d_v}$ - 值矩阵
- $d_k$ - 键的维度
- $d_v$ - 值的维度
- $\sqrt{d_k}$ - 缩放因子
计算步骤详解
步骤 1:计算相似度分数
通过点积计算 Query 和 Key 的相似度:
$$S = QK^T$$
$S_{ij}$ 表示第 $i$ 个 Query 与第 $j$ 个 Key 的匹配程度。
步骤 2:缩放
除以 $\sqrt{d_k}$ 进行缩放:
$$S’ = \frac{S}{\sqrt{d_k}}$$
为什么需要缩放?
当 $d_k$ 较大时,点积结果会很大,导致 softmax 进入梯度极小的饱和区(one-hot 分布)。缩放可以保持梯度稳定。
数学解释:
- 假设 $q$ 和 $k$ 的元素独立且均值为 0、方差为 1
- 点积 $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ 的均值为 0、方差为 $d_k$
- 除以 $\sqrt{d_k}$ 后,方差重新变为 1
步骤 3:Softmax 归一化
应用 softmax 函数,将分数转换为概率分布:
$$A = \text{softmax}(S’)$$
$$A_{ij} = \frac{\exp(S’{ij})}{\sum{l=1}^{m} \exp(S’_{il})}$$
$A_{ij}$ 表示在生成第 $i$ 个输出时,对第 $j$ 个输入的关注程度(权重)。
步骤 4:加权求和
使用注意力权重对 Value 进行加权求和:
$$O = AV$$
最终输出 $O$ 包含了所有位置的信息,但重要位置的贡献更大。
可视化示意
注意力权重矩阵热力图:
K1 K2 K3 K4 K5
┌─────────────────────────────┐
Q1 │ 0.1 0.1 0.6 0.1 0.1 │ → 关注 K3
Q2 │ 0.05 0.05 0.1 0.7 0.1 │ → 关注 K4
Q3 │ 0.2 0.5 0.1 0.1 0.1 │ → 关注 K2
└─────────────────────────────┘
↓ ↓ ↓ ↓ ↓
V1 V2 V3 V4 V5
输出 = 加权和 (如 Q1 ≈ 0.6×V3)
Multi-Head Attention
多头注意力是 Transformer 的核心创新之一,它允许模型同时关注不同子空间的信息。
为什么要 Multi-Head?
单个注意力头只能学习一种注意力模式,而 Multi-Head 可以:
捕捉多种关系:不同头关注不同类型的依赖
- 头 1:语法关系(主谓一致)
- 头 2:语义关系(指代消解)
- 头 3:位置关系(相邻词)
增强表达能力:多个子空间并行学习,提供更丰富的特征表示
提高鲁棒性:多个头的集成效果,避免单点失败
结构图
Input: Q, K, V
│
▼
┌────┴────┐
│ Linear │ 投影到 h 个子空间
└────┬────┘
│
┌────┼────┬──────────┐
│ │ │ │
▼ ▼ ▼ ▼
┌───────┐┌───────┐┌───────┐┌───────┐
│Head 1 ││Head 2 ││ ... ││Head h │ 并行计算注意力
└───┬───┘└───┬───┘└───┬───┘└───┬───┘
│ │ │ │
└────────┴────────┴────────┘
│
▼
Concatenate 拼接所有头
│
▼
Linear 最终投影
│
▼
Output
数学表达
对于每个头 $i$:
$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$
其中 $W_i^Q \in \mathbb{R}^{d_{model} \times d_k}$,$W_i^K \in \mathbb{R}^{d_{model} \times d_k}$,$W_i^V \in \mathbb{R}^{d_{model} \times d_v}$ 是投影矩阵。
拼接所有头并投影:
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O$$
其中 $W^O \in \mathbb{R}^{h \cdot d_v \times d_{model}}$ 是输出投影矩阵。
典型配置
在原始 Transformer 论文中:
- $d_{model} = 512$(模型维度)
- $h = 8$(8 个头)
- $d_k = d_v = 64$(每个头 64 维)
- $8 \times 64 = 512$(总维度保持一致)
Self-Attention(自注意力)
Self-Attention 是 Multi-Head Attention 的特例,其中 $Q = K = V$ 都来自同一个输入序列。
工作原理
给定输入序列 $X = (x_1, x_2, \ldots, x_n)$:
线性投影
$$Q = XW^Q, \quad K = XW^K, \quad V = XW^V$$计算注意力
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$输出
每个位置的输出是所有位置信息的加权组合
示例分析
句子:”The cat sat on the mat because it was comfortable”
计算 “it” 的 Self-Attention:
输入:[The, cat, sat, on, the, mat, because, it, was, comfortable]
"it" 的注意力权重分布:
The: 0.02
cat: 0.25 ← 指代关系
sat: 0.05
on: 0.03
the: 0.02
mat: 0.15
because: 0.08
it: 0.10 ← 自身
was: 0.05
comfortable: 0.25 ← 因果关系
输出 = 0.25×Embedding(cat) + 0.25×Embedding(comfortable) + ...
可视化注意力权重:
"it" → 关注的主要词:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
The ▏▎ 2%
cat ▓▓▓▓▓▓▓▓ 25% ★
sat ▏▎ 5%
on ▏▎ 3%
the ▏▎ 2%
mat ▓▓▓▓ 15%
because ▓▓ 8%
it ▓▓▓ 10%
was ▏▎ 5%
comfortable ▓▓▓▓▓▓▓▓ 25% ★
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
这表明模型学会了:
- 指代消解:”it” 指代 “cat”
- 因果推理:”comfortable” 解释了原因
Positional Encoding(位置编码)
Self-Attention 是排列不变的(permutation-equivariant),即打乱输入顺序不会改变注意力权重。为了利用序列的顺序信息,Transformer 引入了位置编码。
为什么需要位置编码?
- RNN 天然具有顺序性(按时间步处理)
- Attention 没有内置的位置概念
- 需要显式注入位置信息
正弦余弦位置编码
原始论文使用三角函数生成位置编码:
$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$
$$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$
其中:
- $pos$ 是位置(0, 1, 2, …)
- $i$ 是维度索引(0, 1, …, $d_{model}/2 - 1$)
设计原理
优点:
- 唯一性:每个位置有独特的编码
- 相对位置感知:对于固定偏移量 $k$,$PE_{pos+k}$ 可以表示为 $PE_{pos}$ 的线性函数
- 外推能力:可以处理比训练时更长的序列
可视化位置编码波形:
维度 0-63: 高频正弦波(短波长)
维度 64-127: 中频正弦波
维度 128-191:低频正弦波
维度 192-255:极低频(接近常数)
不同维度使用不同频率的正余弦波
→ 形成唯一的"位置指纹"
位置编码的实现
import numpy as np
def get_positional_encoding(position, d_model):
"""
生成正弦余弦位置编码
Args:
position: 位置索引 (sequence_length,)
d_model: 模型维度
Returns:
pe: 位置编码 (position, d_model)
"""
pe = np.zeros((position, d_model))
pos = np.arange(position)[:, np.newaxis] # (position, 1)
div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
pe[:, 0::2] = np.sin(pos * div_term) # 偶数维度用 sin
pe[:, 1::2] = np.cos(pos * div_term) # 奇数维度用 cos
return pe
# 示例
pe = get_positional_encoding(100, 512)
print(f"位置编码形状:{pe.shape}") # (100, 512)
可学习的位置编码
除了固定的正弦编码,也可以使用可学习的 Embedding:
class LearnablePositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()
self.pe = nn.Embedding(max_len, d_model)
# pe.weight 会在训练中自动学习
def forward(self, x):
positions = torch.arange(x.size(1)).unsqueeze(0)
return x + self.pe(positions)
Transformer Encoder 中的 Attention
完整的 Transformer Encoder 包含多层 Multi-Head Self-Attention。
Encoder 结构
输入 Embedding + 位置编码
│
▼
┌─────────────────────────┐
│ Multi-Head Attention │ 自注意力层
│ (Q=K=V=输入) │
│ │
│ Add & Norm │ 残差连接 + 层归一化
│ │
│ Feed Forward │ 位置前馈网络
│ (两层 MLP + ReLU) │
│ │
│ Add & Norm │ 残差连接 + 层归一化
└─────────────────────────┘
│
▼
重复 N 次 (N=6)
│
▼
Encoder 输出
层归一化(Layer Normalization)
在每个子层之后使用 LayerNorm:
$$\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sigma} + \beta$$
其中 $\mu$ 和 $\sigma$ 是均值和标准差,$\gamma$ 和 $\beta$ 是可学习参数。
残差连接
$$\text{Output} = \text{LayerNorm}(x + \text{Sublayer}(x))$$
残差连接帮助梯度流动,允许训练更深的网络。
Transformer Decoder 中的 Attention
Decoder 包含三种注意力机制:
1. Masked Multi-Head Self-Attention
在训练时防止看到未来信息(只能关注当前位置之前的词)。
Mask 矩阵:
t1 t2 t3 t4
t1 [ 0 -∞ -∞ -∞ ] ← t1 只能看到 t1
t2 [ 0 0 -∞ -∞ ] ← t2 可以看到 t1,t2
t3 [ 0 0 0 -∞ ] ← t3 可以看到 t1,t2,t3
t4 [ 0 0 0 0 ] ← t4 可以看到全部
注意力计算:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V$$
其中 $M$ 是上三角掩码矩阵(未来位置为 $-\infty$)。
2. Cross-Attention(交叉注意力)
连接 Encoder 和 Decoder 的桥梁:
- Query 来自 Decoder
- Key, Value 来自 Encoder
Decoder 输入 (Query) ───┐
├──→ Attention → Decoder 下一层
Encoder 输出 (K,V) ──────┘
这允许 Decoder 在生成每个词时,关注输入序列的相关部分。
3. Feed Forward Network
与 Encoder 相同的两层 MLP。
Python 实现
完整实现 Scaled Dot-Product Attention
import numpy as np
def softmax(x, axis=-1):
"""数值稳定的 softmax"""
e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
return e_x / e_x.sum(axis=axis, keepdims=True)
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
缩放点积注意力
Args:
Q: 查询矩阵 (seq_len_q, d_k)
K: 键矩阵 (seq_len_k, d_k)
V: 值矩阵 (seq_len_v, d_v)
mask: 可选掩码
Returns:
output: 注意力输出 (seq_len_q, d_v)
attention_weights: 注意力权重 (seq_len_q, seq_len_k)
"""
d_k = Q.shape[-1]
# 1. 计算相似度分数
scores = np.dot(Q, K.T) # (seq_len_q, seq_len_k)
# 2. 缩放
scores = scores / np.sqrt(d_k)
# 3. 应用掩码(如果有)
if mask is not None:
scores = scores + (mask * -1e9)
# 4. Softmax 归一化
attention_weights = softmax(scores, axis=-1)
# 5. 加权求和
output = np.dot(attention_weights, V)
return output, attention_weights
# 测试
d_k = 64
d_v = 64
seq_len_q = 5
seq_len_k = 5
Q = np.random.randn(seq_len_q, d_k)
K = np.random.randn(seq_len_k, d_k)
V = np.random.randn(seq_len_k, d_v)
output, weights = scaled_dot_product_attention(Q, K, V)
print(f"输出形状:{output.shape}") # (5, 64)
print(f"注意力权重形状:{weights.shape}") # (5, 5)
Multi-Head Attention 实现
class MultiHeadAttention:
def __init__(self, d_model=512, num_heads=8):
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# 初始化投影矩阵
self.W_q = np.random.randn(d_model, d_model)
self.W_k = np.random.randn(d_model, d_model)
self.W_v = np.random.randn(d_model, d_model)
self.W_o = np.random.randn(d_model, d_model)
def split_heads(self, X):
"""将最后维度分割到多个头"""
batch_size, seq_len, _ = X.shape
# (batch, seq_len, d_model) → (batch, num_heads, seq_len, d_k)
X = X.reshape(batch_size, seq_len, self.num_heads, self.d_k)
return X.transpose(0, 2, 1, 3)
def forward(self, Q, K, V, mask=None):
batch_size = Q.shape[0]
# 1. 线性投影
Q = np.dot(Q, self.W_q) # (batch, seq_len_q, d_model)
K = np.dot(K, self.W_k)
V = np.dot(V, self.W_v)
# 2. 分割多头
Q = self.split_heads(Q) # (batch, num_heads, seq_len_q, d_k)
K = self.split_heads(K)
V = self.split_heads(V)
# 3. 并行计算所有头的注意力
scores = np.matmul(Q, K.transpose(0, 1, 3, 2)) / np.sqrt(self.d_k)
if mask is not None:
scores = scores + (mask * -1e9)
attention_weights = softmax(scores, axis=-1)
head_output = np.matmul(attention_weights, V) # (batch, h, seq_len, d_k)
# 4. 拼接头
head_output = head_output.transpose(0, 2, 1, 3)
concat_output = head_output.reshape(batch_size, -1, self.d_model)
# 5. 最终投影
output = np.dot(concat_output, self.W_o)
return output, attention_weights
# 测试
model = MultiHeadAttention(d_model=512, num_heads=8)
batch_size = 2
seq_len = 10
Q = np.random.randn(batch_size, seq_len, 512)
K = np.random.randn(batch_size, seq_len, 512)
V = np.random.randn(batch_size, seq_len, 512)
output, attn = model.forward(Q, K, V)
print(f"Multi-Head 输出形状:{output.shape}") # (2, 10, 512)
使用 PyTorch 实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
def __init__(self, d_model=512, num_heads=8):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.q_linear = nn.Linear(d_model, d_model)
self.k_linear = nn.Linear(d_model, d_model)
self.v_linear = nn.Linear(d_model, d_model)
self.out_linear = nn.Linear(d_model, d_model)
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# 线性投影
Q = self.q_linear(Q) # (batch, seq_len, d_model)
K = self.k_linear(K)
V = self.v_linear(V)
# 分割多头
Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 缩放点积注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = F.softmax(scores, dim=-1) # (batch, h, seq_len, seq_len)
head_out = torch.matmul(attn_weights, V)
# 拼接
head_out = head_out.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
# 输出投影
output = self.out_linear(head_out)
return output, attn_weights
# 使用示例
model = MultiHeadAttention()
Q = torch.randn(2, 10, 512)
K = torch.randn(2, 10, 512)
V = torch.randn(2, 10, 512)
output, attn = model(Q, K, V)
print(f"PyTorch 输出形状:{output.shape}")
注意力可视化
绘制注意力权重热力图
import matplotlib.pyplot as plt
import seaborn as sns
def plot_attention_heatmap(attention_weights, tokens_x, tokens_y, title="Attention Weights"):
"""
绘制注意力权重热力图
Args:
attention_weights: 注意力权重矩阵 (len_y, len_x)
tokens_x: X 轴的 token 列表
tokens_y: Y 轴的 token 列表
"""
plt.figure(figsize=(12, 8))
ax = sns.heatmap(attention_weights,
annot=True,
fmt='.2f',
xticklabels=tokens_x,
yticklabels=tokens_y,
cmap='YlOrRd',
cbar_kws={'label': 'Attention Weight'})
plt.xlabel('Keys / Values')
plt.ylabel('Queries')
plt.title(title)
plt.xticks(rotation=45)
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig('attention_heatmap.png', dpi=300)
plt.show()
# 示例用法
tokens = ['The', 'cat', 'sat', 'on', 'the', 'mat']
attn_matrix = np.random.rand(6, 6)
attn_matrix = attn_matrix / attn_matrix.sum(axis=1, keepdims=True) # 归一化
plot_attention_heatmap(attn_matrix, tokens, tokens, "Self-Attention Visualization")
多头注意力对比
def compare_multiple_heads(attention_heads, tokens, head_names=None):
"""对比不同注意力头的模式"""
n_heads = len(attention_heads)
fig, axes = plt.subplots(2, 4, figsize=(20, 10))
axes = axes.flatten()
for i, (ax, attn) in enumerate(zip(axes, attention_heads)):
im = ax.imshow(attn, cmap='Blues', aspect='auto')
ax.set_xticks(range(len(tokens)))
ax.set_yticks(range(len(tokens)))
ax.set_xticklabels(tokens)
ax.set_yticklabels(tokens)
ax.set_title(f'Head {i+1}')
plt.colorbar(im, ax=ax)
plt.tight_layout()
plt.savefig('multi_head_comparison.png', dpi=300)
plt.show()
Attention 的变体
1. Global Attention vs Local Attention
Global Attention(全局注意力):
- 关注序列的所有位置
- 计算复杂度 $O(n^2)$
- 适用于短序列
Local Attention(局部注意力):
- 只关注窗口内的位置
- 计算复杂度 $O(n \times w)$,$w$ 为窗口大小
- 适用于长序列
2. Sparse Attention(稀疏注意力)
通过限制注意力范围减少计算量:
- Fixed Window:只关注前后 $k$ 个位置
- Strided:每隔 $k$ 个位置关注一次
- Longformer:结合全局和局部注意力
3. Cross-Attention(交叉注意力)
用于 Encoder-Decoder 架构:
- Query 来自一个序列
- Key, Value 来自另一个序列
- 实现序列间的信息融合
4. Self-Attention vs Cross-Attention
| 类型 | Q来源 | K,V来源 | 应用场景 |
|---|---|---|---|
| Self-Attention | 序列 X | 序列 X | Encoder 内部 |
| Cross-Attention | 序列 Y | 序列 X | Decoder 关注 Encoder |
实际应用案例
1. 机器翻译
源句 (EN): The cat sits on the mat
目标句 (DE): Die Katze sitzt auf der Matte
解码 "Katze" 时的注意力:
- 高度关注 "cat"(直接翻译)
- 适度关注 "The"(冠词一致性)
- 少量关注 "sits"(动词变位线索)
2. 文本摘要
原文:The researcher conducted a comprehensive study on climate change...
摘要:Study on climate change...
注意力聚焦:
- "researcher conducted" → 动作执行者
- "comprehensive study" → 核心事件
- "climate change" → 主题
3. 问答系统
问题:Who invented the telephone?
文档:Alexander Graham Bell was awarded the first U.S. patent for the telephone in 1876...
注意力分布:
- "Alexander Graham Bell" ← 高亮(答案)
- "telephone" ← 匹配问题关键词
- "patent" ← 相关证据
性能优化技巧
1. 内存效率
问题: $O(n^2)$ 的内存占用限制序列长度
解决方案:
- Gradient Checkpointing:牺牲计算换内存
- Mixed Precision Training:使用 FP16 减少内存
- Activation Recomputation:动态重计算激活值
2. 计算加速
Flash Attention:
- GPU 内核优化
- 减少 HBM 访问
- 速度提升 2-3 倍
Kernel Fusion:
- 合并多个 CUDA 内核
- 减少内存带宽压力
3. 分布式训练
- Data Parallel:数据并行
- Model Parallel:模型并行(超大模型)
- Pipeline Parallel:流水线并行
总结
Attention 的核心优势
- 并行化能力强:可同时处理所有位置,训练效率高
- 长距离依赖:任意两位置间的距离为 O(1)
- 可解释性好:注意力权重提供决策依据
- 灵活通用:适用于各种序列任务
Transformer 的成功要素
Transformer = Self-Attention + Multi-Head + Position Encoding + Residual + LayerNorm
- Self-Attention:建立全局依赖关系
- Multi-Head:多子空间特征学习
- Position Encoding:注入顺序信息
- Residual Connection:深层网络训练稳定
- Layer Normalization:加速收敛
影响与展望
Transformer 不仅在 NLP 领域取得巨大成功,还扩展到:
- 计算机视觉:Vision Transformer (ViT)
- 语音处理:Speech Transformer
- 多模态:DALL-E、CLIP
- 蛋白质折叠:AlphaFold2
Attention 机制已经成为深度学习的基础组件之一,理解其原理对掌握现代 AI 技术至关重要。
参考资料
- Vaswani, A., et al. (2017). “Attention Is All You Need”. NeurIPS.
- The Illustrated Transformer
- Attention and Augmented Recurrent Neural Networks
- PyTorch Documentation: nn.MultiheadAttention
- TensorFlow Tutorial: Transformer model for language understanding