1 Theory
What sets this model apart: it is based entirely on attention mechanisms, dispensing with recurrence and convolution altogether.
It is a model architecture that avoids recurrence and instead relies entirely on attention to draw global dependencies between input and output.
Self-attention, sometimes called intra-attention, is an attention mechanism that relates different positions of a single sequence in order to compute a representation of that sequence.
Compared with traditional RNN/CNN-based networks, attention averages over the positions it attends to, which reduces the effective resolution of the features; multi-head attention is used to counteract this effect, and the model still follows the encoder-decoder framework.
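For reference, the two core formulas from the original paper (Q, K, V are the query, key and value matrices, d_k the per-head dimension, and W_i^Q, W_i^K, W_i^V, W^O the learned projections):

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V

\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\,W^{O}, \qquad \mathrm{head}_i = \mathrm{Attention}(QW_i^{Q}, KW_i^{K}, VW_i^{V})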
2 Practice
import torch
import torch.nn as nn
from torch import Tensor
from typing import Optional, Tuple
import numpy as np
import torch.nn.functional as F
import math
class ScaledDotAttention(nn.Module):
    def __init__(self, d_k):
        """d_k: dimension of each attention head"""
        super(ScaledDotAttention, self).__init__()
        self.d_k = d_k

    def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        # q/k/v: (batch, seq_len, num_heads, d_k)
        score = torch.einsum("nqhd,nkhd->nhqk", [q, k]) / np.sqrt(self.d_k)  # (batch, heads, q_len, k_len)
        if mask is not None:
            score.masked_fill_(mask == 0, -float('Inf'))  # suppress masked positions before the softmax
        attn = F.softmax(score, -1)
        context = torch.einsum("nhqk,nkhd->nqhd", [attn, v])  # attention-weighted sum of the values
        return context, attn
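# A quick shape check of ScaledDotAttention (my own illustrative example, not part of the original
# listing). The einsum convention above expects q/k/v of shape (batch, seq_len, num_heads, d_k):
# q = torch.rand(2, 5, 8, 64); k = torch.rand(2, 6, 8, 64); v = torch.rand(2, 6, 8, 64)
# context, attn = ScaledDotAttention(64)(q, k, v)  # context: (2, 5, 8, 64), attn: (2, 8, 5, 6)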
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8):
        """
        d_model: input dimension of q/k/v
        num_heads: number of attention heads
        """
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        assert d_model % num_heads == 0, "d_model % num_heads should be zero"
        self.d_k = d_model // num_heads
        self.scaled_dot_attn = ScaledDotAttention(self.d_k)
        # projections act on the per-head dimension, so the same d_k x d_k weights are shared across heads
        self.W_Q = nn.Linear(self.d_k, self.d_k, bias=False)
        self.W_K = nn.Linear(self.d_k, self.d_k, bias=False)
        self.W_V = nn.Linear(self.d_k, self.d_k, bias=False)
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask):
        """
        query: (batch, q_len, d_model): from the previous decoder layer, the encoder input, or the decoder input
        key:   (batch, k_len, d_model): from the encoder output, the encoder input, or the decoder input
        value: (batch, v_len, d_model): from the encoder output, the encoder input, or the decoder input
        """
        N = value.size(0)
        # split d_model into num_heads heads of size d_k
        query = query.view(N, -1, self.num_heads, self.d_k)
        key = key.view(N, -1, self.num_heads, self.d_k)
        value = value.view(N, -1, self.num_heads, self.d_k)
        query = self.W_Q(query)
        key = self.W_K(key)
        value = self.W_V(value)
        context, attn = self.scaled_dot_attn(query, key, value, mask)
        # concatenate the heads and project back to d_model
        context = self.W_O(context.reshape(N, query.size(1), self.num_heads * self.d_k))
        return context, attn
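# Illustrative usage of MultiHeadAttention (an assumed example, not part of the original listing).
# The mask is broadcast against the score tensor of shape (batch, heads, q_len, k_len), so a
# padding mask of shape (batch, 1, 1, k_len) with 1 = keep and 0 = mask out works here:
# mha = MultiHeadAttention(d_model=512, num_heads=8)
# x = torch.rand(2, 10, 512)            # (batch, seq_len, d_model)
# pad_mask = torch.ones(2, 1, 1, 10)
# out, attn = mha(x, x, x, pad_mask)    # out: (2, 10, 512), attn: (2, 8, 10, 10)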
class PositionEncoding(nn.Module):
    def __init__(self, d_model, max_len=500):
        super(PositionEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        pe.requires_grad = False
        for pos in range(max_len):
            for i in range(0, d_model, 2):  # even dimension indices
                # sine on even dimensions, cosine on odd dimensions, sharing the same wavelength
                pe[pos, i] = math.sin(pos / (10000 ** (i / d_model)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** (i / d_model)))
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # return the encoding for the first x.size(1) positions
        return self.pe[:, :x.size(1)]
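# PositionEncoding returns only the encoding itself; a typical usage (my assumption, mirroring the
# paper) adds it to the token embeddings before the first block:
# pos_enc = PositionEncoding(d_model=512)
# emb = torch.rand(2, 10, 512)          # stand-in for an embedding lookup
# x = emb + pos_enc(emb)                # the (1, 10, 512) encoding broadcasts over the batch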
class EncoderBlock(nn.Module):
    def __init__(self, d_model, d_ff, num_heads=8, dropout=0.1):
        """
        dropout is applied to the output of each sub-layer
        """
        super(EncoderBlock, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.FFN = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model))

    def forward(self, query, key, value, mask):
        # sub-layer 1: multi-head attention with a residual connection and LayerNorm
        context, _ = self.attn(query, key, value, mask)
        x = self.dropout(self.norm1(context + query))
        # sub-layer 2: position-wise feed-forward network with a residual connection and LayerNorm
        forward = self.FFN(x)
        out = self.dropout(self.norm2(x + forward))
        return out
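# In encoder self-attention, query, key and value are all the same tensor (sketch only):
# enc = EncoderBlock(d_model=512, d_ff=2048)
# x = torch.rand(2, 10, 512)
# enc_out = enc(x, x, x, torch.ones(2, 1, 1, 10))  # (2, 10, 512)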
class DecoderBlock(nn.Module):
    def __init__(self, d_model, d_ff, num_heads=8, dropout=0.1):
        super(DecoderBlock, self).__init__()
        self.norm = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, num_heads)
        # reuse EncoderBlock for the encoder-decoder attention + FFN sub-layers
        self.block = EncoderBlock(d_model, d_ff, num_heads, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key, value, src_mask, tgt_mask):
        # masked self-attention over the decoder input
        context, _ = self.attn(x, x, x, tgt_mask)
        query = self.dropout(self.norm(context + x))
        # encoder-decoder attention: key/value are the encoder output
        out = self.block(query, key, value, src_mask)
        return out
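To see how the pieces fit together, here is a small smoke test of my own (it is not part of the linked repository; the dimensions and mask shapes are assumptions that match the code above). It adds positional encodings to already-embedded source and target sequences, runs one encoder block, and feeds the result into one decoder block with a causal mask built from torch.tril:

# end-to-end smoke test (illustrative sketch, not from the original post)
d_model, d_ff, num_heads = 512, 2048, 8
src = torch.rand(2, 10, d_model)  # already-embedded source sequence
tgt = torch.rand(2, 7, d_model)   # already-embedded target sequence

pos_enc = PositionEncoding(d_model)
src = src + pos_enc(src)
tgt = tgt + pos_enc(tgt)

src_mask = torch.ones(2, 1, 1, 10)                        # no padding in this toy example
tgt_mask = torch.tril(torch.ones(7, 7)).view(1, 1, 7, 7)  # causal mask: position i attends to positions <= i

encoder = EncoderBlock(d_model, d_ff, num_heads)
decoder = DecoderBlock(d_model, d_ff, num_heads)

memory = encoder(src, src, src, src_mask)                 # (2, 10, 512)
out = decoder(tgt, memory, memory, src_mask, tgt_mask)    # (2, 7, 512)
print(memory.shape, out.shape)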
Full code: https://github.com/mengjizhiyou/pytorch_model/blob/main/Transformer