LLM Transformer PyTorch模块实现如何处理?

摘要:✨Attention Is All You Need https:arxiv.orgabs1706.03762 为了加深对于transformer架构的理解 Talk is cheap. Show me the code. 以下代码
✨Attention Is All You Need https://arxiv.org/abs/1706.03762 为了加深对于transformer架构的理解 Talk is cheap. Show me the code. 以下代码皆由Gemini 2.5 Flash生成(已经过验证) ✨Attention(注意力机制) Scaled Dot-Product Attention (缩放点积注意力) Scaled Dot-Product Attention 是注意力机制的基础形式,它通过计算查询(Query, Q)和键(Key, K)的点积来衡量它们之间的相似度,然后除以一个缩放因子 $ \sqrt{d_k} $(d_k 是键向量的维度),以防止点积过大导致 softmax 函数进入梯度饱和区。最后,将注意力权重与值(Value, V)进行加权求和,得到最终的输出。 数学表达式为: $$ \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$ import torch import torch.nn as nn import math def scaled_dot_product_attention(query, key, value, mask=None): """ 实现缩放点积注意力。 Args: query (torch.Tensor): 查询张量,形状通常为 (batch_size, num_heads, seq_len_q, d_k)。 key (torch.Tensor): 键张量,形状通常为 (batch_size, num_heads, seq_len_k, d_k)。 value (torch.Tensor): 值张量,形状通常为 (batch_size, num_heads, seq_len_v, d_v)。 注意:seq_len_k 必须等于 seq_len_v。 mask (torch.Tensor, optional): 注意力掩码,形状通常为 (batch_size, 1, seq_len_q, seq_len_k)。 用于掩盖某些位置的注意力得分。 Returns: torch.Tensor: 注意力机制的输出,形状为 (batch_size, num_heads, seq_len_q, d_v)。 torch.Tensor: 注意力权重,形状为 (batch_size, num_heads, seq_len_q, seq_len_k)。
阅读全文