Llama

Tao Zou

2025-03-30

RMSNorm

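Unlike the standard Transformer, Llama uses RMSNorm and moves the first normalization layer in front of the multi-head self-attention layer and the second normalization layer in front of the fully connected (feed-forward) layer. This is pre-normalization, whereas the original Transformer normalizes after each residual connection. (Compare with the decoder structure of the standard attention mechanism.)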

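RMSNorm is a simplified variant of LayerNorm. It skips the mean subtraction and only rescales each token's vector by its root mean square, so it has less computational overhead than LayerNorm. For example, take a batch of two samples \(k=1,2\), each with 2 tokens (rows \(i\)) and 4 hidden dimensions (columns \(j\)):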

\[ \begin{bmatrix}x^{(1)}_{11}&x^{(1)}_{12}&x^{(1)}_{13}&x^{(1)}_{14}\\x^{(1)}_{21}&x^{(1)}_{22}&x^{(1)}_{23}&x^{(1)}_{24}\end{bmatrix}\\ \begin{bmatrix}x^{(2)}_{11}&x^{(2)}_{12}&x^{(2)}_{13}&x^{(2)}_{14}\\x^{(2)}_{21}&x^{(2)}_{22}&x^{(2)}_{23}&x^{(2)}_{24}\end{bmatrix} \]

\[ rms^{(k)}_i=\sqrt{\frac{1}{4}\sum_{j=1}^{4}\big[x_{ij}^{(k)}\big]^2+\epsilon} \]

where \(\epsilon\) is a very small number that keeps \(rms\) from being \(0\), which avoids division by zero.

Update rule: \[ [x_{i1}^{(k)},x_{i2}^{(k)},x_{i3}^{(k)},x_{i4}^{(k)}] \leftarrow \Big[\frac{x_{i1}^{(k)}}{rms_i^{(k)}}g_1,\frac{x_{i2}^{(k)}}{rms_i^{(k)}}g_2,\frac{x_{i3}^{(k)}}{rms_i^{(k)}}g_3,\frac{x_{i4}^{(k)}}{rms_i^{(k)}}g_4\Big] \]

where \(g_j\) are learnable scaling parameters, one per hidden dimension, shared by all tokens and samples.

(Figure: Batch Normalization vs. Layer Normalization)

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.d_model = d_model
        self.eps = eps
        # learnable per-dimension scale g_j, initialised to 1
        self.scale = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        # root mean square over the last (hidden) dimension; eps sits inside the sqrt
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.scale

# batch of 2 samples, each with 2 tokens and 4 hidden dims
X = torch.tensor([
    [[1, 2, 3, 4],
     [5, 6, 7, 8]],
    [[5, 6, 7, 8],
     [5, 1, 0, -1]]
], dtype=torch.float32)

with torch.no_grad():
    rmsnorm = RMSNorm(4)
    print(rmsnorm(X))
## tensor([[[ 0.3651,  0.7303,  1.0954,  1.4606],
##          [ 0.7581,  0.9097,  1.0613,  1.2130]],
## 
##         [[ 0.7581,  0.9097,  1.0613,  1.2130],
##          [ 1.9245,  0.3849,  0.0000, -0.3849]]])
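The sketch below shows what RMSNorm leaves out compared with LayerNorm. It is a minimal pure-Python version, so it runs without PyTorch, and it normalizes a single row both ways. The helper names `layer_norm` and `rms_norm` are illustrative, and the learnable gain, bias and scale are fixed to 1 and 0. The RMSNorm values match the first row of the PyTorch output above.

```python
import math

def layer_norm(row, eps=1e-8):
    # LayerNorm without its learnable gain/bias: subtract the mean, divide by the std
    mean = sum(row) / len(row)
    var = sum((v - mean) ** 2 for v in row) / len(row)
    return [(v - mean) / math.sqrt(var + eps) for v in row]

def rms_norm(row, eps=1e-8):
    # RMSNorm without its learnable scale g_j: no mean subtraction, divide by the RMS
    rms = math.sqrt(sum(v * v for v in row) / len(row) + eps)
    return [v / rms for v in row]

row = [1.0, 2.0, 3.0, 4.0]
print([round(v, 4) for v in rms_norm(row)])    # [0.3651, 0.7303, 1.0954, 1.4606]
print([round(v, 4) for v in layer_norm(row)])  # mean-centered, so values are shifted

# Once the input is centered, the two normalizations give the same result
centered = [v - sum(row) / len(row) for v in row]
print(all(abs(a - b) < 1e-6 for a, b in zip(rms_norm(centered), layer_norm(centered))))  # True
```

The two functions differ only in the mean subtraction. Dropping that step saves one reduction over the hidden dimension per token.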

SwiGLU

Swish function

\[Swish(\vec{x};\beta)=\vec{x}\cdot\text{Sigmoid}(\beta\cdot\vec{x})\]

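The graph of the Swish function is shown below for \(\beta\) close to positive infinity. As \(\beta\to+\infty\), \(\text{Sigmoid}(\beta x)\) approaches a step function, so Swish approaches ReLU. With \(\beta=0\), Swish reduces to the linear function \(x/2\).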
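Here is a small pure-Python sketch of Swish for a few values of \(\beta\). The helper names and sample points are illustrative. It measures how far Swish is from ReLU on a handful of inputs. The gap shrinks toward 0 as \(\beta\) grows, and \(\beta=1\) is the function also known as SiLU.

```python
import math

def sigmoid(z):
    # numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def swish(x, beta=1.0):
    # Swish(x; beta) = x * Sigmoid(beta * x); beta = 1 is also called SiLU
    return x * sigmoid(beta * x)

def relu(x):
    return max(0.0, x)

xs = [-3.0, -1.0, 0.0, 1.0, 3.0]
for beta in (0.0, 1.0, 10.0, 100.0):
    gap = max(abs(swish(x, beta) - relu(x)) for x in xs)
    print(f"beta={beta:>5}: max |Swish - ReLU| on xs = {gap:.2e}")
```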

GLU (Gated Linear Units)

\[GLU(\vec{x};\boldsymbol{W}_1,\boldsymbol{W}_2,b_1,b_2)=\text{Sigmoid}(\vec{x}\boldsymbol{W}_1+b_1)\odot(\vec{x}\boldsymbol{W}_2+b_2)\]

\(\odot\) is the Hadamard (element-wise) product. \(\boldsymbol{W}_1\), \(\boldsymbol{W}_2\), \(b_1\), \(b_2\) are learnable parameters.

SwiGLU

\[SwiGLU(\vec{x};\boldsymbol{W}_1,\boldsymbol{W}_2,\beta)=Swish(\vec{x}\boldsymbol{W}_1; \beta)\odot(\vec{x}\boldsymbol{W}_2)\]

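Llama uses SwiGLU in place of the ReLU activation in its feed-forward (MLP) layers. The SwiGLU output is then projected back to the model dimension by a third weight matrix.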
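The following is a minimal sketch of a SwiGLU feed-forward block. It assumes a Llama-style layout with \(\beta=1\) (SiLU) and no biases. The gate branch \(\text{Swish}(\vec{x}\boldsymbol{W}_1)\) is multiplied element-wise by \(\vec{x}\boldsymbol{W}_2\) and then projected back to the model dimension by a third matrix. That matrix is called `W3` here; the name is an assumption, since codebases label the three matrices differently. The weights are made-up toy numbers.

```python
import math

def silu(z):
    # Swish with beta = 1, written in a numerically stable way
    if z >= 0:
        return z / (1.0 + math.exp(-z))
    e = math.exp(z)
    return z * e / (1.0 + e)

def matvec(x, W):
    # row vector x (length d_in) times matrix W (d_in x d_out)
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

def swiglu_ffn(x, W1, W2, W3):
    gate = [silu(v) for v in matvec(x, W1)]   # Swish(x W1)
    up = matvec(x, W2)                         # x W2
    h = [g * u for g, u in zip(gate, up)]      # Hadamard product
    return matvec(h, W3)                       # project back to d_model

# toy sizes: d_model = 2, hidden = 3 (illustrative weights)
W1 = [[1.0, 0.0, -1.0], [0.0, 1.0, 1.0]]
W2 = [[0.5, 0.5, 0.5], [1.0, -1.0, 0.0]]
W3 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(swiglu_ffn([1.0, 2.0], W1, W2, W3))
```

The hidden width has no constraint of its own, so the block maps a `d_model` vector back to a `d_model` vector regardless of the hidden size chosen.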

RoPE

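Rotary positional embedding (RoPE) encodes positions by rotating the query and key vectors in each self-attention layer. The dot product between a query and a key then depends only on the relative position of the two tokens. The rotation angles are fixed rather than learned. RoPE is also regarded as better suited to very long input sequences.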

Suppose I have a sequence of \(8\) tokens with \(2\) hidden dimensions.

\[ \color{red}{\vec{q}_2} \mathop{\longleftarrow}^{W_q} \begin{bmatrix} x_{11}&x_{12}\\ \color{red}{x_{21}}&\color{red}{x_{22}}\\ \vdots&\vdots\\ x_{81}&x_{82} \end{bmatrix} \ \ \color{red}{\vec{k}_8} \mathop{\longleftarrow}^{W_k} \begin{bmatrix} x_{11}&x_{12}\\ x_{21}&x_{22}\\ \vdots&\vdots\\ \color{red}{x_{81}}&\color{red}{x_{82}} \end{bmatrix} \]

In the standard attention mechanism, \(\vec{q}_2\) and \(\vec{k}_8\) are combined directly with a dot product. With RoPE, \(\vec{q}_2\) and \(\vec{k}_8\) are first rotated by angles proportional to their positions, \(2\theta\) and \(8\theta\):

\[ f(\vec{q}_2)=\vec{q}_2e^{i2\theta}=\big{(}q_2^{(1)}+iq_2^{(2)}\big{)}(\cos2\theta+i\sin2\theta)\\ f(\vec{k}_8)=\vec{k}_8e^{i8\theta}=\big{(}k_8^{(1)}+ik_8^{(2)}\big{)}(\cos8\theta+i\sin8\theta) \]
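The next sketch is a quick pure-Python check of this 2-D rotation. It treats \((v^{(1)}, v^{(2)})\) as the complex number \(v^{(1)}+iv^{(2)}\) and multiplies it by \(e^{i\cdot pos\cdot\theta}\). The angle \(\theta=0.1\) and the vectors `q` and `k` are arbitrary illustrative values. The point is that shifting both positions by the same amount leaves the dot product unchanged.

```python
import cmath
import math

theta = 0.1  # illustrative angle

def rotate(v, pos):
    # (v1, v2) -> v1 + i*v2, multiplied by e^{i * pos * theta}
    z = complex(v[0], v[1]) * cmath.exp(1j * pos * theta)
    return (z.real, z.imag)

def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]

q = (0.3, -1.2)
k = (0.8, 0.5)
s_2_8 = dot(rotate(q, 2), rotate(k, 8))     # positions m = 2, n = 8
s_5_11 = dot(rotate(q, 5), rotate(k, 11))   # same offset, both shifted by 3
print(s_2_8, s_5_11, math.isclose(s_2_8, s_5_11))  # the last value is True
```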

The dot product of the two rotated vectors is the real part of \(f(\vec{q}_2)\overline{f(\vec{k}_8)}\), and it depends only on the relative position \(8-2\):

\[ \begin{bmatrix}q_2^{(1)}&q_2^{(2)}\end{bmatrix} \begin{bmatrix} \cos(8-2)\theta&-\sin(8-2)\theta\\ \sin(8-2)\theta&\cos(8-2)\theta \end{bmatrix} \begin{bmatrix}k_8^{(1)}\\k_8^{(2)}\end{bmatrix} \]

In practice the vectors are much longer. Llama-7B has a hidden size of 4,096, split into 32 attention heads of 128 dimensions each, and RoPE is applied within each head. The \(d\) dimensions are grouped into consecutive pairs, and pair \(i\) is rotated with its own frequency \(\theta_i=10000^{-2i/d}\) (\(i=0,1,\dots,d/2-1\)). For \(\vec{q}_m\) and \(\vec{k}_n\) the score becomes:

\[ \vec{q}_m^T \begin{bmatrix} \cos(n-m)\theta_0&-\sin(n-m)\theta_0&0&0&\cdots\\ \sin(n-m)\theta_0&\cos(n-m)\theta_0&0&0&\cdots\\ 0&0&\cos(n-m)\theta_1&-\sin(n-m)\theta_1&\cdots\\ 0&0&\sin(n-m)\theta_1&\cos(n-m)\theta_1&\cdots\\ \vdots&\vdots&\vdots&\vdots&\ddots \end{bmatrix} \vec{k}_n \]

Llama no longer uses an absolute positional embedding. RoPE, applied inside every self-attention layer, replaces it.
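The general case can be sketched in pure Python. This sketch pairs consecutive dimensions \((2i, 2i+1)\), as in the block-diagonal matrix, and uses the frequencies \(\theta_i=10000^{-2i/d}\) from the RoFormer paper. It is an illustration, not Llama's actual code; real implementations differ in how they pair dimensions and apply the rotation per attention head. The function name `rope` and the test vectors are hypothetical.

```python
import math

def rope(vec, pos, base=10000.0):
    # Rotate consecutive pairs (vec[2i], vec[2i+1]) by the angle pos * theta_i,
    # with theta_i = base ** (-2i / d)
    d = len(vec)
    assert d % 2 == 0, "RoPE needs an even number of dimensions"
    out = []
    for i in range(d // 2):
        theta_i = base ** (-2 * i / d)
        c, s = math.cos(pos * theta_i), math.sin(pos * theta_i)
        x0, x1 = vec[2 * i], vec[2 * i + 1]
        out += [x0 * c - x1 * s, x0 * s + x1 * c]  # 2-D rotation of this pair
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.1, -0.4, 0.7, 0.2, -0.3, 0.5, 0.9, -0.8]
k = [0.6, 0.3, -0.2, 0.4, 0.8, -0.1, 0.05, 0.7]
# Same offset m - n = -6 at different absolute positions gives the same score
print(dot(rope(q, 2), rope(k, 8)), dot(rope(q, 100), rope(k, 106)))
```

Because each pair is only rotated, vector norms are preserved. Position information enters the score only through the offset \(n-m\).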

Optimization of Attention Mechanism

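Suppose the input is an \(n\times d\) matrix, where \(n\) is the sequence length and \(d\) is the hidden dimension. Computing the \(n\times n\) attention scores \(QK^T\) and multiplying them with \(V\) both take \(O(n^2d)\) time, and the score matrix itself needs \(O(n^2)\) memory.

Therefore, long sequences are very expensive: doubling \(n\) roughly quadruples the cost of self-attention.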
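To make the \(O(n^2d)\) cost concrete, here is a deliberately naive pure-Python self-attention. It rests on several simplifying assumptions: a single head, \(Q=K=V=X\), and no learned projections. It counts the multiplications in \(QK^T\) and in the weighted sum over \(V\), ignoring the softmax and scaling. The count comes out to exactly \(2n^2d\).

```python
import math

def naive_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d)) V on Python lists; also returns a multiplication count."""
    n, d = len(Q), len(Q[0])
    mults = 0
    out = []
    for i in range(n):
        # one row of the n x n score matrix: query i against every key (n * d mults)
        scores = [sum(Q[i][t] * K[j][t] for t in range(d)) / math.sqrt(d) for j in range(n)]
        mults += n * d
        # softmax over the row (max subtracted for numerical stability)
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        w = [x / z for x in w]
        # weighted sum of all n value vectors (another n * d mults)
        out.append([sum(w[j] * V[j][t] for j in range(n)) for t in range(d)])
        mults += n * d
    return out, mults

d = 4
for n in (4, 8, 16):
    X = [[float((i + t) % 3) for t in range(d)] for i in range(n)]
    _, mults = naive_attention(X, X, X)
    print(n, mults, 2 * n * n * d)  # counted vs. predicted 2 * n^2 * d
```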

Sparse attention

sparse attention based on position

The following graph shows basic types of position-based sparse attention: (1) Global Attention; (2) Band Attention; (3) Dilated Attention; (4) Random Attention; (5) Block Attention.

A node or a row/column in the graph represents a token.
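The position-based patterns can be written as boolean masks, where `mask[i][j]` is `True` when token \(i\) may attend to token \(j\). The sketch below builds simplified versions of band, dilated, global, random and block attention. The window sizes, gaps and helper names are illustrative choices. Combining masks with an element-wise OR gives hybrids such as band + global (Longformer-style) or band + global + random (BigBird-style).

```python
import random

def band_mask(n, w):
    # Band (sliding-window) attention: token i sees tokens within distance w
    return [[abs(i - j) <= w for j in range(n)] for i in range(n)]

def dilated_mask(n, w, gap):
    # Dilated attention: a band of w steps, but only every `gap`-th token
    return [[abs(i - j) <= w * gap and (i - j) % gap == 0 for j in range(n)]
            for i in range(n)]

def global_mask(n, global_tokens):
    # Global attention: chosen tokens see every token and are seen by every token
    g = set(global_tokens)
    return [[i in g or j in g for j in range(n)] for i in range(n)]

def random_mask(n, r, seed=0):
    # Random attention: each query sees r randomly chosen keys
    rng = random.Random(seed)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        for j in rng.sample(range(n), r):
            mask[i][j] = True
    return mask

def block_mask(n, b):
    # Block attention: tokens only see tokens in the same block of size b
    return [[i // b == j // b for j in range(n)] for i in range(n)]

def combine(*masks):
    # element-wise OR of several patterns
    n = len(masks[0])
    return [[any(m[i][j] for m in masks) for j in range(n)] for i in range(n)]

def show(mask):
    for row in mask:
        print("".join("x" if v else "." for v in row))

# band + global, in the spirit of Longformer; add random_mask(...) for a BigBird-like mix
show(combine(band_mask(8, 1), global_mask(8, [0])))
```

Each mask keeps only \(O(n)\) entries per pattern instead of \(n^2\), except the few global rows and columns, which is where the savings come from.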

The following graph shows advanced types of position-based sparse attention: (1) Star-Transformer; (2) Longformer; (3) ETC; (4) BigBird.

sparse attention based on content

Routing Transformer: uses k-means clustering to select the keys for each query. A query attends only to keys that fall into the same cluster as the query.
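Here is a simplified sketch of the clustering idea. It clusters the query and key vectors with k-means and lets query \(i\) attend key \(j\) only if both land in the same cluster. For readability it uses plain batch k-means with a deterministic farthest-point initialization. The actual Routing Transformer updates its centroids online during training, which is not shown. All names and data are illustrative.

```python
import math

def nearest(p, centroids):
    # index of the closest centroid (Euclidean distance)
    return min(range(len(centroids)), key=lambda c: math.dist(p, centroids[c]))

def kmeans(points, k, iters=10):
    # Lloyd's k-means with a deterministic farthest-point initialisation
    centroids = [list(points[0])]
    while len(centroids) < k:
        far = max(points, key=lambda p: min(math.dist(p, c) for c in centroids))
        centroids.append(list(far))
    for _ in range(iters):
        clusters = [[] for _ in range(k)]
        for p in points:
            clusters[nearest(p, centroids)].append(p)
        for c, members in enumerate(clusters):
            if members:  # keep the old centroid if a cluster is empty
                centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    return centroids

def routing_mask(Q, K, k=2):
    # queries and keys share the same centroids
    centroids = kmeans(Q + K, k)
    q_ids = [nearest(q, centroids) for q in Q]
    k_ids = [nearest(x, centroids) for x in K]
    # query i may attend key j only if both are routed to the same cluster
    return [[qi == kj for kj in k_ids] for qi in q_ids]

Q = [[0.0, 0.1], [5.0, 5.2], [0.2, -0.1]]
K = [[0.1, 0.0], [4.9, 5.1], [5.1, 4.8], [-0.2, 0.2]]
for row in routing_mask(Q, K):
    print(row)
```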

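Reformer: uses locality-sensitive hashing (LSH) to put similar queries and keys into the same bucket. Each query attends only to keys in its own bucket.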
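Below is a simplified sketch of Reformer-style angular LSH bucketing. Each vector is projected with a random Gaussian matrix \(R\), and the bucket is \(\arg\max([xR; -xR])\), so vectors pointing in similar directions tend to share a bucket. Reformer also uses shared queries/keys, multiple hash rounds, and sorting and chunking by bucket; none of these are modeled here. The helper names are hypothetical.

```python
import random

def lsh_buckets(vectors, n_buckets, seed=0):
    # Angular LSH: project with a random d x (n_buckets/2) matrix R,
    # then take the argmax over the concatenation [xR, -xR]
    assert n_buckets % 2 == 0
    rng = random.Random(seed)
    d = len(vectors[0])
    R = [[rng.gauss(0.0, 1.0) for _ in range(n_buckets // 2)] for _ in range(d)]
    buckets = []
    for x in vectors:
        proj = [sum(x[i] * R[i][j] for i in range(d)) for j in range(n_buckets // 2)]
        scores = proj + [-p for p in proj]
        buckets.append(scores.index(max(scores)))
    return buckets

def lsh_mask(vectors, n_buckets, seed=0):
    # with shared queries/keys there is one bucket id per token;
    # token i attends token j only if they land in the same bucket
    b = lsh_buckets(vectors, n_buckets, seed)
    return [[b[i] == b[j] for j in range(len(b))] for i in range(len(b))]

tokens = [[1.0, 2.0, -0.5], [2.0, 4.0, -1.0], [-1.0, -2.0, 0.5], [0.3, -1.0, 2.0]]
print(lsh_buckets(tokens, n_buckets=4))
```

Scaling a vector by a positive number never changes its bucket, and negating it moves it to the "opposite" bucket. This is why the hash tracks direction, which is what the attention dot product is sensitive to.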

FlashAttention

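FlashAttention speeds up self-attention by changing how it is computed on the GPU. It processes the queries, keys and values in tiles that fit into fast on-chip SRAM, and it never writes the full \(n\times n\) attention matrix to the slower high-bandwidth memory (HBM). The result is exact, the same as standard attention. The speed-up comes from far fewer memory reads and writes, not from fewer arithmetic operations.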
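The core trick that avoids storing the full score matrix is an online (streaming) softmax. The pure-Python sketch below processes the keys and values for one query in tiles. It keeps only a running maximum, a running denominator and a running weighted sum, and it produces the same result as the naive computation. This only illustrates the recurrence: the real FlashAttention is a fused GPU kernel that also tiles the queries and manages SRAM explicitly. The function names and data are illustrative.

```python
import math

def attention_row_naive(q, K, V):
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in K]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(w[j] * V[j][t] for j in range(len(K))) / z for t in range(len(V[0]))]

def attention_row_tiled(q, K, V, tile=2):
    # running max (m), running softmax denominator (l), running unnormalised output (acc)
    m, l = -math.inf, 0.0
    acc = [0.0] * len(V[0])
    for start in range(0, len(K), tile):
        Kt, Vt = K[start:start + tile], V[start:start + tile]
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in Kt]
        m_new = max(m, max(s))
        scale = math.exp(m - m_new)  # rescale old statistics to the new max (0.0 on the first tile)
        p = [math.exp(x - m_new) for x in s]
        l = l * scale + sum(p)
        acc = [a * scale + sum(p[j] * Vt[j][t] for j in range(len(Vt)))
               for t, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]

q = [0.5, -1.0, 2.0]
K = [[1.0, 0.0, 0.5], [-0.5, 2.0, 1.0], [0.3, 0.3, -1.0], [2.0, -1.0, 0.0], [0.0, 1.0, 1.0]]
V = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5], [-2.0, 1.0], [0.5, 0.5]]
print(attention_row_naive(q, K, V))
print(attention_row_tiled(q, K, V, tile=2))  # same values up to rounding
```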

Multi-Query Attention

In MQA, all query heads share a single set of keys and values: the number of key/value heads is reduced to 1. This significantly reduces the parameters of the self-attention layer, as well as the size of the key/value cache kept during generation.

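For example, take a hidden size of \(28\) split into \(4\) heads of dimension \(7\). In multi-head attention, \(W_q\), \(W_k\) and \(W_v\) are each \(28\times28\), so the q, k, v projections have \(28\times28\times3=2{,}352\) parameters. In multi-query attention, \(W_q\) stays \(28\times28\) while \(W_k\) and \(W_v\) shrink to \(28\times7\). That gives \(28\times(28+2\times7)=1{,}176\) parameters, where \(7\) is the dimension of a single head.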
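The parameter counts can be reproduced with a small helper. It assumes no biases and counts only \(W_q\), \(W_k\) and \(W_v\), leaving out the output projection as in the example. `qkv_params` is a hypothetical helper. Setting `n_kv_heads` between 1 and `n_heads` gives the grouped-query case.

```python
def qkv_params(d_model, n_heads, n_kv_heads):
    # W_q, W_k, W_v only; no biases, no output projection
    head_dim = d_model // n_heads
    w_q = d_model * n_heads * head_dim           # one query projection for every head
    w_kv = 2 * d_model * n_kv_heads * head_dim   # key + value projections, shared heads only
    return w_q + w_kv

print(qkv_params(28, n_heads=4, n_kv_heads=4))  # MHA: 2352
print(qkv_params(28, n_heads=4, n_kv_heads=1))  # MQA: 1176
print(qkv_params(28, n_heads=4, n_kv_heads=2))  # 2 key/value heads (grouped): 1568
```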

However, MQA cuts the key/value parameters so drastically that model quality can suffer: it is the fastest variant but tends to be lower in quality. MQA is used in ChatGLM2 and Google Gemini.

Grouped-Query Attention

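GQA is a compromise between MHA and MQA. The query heads are split into groups, and each group shares one key head and one value head: a single group gives MQA, and one group per query head gives MHA. GQA is used in Llama2, Mistral and Google Gemma 2. Reference: GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.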
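The sketch below shows how GQA can assign query heads to shared key/value heads, assuming contiguous groups of equal size. `repeat_kv` here is a hypothetical list-based helper that copies each key/value head so there is one per query head. After that step, ordinary multi-head attention code can be reused. With 8 query heads, using 8, 2 or 1 key/value heads corresponds to MHA, GQA and MQA.

```python
def kv_head_for(query_head, n_heads, n_kv_heads):
    # query heads are split into n_kv_heads contiguous, equal-sized groups;
    # every query head in a group reads the same key/value head
    assert n_heads % n_kv_heads == 0, "n_heads must be a multiple of n_kv_heads"
    group_size = n_heads // n_kv_heads
    return query_head // group_size

def repeat_kv(kv_heads, n_heads):
    # copy each key (or value) head so that there is one per query head
    return [kv_heads[kv_head_for(h, n_heads, len(kv_heads))] for h in range(n_heads)]

for n_kv in (8, 2, 1):  # MHA, GQA, MQA with 8 query heads
    print(n_kv, [kv_head_for(h, 8, n_kv) for h in range(8)])

print(repeat_kv(["K0", "K1"], 4))  # ['K0', 'K0', 'K1', 'K1']
```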

Data Preparation of LLM

Tokenizer

Data size

Llama2-70B was trained on \(2\)T tokens.

Data quality

Data diversity

Llama1, Llama2, Llama3

Model     Parameter size    Training tokens
Llama     65B               1.4T
Llama2    70B               2T
Llama3    70B               15T

Llama1

Network structure:

  1. Use RMSNorm instead of LayerNorm. RMSNorm is placed before the self-attention layer and before the MLP layer.

  2. Use SwiGLU instead of ReLU.

  3. Use RoPE instead of absolute positional embeddings.

Optimizer: AdamW

Pretraining.

Llama2

Llama2-7B and Llama2-13B use the same architecture as Llama1, while Llama2-70B uses GQA instead of MHA.

Pretraining + SFT.

Llama3

Tokenizer with a vocabulary size of \(128\)K.

Llama3-8B and Llama3-70B both use GQA instead of MHA.

Llama3 is trained on sequences of \(8,192\) tokens.

Pretraining + SFT.