Attention Mechanism

Tao Zou

2025-03-30

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

Input Data

Source Target
I am from China 我来自中国
You and me are best friends 你我是最好的朋友

batch_size=2, num_steps=8, “<unk>”=0, “<pad>” = 1, “<bos>”=2, “<eos>”=3.

‘I am from China’ -> \([4, 5, 6, 7, 1, 1, 1, 1]\)

‘You and me are best friends’ -> \([8, 9, 10, 11, 12, 13, 1, 1]\)

\[X:\begin{bmatrix}4&5&6&7&1&1\\8&9&10&11&12&13\end{bmatrix}\ \ X\_valid\_len:\begin{bmatrix}4&6\end{bmatrix}\]

‘我来自中国’ -> \([4, 5, 6, 7, 8, 1, 1, 1]\)

‘你我是最好的朋友’ -> \([9, 4, 11, 12, 13, 14, 15, 16]\)

\[Y:\begin{bmatrix}4&5&6&7&8&1&1&1\\9&4&11&12&13&14&15&16\end{bmatrix}\ \ Y\_valid\_len:\begin{bmatrix}5&8\end{bmatrix}\]

Basic Functions

sequence_mask

def sequence_mask(X, valid_len, value=0.0):
    '''
    Replace the entries beyond each row's valid length with `value`.

    Position j of row i is kept when j < valid_len[i]; every other position
    along dim 1 is overwritten. The result is a new tensor: the argument X is
    left untouched, so a caller that passes a view (for example the reshaped
    score matrix in masked_softmax) does not have its own data changed.

    :param X: (batch_size, seq_len) or (batch_size, seq_len, input_dim)
    :param valid_len: (batch_size, ), number of valid leading steps per row
    :param value: fill value written to the masked positions
    :return: masked copy of X, same shape and dtype as X

    When called from masked_softmax, the score matrix
    (query_lens, num_hiddens) * (key_lens, num_hiddens)^T = (query_lens, key_lens)
    has been flattened first:
    X: (batch_size * query_lens, key_lens)
    valid_len: (batch_size, ) --torch.repeat_interleave()--> (batch_size * query_lens, )
    '''
    maxlen = X.shape[1]
    # (1, maxlen) < (batch_size, 1) broadcasts to a (batch_size, maxlen) boolean mask.
    mask = torch.arange(maxlen, dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
    # Work on a copy so the caller's tensor is never modified in place.
    X = X.clone()
    X[~mask] = value
    return X
# Demo: mask a tensor of ones with -99 using per-sample valid lengths [4, 6].
X = torch.ones(2, 6, 8)  # shape: (batch_size, seq_len, input_dim)
valid_len = torch.tensor([4, 6]).reshape(2, )  # shape: (batch_size, )
# Steps 4 and 5 of the first sample become -99; the second sample is fully valid.
print(sequence_mask(X, valid_len, -99))
## tensor([[[  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.],
##          [  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.],
##          [  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.],
##          [  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.],
##          [-99., -99., -99., -99., -99., -99., -99., -99.],
##          [-99., -99., -99., -99., -99., -99., -99., -99.]],
## 
##         [[  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.],
##          [  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.],
##          [  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.],
##          [  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.],
##          [  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.],
##          [  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.]]])

(1, seq_len) < (batch_size, 1) \[ \begin{bmatrix}\begin{bmatrix}0&1&2\end{bmatrix}\end{bmatrix} < \begin{bmatrix}\begin{bmatrix}1\end{bmatrix}\\\begin{bmatrix}2\end{bmatrix}\end{bmatrix}\\ \downarrow\\ \begin{bmatrix}\begin{bmatrix}0&1&2\end{bmatrix}\\\begin{bmatrix}0&1&2\end{bmatrix}\end{bmatrix} < \begin{bmatrix}\begin{bmatrix}1&1&1\end{bmatrix}\\\begin{bmatrix}2&2&2\end{bmatrix}\end{bmatrix}\\ \downarrow\\ mask:\begin{bmatrix}\begin{bmatrix}True&False&False\end{bmatrix}\\\begin{bmatrix}True&True&False\end{bmatrix}\end{bmatrix} \]

masked_softmax

def masked_softmax(X, valid_lens):
    '''
    Softmax over the last axis that ignores key positions beyond the valid length.

    Shapes in attention:
    query: (2, 6, 14) * key: (2, 8, 14)^T = score: (2, 6, 8)
    score: (2, 6, 8) * value: (2, 8, 14) = (2, 6, 14)
    This function is applied to score.

    :param X: (batch_size, query_lens, key_lens)
    :param valid_lens: None, (batch_size, ) or (batch_size, query_lens)
    :return: tensor shaped like X; masked positions receive (almost) zero weight
    '''
    # Without lengths this is an ordinary softmax.
    if valid_lens is None:
        return F.softmax(X, dim=-1)

    original_shape = X.shape
    if valid_lens.dim() == 1:
        # One length per sample: repeat it for every query row of that sample.
        row_lens = torch.repeat_interleave(valid_lens, original_shape[1])
    else:
        # One length per (sample, query) pair: flatten to one length per row.
        row_lens = valid_lens.reshape(-1)

    # Flatten to (batch_size * query_lens, key_lens) so that each row has a single length;
    # the large negative fill value turns into ~0 after exponentiation.
    flat_scores = X.reshape(-1, original_shape[-1])
    masked_scores = sequence_mask(flat_scores, row_lens, value=-1e6)
    return F.softmax(masked_scores.reshape(original_shape), dim=-1)
# Demo: with valid lengths [4, 6] the last 4 / last 2 key positions receive zero weight.
masked_softmax(torch.rand(2, 6, 8), torch.tensor([4, 6]))
## tensor([[[0.2673, 0.2529, 0.2307, 0.2491, 0.0000, 0.0000, 0.0000, 0.0000],
##          [0.2909, 0.1846, 0.3504, 0.1742, 0.0000, 0.0000, 0.0000, 0.0000],
##          [0.2037, 0.3266, 0.1886, 0.2812, 0.0000, 0.0000, 0.0000, 0.0000],
##          [0.2802, 0.2231, 0.3009, 0.1957, 0.0000, 0.0000, 0.0000, 0.0000],
##          [0.3052, 0.2109, 0.1854, 0.2984, 0.0000, 0.0000, 0.0000, 0.0000],
##          [0.3103, 0.2003, 0.2468, 0.2426, 0.0000, 0.0000, 0.0000, 0.0000]],
## 
##         [[0.2064, 0.1458, 0.1180, 0.1641, 0.1970, 0.1687, 0.0000, 0.0000],
##          [0.1262, 0.1212, 0.1410, 0.1872, 0.2133, 0.2111, 0.0000, 0.0000],
##          [0.2008, 0.1291, 0.0983, 0.1256, 0.2040, 0.2423, 0.0000, 0.0000],
##          [0.1648, 0.1260, 0.2222, 0.1810, 0.1084, 0.1975, 0.0000, 0.0000],
##          [0.1699, 0.1610, 0.1863, 0.2672, 0.1078, 0.1078, 0.0000, 0.0000],
##          [0.1377, 0.1113, 0.2481, 0.1235, 0.2665, 0.1130, 0.0000, 0.0000]]])

DotProductAttention

class DotProductAttention(nn.Module):
    '''Scaled dot-product attention with dropout applied to the attention weights.'''

    def __init__(self, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        '''
        :param queries: (batch_size, query_lens, num_hiddens)
        :param keys: (batch_size, key_lens, num_hiddens)
        :param values: (batch_size, value_lens, num_hiddens)
        :param valid_lens: forwarded to masked_softmax; None disables masking
        :return: attention output, one row per query
        '''
        scale = math.sqrt(queries.shape[-1])
        # Q K^T / sqrt(d): O(n*n*d) for self-attention
        raw_scores = torch.bmm(queries, keys.transpose(1, 2))
        scores = raw_scores / scale
        # Softmax over the keys: O(n*n). The weights are stored for later inspection.
        weights = masked_softmax(scores, valid_lens)
        self.attention_weights = weights
        # Weighted sum of the values: O(n*n*d)
        return torch.bmm(self.dropout(weights), values)
# Demo inputs: queries, keys and values all have shape (2, 6, 14) and are drawn from N(0, 1).
queries, keys, values = torch.normal(0, 1, (2, 6, 14)), torch.normal(0, 1, (2, 6, 14)), torch.normal(0, 1, (2, 6, 14))
attention = DotProductAttention(dropout=0.5)
# eval() puts the module in evaluation mode, which disables dropout.
attention.eval()
## DotProductAttention(
##   (dropout): Dropout(p=0.5, inplace=False)
## )
# Valid lengths [3, 4] restrict each sample to its first 3 / 4 keys.
print(attention(queries, keys, values, torch.tensor([3, 4])).shape)
## torch.Size([2, 6, 14])

Comparison between CNN and Self-Attention in sequence training

Suppose an input sequence \(\boldsymbol{X}_{n\times d}\), CNN kernel \(\boldsymbol{W}_{d\times d}\) of size \(k\ (k\ll n)\), Self-Attention weight \(W^{(i)}_{d\times d},\ i\in\{q, k, v\}\).

In Self-Attention, the time complexity is \(O(n^2d)+O(nd^2)\):

  1. \(\boldsymbol{X}_{n\times d}W^{(i)}_{d\times d}\): \(O(3nd^2)\)

  2. \(\boldsymbol{Q}\boldsymbol{K}^T\): \(O(n^2d)\)

  3. Weighted sum of the values (attention weights times \(\boldsymbol{V}\)): \(O(n^2d)\)

In CNN, the time complexity is \(O(nkd^2)\):

  1. one CNN computation: \(O(kd^2)\)

  2. The kernel is applied at \(n - k + 1\) positions, each costing the amount in step 1: \(O(nkd^2)\)

Therefore, the cost of Self-Attention grows faster with the length of the input sequence (quadratically in \(n\), versus linearly for CNN), but any two tokens interact directly, so the path length between them is \(O(1)\).

MultiHeadAttention

Suppose the input has shape (1, seq_lens=6, input_size=14) and the number of heads is 2. Head1 processes the red area, and head2 processes the blue area. The transpose_qkv function reshapes \((1, 6, 14)\) into \((1*2, 6, 7)\) so that all heads can be computed in parallel. The transpose_output function turns the result back into its original shape.

\[\begin{bmatrix}\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1&\color{red}0&\color{red}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1&\color{red}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{blue}1&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{red}0&\color{red}1&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{red}0&\color{red}1&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\end{bmatrix}\mathop{\longrightarrow}^{transpose\_qkv()}\begin{matrix}\begin{bmatrix}\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1&\color{red}0&\color{red}0\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1&\color{red}0\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0\\\color{red}0&\color{red}1&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0\\\color{red}0&\color{red}1&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0\end{bmatrix}\\ 
\begin{bmatrix}\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{blue}1&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\end{bmatrix} \end{matrix}\]

For example, suppose I have a sequence of length \(n=2\), a hidden size of \(d=4\) and two heads.

\[ \begin{bmatrix} \color{red}{q_{11}}&\color{red}{q_{12}}&\color{blue}{q_{13}}&\color{blue}{q_{14}}\\ \color{red}{q_{21}}&\color{red}{q_{22}}&\color{blue}{q_{23}}&\color{blue}{q_{24}} \end{bmatrix} \mathop{\cdot}^\text{multi-head product} \begin{bmatrix} \color{red}{k_{11}}&\color{red}{k_{12}}&\color{blue}{k_{13}}&\color{blue}{k_{14}}\\ \color{red}{k_{21}}&\color{red}{k_{22}}&\color{blue}{k_{23}}&\color{blue}{k_{24}} \end{bmatrix}^T= \begin{bmatrix} \begin{bmatrix} \color{red}{q_{11}}\color{red}{k_{11}}+\color{red}{q_{12}}\color{red}{k_{12}} & \color{red}{q_{11}}\color{red}{k_{21}}+\color{red}{q_{12}}\color{red}{k_{22}}\\ \color{red}{q_{21}}\color{red}{k_{11}}+\color{red}{q_{22}}\color{red}{k_{12}} & \color{red}{q_{21}}\color{red}{k_{21}}+\color{red}{q_{22}}\color{red}{k_{22}} \end{bmatrix}\\ \begin{bmatrix} \color{blue}{q_{13}}\color{blue}{k_{13}}+\color{blue}{q_{14}}\color{blue}{k_{14}} & \color{blue}{q_{13}}\color{blue}{k_{23}}+\color{blue}{q_{14}}\color{blue}{k_{24}}\\ \color{blue}{q_{23}}\color{blue}{k_{13}}+\color{blue}{q_{24}}\color{blue}{k_{14}} & \color{blue}{q_{23}}\color{blue}{k_{23}}+\color{blue}{q_{24}}\color{blue}{k_{24}} \end{bmatrix} \end{bmatrix} \]

def transpose_qkv(X, num_heads):
    '''
    Split the hidden dimension into heads and fold the heads into the batch axis.

    (batch_size, seq_len, num_hiddens) -> (batch_size * num_heads, seq_len, num_hiddens / num_heads)
    '''
    batch_size, seq_len = X.shape[0], X.shape[1]
    # (batch_size, seq_len, num_heads, head_dim)
    per_head = X.reshape(batch_size, seq_len, num_heads, -1)
    # (batch_size, num_heads, seq_len, head_dim)
    heads_first = per_head.permute(0, 2, 1, 3)
    head_dim = heads_first.shape[3]
    return heads_first.reshape(-1, seq_len, head_dim)

def transpose_output(X, num_heads):
    '''
    Undo transpose_qkv: move the heads out of the batch axis and concatenate them.

    (batch_size * num_heads, seq_len, head_dim) -> (batch_size, seq_len, num_heads * head_dim)
    '''
    seq_len, head_dim = X.shape[1], X.shape[2]
    # (batch_size, num_heads, seq_len, head_dim)
    by_head = X.reshape(-1, num_heads, seq_len, head_dim)
    # (batch_size, seq_len, num_heads, head_dim)
    tokens_first = by_head.permute(0, 2, 1, 3)
    batch_size = tokens_first.shape[0]
    return tokens_first.reshape(batch_size, seq_len, -1)

class MultiHeadAttention(nn.Module):
    '''Multi-head attention that evaluates all heads in one batched DotProductAttention call.'''

    def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        # Input projections for queries, keys and values, followed by the output projection.
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        '''
        :param queries: (batch_size, query_lens, query_size)
        :param keys: (batch_size, key_lens, key_size)
        :param values: (batch_size, key_lens, value_size)
        :param valid_lens: None, or a tensor whose first axis is batch_size
        :return: (batch_size, query_lens, num_hiddens)
        '''
        heads = self.num_heads
        # Project, then fold the heads into the batch axis.
        q = transpose_qkv(self.W_q(queries), heads)
        k = transpose_qkv(self.W_k(keys), heads)
        v = transpose_qkv(self.W_v(values), heads)
        if valid_lens is not None:
            # Every head of a sample shares that sample's valid lengths.
            valid_lens = torch.repeat_interleave(valid_lens, repeats=heads, dim=0)
        per_head_output = self.attention(q, k, v, valid_lens)
        # Concatenate the heads again before the output projection.
        return self.W_o(transpose_output(per_head_output, heads))

Positional encoding

\[\boldsymbol{Embed\_X}+\boldsymbol{P}=\begin{bmatrix}x_{11}&x_{12}&\cdots&x_{1d}\\x_{21}&x_{22}&\cdots&x_{2d}\\\vdots&\vdots&\ddots&\vdots\\x_{n1}&x_{n2}&\cdots&x_{nd}\end{bmatrix}+\begin{bmatrix}p_{11}&p_{12}&\cdots&p_{1d}\\p_{21}&p_{22}&\cdots&p_{2d}\\\vdots&\vdots&\ddots&\vdots\\p_{n1}&p_{n2}&\cdots&p_{nd}\end{bmatrix}\], where \(n\) represents seq_lens, \(d\) represents embedding size, \(\boldsymbol{P}\) is the positional encoding matrix.

\[p_{i, 2j}=\sin\Big{(}\frac{i}{10000^{2j/d}}\Big{)}\ \ \ \ \ \ \ p_{i, 2j+1}=\cos\Big{(}\frac{i}{10000^{2j/d}}\Big{)}\], where \(i=0,1,\cdots,n-1\) and \(j=0, 1, \cdots,d/2-1\).

Properties:

\[\begin{bmatrix}p_{i+k, 2j}\\p_{i+k, 2j+1}\end{bmatrix}=\begin{bmatrix}\cos\Big{(}\frac{k}{10000^{2j/d}}\Big{)}&\sin\Big{(}\frac{k}{10000^{2j/d}}\Big{)}\\-\sin\Big{(}\frac{k}{10000^{2j/d}}\Big{)}&\cos\Big{(}\frac{k}{10000^{2j/d}}\Big{)}\end{bmatrix}\begin{bmatrix}p_{i,2j}\\p_{i,2j+1}\end{bmatrix}\]

\[\begin{bmatrix}p_{i,0}&p_{i,1}&\cdots&p_{i,d-1}\end{bmatrix}\begin{bmatrix}p_{i+k,0}\\p_{i+k, 1}\\\vdots\\p_{i+k, d-1}\end{bmatrix}=\cos\Big{(}\frac{k}{10000^{0/d}}\Big{)}+\cos\Big{(}\frac{k}{10000^{2/d}}\Big{)}+\cdots+\cos\Big{(}\frac{k}{10000^{(d-2)/d}}\Big{)}\]

Assuming \(d=512\), the inner product of vectors as \(k\) increases is shown below.

import plotly.express as px
import pandas as pd
import numpy as np

def myfunc(k, d):
    '''
    Inner product of two sinusoidal positional encodings that are k positions apart.

    Evaluates cos(k / 10000^(0/d)) + cos(k / 10000^(2/d)) + ... + cos(k / 10000^((d-2)/d)).

    :param k: offset between the two positions
    :param d: embedding size
    '''
    # One wavelength scale per (sin, cos) pair of columns.
    scales = np.power(10000, np.arange(0, d, 2) / d)
    return np.sum(np.cos(k / scales))
# Evaluate the inner product for offsets k = 0..199 with d = 512 and show it as a scatter plot.
k = np.arange(0, 200)
y = [myfunc(item, 512) for item in k]
df = pd.DataFrame({'k': k, 'y': y})
fig = px.scatter(df, x='k', y='y', width=768, height=474)
fig.show()
class PositionalEncoding(nn.Module):
    '''
    Add fixed sinusoidal positional encodings to the input, then apply dropout.

    P[0, i, 2j]   = sin(i / 10000^(2j / num_hiddens))
    P[0, i, 2j+1] = cos(i / 10000^(2j / num_hiddens))

    :param num_hiddens: embedding size (odd sizes are supported)
    :param dropout: dropout probability applied after the addition
    :param max_len: longest sequence length that can be encoded
    '''
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        P = torch.zeros((1, max_len, num_hiddens))
        # X[i, j] = i / 10000^(2j / num_hiddens), shape (max_len, ceil(num_hiddens / 2))
        X = (torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) /
             torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens))
        P[:, :, 0::2] = torch.sin(X)
        # For an odd num_hiddens there is one cosine column fewer than sine columns.
        P[:, :, 1::2] = torch.cos(X)[:, :num_hiddens // 2]
        # Non-persistent buffer: moves with the module on .to()/.cuda() but adds no
        # key to state_dict, so existing checkpoints still load.
        self.register_buffer('P', P, persistent=False)

    def forward(self, X):
        '''
        :param X: (batch_size, seq_len, num_hiddens) with seq_len <= max_len
        :return: new tensor X + P[:, :seq_len] after dropout; the argument is not modified
        '''
        positions = self.P[:, :X.shape[1], :].to(X.device)
        # Out-of-place add so the caller's tensor is left untouched; the cast keeps the
        # dtype of X, as the former in-place add did.
        X = (X + positions).to(X.dtype)
        return self.dropout(X)

Add&Norm

Suppose I have an input of dimension \((batch\_size=2, seq\_lens=2, input\_size=4)\).

\[ \begin{bmatrix}x^{(1)}_{11}&x^{(1)}_{12}&x^{(1)}_{13}&x^{(1)}_{14}\\x^{(1)}_{21}&x^{(1)}_{22}&x^{(1)}_{23}&x^{(1)}_{24}\end{bmatrix}\\ \begin{bmatrix}x^{(2)}_{11}&x^{(2)}_{12}&x^{(2)}_{13}&x^{(2)}_{14}\\x^{(2)}_{21}&x^{(2)}_{22}&x^{(2)}_{23}&x^{(2)}_{24}\end{bmatrix} \]

The normalization operator nn.LayerNorm(4) is applied on every token.

\[ mean^{(1)}_1=\frac{1}{4}\sum_j^4x^{(1)}_{1j}\\var^{(1)}_1=\frac{1}{4}\sum_j\big{(}x^{(1)}_{1j}-mean_1^{(1)}\big{)}^2 \]

The normalization operator nn.LayerNorm([2, 4]) is applied on every input text.

\[ mean^{(1)}=\frac{1}{2\times4}\sum^2_i\sum^4_jx^{(1)}_{ij}\\var^{(1)}=\frac{1}{2\times4}\sum_i^2\sum_j^4\big{(}x^{(1)}_{ij}-mean^{(1)}\big{)}^2 \]

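update method: \[ x_{i,j}^{(k)}\rightarrow g\cdot\frac{x_{i,j}^{(k)}-mean^{(k)}}{\sqrt{var^{(k)}+\epsilon}} + b \], where \(\epsilon\) is a small constant for numerical stability (nn.LayerNorm uses \(10^{-5}\) by default), and \(g\) and \(b\) are the learnable scale and shift.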

# ln1 normalizes over the last dimension (the 4 features of each token);
# ln2 normalizes over the last two dimensions (the whole 2 x 4 block of each sample).
ln1 = nn.LayerNorm(4)
ln2 = nn.LayerNorm([2, 4])
# no_grad: this is only a numeric demonstration, no gradients are needed.
with torch.no_grad():
    # X shape: (2, 2, 4)
    X = torch.tensor([
        [[1, 2, 3, 4],
         [5, 6, 7, 8]],
        [[5, 6, 7, 8],
         [5, 1, 0, -1]]
    ], dtype=torch.float32)
    print(ln1(X))
    print(ln2(X))
## tensor([[[-1.3416, -0.4472,  0.4472,  1.3416],
##          [-1.3416, -0.4472,  0.4472,  1.3416]],
## 
##         [[-1.3416, -0.4472,  0.4472,  1.3416],
##          [ 1.6465, -0.1098, -0.5488, -0.9879]]])
## tensor([[[-1.5275, -1.0911, -0.6547, -0.2182],
##          [ 0.2182,  0.6547,  1.0911,  1.5275]],
## 
##         [[ 0.3538,  0.6683,  0.9829,  1.2974],
##          [ 0.3538, -0.9042, -1.2187, -1.5332]]])
class AddNorm(nn.Module):
    '''Residual connection followed by layer normalization: LayerNorm(X + Dropout(Y)).'''

    def __init__(self, normalized_shape, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        '''
        :param X: residual (skip-connection) input
        :param Y: sublayer output; dropout is applied to it before it is added to X
        '''
        residual_sum = X + self.dropout(Y)
        return self.ln(residual_sum)

PositionWiseFFN

class PositionWiseFFN(nn.Module):
    '''Two-layer feed-forward network (Linear -> ReLU -> Linear) applied to the last dimension.'''

    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        '''
        :param X: (..., ffn_num_input)
        :return: (..., ffn_num_outputs)
        '''
        hidden = self.relu(self.dense1(X))
        return self.dense2(hidden)

Encoder&Decoder

Encoder

class EncoderBlock(nn.Module):
    '''One Transformer encoder layer: self-attention and a position-wise FFN, each followed by AddNorm.'''

    def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, dropout, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.attention = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        '''
        :param X: block input; used as queries, keys and values of the self-attention
        :param valid_lens: valid lengths forwarded to the attention layer
        '''
        attended = self.attention(X, X, X, valid_lens)
        Y = self.addnorm1(X, attended)
        return self.addnorm2(Y, self.ffn(Y))

class TransformerEncoder(nn.Module):
    '''Token embedding + positional encoding followed by a stack of EncoderBlock layers.'''

    def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape,
                 ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for layer_idx in range(num_layers):
            block = EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
                                 ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias)
            self.blks.add_module("block" + str(layer_idx), block)

    def forward(self, X, valid_lens):
        '''
        :param X: token indices fed to the embedding layer
        :param valid_lens: valid lengths forwarded to every block
        :return: output of the last block
        '''
        # The embeddings are scaled by sqrt(num_hiddens) before the positional encoding is
        # added (presumably so they are not dwarfed by encodings in [-1, 1] - TODO confirm).
        embedded = self.embedding(X) * math.sqrt(self.num_hiddens)
        X = self.pos_encoding(embedded)
        # One entry per block: the attention weights (score matrix after softmax) of that block.
        collected_weights = []
        for blk in self.blks:
            X = blk(X, valid_lens)
            collected_weights.append(blk.attention.attention.attention_weights)
        self.attention_weights = collected_weights
        return X
# Demo: vocab_size=200, key/query/value/hidden sizes 14, norm_shape=[6, 14],
# ffn_num_input=14, ffn_num_hiddens=28, 2 heads, 6 layers, dropout 0.5.
encoder = TransformerEncoder(200, 14, 14, 14, 14, [6, 14], 14, 28, 2, 6, 0.5)
# eval() is left commented out, so the encoder is not switched to evaluation mode here.
#encoder.eval()
X = torch.ones((2, 6), dtype=torch.long)
valid_lens = torch.tensor([4, 6], dtype=torch.long)
print(encoder(X, valid_lens).shape)
## torch.Size([2, 6, 14])

Decoder

class DecoderBlock(nn.Module):
    '''The i-th Transformer decoder layer: masked self-attention, encoder-decoder attention and FFN.'''

    def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape,
                 ffn_num_input, ffn_num_hiddens, num_heads, dropout, i, **kwargs):
        super().__init__(**kwargs)
        self.i = i
        self.attention1 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        '''
        :param X: decoder input of shape (batch_size, num_steps, num_hiddens)
        :param state: [enc_outputs, enc_valid_lens, per-layer cache]; state[2][self.i] is updated in place
        :return: (block output, state)
        '''
        enc_outputs, enc_valid_lens = state[0], state[1]
        # state[2][i] holds everything this layer has been fed so far; the new input is
        # appended along the time axis and the result serves as keys and values.
        cached = state[2][self.i]
        key_values = X if cached is None else torch.cat((cached, X), dim=1)
        state[2][self.i] = key_values

        dec_valid_lens = None
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Row t gets valid length t + 1, i.e. lengths 1..num_steps for every sample,
            # so a position cannot attend to later positions.
            dec_valid_lens = torch.arange(1, num_steps+1, device=X.device).repeat(batch_size, 1)

        # 1) masked self-attention over the decoder inputs
        self_attended = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, self_attended)
        # 2) attention over the encoder outputs
        cross_attended = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, cross_attended)
        # 3) position-wise feed-forward network
        return self.addnorm3(Z, self.ffn(Z)), state

class AttentionDecoder(nn.Module):
    '''Base class for decoders that expose the attention weights of their last forward pass.'''

    def __init__(self, **kwargs):
        super(AttentionDecoder, self).__init__(**kwargs)

    @property
    def attention_weights(self):
        '''
        Attention weights recorded by the decoder.

        Declared as a property so that the base class and its subclasses (which
        override it as a property) are read the same way. Subclasses must override it.

        :raises NotImplementedError: always, in this base class
        '''
        raise NotImplementedError
      
      
class TransformerDecoder(AttentionDecoder):
    '''Token embedding + positional encoding, a stack of DecoderBlock layers and a vocabulary projection.'''

    def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape,
                 ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, **kwargs):
        super().__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for layer_idx in range(num_layers):
            block = DecoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
                                 ffn_num_input, ffn_num_hiddens, num_heads, dropout, layer_idx)
            self.blks.add_module("block"+str(layer_idx), block)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        '''Return [enc_outputs, enc_valid_lens, cache] with one empty (None) cache slot per layer.'''
        per_layer_cache = [None] * self.num_layers
        return [enc_outputs, enc_valid_lens, per_layer_cache]

    def forward(self, X, state):
        '''
        :param X: target token indices fed to the embedding layer
        :param state: list produced by init_state; passed through (and updated by) every block
        :return: (scores over the vocabulary, state)
        '''
        embedded = self.embedding(X) * math.sqrt(self.num_hiddens)
        X = self.pos_encoding(embedded)
        num_blocks = len(self.blks)
        # Row 0: self-attention weights per block; row 1: encoder-decoder attention weights.
        self_attn_weights = [None] * num_blocks
        cross_attn_weights = [None] * num_blocks
        self._attention_weights = [self_attn_weights, cross_attn_weights]
        for idx, blk in enumerate(self.blks):
            X, state = blk(X, state)
            self_attn_weights[idx] = blk.attention1.attention.attention_weights
            cross_attn_weights[idx] = blk.attention2.attention.attention_weights
        return self.dense(X), state

    @property
    def attention_weights(self):
        '''Attention weights stored by the most recent forward call.'''
        return self._attention_weights

the first multi-head layer in decoder

query after being transposed by multi-heads: \((4, 8, 7)\).

key after being transposed by multi-heads: \((4, 8, 7)\rightarrow^T(4, 7, 8)\).

\[ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times8}\rightarrow\begin{bmatrix}1&0&0&0&0&0&0&0\\s_{21}&s_{22}&0&0&0&0&0&0\\s_{31}&s_{32}&s_{33}&0&0&0&0&0\\s_{41}&s_{42}&s_{43}&s_{44}&0&0&0&0\\s_{51}&s_{52}&s_{53}&s_{54}&s_{55}&0&0&0\\s_{61}&s_{62}&s_{63}&s_{64}&s_{65}&s_{66}&0&0\\s_{71}&s_{72}&s_{73}&s_{74}&s_{75}&s_{76}&s_{77}&0\\s_{81}&s_{82}&s_{83}&s_{84}&s_{85}&s_{86}&s_{87}&s_{88}\end{bmatrix}, \forall i=2,\cdots8; \sum_j s_{ij}=1\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times8}\rightarrow\cdots\cdots\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times8}\rightarrow\cdots\cdots\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times8}\rightarrow\cdots\cdots\\ \]

The above score matrix indicates that the k-th token can only attend to the first k tokens, i.e. itself and the k-1 tokens before it.

the second multi-head layer in decoder

query after being transposed by multi-heads:\((4, 8, 7)\).

key after being transposed by multi-heads:\((4, 6, 7)\rightarrow^T(4, 7, 6)\).

\[ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\vdots\\0\cdots0\\0\cdots0\end{bmatrix}_{7\times6}\rightarrow\begin{bmatrix}\cdots&0&0\\\ddots&\vdots&\vdots\\\cdots&0&0\end{bmatrix}_{8\times6}\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times6}\rightarrow\cdots\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times6}\rightarrow\cdots\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times6}\rightarrow\cdots \]

The above score matrix indicates that the invalid (padding) tokens in the encoder outputs are ignored.

Model

class EncoderDecoder(nn.Module):
    '''Generic wrapper that chains an encoder and a decoder.'''

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        '''
        :param enc_X: encoder input
        :param dec_X: decoder input
        :param args: extra arguments handed to both the encoder and decoder.init_state
        :return: whatever the decoder returns for (dec_X, initial state)
        '''
        encoded = self.encoder(enc_X, *args)
        initial_state = self.decoder.init_state(encoded, *args)
        return self.decoder(dec_X, initial_state)
    
# Build the full model. src_vocab, tgt_vocab and the hyperparameters (key_size, ..., dropout)
# are assumed to be defined by the surrounding training script - they do not appear above.
encoder = TransformerEncoder(len(src_vocab), key_size, query_size, value_size, num_hiddens, norm_shape,
                             ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
decoder = TransformerDecoder(len(tgt_vocab), key_size, query_size, value_size, num_hiddens, norm_shape,
                             ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
model = EncoderDecoder(encoder, decoder)

Train And Prediction

Masked softmax loss

class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    '''Softmax cross-entropy loss in which padded target positions contribute zero.'''

    def forward(self, pred, label, valid_len):
        '''
        :pred's shape: (batch_size, num_steps, vocab_size)
        :label's shape: (batch_size, num_steps)
        :valid_len's shape: (batch_size, )
        :return: (batch_size, ), one loss value per sequence
        '''
        # Keep the per-token losses so that they can be weighted individually.
        self.reduction = 'none'
        # With a batch dimension, cross entropy expects the classes on axis 1:
        # (batch_size, vocab_size, num_steps).
        per_token_loss = super().forward(pred.permute(0, 2, 1), label)
        # 1 for valid steps, 0 for padding.
        weights = sequence_mask(torch.ones_like(label), valid_len)
        # Mean over all num_steps positions; padded positions enter as zeros.
        return (per_token_loss * weights).mean(dim=1)

Train function

def train_seq2seq(model, data_iter, lr, num_epochs, tgt_vocab, device):
    '''
    Train an encoder-decoder model with teacher forcing, Adam and MaskedSoftmaxCELoss.

    :param model: called as model(X, dec_input, X_valid_len); must return (Y_hat, state)
    :param data_iter: iterable yielding batches (X, X_valid_len, Y, Y_valid_len)
    :param lr: learning rate for Adam
    :param num_epochs: number of passes over data_iter
    :param tgt_vocab: target vocabulary; tgt_vocab['<bos>'] is the begin-of-sequence index
    :param device: device the model and every batch are moved to
    '''
    def xavier_init_weights(m):
        # Xavier-initialize the weights of Linear layers and of GRU layers.
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])
    model.apply(xavier_init_weights)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss = MaskedSoftmaxCELoss()

    model.train()
    for epoch in range(num_epochs):
        myloss = 0
        for batch in data_iter:
            optimizer.zero_grad()
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            # below is called teacher forcing: the decoder input is the target
            # sequence shifted right by one step, starting with <bos>
            bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0], device=device).reshape(-1, 1)
            dec_input = torch.cat([bos, Y[:, :-1]], 1)

            Y_hat, _ = model(X, dec_input, X_valid_len)
            l = loss(Y_hat, Y, Y_valid_len)
            batch_loss = l.sum()  # summed once and reused for the running total
            batch_loss.backward()
            # NOTE: gradient clipping (d2l.grad_clipping(net, 1)) is disabled here.
            num_tokens = Y_valid_len.sum()
            optimizer.step()
            # detach(): accumulate only the value, so the running total does not keep
            # the autograd history of every batch alive.
            myloss += batch_loss.detach() / num_tokens
        if (epoch + 1) % 10 == 0:
            print("loss: {:.4f}".format(myloss))

Prediction

def truncate_pad(line, num_steps, padding_token):
    '''
    Cut or pad a token list to exactly num_steps entries.

    :param line: list of tokens
    :param num_steps: target length
    :param padding_token: token appended when line is shorter than num_steps
    '''
    missing = num_steps - len(line)
    if missing < 0:
        # Too long: keep only the first num_steps tokens.
        return line[:num_steps]
    # Short enough: append the required number of padding tokens (possibly none).
    return line + [padding_token] * missing
def predict_seq2seq(model, src_sentence, src_vocab, tgt_vocab, num_steps, device, save_attention_weights=False):
    '''
    Greedy decoding of one source sentence.

    :param model: object with .encoder and .decoder (the decoder provides init_state and attention_weights)
    :param src_sentence: source text; it is lower-cased and split on single spaces
    :param src_vocab: source vocabulary, indexable by a token or a list of tokens
    :param tgt_vocab: target vocabulary, indexable by token and providing to_tokens()
    :param num_steps: padded source length and maximum number of decoding steps
    :param device: device on which the tensors are created
    :param save_attention_weights: collect the decoder attention weights of every step
    :return: (predicted sentence joined by spaces, list of collected attention weights)
    '''
    model.eval()
    words = src_sentence.lower().split(' ')
    src_tokens = src_vocab[words] + [src_vocab['<eos>']]
    # The valid length is taken before truncation/padding.
    enc_valid_len = torch.tensor([len(src_tokens)], device=device)
    src_tokens = truncate_pad(src_tokens, num_steps, src_vocab['<pad>'])

    # Add the batch axis: (1, num_steps)
    enc_X = torch.tensor(src_tokens, dtype=torch.long, device=device).unsqueeze(0)
    enc_outputs = model.encoder(enc_X, enc_valid_len)

    dec_state = model.decoder.init_state(enc_outputs, enc_valid_len)
    # Decoding starts from a single <bos> token: (1, 1)
    dec_X = torch.tensor([tgt_vocab['<bos>']], dtype=torch.long, device=device).unsqueeze(0)

    output_seq = []
    attention_weight_seq = []
    for _ in range(num_steps):
        Y, dec_state = model.decoder(dec_X, dec_state)
        # The most likely token becomes the decoder input of the next step.
        dec_X = Y.argmax(dim=2)
        pred = dec_X.squeeze(dim=0).type(torch.int32).item()

        if save_attention_weights:
            attention_weight_seq.append(model.decoder.attention_weights)

        # Stop at end-of-sequence; <eos> itself is not part of the output.
        if pred == tgt_vocab['<eos>']:
            break
        output_seq.append(pred)

    translation = ' '.join(tgt_vocab.to_tokens(output_seq))
    return translation, attention_weight_seq

Transformer Code

The following code is a copy of all aforementioned Transformer architecture code, convenient for copying.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def sequence_mask(X, valid_len, value=0.0):
    '''
    Return a copy of X in which positions j >= valid_len[i] along dim 1 are set to `value`.

    :param X: (batch_size, seq_len) or (batch_size, seq_len, input_dim)
    :param valid_len: (batch_size, ), number of valid leading steps per row
    :param value: fill value for the masked positions
    :return: masked copy of X; the argument itself is not modified
    '''
    maxlen = X.shape[1]
    # (1, maxlen) < (batch_size, 1) broadcasts to a (batch_size, maxlen) boolean mask.
    mask = torch.arange(maxlen, dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
    # Work on a copy so the caller's tensor is never modified in place.
    X = X.clone()
    X[~mask] = value
    return X
  
def masked_softmax(X, valid_lens):
    '''
    Softmax over the last axis that ignores key positions beyond the valid length.

    :param X: (batch_size, query_lens, key_lens)
    :param valid_lens: None, (batch_size, ) or (batch_size, query_lens)
    :return: tensor shaped like X
    '''
    # Without lengths this is an ordinary softmax.
    if valid_lens is None:
        return F.softmax(X, dim=-1)

    original_shape = X.shape
    if valid_lens.dim() == 1:
        # One length per sample: repeat it for every query row of that sample.
        row_lens = torch.repeat_interleave(valid_lens, original_shape[1])
    else:
        # One length per (sample, query) pair: flatten to one length per row.
        row_lens = valid_lens.reshape(-1)

    # Each row of the flattened scores has a single length; -1e6 turns into ~0 after softmax.
    flat_scores = X.reshape(-1, original_shape[-1])
    masked_scores = sequence_mask(flat_scores, row_lens, value=-1e6)
    return F.softmax(masked_scores.reshape(original_shape), dim=-1)

class DotProductAttention(nn.Module):
    '''Scaled dot-product attention with dropout applied to the attention weights.'''

    def __init__(self, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        '''
        :param queries: (batch_size, query_lens, num_hiddens)
        :param keys: (batch_size, key_lens, num_hiddens)
        :param values: (batch_size, key_lens, num_hiddens)
        :param valid_lens: forwarded to masked_softmax; None disables masking
        '''
        scale = math.sqrt(queries.shape[-1])
        raw_scores = torch.bmm(queries, keys.transpose(1, 2))
        scores = raw_scores / scale
        # The weights are stored on the module for later inspection.
        weights = masked_softmax(scores, valid_lens)
        self.attention_weights = weights
        return torch.bmm(self.dropout(weights), values)

def transpose_qkv(X, num_heads):
    '''(batch_size, seq_len, num_hiddens) -> (batch_size * num_heads, seq_len, num_hiddens / num_heads)'''
    batch_size, seq_len = X.shape[0], X.shape[1]
    # (batch_size, seq_len, num_heads, head_dim)
    per_head = X.reshape(batch_size, seq_len, num_heads, -1)
    # (batch_size, num_heads, seq_len, head_dim)
    heads_first = per_head.permute(0, 2, 1, 3)
    head_dim = heads_first.shape[3]
    return heads_first.reshape(-1, seq_len, head_dim)

def transpose_output(X, num_heads):
    '''(batch_size * num_heads, seq_len, head_dim) -> (batch_size, seq_len, num_heads * head_dim)'''
    seq_len, head_dim = X.shape[1], X.shape[2]
    # (batch_size, num_heads, seq_len, head_dim)
    by_head = X.reshape(-1, num_heads, seq_len, head_dim)
    # (batch_size, seq_len, num_heads, head_dim)
    tokens_first = by_head.permute(0, 2, 1, 3)
    batch_size = tokens_first.shape[0]
    return tokens_first.reshape(batch_size, seq_len, -1)

class MultiHeadAttention(nn.Module):
    '''Multi-head attention that evaluates all heads in one batched DotProductAttention call.'''

    def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        # Input projections for queries, keys and values, followed by the output projection.
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        '''Project the inputs, attend per head, then concatenate the heads and project the result.'''
        heads = self.num_heads
        q = transpose_qkv(self.W_q(queries), heads)
        k = transpose_qkv(self.W_k(keys), heads)
        v = transpose_qkv(self.W_v(values), heads)
        if valid_lens is not None:
            # Every head of a sample shares that sample's valid lengths.
            valid_lens = torch.repeat_interleave(valid_lens, repeats=heads, dim=0)
        per_head_output = self.attention(q, k, v, valid_lens)
        return self.W_o(transpose_output(per_head_output, heads))

class PositionalEncoding(nn.Module):
    '''
    Add fixed sinusoidal positional encodings to the input, then apply dropout.

    P[0, i, 2j]   = sin(i / 10000^(2j / num_hiddens))
    P[0, i, 2j+1] = cos(i / 10000^(2j / num_hiddens))

    :param num_hiddens: embedding size (odd sizes are supported)
    :param dropout: dropout probability applied after the addition
    :param max_len: longest sequence length that can be encoded
    '''
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        P = torch.zeros((1, max_len, num_hiddens))
        X = (torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) /
             torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens))
        P[:, :, 0::2] = torch.sin(X)
        # For an odd num_hiddens there is one cosine column fewer than sine columns.
        P[:, :, 1::2] = torch.cos(X)[:, :num_hiddens // 2]
        # Non-persistent buffer: moves with the module on .to()/.cuda() but adds no
        # key to state_dict, so existing checkpoints still load.
        self.register_buffer('P', P, persistent=False)

    def forward(self, X):
        '''Return X + P[:, :seq_len] after dropout; the argument itself is not modified.'''
        positions = self.P[:, :X.shape[1], :].to(X.device)
        # Out-of-place add; the cast keeps the dtype of X, as the former in-place add did.
        X = (X + positions).to(X.dtype)
        return self.dropout(X)

class AddNorm(nn.Module):
    '''
    Residual connection followed by layer normalization.

    :param normalized_shape: shape argument forwarded to ``nn.LayerNorm``.
    :param dropout: dropout rate applied to the sublayer output before the
                    residual addition.
    '''

    def __init__(self, normalized_shape, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        '''
        :param X: residual (sublayer input).
        :param Y: sublayer output; dropout is applied to it only.
        :return: ``LayerNorm(X + dropout(Y))``.
        '''
        residual_sum = X + self.dropout(Y)
        return self.ln(residual_sum)

class PositionWiseFFN(nn.Module):
    '''
    Two-layer feed-forward network: Linear -> ReLU -> Linear.

    Both layers are ``nn.Linear``, so they act on the last dimension of the
    input only.

    :param ffn_num_input: size of the input's last dimension.
    :param ffn_num_hiddens: width of the hidden layer.
    :param ffn_num_outputs: size of the output's last dimension.
    '''

    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        '''
        :param X: tensor whose last dimension is ``ffn_num_input``.
        :return: tensor with the last dimension replaced by ``ffn_num_outputs``.
        '''
        hidden = self.dense1(X)
        activated = self.relu(hidden)
        return self.dense2(activated)

class EncoderBlock(nn.Module):
    '''
    One encoder layer: self-attention and a position-wise FFN, each wrapped
    in an ``AddNorm`` (residual + layer norm).

    :param key_size, query_size, value_size: input sizes for the attention
           projections.
    :param num_hiddens: attention output size; also used as the FFN output size.
    :param norm_shape: shape handed to both ``AddNorm`` layers.
    :param ffn_num_input: FFN input size.
    :param ffn_num_hiddens: FFN hidden width.
    :param num_heads: number of attention heads.
    :param dropout: dropout rate shared by the attention and ``AddNorm`` layers.
    :param use_bias: whether the attention projections use a bias.
    '''

    def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, dropout, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.attention = MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        '''
        :param X: block input; used as queries, keys and values alike.
        :param valid_lens: valid-length information forwarded to the attention.
        :return: output of the second ``AddNorm``.
        '''
        # Sublayer 1: self-attention with residual connection.
        attended = self.attention(X, X, X, valid_lens)
        Y = self.addnorm1(X, attended)
        # Sublayer 2: feed-forward network with residual connection.
        transformed = self.ffn(Y)
        return self.addnorm2(Y, transformed)

class TransformerEncoder(nn.Module):
    '''
    Token embedding + positional encoding followed by ``num_layers``
    ``EncoderBlock`` layers.

    :param vocab_size: number of rows in the embedding table.
    :param key_size, query_size, value_size, num_hiddens, norm_shape,
           ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias:
           forwarded unchanged to every ``EncoderBlock``.
    :param num_layers: number of stacked encoder blocks.
    '''

    def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape,
                 ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for layer_idx in range(num_layers):
            block = EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
                                 ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias)
            # Blocks are registered as "block0", "block1", ...
            self.blks.add_module("block" + str(layer_idx), block)

    def forward(self, X, valid_lens):
        '''
        :param X: token indices looked up in the embedding table.
        :param valid_lens: valid-length information passed to every block.
        :return: output of the last encoder block.
        '''
        # The embeddings are multiplied by sqrt(num_hiddens) before the
        # positional encodings are added (presumably to keep the learned
        # embeddings on a scale comparable to the [-1, 1] sin/cos terms).
        scaled_embedding = self.embedding(X) * math.sqrt(self.num_hiddens)
        X = self.pos_encoding(scaled_embedding)
        # One slot per block; filled with whatever each block's inner
        # attention module stored in `attention_weights` during this pass.
        self.attention_weights = [None] * len(self.blks)
        for idx, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[idx] = blk.attention.attention.attention_weights
        return X

class DecoderBlock(nn.Module):
    '''
    One decoder layer: self-attention over the tokens seen so far,
    attention over the encoder outputs, and a position-wise FFN, each
    followed by an ``AddNorm``.

    :param key_size, query_size, value_size, num_hiddens, num_heads, dropout:
           forwarded to both ``MultiHeadAttention`` layers.
    :param norm_shape: shape handed to the three ``AddNorm`` layers.
    :param ffn_num_input: FFN input size.
    :param ffn_num_hiddens: FFN hidden width.
    :param i: index of this block; selects its slot in ``state[2]``.
    '''

    def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape,
                 ffn_num_input, ffn_num_hiddens, num_heads, dropout, i, **kwargs):
        super().__init__(**kwargs)
        self.i = i
        self.attention1 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        '''
        :param X: decoder input of shape (batch_size, num_steps, features).
        :param state: indexable with ``state[0]`` = encoder outputs,
                      ``state[1]`` = encoder valid lengths and ``state[2]`` =
                      per-block cache list; ``state[2][self.i]`` is updated
                      in place.
        :return: ``(block_output, state)``.
        '''
        enc_outputs = state[0]
        enc_valid_lens = state[1]

        # Append the current input to everything this block has seen before
        # (along the step dimension) and store the result back in the cache.
        cached = state[2][self.i]
        key_values = X if cached is None else torch.cat((cached, X), dim=1)
        state[2][self.i] = key_values

        dec_valid_lens = None
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Every row is [1, 2, ..., num_steps]; shape (batch_size, num_steps).
            # Presumably read per query position by the attention masking
            # -- confirm against DotProductAttention.
            steps = torch.arange(1, num_steps + 1, device=X.device)
            dec_valid_lens = steps.repeat(batch_size, 1)

        # Sublayer 1: attention over the cached decoder inputs.
        self_attended = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, self_attended)
        # Sublayer 2: attention over the encoder outputs.
        cross_attended = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, cross_attended)
        # Sublayer 3: position-wise feed-forward network.
        output = self.addnorm3(Z, self.ffn(Z))
        return output, state

class AttentionDecoder(nn.Module):
    '''
    Base class for decoders that expose their attention weights.

    Subclasses must override the ``attention_weights`` property.
    '''

    def __init__(self, **kwargs):
        super(AttentionDecoder, self).__init__(**kwargs)

    # Declared as a property so that the base class matches subclasses that
    # override it with @property: plain attribute access now raises instead
    # of silently returning a bound method.
    @property
    def attention_weights(self):
        '''
        :raises NotImplementedError: always; subclasses provide the value.
        '''
        raise NotImplementedError
      
      
class TransformerDecoder(AttentionDecoder):
    '''
    Token embedding + positional encoding, ``num_layers`` ``DecoderBlock``
    layers and a final linear projection to vocabulary size.

    :param vocab_size: size of the embedding table and of the output layer.
    :param key_size, query_size, value_size, num_hiddens, norm_shape,
           ffn_num_input, ffn_num_hiddens, num_heads, dropout:
           forwarded unchanged to every ``DecoderBlock``.
    :param num_layers: number of stacked decoder blocks.
    '''

    def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape,
                 ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, **kwargs):
        super().__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for layer_idx in range(num_layers):
            # Each block receives its own index so it can find its cache slot.
            block = DecoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
                                 ffn_num_input, ffn_num_hiddens, num_heads, dropout, layer_idx)
            self.blks.add_module("block" + str(layer_idx), block)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        '''
        Build the decoder state: encoder outputs, encoder valid lengths and
        one empty (``None``) cache slot per decoder block.
        '''
        per_block_cache = [None] * self.num_layers
        return [enc_outputs, enc_valid_lens, per_block_cache]

    def forward(self, X, state):
        '''
        :param X: token indices looked up in the embedding table.
        :param state: state list as produced by ``init_state``.
        :return: ``(logits, state)`` where logits come from the final linear layer.
        '''
        scaled_embedding = self.embedding(X) * math.sqrt(self.num_hiddens)
        X = self.pos_encoding(scaled_embedding)

        num_blks = len(self.blks)
        # Row 0: weights from each block's attention1; row 1: from attention2.
        first_attn = [None] * num_blks
        second_attn = [None] * num_blks
        self._attention_weights = [first_attn, second_attn]

        for idx, blk in enumerate(self.blks):
            X, state = blk(X, state)
            first_attn[idx] = blk.attention1.attention.attention_weights
            second_attn[idx] = blk.attention2.attention.attention_weights
        return self.dense(X), state

    @property
    def attention_weights(self):
        '''Attention weights recorded during the most recent forward pass.'''
        return self._attention_weights

class EncoderDecoder(nn.Module):
    '''
    Thin wrapper that chains an encoder and a decoder.

    :param encoder: module called as ``encoder(enc_X, *args)``.
    :param decoder: module providing ``init_state(enc_outputs, *args)`` and
                    callable as ``decoder(dec_X, state)``.
    '''

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        '''
        :param enc_X: encoder input.
        :param dec_X: decoder input.
        :param args: extra positional arguments passed to both the encoder
                     call and ``decoder.init_state``.
        :return: whatever the decoder returns.
        '''
        encoded = self.encoder(enc_X, *args)
        initial_state = self.decoder.init_state(encoded, *args)
        return self.decoder(dec_X, initial_state)
    
# Build the encoder and decoder from the shared hyperparameters; only the
# vocabulary differs between the two.
encoder = TransformerEncoder(
    vocab_size=len(src_vocab),
    key_size=key_size,
    query_size=query_size,
    value_size=value_size,
    num_hiddens=num_hiddens,
    norm_shape=norm_shape,
    ffn_num_input=ffn_num_input,
    ffn_num_hiddens=ffn_num_hiddens,
    num_heads=num_heads,
    num_layers=num_layers,
    dropout=dropout,
)
decoder = TransformerDecoder(
    vocab_size=len(tgt_vocab),
    key_size=key_size,
    query_size=query_size,
    value_size=value_size,
    num_hiddens=num_hiddens,
    norm_shape=norm_shape,
    ffn_num_input=ffn_num_input,
    ffn_num_hiddens=ffn_num_hiddens,
    num_heads=num_heads,
    num_layers=num_layers,
    dropout=dropout,
)
# Chain them into a single sequence-to-sequence model.
model = EncoderDecoder(encoder=encoder, decoder=decoder)

Llama