当前位置: 首页 > news >正文

Transformer 大语言模型(LLM)基石 - Transformer架构详解 - 多头注意力机制(Multi-head Attention)原理介绍以及算法实现

锋哥原创的Transformer 大语言模型(LLM)基石视频教程:

https://www.bilibili.com/video/BV1X92pBqEhV

课程介绍

本课程主要讲解Transformer简介,Transformer架构介绍,Transformer架构详解,包括输入层,位置编码,多头注意力机制,前馈神经网络,编码器层,解码器层,输出层,以及Transformer Pytorch2内置实现,Transformer基于PyTorch2手写实现等知识。

Transformer 大语言模型(LLM)基石 - Transformer架构详解 - 多头注意力机制(Multi-head Attention)原理介绍以及算法实现

多头注意力机制的核心思想是:将注意力机制中的查询、键、值映射到多个子空间,分别计算不同的注意力,然后将这些结果拼接起来,再通过一个线性变换输出。

这样做的好处:

  • 信息融合:每个头会从不同的表示空间关注不同的信息,帮助模型学习到更加多样化的表示。

  • 并行计算:多个注意力头可以并行计算,提升计算效率。

多头注意力的计算流程如下:

  1. 线性变换:首先将输入的查询、键和值分别通过不同的权重矩阵进行线性变换,得到多个头的查询、键和值。

  2. 计算每个头的注意力:对于每个头,使用缩放点积注意力计算注意力权重。

  3. 拼接输出:将所有头的输出拼接在一起。

  4. 线性变换:最后,将拼接后的结果通过一个线性变换,得到最终的输出。

作用:理论上来讲,随机事件越多,越能逼近真相。 从不同表示子空间并行捕获信息,增强模型表达能力。

每一组注意力用于将输入映射到不同的子表示空间,这使得模型可以在不同子表示空间中关注不同的位置。整个计算过程可表示为:

按照上面的方法,使用不同的权重矩阵进行 8 次自注意力计算,就可以得到 8 个不同的 Z 矩阵。

接下来就有点麻烦了。因为前馈神经网络层接收的是 1 个矩阵(每个词的词向量),而不是上面的 8 个矩阵。因此,我们需要一种方法将这 8 个矩阵整合为一个矩阵。具体方法如下:

这差不多就是多头注意力的全部内容了。下面将所有内容放到一张图中,以便我们可以统一查看。

通俗解释:

多头注意力机制计算前:

Q=K=V 维度 [3,5,512]

进行多头注意力机制,我们要进行如下操作:

先将最后一次词嵌入维度512进行分割,我们根据论文建议,分成8个头,每个头就会获取(512/8=64)64维度的嵌入表示。Q=K=V=[3,5,8,64]

8个头,就有8组注意力,每组 Qi=Ki=Vi=[3,5,64]

最后8个头注意力都计算完后,我们在对计算结果进行拼接融合,最终的注意力结果标识还是之前的Q=K=V 维度 [3,5,512]

多头注意力机制(Multi-head Attention)算法实现

def create_sequence_mask(seq_len): """ 创建序列掩码(下三角矩阵) """ mask = np.triu(np.ones((3, 8, seq_len, seq_len), dtype=np.uint8), k=1) return torch.from_numpy(1 - mask) # 实现自注意力机制层 def self_attention(query, key, value, mask=None, dropout=None): """ 自注意力机制层 参数: query: 查询张量 [batch_size, seq_len, d_model] key: 关键张量 [batch_size, seq_len, d_model] value: 值张量 [batch_size, seq_len, d_model] mask: 掩码张量 [batch_size, seq_len, seq_len] dropout: Dropout概率 防止模型过拟合,从而提升其泛化能力 返回: 注意力输出和注意力权重 """ # 获取词嵌入维度 512 dk = query.size(-1) print("dk:", dk) # 计算注意力分数 [batch_size, seq_len, seq_len] attention_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(dk) # 将掩码位置的注意力分数设置为负无穷 if mask is not None: attention_scores = attention_scores.masked_fill(mask == 0, -1e9) # softmax 函数,将注意力分数进行归一化 attention_weights = F.softmax(attention_scores, dim=-1) # 对权重进行dropout,随时失活 if dropout is not None: attention_weights = dropout(attention_weights) return torch.matmul(attention_weights, value), attention_weights # 多头注意力机制 class MultiHeadAttention(nn.Module): """ 参数: d_model: 词嵌入维度(必须为偶数) num_heads: 头的数量 dropout: 随机丢失率 """ def __init__(self, d_model, num_heads, dropout=0.1): super().__init__() self.d_k = d_model // num_heads # 512/8=64 词嵌入维度被分为8个头,每个头维度为64 self.num_heads = num_heads # 头的数量 # 定义Query, Key, Value的线性变换 self.Q_linear = nn.Linear(d_model, d_model) self.K_linear = nn.Linear(d_model, d_model) self.V_linear = nn.Linear(d_model, d_model) # 输出线性变换 self.out = nn.Linear(d_model, d_model) # 创建一个Dropout层 self.dropout = nn.Dropout(dropout) def forward(self, query, key, value, mask=None): # query key value的维度 [3, 8, 512] mask要求为[3,head,seq_len,seq_len]->[3, 8, 8, 8] # 获取batch_size batch_size = query.shape[0] # 3 # 线性变换Q, K, V Q = self.Q_linear(query) K = self.K_linear(key) V = self.V_linear(value) # 分割成多个头 Q = Q.view(batch_size, -1, self.num_heads, self.d_k) # [3, 5, 8, 64] K = K.view(batch_size, -1, self.num_heads, self.d_k) # [3, 5, 8, 64] V = V.view(batch_size, -1, self.num_heads, self.d_k) # [3, 5, 8, 64] print("Q:", Q.shape) # 转置成方便进行点积操作的形状 Q = Q.transpose(1, 2) # [3, 8, 5, 64] K = K.transpose(1, 2) # [3, 8, 5, 64] V = V.transpose(1, 2) # [3, 8, 5, 64] print("Q:", Q.shape) # 实现多头注意力机制计算 attention_output, attention_weights = self_attention(Q, K, V, mask, self.dropout) print("attention_output:", attention_output.shape) # [3, 8, 5, 64] print("attention_weights:", attention_weights.shape) # [3, 8, 5, 5] # 多头合并 attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_k * self.num_heads) print("attention_output:", attention_output.shape) # [3, 5, 512] # 输出线性变换 output = self.out(attention_output) return output if __name__ == '__main__': vocab_size = 2000 # 词表大小 embedding_dim = 512 # 词嵌入维度的大小 embeddings = Embeddings(vocab_size, embedding_dim) embed_result = embeddings( torch.tensor([[1999, 2, 99, 4, 5], [66, 2, 3, 22, 5], [66, 2, 3, 4, 5]])) print("embed_result.shape:", embed_result.shape) print("embed_result", embed_result) positional_encoding = PositionalEncoding(embedding_dim) result = positional_encoding(embed_result) print("result:", result) print("result.shape:", result.shape) # 测试自注意力机制 # query = key = value = result # mask = create_sequence_mask(5) # dropout = nn.Dropout(0.1) # attention_output, attention_weights = self_attention(query, key, value, mask, dropout) # print("attention_output.shape:", attention_output.shape) # [3, 5, 512] # print("attention_weights.shape:", attention_weights.shape) # [3, 5, 5] mha = MultiHeadAttention(d_model=512, num_heads=8) print(mha) mask = create_sequence_mask(5) result = mha(result, result, result, mask) print("result.shape:", result.shape) # [3, 5, 512]

运行输出:

http://www.cnnetsun.cn/news/2974.html

相关文章:

  • 梦笔记20251211
  • 23、网络应用管理技术全解析
  • 24、Web应用安全:服务器与客户端防护全解析
  • 25、网页客户端安全:跨站脚本攻击与Cookie劫持揭秘
  • 26、Web应用安全测试案例深度剖析
  • 28、Libwhisker与UrlScan:Web安全工具的使用与配置
  • 29、UrlScan 安装、配置与 Windows 更新全指南
  • 30、IISLockdown与UrlScan的安装配置全解析
  • 1、网络应用安全攻防全解析
  • 17、Web会话ID管理与分析全解析
  • 27、Web应用安全测试与防护全解析
  • Android AOSP 15 源码Ubuntu编译
  • 电平 —— 数字电路与通信领域的核心概念
  • PanSearch - 网盘影视资源搜索聚合工具(KaiGe AI出品
  • Pr教程资源合集
  • 8、文件、流与 XML 处理全解析
  • 9、Qt 应用中提供用户帮助的全面指南
  • 10、应用程序国际化与本地化全解析
  • 11、Qt插件开发全解析
  • 13、Qt 数据库开发:从基础到实践
  • 14、Qt网络编程:从客户端协议到套接字的全面指南
  • 15、构建Qt项目:QMake与CMake的全面指南
  • 16、单元测试:提升软件质量的有效策略
  • 17、Qt 开发中的第三方工具、容器、类型与宏
  • 12、并行编程:线程与进程的深入探索
  • Redis篇1——Redis深度剖析:从 5 种对象到 6 大底层结构
  • 14、Linux 系统 I/O 设备管理与驱动详解
  • 15、Linux磁盘缓存机制解析
  • 16、深入理解文件访问机制:从读写操作到内存映射与直接I/O
  • 17、Linux内存交换与页面回收机制解析