import math
import torch
import torch.nn as nn
import torch.nn.functional as F
Input Data
| Source | Target |
| --- | --- |
| I am from China | 我来自中国 |
| You and me are best friends | 你我是最好的朋友 |
Suppose batch_size=2, the source sequences are padded to num_steps=6, the target sequences are padded to num_steps=8, and both vocabularies reserve "<unk>"=0, "<pad>"=1, "<bos>"=2, "<eos>"=3.
'I am from China' -> \([4, 5, 6, 7, 1, 1]\)
'You and me are best friends' -> \([8, 9, 10, 11, 12, 13]\)
\[X:\begin{bmatrix}4&5&6&7&1&1\\8&9&10&11&12&13\end{bmatrix}\ \ X\_valid\_len:\begin{bmatrix}4&6\end{bmatrix}\]
‘我来自中国’ -> \([4, 5, 6, 7, 8, 1, 1, 1]\)
‘你我是最好的朋友’ -> \([9, 4, 11, 12, 13, 14, 15, 16]\)
\[Y:\begin{bmatrix}4&5&6&7&8&1&1&1\\9&4&11&12&13&14&15&16\end{bmatrix}\ \ Y\_valid\_len:\begin{bmatrix}5&8\end{bmatrix}\]
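As tensors, this toy batch looks like the following (a minimal sketch; the token ids above are illustrative and not produced by a real vocabulary):
X = torch.tensor([[4, 5, 6, 7, 1, 1],
                  [8, 9, 10, 11, 12, 13]])          # (batch_size, src num_steps) = (2, 6)
X_valid_len = torch.tensor([4, 6])                  # number of tokens before padding
Y = torch.tensor([[4, 5, 6, 7, 8, 1, 1, 1],
                  [9, 4, 11, 12, 13, 14, 15, 16]])  # (batch_size, tgt num_steps) = (2, 8)
Y_valid_len = torch.tensor([5, 8])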
Basic Functions
sequence_mask
def sequence_mask(X, valid_len, value=0.0):
    '''
    Set every position of X at or beyond its valid length to `value`.
    Standalone usage (as in the example below):
    :param X: (batch_size, seq_len, input_dim)
    :param valid_len: (batch_size, )
    Usage inside masked_softmax(), where each flattened row of scores is masked:
    (query_lens, num_hiddens) * (key_lens, num_hiddens)^T = (query_lens, key_lens)
    :param X: (batch_size * query_lens, key_lens)
    :param valid_len: (batch_size, ) --torch.repeat_interleave()--> (batch_size * query_lens, )
    '''
maxlen = X.shape[1]
mask = torch.arange(maxlen, dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
X[~mask] = value
return X
X = torch.ones(2, 6, 8) # shape: (batch_size, seq_len, input_dim)
valid_len = torch.tensor([4, 6]).reshape(2, ) # shape: (batch_size, )
print(sequence_mask(X, valid_len, -99))
## tensor([[[ 1., 1., 1., 1., 1., 1., 1., 1.],
## [ 1., 1., 1., 1., 1., 1., 1., 1.],
## [ 1., 1., 1., 1., 1., 1., 1., 1.],
## [ 1., 1., 1., 1., 1., 1., 1., 1.],
## [-99., -99., -99., -99., -99., -99., -99., -99.],
## [-99., -99., -99., -99., -99., -99., -99., -99.]],
##
## [[ 1., 1., 1., 1., 1., 1., 1., 1.],
## [ 1., 1., 1., 1., 1., 1., 1., 1.],
## [ 1., 1., 1., 1., 1., 1., 1., 1.],
## [ 1., 1., 1., 1., 1., 1., 1., 1.],
## [ 1., 1., 1., 1., 1., 1., 1., 1.],
## [ 1., 1., 1., 1., 1., 1., 1., 1.]]])
(1, seq_len) < (batch_size, 1) \[ \begin{bmatrix}\begin{bmatrix}0&1&2\end{bmatrix}\end{bmatrix} < \begin{bmatrix}\begin{bmatrix}1\end{bmatrix}\\\begin{bmatrix}2\end{bmatrix}\end{bmatrix}\\ \downarrow\\ \begin{bmatrix}\begin{bmatrix}0&1&2\end{bmatrix}\\\begin{bmatrix}0&1&2\end{bmatrix}\end{bmatrix} < \begin{bmatrix}\begin{bmatrix}1&1&1\end{bmatrix}\\\begin{bmatrix}2&2&2\end{bmatrix}\end{bmatrix}\\ \downarrow\\ mask:\begin{bmatrix}\begin{bmatrix}True&False&False\end{bmatrix}\\\begin{bmatrix}True&True&False\end{bmatrix}\end{bmatrix} \]
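A quick check of this broadcasting rule with seq_len=3 and valid lengths [1, 2]:
torch.arange(3)[None, :] < torch.tensor([1, 2])[:, None]
## tensor([[ True, False, False],
##         [ True,  True, False]])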
masked_softmax
def masked_softmax(X, valid_lens):
'''
query: (2, 6, 14) * key: (2, 8, 14)^T = score: (2, 6, 8)
score: (2, 6, 8) * value: (2, 8, 14) = (2, 6, 14)
    masked_softmax() masks the padded key positions of the score matrix before normalizing.
    :param X: (batch_size, query_lens, key_lens)
    :param valid_lens: (batch_size, ) or (batch_size, query_lens)
'''
if valid_lens is None:
return F.softmax(X, dim=-1)
else:
shape = X.shape
if valid_lens.dim() == 1:
valid_lens = torch.repeat_interleave(valid_lens, shape[1])
else: # I will discuss this after!
valid_lens = valid_lens.reshape(-1)
X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
return F.softmax(X.reshape(shape), dim=-1)
masked_softmax(torch.rand(2, 6, 8), torch.tensor([4, 6]))
## tensor([[[0.2673, 0.2529, 0.2307, 0.2491, 0.0000, 0.0000, 0.0000, 0.0000],
## [0.2909, 0.1846, 0.3504, 0.1742, 0.0000, 0.0000, 0.0000, 0.0000],
## [0.2037, 0.3266, 0.1886, 0.2812, 0.0000, 0.0000, 0.0000, 0.0000],
## [0.2802, 0.2231, 0.3009, 0.1957, 0.0000, 0.0000, 0.0000, 0.0000],
## [0.3052, 0.2109, 0.1854, 0.2984, 0.0000, 0.0000, 0.0000, 0.0000],
## [0.3103, 0.2003, 0.2468, 0.2426, 0.0000, 0.0000, 0.0000, 0.0000]],
##
## [[0.2064, 0.1458, 0.1180, 0.1641, 0.1970, 0.1687, 0.0000, 0.0000],
## [0.1262, 0.1212, 0.1410, 0.1872, 0.2133, 0.2111, 0.0000, 0.0000],
## [0.2008, 0.1291, 0.0983, 0.1256, 0.2040, 0.2423, 0.0000, 0.0000],
## [0.1648, 0.1260, 0.2222, 0.1810, 0.1084, 0.1975, 0.0000, 0.0000],
## [0.1699, 0.1610, 0.1863, 0.2672, 0.1078, 0.1078, 0.0000, 0.0000],
## [0.1377, 0.1113, 0.2481, 0.1235, 0.2665, 0.1130, 0.0000, 0.0000]]])
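valid_lens can also be 2-D with shape (batch_size, query_lens), giving every query row its own valid length; the else branch above simply flattens it. This is exactly the form the decoder's causal mask takes later. A minimal sketch (random values, so only the zero pattern matters):
out = masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))
print(out[0, 0])  # only the first entry is non-zero (valid length 1)
print(out[1, 1])  # all four entries are non-zero (valid length 4)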
DotProductAttention
class DotProductAttention(nn.Module):
def __init__(self, dropout, **kwargs):
super(DotProductAttention, self).__init__(**kwargs)
self.dropout = nn.Dropout(dropout)
def forward(self, queries, keys, values, valid_lens=None):
'''
:param queries: (batch_size, query_lens, num_hiddens)
:param keys: (batch_size, key_lens, num_hiddens)
:param values: (batch_size, value_lens, num_hiddens)
'''
d = queries.shape[-1]
scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d) # O(n*n*d) self-Attention
self.attention_weights = masked_softmax(scores, valid_lens) # O(n*n) self-Attention
return torch.bmm(self.dropout(self.attention_weights), values) # O(n*n*d) self-Attention
queries, keys, values = torch.normal(0, 1, (2, 6, 14)), torch.normal(0, 1, (2, 6, 14)), torch.normal(0, 1, (2, 6, 14))
attention = DotProductAttention(dropout=0.5)
attention.eval()
## DotProductAttention(
## (dropout): Dropout(p=0.5, inplace=False)
## )
print(attention(queries, keys, values, torch.tensor([3, 4])).shape)
## torch.Size([2, 6, 14])
Comparison between CNN and Self-Attention in sequence training
Suppose an input sequence \(\boldsymbol{X}_{n\times d}\), a CNN kernel of width \(k\ (k\ll n)\) with \(d\) input and \(d\) output channels, and self-attention projection weights \(W^{(i)}_{d\times d},\ i\in\{q, k, v\}\).
In Self-Attention, the time complexity is \(O(n^2d)+O(nd^2)\):
\(\boldsymbol{X}_{n\times d}W^{(i)}_{d\times d}\): \(O(3nd^2)\)
\(\boldsymbol{Q}\boldsymbol{K}^T\): \(O(n^2d)\)
Weighting the values with the scores: \(O(n^2d)\)
In CNN, the time complexity is \(O(nkd^2)\):
one convolution window: \(O(kd^2)\)
\(n - k + 1\) window positions in total: \(O(nkd^2)\)
Therefore self-attention is more sensitive to the sequence length \(n\) than a CNN, but the path between any two tokens has length \(O(1)\), which makes long-range dependencies easy to capture.
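A back-of-the-envelope check of these formulas (pure arithmetic, no benchmarking): doubling \(n\) roughly quadruples the \(O(n^2d)\) term of self-attention but only doubles the CNN cost.
def self_attention_ops(n, d):
    return 3 * n * d**2 + 2 * n**2 * d  # Q/K/V projections + QK^T + weighting the values

def cnn_ops(n, d, k):
    return n * k * d**2                 # ~n window positions, each O(k*d^2)

for n in (512, 1024, 2048):
    print(n, self_attention_ops(n, d=512), cnn_ops(n, d=512, k=3))
## 512 671088640 402653184
## 1024 1879048192 805306368
## 2048 5905580032 1610612736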
MultiHeadAttention
Suppose an input matrix of shape (1, seq_lens=6, input_size=14) and 2 heads. Head 1 processes the red area and head 2 processes the blue area. The transpose_qkv function reshapes \((1, 6, 14)\) into \((1\times2, 6, 7)\) so that all heads are computed in parallel by one batched matrix multiplication; the transpose_output function restores the original form.
\[\begin{bmatrix}\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1&\color{red}0&\color{red}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1&\color{red}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{blue}1&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{red}0&\color{red}1&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{red}0&\color{red}1&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\end{bmatrix}\mathop{\longrightarrow}^{transpose\_qkv()}\begin{matrix}\begin{bmatrix}\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1&\color{red}0&\color{red}0\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1&\color{red}0\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0\\\color{red}0&\color{red}1&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0\\\color{red}0&\color{red}1&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0\end{bmatrix}\\ \begin{bmatrix}\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{blue}1&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\end{bmatrix} \end{matrix}\]
For example, suppose a sequence of length \(n=2\) with hidden size \(d=4\) and 2 heads.
\[ \begin{bmatrix} \color{red}{q_{11}}&\color{red}{q_{12}}&\color{blue}{q_{13}}&\color{blue}{q_{14}}\\ \color{red}{q_{21}}&\color{red}{q_{22}}&\color{blue}{q_{23}}&\color{blue}{q_{24}} \end{bmatrix} \mathop{\cdot}^\text{multi-head product} \begin{bmatrix} \color{red}{k_{11}}&\color{red}{k_{12}}&\color{blue}{k_{13}}&\color{blue}{k_{14}}\\ \color{red}{k_{21}}&\color{red}{k_{22}}&\color{blue}{k_{23}}&\color{blue}{k_{24}} \end{bmatrix}^T= \begin{bmatrix} \begin{bmatrix} \color{red}{q_{11}}\color{red}{k_{11}}+\color{red}{q_{12}}\color{red}{k_{12}} & \color{red}{q_{11}}\color{red}{k_{21}}+\color{red}{q_{12}}\color{red}{k_{22}}\\ \color{red}{q_{21}}\color{red}{k_{11}}+\color{red}{q_{22}}\color{red}{k_{12}} & \color{red}{q_{21}}\color{red}{k_{21}}+\color{red}{q_{22}}\color{red}{k_{22}} \end{bmatrix}\\ \begin{bmatrix} \color{blue}{q_{13}}\color{blue}{k_{13}}+\color{blue}{q_{14}}\color{blue}{k_{14}} & \color{blue}{q_{13}}\color{blue}{k_{23}}+\color{blue}{q_{14}}\color{blue}{k_{24}}\\ \color{blue}{q_{23}}\color{blue}{k_{13}}+\color{blue}{q_{24}}\color{blue}{k_{14}} & \color{blue}{q_{23}}\color{blue}{k_{23}}+\color{blue}{q_{24}}\color{blue}{k_{24}} \end{bmatrix} \end{bmatrix} \]
def transpose_qkv(X, num_heads):
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
X = X.permute(0, 2, 1, 3)
return X.reshape(-1, X.shape[2], X.shape[3])
def transpose_output(X, num_heads):
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
X = X.permute(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)
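A shape check for the (1, 6, 14) example above with 2 heads:
X = torch.zeros(1, 6, 14)
H = transpose_qkv(X, num_heads=2)
print(H.shape)                                 # torch.Size([2, 6, 7])
print(transpose_output(H, num_heads=2).shape)  # torch.Size([1, 6, 14])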
class MultiHeadAttention(nn.Module):
def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self.num_heads = num_heads
self.attention = DotProductAttention(dropout)
self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
def forward(self, queries, keys, values, valid_lens):
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None:
valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
output = self.attention(queries, keys, values, valid_lens)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)
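A usage sketch with the toy sizes used throughout (num_hiddens=14, 2 heads, 6 queries attending over 8 key-value pairs):
attention = MultiHeadAttention(key_size=14, query_size=14, value_size=14,
                               num_hiddens=14, num_heads=2, dropout=0.5)
attention.eval()
queries = torch.normal(0, 1, (2, 6, 14))
keys = values = torch.normal(0, 1, (2, 8, 14))
print(attention(queries, keys, values, torch.tensor([4, 6])).shape)
## torch.Size([2, 6, 14])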
Positional encoding
\[\boldsymbol{Embed\_X}+\boldsymbol{P}=\begin{bmatrix}x_{11}&x_{12}&\cdots&x_{1d}\\x_{21}&x_{22}&\cdots&x_{2d}\\\vdots&\vdots&\ddots&\vdots\\x_{n1}&x_{n2}&\cdots&x_{nd}\end{bmatrix}+\begin{bmatrix}p_{11}&p_{12}&\cdots&p_{1d}\\p_{21}&p_{22}&\cdots&p_{2d}\\\vdots&\vdots&\ddots&\vdots\\p_{n1}&p_{n2}&\cdots&p_{nd}\end{bmatrix}\], where \(n\) represents seq_lens, \(d\) represents the embedding size, and \(\boldsymbol{P}\) is the positional encoding matrix.
\[p_{i, 2j}=\sin\Big{(}\frac{i}{10000^{2j/d}}\Big{)}\ \ \ \ \ \ \ p_{i, 2j+1}=\cos\Big{(}\frac{i}{10000^{2j/d}}\Big{)}\], where \(i=0,1,\cdots,n-1\) and \(j=0, 1, \cdots,d/2-1\).
Properties:
\[\begin{bmatrix}p_{i+k, 2j}\\p_{i+k, 2j+1}\end{bmatrix}=\begin{bmatrix}\cos\Big{(}\frac{k}{10000^{2j/d}}\Big{)}&\sin\Big{(}\frac{k}{10000^{2j/d}}\Big{)}\\-\sin\Big{(}\frac{k}{10000^{2j/d}}\Big{)}&\cos\Big{(}\frac{k}{10000^{2j/d}}\Big{)}\end{bmatrix}\begin{bmatrix}p_{i,2j}\\p_{i,2j+1}\end{bmatrix}\]
\[\begin{bmatrix}p_{i,0}&p_{i,1}&\cdots&p_{i,d-1}\end{bmatrix}\begin{bmatrix}p_{i+k,0}\\p_{i+k, 1}\\\vdots\\p_{i+k, d-1}\end{bmatrix}=\cos\Big{(}\frac{k}{10000^{0/d}}\Big{)}+\cos\Big{(}\frac{k}{10000^{2/d}}\Big{)}+\cdots+\cos\Big{(}\frac{k}{10000^{(d-2)/d}}\Big{)}\]
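Both identities follow from the angle-addition formulas for sine and cosine. A quick numeric check of the first one, with arbitrary \(i\), \(k\), \(j\), \(d\) (values chosen here only for illustration):
import numpy as np
i, k, j, d = 3, 5, 2, 16
w = 1 / 10000 ** (2 * j / d)                 # frequency shared by dimensions 2j and 2j+1
p = lambda pos: np.array([np.sin(pos * w), np.cos(pos * w)])
R = np.array([[np.cos(k * w), np.sin(k * w)],
              [-np.sin(k * w), np.cos(k * w)]])
print(np.allclose(p(i + k), R @ p(i)))
## True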
Assuming \(d=512\), the inner product of vectors as \(k\) increases is shown below.
import plotly.express as px
import pandas as pd
import numpy as np
def myfunc(k, d):
exponents = np.arange(0, d, 2)/d
a = k / np.power(10000, exponents)
a = np.cos(a)
# print(a)
return np.sum(a)
k = np.arange(0, 200)
y = [myfunc(item, 512) for item in k]
df = pd.DataFrame({'k': k, 'y': y})
fig = px.scatter(df, x='k', y='y', width=768, height=474)
fig.show()
class PositionalEncoding(nn.Module):
def __init__(self, num_hiddens, dropout, max_len=1000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(dropout)
self.P = torch.zeros((1, max_len, num_hiddens))
X = (torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) /
torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens))
self.P[:, :, 0::2] = torch.sin(X)
self.P[:, :, 1::2] = torch.cos(X)
def forward(self, X):
X += self.P[:, :X.shape[1], :].to(X.device)
return self.dropout(X)
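A quick shape check; since the input here is all zeros and dropout is disabled, the output is exactly \(\boldsymbol{P}[:, :6, :]\) broadcast over the batch:
pos_enc = PositionalEncoding(num_hiddens=14, dropout=0)
pos_enc.eval()
print(pos_enc(torch.zeros(2, 6, 14)).shape)
## torch.Size([2, 6, 14])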
Add&Norm
Suppose I have an input of dimension \((batch\_size=2, seq\_lens=2, input\_size=4)\).
\[ \begin{bmatrix}x^{(1)}_{11}&x^{(1)}_{12}&x^{(1)}_{13}&x^{(1)}_{14}\\x^{(1)}_{21}&x^{(1)}_{22}&x^{(1)}_{23}&x^{(1)}_{24}\end{bmatrix}\\ \begin{bmatrix}x^{(2)}_{11}&x^{(2)}_{12}&x^{(2)}_{13}&x^{(2)}_{14}\\x^{(2)}_{21}&x^{(2)}_{22}&x^{(2)}_{23}&x^{(2)}_{24}\end{bmatrix} \]
The normalization operator nn.LayerNorm(4) is applied to every token separately (it normalizes over the last dimension only).
\[ mean^{(1)}_1=\frac{1}{4}\sum_j^4x^{(1)}_{1j}\\var^{(1)}_1=\frac{1}{4}\sum_j\big{(}x^{(1)}_{1j}-mean_1^{(1)}\big{)}^2 \]
The normalization operator nn.LayerNorm([2, 4]) is applied to every input text as a whole (it normalizes over all tokens of a sequence jointly).
\[ mean^{(1)}=\frac{1}{2\times4}\sum^2_i\sum^4_jx^{(1)}_{ij}\\var^{(1)}=\frac{1}{2\times4}\sum_i^2\sum_j^4\big{(}x^{(1)}_{ij}-mean^{(1)}\big{)}^2 \]
Both variants then apply a learnable affine transform with gain \(g\) and bias \(b\): \[ x_{i,j}^{(k)}\rightarrow g\cdot\frac{x_{i,j}^{(k)}-mean^{(k)}}{\sqrt{var^{(k)}}} + b \]
ln1 = nn.LayerNorm(4)
ln2 = nn.LayerNorm([2, 4])
with torch.no_grad():
# X shape: (2, 2, 4)
X = torch.tensor([
[[1, 2, 3, 4],
[5, 6, 7, 8]],
[[5, 6, 7, 8],
[5, 1, 0, -1]]
], dtype=torch.float32)
print(ln1(X))
print(ln2(X))
## tensor([[[-1.3416, -0.4472, 0.4472, 1.3416],
## [-1.3416, -0.4472, 0.4472, 1.3416]],
##
## [[-1.3416, -0.4472, 0.4472, 1.3416],
## [ 1.6465, -0.1098, -0.5488, -0.9879]]])
## tensor([[[-1.5275, -1.0911, -0.6547, -0.2182],
## [ 0.2182, 0.6547, 1.0911, 1.5275]],
##
## [[ 0.3538, 0.6683, 0.9829, 1.2974],
## [ 0.3538, -0.9042, -1.2187, -1.5332]]])
class AddNorm(nn.Module):
def __init__(self, normalized_shape, dropout, **kwargs):
super(AddNorm, self).__init__(**kwargs)
self.dropout = nn.Dropout(dropout)
self.ln = nn.LayerNorm(normalized_shape)
def forward(self, X, Y):
return self.ln(X + self.dropout(Y))
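AddNorm keeps the input shape, which is what the residual connection requires; a quick check with the per-token normalized_shape discussed above:
add_norm = AddNorm(normalized_shape=4, dropout=0.5)
add_norm.eval()
print(add_norm(torch.ones(2, 2, 4), torch.ones(2, 2, 4)).shape)
## torch.Size([2, 2, 4])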
PositionWiseFFN
class PositionWiseFFN(nn.Module):
def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
super(PositionWiseFFN, self).__init__(**kwargs)
self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
self.relu = nn.ReLU()
self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)
def forward(self, X):
return self.dense2(self.relu(self.dense1(X)))
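The FFN acts on the last dimension only, so it treats every position independently: positions with identical input vectors get identical outputs.
ffn = PositionWiseFFN(ffn_num_input=4, ffn_num_hiddens=8, ffn_num_outputs=4)
ffn.eval()
out = ffn(torch.ones(2, 3, 4))
print(out.shape)                             # torch.Size([2, 3, 4])
print(torch.allclose(out[0, 0], out[1, 2]))  # True: same input vector -> same output vector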
Encoder&Decoder
Encoder
class EncoderBlock(nn.Module):
def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input,
ffn_num_hiddens, num_heads, dropout, use_bias=False, **kwargs):
super(EncoderBlock, self).__init__(**kwargs)
self.attention = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias)
self.addnorm1 = AddNorm(norm_shape, dropout)
self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
self.addnorm2 = AddNorm(norm_shape, dropout)
def forward(self, X, valid_lens):
Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
return self.addnorm2(Y, self.ffn(Y))
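An encoder block preserves the input shape, which is what allows stacking num_layers of them; a check with the toy sizes (seq_len 6, num_hiddens 14, 2 heads):
blk = EncoderBlock(key_size=14, query_size=14, value_size=14, num_hiddens=14,
                   norm_shape=[6, 14], ffn_num_input=14, ffn_num_hiddens=28,
                   num_heads=2, dropout=0.5)
blk.eval()
print(blk(torch.normal(0, 1, (2, 6, 14)), torch.tensor([4, 6])).shape)
## torch.Size([2, 6, 14])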
class TransformerEncoder(nn.Module):
def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape,
ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, use_bias=False, **kwargs):
super(TransformerEncoder, self).__init__(**kwargs)
self.num_hiddens = num_hiddens
self.embedding = nn.Embedding(vocab_size, num_hiddens)
self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
self.blks = nn.Sequential()
for i in range(num_layers):
self.blks.add_module("block" + str(i),
EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias))
def forward(self, X, valid_lens):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))  # rescale the embeddings so their magnitude matches the positional encodings, whose entries lie in [-1, 1]
        self.attention_weights = [None] * len(self.blks)  # the attention weights (softmax-normalized scores) of each block
for i, blk in enumerate(self.blks):
X = blk(X, valid_lens)
self.attention_weights[i] = blk.attention.attention.attention_weights
return X
encoder = TransformerEncoder(200, 14, 14, 14, 14, [6, 14], 14, 28, 2, 6, 0.5)
#encoder.eval()
X = torch.ones((2, 6), dtype=torch.long)
valid_lens = torch.tensor([4, 6], dtype=torch.long)
print(encoder(X, valid_lens).shape)
## torch.Size([2, 6, 14])
Decoder
class DecoderBlock(nn.Module):
def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape,
ffn_num_input, ffn_num_hiddens, num_heads, dropout, i, **kwargs):
super(DecoderBlock, self).__init__(**kwargs)
self.i = i
self.attention1 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
self.addnorm1 = AddNorm(norm_shape, dropout)
self.attention2 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
self.addnorm2 = AddNorm(norm_shape, dropout)
self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
self.addnorm3 = AddNorm(norm_shape, dropout)
def forward(self, X, state):
enc_outputs, enc_valid_lens = state[0], state[1]
if state[2][self.i] is None:
key_values = X
else:
key_values = torch.cat((state[2][self.i], X), dim=1)
state[2][self.i] = key_values
if self.training:
batch_size, num_steps, _ = X.shape
dec_valid_lens = torch.arange(1, num_steps+1, device=X.device).repeat(batch_size, 1)
else:
dec_valid_lens = None
X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
Y = self.addnorm1(X, X2)
Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
Z = self.addnorm2(Y, Y2)
return self.addnorm3(Z, self.ffn(Z)), state
class AttentionDecoder(nn.Module):
def __init__(self, **kwargs):
super(AttentionDecoder, self).__init__(**kwargs)
def attention_weights(self):
raise NotImplementedError
class TransformerDecoder(AttentionDecoder):
def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape,
ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, **kwargs):
super(TransformerDecoder, self).__init__(**kwargs)
self.num_hiddens = num_hiddens
self.num_layers = num_layers
self.embedding = nn.Embedding(vocab_size, num_hiddens)
self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
self.blks = nn.Sequential()
for i in range(num_layers):
self.blks.add_module("block"+str(i),
DecoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
ffn_num_input, ffn_num_hiddens, num_heads, dropout, i))
self.dense = nn.Linear(num_hiddens, vocab_size)
def init_state(self, enc_outputs, enc_valid_lens, *args):
return [enc_outputs, enc_valid_lens, [None] * self.num_layers]
def forward(self, X, state):
X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
for i, blk in enumerate(self.blks):
X, state = blk(X, state)
self._attention_weights[0][i] = blk.attention1.attention.attention_weights
self._attention_weights[1][i] = blk.attention2.attention.attention_weights
return self.dense(X), state
@property
def attention_weights(self):
return self._attention_weights
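To make the shape discussion below concrete, here is a forward pass in training mode that mirrors the encoder example above (target num_steps 8, num_hiddens 14, 2 heads, norm_shape [8, 14] to match the 8 target steps); inside each block the per-head queries have shape (4, 8, 7).
decoder = TransformerDecoder(200, 14, 14, 14, 14, [8, 14], 14, 28, 2, 6, 0.5)
decoder.train()
state = decoder.init_state(encoder(X, valid_lens), valid_lens)  # encoder outputs: (2, 6, 14)
dec_out, state = decoder(torch.ones((2, 8), dtype=torch.long), state)
print(dec_out.shape)                          # torch.Size([2, 8, 200]): logits over the target vocab
print(decoder.attention_weights[0][0].shape)  # torch.Size([4, 8, 8]): masked self-attention weights
print(decoder.attention_weights[1][0].shape)  # torch.Size([4, 8, 6]): encoder-decoder attention weights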
The first multi-head attention layer in the decoder
The queries, after transpose_qkv, have shape \((4, 8, 7)\), i.e. (batch_size*num_heads, num_steps, head_dim) = (2*2, 8, 7).
The keys have the same shape \((4, 8, 7)\) and are transposed to \((4, 7, 8)\) for the score computation.
\[ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times8}\rightarrow\begin{bmatrix}1&0&0&0&0&0&0&0\\s_{21}&s_{22}&0&0&0&0&0&0\\s_{31}&s_{32}&s_{33}&0&0&0&0&0\\s_{41}&s_{42}&s_{43}&s_{44}&0&0&0&0\\s_{51}&s_{52}&s_{53}&s_{54}&s_{55}&0&0&0\\s_{61}&s_{62}&s_{63}&s_{64}&s_{65}&s_{66}&0&0\\s_{71}&s_{72}&s_{73}&s_{74}&s_{75}&s_{76}&s_{77}&0\\s_{81}&s_{82}&s_{83}&s_{84}&s_{85}&s_{86}&s_{87}&s_{88}\end{bmatrix}, \forall i=2,\cdots8; \sum_j s_{ij}=1\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times8}\rightarrow\cdots\cdots\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times8}\rightarrow\cdots\cdots\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times8}\rightarrow\cdots\cdots\\ \]
The score matrices above show that the k-th target token can only attend to the first k tokens, i.e. itself and the tokens before it.
The second multi-head attention layer in the decoder
The queries, after transpose_qkv, have shape \((4, 8, 7)\).
The keys come from the encoder outputs; after transpose_qkv they have shape \((4, 6, 7)\) and are transposed to \((4, 7, 6)\) for the score computation.
\[ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\vdots\\0\cdots0\\0\cdots0\end{bmatrix}_{7\times6}\rightarrow\begin{bmatrix}\cdots&0&0\\\ddots&\vdots&\vdots\\\cdots&0&0\end{bmatrix}_{8\times6}\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times6}\rightarrow\cdots\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times6}\rightarrow\cdots\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times6}\rightarrow\cdots \]
The score matrices above show that the padded (invalid) positions of the encoder outputs are ignored.
Model
class EncoderDecoder(nn.Module):
def __init__(self, encoder, decoder, **kwargs):
super(EncoderDecoder, self).__init__(**kwargs)
self.encoder = encoder
self.decoder = decoder
def forward(self, enc_X, dec_X, *args):
enc_outputs = self.encoder(enc_X, *args)
dec_state = self.decoder.init_state(enc_outputs, *args)
return self.decoder(dec_X, dec_state)
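The construction below assumes that the vocabularies and hyperparameters come from the data pipeline; they are not defined in this post. A minimal sketch of plausible values (all assumed, sized like the toy example):
num_hiddens, num_layers, num_heads, dropout = 14, 2, 2, 0.1
key_size = query_size = value_size = ffn_num_input = num_hiddens
ffn_num_hiddens = 28
norm_shape = [num_hiddens]  # per-token LayerNorm, so the decoder also works one token at a time at prediction
# src_vocab and tgt_vocab come from tokenizing the corpus (e.g. a d2l-style Vocab);
# only len(src_vocab), len(tgt_vocab) and the special-token ids are needed below.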
encoder = TransformerEncoder(len(src_vocab), key_size, query_size, value_size, num_hiddens, norm_shape,
ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
decoder = TransformerDecoder(len(tgt_vocab), key_size, query_size, value_size, num_hiddens, norm_shape,
ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
model = EncoderDecoder(encoder, decoder)
Train And Prediction
Masked softmax loss
class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
def forward(self, pred, label, valid_len):
'''
:pred's shape: (batch_size, num_steps, vocab_size)
:label's shape: (batch_size, num_steps)
:valid_len's shape: (batch_size, )
'''
weights = torch.ones_like(label)
weights = sequence_mask(weights, valid_len)
self.reduction = 'none'
unweighted_loss = super(MaskedSoftmaxCELoss, self).forward(
pred.permute(0, 2, 1), label)
        # nn.CrossEntropyLoss expects the class dimension second, hence pred.permute(0, 2, 1).
weighted_loss = (unweighted_loss * weights).mean(dim=1)
return weighted_loss # Each sequence has a loss value.
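A quick check: with uniform logits the per-step loss is \(\ln 10\approx2.3026\); the second sequence only counts its first two steps, and the third (valid length 0) is masked out entirely.
loss = MaskedSoftmaxCELoss()
print(loss(torch.ones(3, 4, 10), torch.ones((3, 4), dtype=torch.long), torch.tensor([4, 2, 0])))
## tensor([2.3026, 1.1513, 0.0000])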
Train function
def train_seq2seq(model, data_iter, lr, num_epochs, tgt_vocab, device):
def xavier_init_weights(m):
if type(m) == nn.Linear:
nn.init.xavier_uniform_(m.weight)
if type(m) == nn.GRU:
for param in m._flat_weights_names:
if "weight" in param:
nn.init.xavier_uniform_(m._parameters[param])
model.apply(xavier_init_weights)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss = MaskedSoftmaxCELoss()
model.train()
for epoch in range(num_epochs):
myloss = 0
for batch in data_iter:
optimizer.zero_grad()
X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            # teacher forcing: the decoder input is <bos> followed by the ground-truth target shifted right by one
bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0], device=device).reshape(-1, 1)
dec_input = torch.cat([bos, Y[:, :-1]], 1)
Y_hat, _ = model(X, dec_input, X_valid_len)
l = loss(Y_hat, Y, Y_valid_len)
l.sum().backward()
# d2l.grad_clipping(net, 1)
num_tokens = Y_valid_len.sum()
optimizer.step()
myloss += l.sum() / num_tokens
if (epoch + 1) % 10 == 0:
print("loss: {:.4f}".format(myloss))
Prediction
def truncate_pad(line, num_steps, padding_token):
if len(line) > num_steps:
return line[:num_steps]
else:
return line + [padding_token] * (num_steps - len(line))
def predict_seq2seq(model, src_sentence, src_vocab, tgt_vocab, num_steps, device, save_attention_weights=False):
model.eval()
src_tokens = src_vocab[src_sentence.lower().split(' ')] + [src_vocab['<eos>']]
enc_valid_len = torch.tensor([len(src_tokens)], device=device)
src_tokens = truncate_pad(src_tokens, num_steps, src_vocab['<pad>'])
enc_X = torch.unsqueeze(torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)
enc_outputs = model.encoder(enc_X, enc_valid_len)
dec_state = model.decoder.init_state(enc_outputs, enc_valid_len)
dec_X = torch.unsqueeze(torch.tensor([tgt_vocab['<bos>']], dtype=torch.long, device=device), dim=0)
output_seq, attention_weight_seq = [], []
for _ in range(num_steps):
Y, dec_state = model.decoder(dec_X, dec_state)
dec_X = Y.argmax(dim=2)
pred = dec_X.squeeze(dim=0).type(torch.int32).item()
if save_attention_weights:
attention_weight_seq.append(model.decoder.attention_weights)
if pred == tgt_vocab['<eos>']:
break
output_seq.append(pred)
return ' '.join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq
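A usage sketch (hypothetical: it assumes a trained model and d2l-style Vocab objects for src_vocab and tgt_vocab, which support list indexing and a to_tokens method):
translation, attn_weights = predict_seq2seq(model, 'I am from China', src_vocab, tgt_vocab,
                                            num_steps=8, device=device, save_attention_weights=True)
print(translation)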
Transformer Code
The following code gathers all of the Transformer code above in one place, for convenient copying.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def sequence_mask(X, valid_len, value=0.0):
maxlen = X.shape[1]
mask = torch.arange(maxlen, dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
X[~mask] = value
return X
def masked_softmax(X, valid_lens):
if valid_lens is None:
return F.softmax(X, dim=-1)
else:
shape = X.shape
if valid_lens.dim() == 1:
valid_lens = torch.repeat_interleave(valid_lens, shape[1])
else: # I will discuss this after!
valid_lens = valid_lens.reshape(-1)
X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
return F.softmax(X.reshape(shape), dim=-1)
class DotProductAttention(nn.Module):
def __init__(self, dropout, **kwargs):
super(DotProductAttention, self).__init__(**kwargs)
self.dropout = nn.Dropout(dropout)
def forward(self, queries, keys, values, valid_lens=None):
d = queries.shape[-1]
scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
self.attention_weights = masked_softmax(scores, valid_lens)
return torch.bmm(self.dropout(self.attention_weights), values)
def transpose_qkv(X, num_heads):
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
X = X.permute(0, 2, 1, 3)
return X.reshape(-1, X.shape[2], X.shape[3])
def transpose_output(X, num_heads):
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
X = X.permute(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)
class MultiHeadAttention(nn.Module):
def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self.num_heads = num_heads
self.attention = DotProductAttention(dropout)
self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
def forward(self, queries, keys, values, valid_lens):
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None:
valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
output = self.attention(queries, keys, values, valid_lens)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)
class PositionalEncoding(nn.Module):
def __init__(self, num_hiddens, dropout, max_len=1000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(dropout)
self.P = torch.zeros((1, max_len, num_hiddens))
X = (torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) /
torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens))
self.P[:, :, 0::2] = torch.sin(X)
self.P[:, :, 1::2] = torch.cos(X)
def forward(self, X):
X += self.P[:, :X.shape[1], :].to(X.device)
return self.dropout(X)
class AddNorm(nn.Module):
def __init__(self, normalized_shape, dropout, **kwargs):
super(AddNorm, self).__init__(**kwargs)
self.dropout = nn.Dropout(dropout)
self.ln = nn.LayerNorm(normalized_shape)
def forward(self, X, Y):
return self.ln(X + self.dropout(Y))
class PositionWiseFFN(nn.Module):
def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
super(PositionWiseFFN, self).__init__(**kwargs)
self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
self.relu = nn.ReLU()
self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)
def forward(self, X):
return self.dense2(self.relu(self.dense1(X)))
class EncoderBlock(nn.Module):
def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input,
ffn_num_hiddens, num_heads, dropout, use_bias=False, **kwargs):
super(EncoderBlock, self).__init__(**kwargs)
self.attention = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias)
self.addnorm1 = AddNorm(norm_shape, dropout)
self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
self.addnorm2 = AddNorm(norm_shape, dropout)
def forward(self, X, valid_lens):
Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
return self.addnorm2(Y, self.ffn(Y))
class TransformerEncoder(nn.Module):
def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape,
ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, use_bias=False, **kwargs):
super(TransformerEncoder, self).__init__(**kwargs)
self.num_hiddens = num_hiddens
self.embedding = nn.Embedding(vocab_size, num_hiddens)
self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
self.blks = nn.Sequential()
for i in range(num_layers):
self.blks.add_module("block" + str(i),
EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias))
def forward(self, X, valid_lens):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))  # rescale the embeddings so their magnitude matches the positional encodings, whose entries lie in [-1, 1]
        self.attention_weights = [None] * len(self.blks)  # the attention weights (softmax-normalized scores) of each block
for i, blk in enumerate(self.blks):
X = blk(X, valid_lens)
self.attention_weights[i] = blk.attention.attention.attention_weights
return X
class DecoderBlock(nn.Module):
def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape,
ffn_num_input, ffn_num_hiddens, num_heads, dropout, i, **kwargs):
super(DecoderBlock, self).__init__(**kwargs)
self.i = i
self.attention1 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
self.addnorm1 = AddNorm(norm_shape, dropout)
self.attention2 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
self.addnorm2 = AddNorm(norm_shape, dropout)
self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
self.addnorm3 = AddNorm(norm_shape, dropout)
def forward(self, X, state):
enc_outputs, enc_valid_lens = state[0], state[1]
if state[2][self.i] is None:
key_values = X
else:
key_values = torch.cat((state[2][self.i], X), dim=1)
state[2][self.i] = key_values
if self.training:
batch_size, num_steps, _ = X.shape
dec_valid_lens = torch.arange(1, num_steps+1, device=X.device).repeat(batch_size, 1)
else:
dec_valid_lens = None
X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
Y = self.addnorm1(X, X2)
Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
Z = self.addnorm2(Y, Y2)
return self.addnorm3(Z, self.ffn(Z)), state
class AttentionDecoder(nn.Module):
def __init__(self, **kwargs):
super(AttentionDecoder, self).__init__(**kwargs)
def attention_weights(self):
raise NotImplementedError
class TransformerDecoder(AttentionDecoder):
def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape,
ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, **kwargs):
super(TransformerDecoder, self).__init__(**kwargs)
self.num_hiddens = num_hiddens
self.num_layers = num_layers
self.embedding = nn.Embedding(vocab_size, num_hiddens)
self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
self.blks = nn.Sequential()
for i in range(num_layers):
self.blks.add_module("block"+str(i),
DecoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
ffn_num_input, ffn_num_hiddens, num_heads, dropout, i))
self.dense = nn.Linear(num_hiddens, vocab_size)
def init_state(self, enc_outputs, enc_valid_lens, *args):
return [enc_outputs, enc_valid_lens, [None] * self.num_layers]
def forward(self, X, state):
X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
for i, blk in enumerate(self.blks):
X, state = blk(X, state)
self._attention_weights[0][i] = blk.attention1.attention.attention_weights
self._attention_weights[1][i] = blk.attention2.attention.attention_weights
return self.dense(X), state
@property
def attention_weights(self):
return self._attention_weights
class EncoderDecoder(nn.Module):
def __init__(self, encoder, decoder, **kwargs):
super(EncoderDecoder, self).__init__(**kwargs)
self.encoder = encoder
self.decoder = decoder
def forward(self, enc_X, dec_X, *args):
enc_outputs = self.encoder(enc_X, *args)
dec_state = self.decoder.init_state(enc_outputs, *args)
return self.decoder(dec_X, dec_state)
encoder = TransformerEncoder(len(src_vocab), key_size, query_size, value_size, num_hiddens, norm_shape,
ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
decoder = TransformerDecoder(len(tgt_vocab), key_size, query_size, value_size, num_hiddens, norm_shape,
ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
model = EncoderDecoder(encoder, decoder)