# A.2 PyTorch Implementation Examples

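The code below shows PyTorch implementations of the core Transformer components. For more runnable examples that accompany the mathematical derivations in the main text, see the inline code in [Chapter 2](/llm_internals/di-yi-bu-fen-ji-chu-pian/02_attention/2.2_scaled_dot_product.md) and [Chapter 3](/llm_internals/di-yi-bu-fen-ji-chu-pian/03_components/3.4_feedforward.md).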

## Scaled Dot-Product Attention

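The function below supports the two common kinds of mask: padding masks and causal masks.

- A padding mask usually has shape `(batch, key_len)`.
- A causal mask usually has shape `(query_len, key_len)` or `(batch, query_len, key_len)`.
- A 4D mask that is already broadcastable is passed through unchanged.

Nonzero (or `True`) entries mark positions that may be attended to. Zero (or `False`) entries are masked out. `Q`, `K`, and `V` are expected in the multi-head layout `(batch, n_heads, seq_len, d_k)`.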

```python
import torch
import torch.nn.functional as F
import math

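def scaled_dot_product_attention(Q, K, V, mask=None):
    """Scaled dot-product attention.

    Q, K, V: (batch, n_heads, seq_len, d_k).
    mask: nonzero/True = may attend, 0/False = masked out.
    Returns (output, attn_weights).
    """
    d_k = Q.size(-1)
    # scores: (batch, n_heads, query_len, key_len)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Bring the mask to 4D so it broadcasts against scores
        if mask.dim() == 2:
            # Padding mask (batch, key_len). This branch is checked first, so when
            # batch == query_len a 2D mask is interpreted as a padding mask.
            if mask.size(0) == Q.size(0) and mask.size(1) == K.size(-2):
                mask = mask[:, None, None, :]
            # Causal mask (query_len, key_len), shared across batch and heads
            elif mask.size(0) == Q.size(-2) and mask.size(1) == K.size(-2):
                mask = mask.unsqueeze(0).unsqueeze(0)
            else:
                raise ValueError("2D mask must be (batch, key_len) or (query_len, key_len)")
        elif mask.dim() == 3:
            # (batch, query_len, key_len): add the head dimension
            mask = mask[:, None, :, :]
        elif mask.dim() != 4:
            raise ValueError("mask must have 2, 3, or 4 dimensions")
        valid_mask = mask != 0
        # Use the dtype's minimum rather than -inf, so a fully masked row
        # yields a uniform softmax instead of NaN
        scores = scores.masked_fill(~valid_mask, torch.finfo(scores.dtype).min)
    attn_weights = F.softmax(scores, dim=-1)
    if mask is not None:
        # Rows whose keys are all masked attend to nothing: zero their weights
        fully_masked = valid_mask.sum(dim=-1, keepdim=True) == 0
        attn_weights = attn_weights.masked_fill(fully_masked, 0.0)
    output = torch.matmul(attn_weights, V)
    return output, attn_weights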
```
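The sketch below shows one way to build the mask shapes described above. It makes these assumptions:

- Sequences are right-padded, with illustrative valid lengths held in a hypothetical `lengths` tensor.
- The padding mask `(batch, key_len)` and the causal mask `(query_len, key_len)`, built with `torch.tril`, are combined into a boolean `(batch, 1, query_len, key_len)` mask.

It then cross-checks a plain softmax computation against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`. That function is available in PyTorch 2.0 and later. Its boolean `attn_mask` treats `True` as "may attend", which is the same convention as the function above.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, n_heads, seq_len, d_k = 2, 4, 5, 8
lengths = torch.tensor([5, 3])  # hypothetical valid lengths (right padding)

# Padding mask (batch, key_len): True where the key position holds a real token
padding_mask = torch.arange(seq_len)[None, :] < lengths[:, None]
# Causal mask (query_len, key_len): query i may attend to keys 0..i
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
# Combined boolean mask, broadcastable to (batch, n_heads, query_len, key_len)
combined = padding_mask[:, None, None, :] & causal_mask[None, None, :, :]

Q = torch.randn(batch, n_heads, seq_len, d_k)
K = torch.randn(batch, n_heads, seq_len, d_k)
V = torch.randn(batch, n_heads, seq_len, d_k)

# Manual computation: softmax(Q K^T / sqrt(d_k)) V, masked scores set to -inf
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
scores = scores.masked_fill(~combined, float("-inf"))
manual = torch.softmax(scores, dim=-1) @ V

# Built-in (PyTorch >= 2.0); boolean attn_mask: True = take part in attention
builtin = F.scaled_dot_product_attention(Q, K, V, attn_mask=combined)

print(combined.shape)  # torch.Size([2, 1, 5, 5])
print(torch.allclose(manual, builtin, atol=1e-5))
```

This sketch uses `-inf`, as in the usual textbook formulation. That only works here because every query row keeps at least one visible key (key 0). The function above uses `torch.finfo(...).min` plus an explicit zeroing step so that fully masked rows produce zeros rather than NaN.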

## Multi-Head Attention

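Multi-head attention requires `d_model` to be divisible by the number of heads. Otherwise the projected tensor cannot be split evenly into heads. Each head works in `d_k = d_model // n_heads` dimensions; for example, 512 / 8 = 64.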

```python
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        # Project, then split into heads: (batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_k)
        Q = self.W_q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        # Attention is computed independently per head; the mask broadcasts over heads
        out, _ = scaled_dot_product_attention(Q, K, V, mask)
        # Merge heads: (batch, n_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.W_o(out)
```
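The sketch below makes the reshaping in `forward` concrete, using illustrative sizes `d_model=512` and `n_heads=8`. It checks three things:

- Splitting with `view(...).transpose(1, 2)` gives head `h` exactly the feature columns `h*d_k:(h+1)*d_k` of the projection.
- The merge step needs `.contiguous()`, because a freshly computed `(batch, n_heads, seq_len, d_k)` tensor is no longer contiguous after `transpose(1, 2)`.
- Merging is the same as concatenating the heads along the feature dimension.

The tensor `proj` merely stands in for the output of `self.W_q(Q)`.

```python
import torch

batch, seq_len, d_model, n_heads = 2, 6, 512, 8
d_k = d_model // n_heads  # 512 // 8 = 64

proj = torch.randn(batch, seq_len, d_model)  # stand-in for self.W_q(Q)

# Split: (batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_k)
heads = proj.view(batch, seq_len, n_heads, d_k).transpose(1, 2)

# Head h sees one contiguous block of d_k feature columns
for h in range(n_heads):
    assert torch.equal(heads[:, h], proj[..., h * d_k:(h + 1) * d_k])

# An attention output is a new tensor laid out as (batch, n_heads, seq_len, d_k)
out = torch.randn(batch, n_heads, seq_len, d_k)
swapped = out.transpose(1, 2)  # (batch, seq_len, n_heads, d_k), strides not row-major
print(swapped.is_contiguous())  # False

view_failed = False
try:
    swapped.view(batch, seq_len, d_model)  # incompatible strides
except RuntimeError:
    view_failed = True

merged = swapped.contiguous().view(batch, seq_len, d_model)
# Merging == concatenating the heads along the feature dimension
assert torch.equal(merged, torch.cat([out[:, h] for h in range(n_heads)], dim=-1))
print(merged.shape)  # torch.Size([2, 6, 512])
```

`reshape` would also work for the merge: it returns a view when the strides allow and copies otherwise. The `.contiguous().view(...)` form in the class above simply makes the copy explicit.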

## Feed-Forward Network

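The block below reuses the `nn` and `F` imported earlier. It shows the minimal structure of the position-wise feed-forward network:

1. Expand from `d_model` to `d_ff`.
2. Apply ReLU and dropout.
3. Project back to `d_model`.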

```python
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))
```
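The sketch below is a quick sanity check of this structure. It assumes `d_model=512` and `d_ff=2048`, the 4x ratio used in the original Transformer. It rebuilds the same layer order with `nn.Sequential` so it runs on its own. It checks two things:

- The layer has `2 * d_model * d_ff + d_ff + d_model` parameters.
- Both linear layers act only on the last dimension, so applying the network to a whole sequence matches applying it to each position separately. This is checked in `eval()` mode so that dropout is off.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 512, 2048

# Same order as FeedForward above: Linear -> ReLU -> Dropout -> Linear
ffn = nn.Sequential(
    nn.Linear(d_model, d_ff),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(d_ff, d_model),
)

n_params = sum(p.numel() for p in ffn.parameters())
print(n_params)  # 2099712 = 2*512*2048 + 2048 + 512

ffn.eval()  # disable dropout so the comparison is deterministic
x = torch.randn(2, 5, d_model)  # (batch, seq_len, d_model)
with torch.no_grad():
    full = ffn(x)
    # Apply the network one position at a time and stack the results
    per_position = torch.stack([ffn(x[:, t]) for t in range(x.size(1))], dim=1)

print(torch.allclose(full, per_position, atol=1e-5, rtol=1e-4))
```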

## Training Loop Example

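Below is a training-loop skeleton. It omits the concrete model class, the data loader, and the epoch configuration: `MyTransformerModel`, `dataloader`, and `num_epochs` are placeholders. It focuses on where the loss computation, gradient clipping, and optimizer update happen.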

```python
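import torch
import torch.nn as nn

# Placeholders: MyTransformerModel, num_epochs and dataloader must be defined elsewhere
model = MyTransformerModel(vocab_size=32000, d_model=512, n_heads=8, n_layers=6)
# ignore_index=0 assumes the padding token id is 0
criterion = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-9)

model.train()
for epoch in range(num_epochs):
    for batch in dataloader:
        input_ids, target_ids = batch        # both (batch, seq_len)
        optimizer.zero_grad()
        logits = model(input_ids)            # (batch, seq_len, vocab_size)
        # Flatten to (batch * seq_len, vocab_size) and (batch * seq_len,)
        loss = criterion(logits.reshape(-1, logits.size(-1)), target_ids.reshape(-1))
        loss.backward()
        # Rescale gradients so their global L2 norm is at most 1.0
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()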
```
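The skeleton above cannot run on its own. The sketch below is a runnable miniature that keeps the same order of calls: `zero_grad`, forward, loss, `backward`, `clip_grad_norm_`, then `step`. It makes these assumptions:

- `MyTransformerModel` is replaced by a hypothetical `TinyLM`: just an embedding and a linear output head, with no attention.
- It trains on a single synthetic, right-padded batch whose padding id is 0, matching `ignore_index=0`.
- The vocabulary size, the learning rate (larger than the skeleton's, so the change shows within a few steps), and the step count are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
PAD_ID = 0        # assumption: padding token id 0, matching ignore_index=0
VOCAB_SIZE = 100  # toy vocabulary size (illustrative)


class TinyLM(nn.Module):
    """Hypothetical stand-in for MyTransformerModel: embedding + output head only."""

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=PAD_ID)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids):
        return self.head(self.embed(input_ids))  # (batch, seq_len, vocab_size)


model = TinyLM(VOCAB_SIZE, d_model=32)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID, label_smoothing=0.1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, betas=(0.9, 0.98), eps=1e-9)

# One synthetic batch: real tokens are 1..99, the last two positions are padding
input_ids = torch.randint(1, VOCAB_SIZE, (4, 8))
input_ids[:, 6:] = PAD_ID
# Next-token targets; positions whose target is PAD_ID are ignored by the loss
target_ids = torch.roll(input_ids, shifts=-1, dims=1)
target_ids[:, -1] = PAD_ID

losses = []
model.train()
for step in range(30):
    optimizer.zero_grad()
    logits = model(input_ids)
    loss = criterion(logits.reshape(-1, logits.size(-1)), target_ids.reshape(-1))
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

`clip_grad_norm_` returns the total gradient norm measured before clipping. Logging that value is a common way to spot exploding gradients.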


