LLM Transformer PyTorch模块实现如何处理?
摘要:✨Attention Is All You Need https:arxiv.orgabs1706.03762 为了加深对于transformer架构的理解 Talk is cheap. Show me the code. 以下代码
✨Attention Is All You Need
https://arxiv.org/abs/1706.03762
为了加深对于transformer架构的理解
Talk is cheap. Show me the code.
以下代码皆由Gemini 2.5 Flash生成(已经过验证)
✨Attention(注意力机制)
Scaled Dot-Product Attention (缩放点积注意力)
Scaled Dot-Product Attention 是注意力机制的基础形式,它通过计算查询(Query, Q)和键(Key, K)的点积来衡量它们之间的相似度,然后除以一个缩放因子 $ \sqrt{d_k} $(d_k 是键向量的维度),以防止点积过大导致 softmax 函数进入梯度饱和区。最后,将注意力权重与值(Value, V)进行加权求和,得到最终的输出。
数学表达式为:
$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$
import torch
import torch.nn as nn
import math
def scaled_dot_product_attention(query, key, value, mask=None):
"""
实现缩放点积注意力。
Args:
query (torch.Tensor): 查询张量,形状通常为 (batch_size, num_heads, seq_len_q, d_k)。
key (torch.Tensor): 键张量,形状通常为 (batch_size, num_heads, seq_len_k, d_k)。
value (torch.Tensor): 值张量,形状通常为 (batch_size, num_heads, seq_len_v, d_v)。
注意:seq_len_k 必须等于 seq_len_v。
mask (torch.Tensor, optional): 注意力掩码,形状通常为 (batch_size, 1, seq_len_q, seq_len_k)。
用于掩盖某些位置的注意力得分。
Returns:
torch.Tensor: 注意力机制的输出,形状为 (batch_size, num_heads, seq_len_q, d_v)。
torch.Tensor: 注意力权重,形状为 (batch_size, num_heads, seq_len_q, seq_len_k)。
