RMSNorm
Unlike the standard Transformer, Llama uses RMSNorm and moves the first normalization layer before the multi-head self-attention layer and the second normalization layer before the fully connected layer (compare this with the decoder structure of the original attention mechanism).
RMSNorm is a simplified variant of LayerNorm with less computational overhead than LayerNorm. For example, consider a batch of two \(2\times4\) input matrices:
\[ \begin{bmatrix}x^{(1)}_{11}&x^{(1)}_{12}&x^{(1)}_{13}&x^{(1)}_{14}\\x^{(1)}_{21}&x^{(1)}_{22}&x^{(1)}_{23}&x^{(1)}_{24}\end{bmatrix}\\ \begin{bmatrix}x^{(2)}_{11}&x^{(2)}_{12}&x^{(2)}_{13}&x^{(2)}_{14}\\x^{(2)}_{21}&x^{(2)}_{22}&x^{(2)}_{23}&x^{(2)}_{24}\end{bmatrix} \]
\[ rms^{(k)}_i=\sqrt{\frac{1}{4}\sum^{4}_{j=1}\big[x_{ij}^{(k)}\big]^2+\epsilon} \]
where \(\epsilon\) is a very small number that prevents \(rms^{(k)}_i\) from being \(0\).
Update method: \[ [x_{i1}^{(k)},x_{i2}^{(k)},x_{i3}^{(k)},x_{i4}^{(k)}] = \Big[\frac{x_{i1}^{(k)}}{rms_i^{(k)}}g_1,\frac{x_{i2}^{(k)}}{rms_i^{(k)}}g_2,\frac{x_{i3}^{(k)}}{rms_i^{(k)}}g_3,\frac{x_{i4}^{(k)}}{rms_i^{(k)}}g_4\Big] \]
where \(g_j\) are learnable scaling parameters.
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-8):
        super(RMSNorm, self).__init__()
        self.d_model = d_model
        self.eps = eps
        # learnable scaling parameters g_j, initialized to ones
        self.scale = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        # root mean square over the hidden dimension
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        x = x / rms * self.scale
        return x
X = torch.tensor([
    [[1, 2, 3, 4],
     [5, 6, 7, 8]],
    [[5, 6, 7, 8],
     [5, 1, 0, -1]]
], dtype=torch.float32)

with torch.no_grad():
    rmsnorm = RMSNorm(4)
    print(rmsnorm(X))
## tensor([[[ 0.3651, 0.7303, 1.0954, 1.4606],
## [ 0.7581, 0.9097, 1.0613, 1.2130]],
##
## [[ 0.7581, 0.9097, 1.0613, 1.2130],
## [ 1.9245, 0.3849, 0.0000, -0.3849]]])
SwiGLU
Swish function
\[Swish(\vec{x};\beta)=\vec{x}\cdot\text{Sigmoid}(\beta\cdot\vec{x})\]
As \(\beta\) approaches positive infinity, Swish approaches the ReLU function; with \(\beta=1\) it is the SiLU activation.
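As a quick numerical check of this limiting behavior (a toy snippet, not part of Llama):

import torch

# Swish(x; beta) = x * sigmoid(beta * x); for large beta it approaches ReLU,
# and beta = 1 gives the SiLU activation used in practice.
x = torch.linspace(-3, 3, 7)
for beta in (0.5, 1.0, 100.0):
    print(beta, x * torch.sigmoid(beta * x))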
GLU (Gated Linear Units)
\[GLU(\vec{x};\boldsymbol{W}_1,\boldsymbol{W}_2,b_1,b_2)=\text{Sigmoid}(\vec{x}\boldsymbol{W}_1+b_1)\odot(\vec{x}\boldsymbol{W}_2+b_2)\]
where \(\odot\) is the Hadamard product, and \(\boldsymbol{W}_1\), \(\boldsymbol{W}_2\), \(b_1\), \(b_2\) are learnable parameters.
SwiGLU
\[SwiGLU=Swish(\vec{x}\boldsymbol{W}_1; \beta)\odot(\vec{x}\boldsymbol{W}_2)\]
Llama uses this SwiGLU function to replace the ReLU activation in the feed-forward (MLP) layers.
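A minimal PyTorch sketch of a SwiGLU feed-forward block (the module and dimension names here are my own, not Llama's exact implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    # Feed-forward block: Swish(x W1) element-wise multiplied by (x W2),
    # followed by an output projection W3.
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_hidden, bias=False)  # gate branch
        self.w2 = nn.Linear(d_model, d_hidden, bias=False)  # value branch
        self.w3 = nn.Linear(d_hidden, d_model, bias=False)  # output projection

    def forward(self, x):
        # F.silu is Swish with beta = 1
        return self.w3(F.silu(self.w1(x)) * self.w2(x))

ffn = SwiGLUFFN(d_model=512, d_hidden=1376)
print(ffn(torch.randn(2, 8, 512)).shape)   # torch.Size([2, 8, 512])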
RoPE
Rotary positional embedding (RoPE) is used to encode the relative positional relationship between the query and the key in self-attention layers. It also generalizes better to very long input sequences.
Suppose we have a sequence of \(8\) tokens with \(2\) hidden dimensions.
\[ \color{red}{\vec{q}_2} \mathop{\longleftarrow}^{W_q} \begin{bmatrix} x_{11}&x_{12}\\ \color{red}{x_{21}}&\color{red}{x_{22}}\\ \vdots&\vdots\\ x_{81}&x_{82} \end{bmatrix} \ \ \color{red}{\vec{k}_8} \mathop{\longleftarrow}^{W_k} \begin{bmatrix} x_{11}&x_{12}\\ x_{21}&x_{22}\\ \vdots&\vdots\\ \color{red}{x_{81}}&\color{red}{x_{82}} \end{bmatrix} \]
In the attention mechanism, the dot product of \(\vec{q}_2\) and \(\vec{k}_8\) is taken directly. With RoPE, \(\vec{q}_2\) and \(\vec{k}_8\) are first rotated.
\[ f(\vec{q}_2)=\vec{q}_2e^{i2\theta}=\big{(}q_2^{(1)}+iq_2^{(2)}\big{)}(\cos2\theta+i\sin2\theta)\\ f(\vec{k}_8)=\vec{k}_8e^{i8\theta}=\big{(}k_8^{(1)}+ik_8^{(2)}\big{)}(\cos8\theta+i\sin8\theta) \]
Their inner product is then computed through the rotation matrix below.
\[ \begin{bmatrix}q_2^{(1)}&q_2^{(2)}\end{bmatrix} \begin{bmatrix} \cos(2-8)\theta&-\sin(2-8)\theta\\ \sin(2-8)\theta&\cos(2-8)\theta \end{bmatrix} \begin{bmatrix}k_8^{(1)}\\k_8^{(2)}\end{bmatrix} \]
In general, the hidden dimension is \(4096\) or more, so for \(\vec{q}_m\) and \(\vec{k}_n\) the rotation is applied pairwise to the dimensions, as in the following formula.
\[ \vec{q}_m^T \begin{bmatrix} \cos(m-n)\theta_0&-\sin(m-n)\theta_0&0&0&\cdots\\ \sin(m-n)\theta_0&\cos(m-n)\theta_0&0&0&\cdots\\ 0&0&\cos(m-n)\theta_1&-\sin(m-n)\theta_1&\cdots\\ 0&0&\sin(m-n)\theta_1&\cos(m-n)\theta_1&\cdots\\ \vdots&\vdots&\vdots&\vdots&\ddots \end{bmatrix} \vec{k}_n \]
In Llama there is no longer an absolute positional embedding; it is replaced by RoPE inside the self-attention layers.
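A minimal sketch of applying the rotation to query and key vectors (the function name and the base frequency \(10000\) follow the RoPE paper's convention; this is an illustration, not Llama's exact implementation):

import torch

def rope_rotate(x, base=10000.0):
    # Rotate pairs of dimensions of x (shape: seq_len x dim, dim even) by
    # position-dependent angles m * theta_i, where theta_i = base^(-2i/dim).
    seq_len, dim = x.shape
    theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), theta)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[:, 0::2], x[:, 1::2]          # treat (even, odd) dims as 2-D pairs
    out = torch.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin       # 2-D rotation of each pair
    out[:, 1::2] = x1 * sin + x2 * cos
    return out

q = torch.randn(8, 4)                        # 8 tokens, 4 hidden dims
k = torch.randn(8, 4)
q_rot, k_rot = rope_rotate(q), rope_rotate(k)
# The score q_rot[m] @ k_rot[n] depends on position only through the offset m - n.
score = q_rot[1] @ k_rot[7]                  # corresponds to q_2 . k_8 in the example above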
Optimization of Attention Mechanism
Suppose the input is an \(n\times d\) matrix, where \(n\) is the sequence length and \(d\) is the hidden dimension. The time complexity of the self-attention calculation is \(O(n^2d)\), so long sequences cost huge computational resources.
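To make the quadratic cost concrete, here is a toy illustration (projection matrices omitted, sizes chosen arbitrarily):

import torch

n, d = 1024, 64
X = torch.randn(n, d)
scores = X @ X.T                                     # n x n scores: ~n^2 * d multiply-adds
out = torch.softmax(scores / d ** 0.5, dim=-1) @ X   # another ~n^2 * d
print(scores.shape)                                  # torch.Size([1024, 1024])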
Sparse attention
Sparse attention based on position
The following graph shows basic types of position-based sparse attention: (1) Global Attention; (2) Band Attention; (3) Dilated Attention; (4) Random Attention; (5) Block Attention.
A node or a row/column in the graph represents a token.
The following graph shows advanced types of position-based sparse attention: (1) Star-Transformer; (2) Longformer; (3) ETC; (4) BigBird.
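As a small illustration of the position-based patterns, a band (sliding-window) attention mask that lets each token attend only to its \(w\) nearest neighbors can be built like this (a toy sketch):

import torch

n, w = 8, 2
idx = torch.arange(n)
# True where |i - j| <= w: each token attends only to a window of 2w + 1 positions.
band_mask = (idx[None, :] - idx[:, None]).abs() <= w
print(band_mask.int())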
Sparse attention based on content
Routing Transformer: uses k-means clustering to select keys for each query; a query only attends to keys that fall in the same cluster as the query.
Reformer: uses locality-sensitive hashing (LSH) to bucket queries and keys, so that each query only attends to keys in the same hash bucket.
FlashAttention
The FlashAttention algorithm reorganizes how the computation is carried out on the GPU, processing attention block by block to reduce reads and writes to GPU memory, which speeds up the self-attention calculation.
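A simplified sketch of the tiling idea is shown below; only the query dimension is tiled here, whereas real FlashAttention also tiles the keys/values and keeps an online softmax in on-chip memory:

import torch

def chunked_attention(Q, K, V, chunk=128):
    # Process queries block by block so the full n x n score matrix
    # is never materialized at once (peak memory: chunk x n instead of n x n).
    outs = []
    for i in range(0, Q.shape[0], chunk):
        q = Q[i:i + chunk]                               # (chunk, d)
        scores = q @ K.T / K.shape[-1] ** 0.5            # (chunk, n)
        outs.append(torch.softmax(scores, dim=-1) @ V)   # (chunk, d)
    return torch.cat(outs, dim=0)

Q = K = V = torch.randn(1024, 64)
full = torch.softmax(Q @ K.T / 64 ** 0.5, dim=-1) @ V
print(torch.allclose(chunked_attention(Q, K, V), full, atol=1e-5))  # True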
Multi-Query Attention
All query heads share a single set of keys and values: MQA reduces the number of key and value heads to \(1\), so the parameters in the self-attention layer are significantly reduced.
For example, in multi-head attention the q, k, v parameter count is \(28\times28\times3=2{,}352\) for \(W_{q\ (28,28)}\), \(W_{k\ (28,28)}\) and \(W_{v\ (28,28)}\). In multi-query attention the q, k, v parameter count is \(28\times(28+2\times7)=1{,}176\), where \(7\) is the dimension of a single head.
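The arithmetic above can be checked in a couple of lines (assuming 4 heads of dimension 7):

# d_model = 28, 4 heads of head_dim 7
d_model, head_dim = 28, 7
mha = d_model * d_model * 3                    # W_q, W_k, W_v all (28, 28)
mqa = d_model * (d_model + 2 * head_dim)       # W_q (28, 28); W_k and W_v shrink to (28, 7)
print(mha, mqa)                                # 2352 1176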
However, MQA may not work well because it reduces the parameters so dramatically: it is the fastest variant but lower in quality. MQA is used in ChatGLM2 and Google Gemini.
Grouped-Query Attention
GQA is a compromise between MHA and MQA. It is used in Llama2, Mistral and Google Gemma 2. See GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.
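A toy PyTorch sketch of grouped-query attention (all sizes here are made up for illustration; setting the number of key/value heads to 1 recovers MQA, and setting it equal to the number of query heads recovers standard MHA):

import torch
import torch.nn.functional as F

batch, seq, head_dim = 2, 10, 4
n_q_heads, n_kv_heads = 8, 2                   # 8 query heads share 2 key/value heads

q = torch.randn(batch, n_q_heads, seq, head_dim)
k = torch.randn(batch, n_kv_heads, seq, head_dim)
v = torch.randn(batch, n_kv_heads, seq, head_dim)

# Repeat each k/v head so that every group of 4 query heads sees the same keys/values.
group = n_q_heads // n_kv_heads
k = k.repeat_interleave(group, dim=1)          # (batch, n_q_heads, seq, head_dim)
v = v.repeat_interleave(group, dim=1)

scores = q @ k.transpose(-2, -1) / head_dim ** 0.5
out = F.softmax(scores, dim=-1) @ v            # (batch, n_q_heads, seq, head_dim)
print(out.shape)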
Data Preparation of LLM
Tokenizer
Data size
Llama2-70B is trained on \(2\)T tokens.
Data quality
Data diversity
Llama1, Llama2, Llama3
|        | parameter size | training tokens |
|--------|----------------|-----------------|
| Llama  | 65B            | 1.4T            |
| Llama2 | 70B            | 2T              |
| Llama3 | 70B            | 15T             |
Llama1
Network structure:
Use RMSNorm to replace LayerNorm; RMSNorm is placed before the self-attention layer and the MLP layer.
Use SwiGLU to replace ReLU.
Use RoPE to replace the absolute positional embedding.
Optimizer: AdamW
Pretraining.
Llama2
Llama2-7B and Llama2-13B use the same architecture as Llama1, but Llama2-70B uses GQA to replace MHA.
Pretraining + SFT.
Llama3
Tokenizer vocabulary size of \(128\)K.
Llama3-8B and Llama3-70B use GQA to replace MHA.
Llama3 is trained on long sequences of \(8{,}192\) tokens.
Pretraining + SFT.