使用 PyTorch 从头构建自己的 Llama 3 架构

2024年9月15日 450点热度 0人点赞 0条评论

本文提供从头开始构建 Llama 3 模型的完整架构并对自定义数据集执行训练和推理的分步指南。

读完这篇文章后您将获得什么成果?

  1. 您将深入了解 Llama 3 模型的每个组件在内部的工作原理。
  2. 您将编写代码来构建 Llama 3 的每个组件,然后将它们组装在一起以构建功能齐全的 Llama 3 模型。
  3. 您还将编写代码来使用新的自定义数据集训练您的模型。
  4. 您还将编写代码来执行推理,以便您的 Llama 3 模型可以根据输入提示生成新文本。

先决条件

  • 需要具备 Python 和 Pytorch 的基本知识。
  • 虽然不是强制性的,但对 Transformer 概念(例如自我注意力)以及深度神经网络知识的基本理解肯定会有所帮助。

现在我们知道了我们想要实现的目标,让我们开始一步一步构建一切。

步骤 1:输入块

如上图的 Llama 3 架构图所示,输入块有 3 个组件:文本/提示、标记器和嵌入。

输入块内的组件如何工作?有一句流行的说法“一图胜千言”,让我们查看下面的流程图来了解输入块内的工作流程。

  • 首先,单个或一批文本/提示将被传递到模型中。例如:上面流程图中的“Hello World”。
  • 由于模型无法处理文本,因此输入应始终为数字格式。Tokenizer 有助于将这些文本/提示转换为 token-id(这是词汇表中 token 的索引号表示)。我们将使用流行的 Tiny Shakespeare 数据集来构建词汇表并训练我们的模型。
  • Llama 3 模型中使用的标记器是 TikToken,一种子词标记器。但是,我们将使用字符级标记器来构建模型。主要原因是我们应该知道如何自己构建词汇表和标记器,包括编码和解码函数。这样,​​我们将能够了解一切在幕后是如何运作的,并且我们将完全控制代码。
  • 最后,每个 token-id 将被转换为维度为 128 的嵌入向量(在原始 Llama 3 8B 中为 4096)。然后,嵌入将被传递到下一个称为解码器块的块中。

让我们对输入块进行编码:

# Import necessary libraries
import torch
from torch import nn
from torch.nn import functional as F

import math
import numpy as np
import time
from dataclasses import dataclass
from typing import Optional, Tuple, List
import pandas as pd
from matplotlib import pyplot as plt
### Step 1: Input Block ###

# Using Tiny Shakespeare dataset for character-level tokenizer. Some part of the following character-level tokenizer is referenced from Andrej karpathy's GitHub (https://github.com/karpathy/nanoGPT/blob/master/data/shakespeare_char/prepare.py) which I found is explained very well.
# Load tiny_shakespeare data file (https://github.com/tamangmilan/llama3/blob/main/tiny_shakespeare.txt)

device: str = 'cuda' if torch.cuda.is_available() else 'cpu'   # Assign device to cuda or cpu based on availability

# Load tiny_shakespeare data file.
with open('tiny_shakespeare.txt', 'r') as f:
  data = f.read()

# Prepare vocabulary by taking all the unique characters from the tiny_shakespeare data
vocab = sorted(list(set(data)))

# Training Llama 3 model requires addtional tokens such as <|begin_of_text|>, <|end_of_text|> and <|pad_id|>, we'll add them into vocabulary
vocab.extend(['<|begin_of_text|>','<|end_of_text|>','<|pad_id|>'])
vocab_size = len(vocab)

# Create a mapping between characters with corresponding integer indexes in vocabulary.
# This is important to build tokenizers encode and decode functions.
itos = {i:ch for i, ch in enumerate(vocab)}
stoi = {ch:i for i, ch in enumerate(vocab)}

# Tokenizers encode function: take a string, output a list of integers
def encode(s):
  return [stoi[ch] for ch in s]

# Tokenizers decode function: take a list of integers, output a string
def decode(l):
  return ''.join(itos[i] for i in l)

# Define tensor token variable to be used later during model training
token_bos = torch.tensor([stoi['<|begin_of_text|>']], dtype=torch.int, device=device)
token_eos = torch.tensor([stoi['<|end_of_text|>']], dtype=torch.int, device=device)
token_pad = torch.tensor([stoi['<|pad_id|>']], dtype=torch.int, device=device)

prompts = "Hello World"
encoded_tokens = encode(prompts)
decoded_text = decode(encoded_tokens)

### Test: Input Block Code ###
# You need take out the triple quotes below to perform testing
"""
print(f"Lenth of shakespeare in character: {len(data)}")
print(f"The vocabulary looks like this: {''.join(vocab)}\n")
print(f"Vocab size: {vocab_size}")
print(f"encoded_tokens: {encoded_tokens}")
print(f"decoded_text: {decoded_text}")
"""
### Test Results: ###
"""
Lenth of shakespeare in character: 1115394
The vocabulary looks like this: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz<|begin_of_text|><|end_of_text|><|pad_id|>

Vocab size: 68
encoded_tokens: [20, 43, 50, 50, 53, 1, 35, 53, 56, 50, 42]
decoded_text: Hello World
"""

步骤 2:解码器块

如果您查看上面的架构图,解码器块由以下子组件组成。

  • RMS 范数
  • 旋转位置编码
  • 键值缓存
  • 群组查询注意力机制
  • 前馈网络
  • 解码器块

让我们逐一深入研究每个子组件。

2a. RMS 范数(均方根归一化):

为什么需要 RMNSNorm?在上面的架构图中,你一定注意到输入块(即嵌入向量)的输出经过了RMNSNorm 块。这是因为嵌入向量具有许多维度(Llama3-8b 中为 4096 维),并且总是有可能具有不同范围内的值。这可能导致模型梯度爆炸或消失,从而导致收敛缓慢甚至发散。RMNSNorm 将这些值纳入一定范围,有助于稳定和加速训练过程。这使梯度具有更一致的幅度,从而使模型收敛得更快。

RMSNorm 是如何工作的呢?我们先来看下面的图。

  • 与层归一化一样,RMNSNorm 沿嵌入特征或维度应用。上图的嵌入形状为 [3,3],这意味着每个 token 都有 3 个维度。

示例:我们将 RMNSNorm 应用于第一个标记 X1 的嵌入:

  • 每个维度(即 x11、x12 和 x13)上的标记 X1 的值将分别除以所有这些值的均方根。公式如上图所示。
  • E(Epsilon)是一个较小的常数,为了保证数值稳定性,它被添加到均方根中以避免被零除。
  • 最后,将其与缩放参数Gamma (Y)相乘。每个特征都有一个唯一的 Gamma 参数(就像上图中 dim d1 的 Y1、d2 的 Y2 和 dim d3 的 Y3 一样),这是一个学习参数,可以按比例放大或缩小,从而为标准化带来进一步的稳定性。gamma 参数初始化为值 1(如上图所示)。
  • 正如您在上面的示例中注意到的,嵌入值很大,并且分布范围很广。应用 RMSNorm 后,值要小得多,并且分布范围很小。计算已使用实际 RMSNorm 函数完成。

为什么选择 RMSNorm 而不是层归一化?正如您在上面的示例中注意到的那样,我们没有计算在层归一化的情况下进行的任何均值或方差。因此,我们可以说 RMSNorm 通过避免计算均值和方差来减少计算开销。此外,根据作者的论文,RMSNorm 在不影响准确性的情况下提供了性能优势。

让我们对 RMSNorm 进行编码:

# Step2: The Decoder Block
# Note: Since the Llama 3 model is developed by Meta, so to be in sync with their codebase and for future compatibility,
# I will use most of the code from Meta GitHub with some necessary changes required to achieve our goal.

# Define parameters dataclass: we'll use these parameters during model building, training and inference.
# Note: Since we want to see the results of training and inferencing faster rather than focusing on high accuracy, we're taking lower values for most of the parameters which are set higher in the Llama 3 model.

@dataclass
class ModelArgs:
    dim: int = 512              # embedding dimension
    n_layers: int = 8           # number of model decoder blocks
    n_heads: int = 8            # number of heads for queries embedding
    n_kv_heads: int = 4         # number of heads for keys and values embedding
    vocab_size: int = len(vocab) # Length of vocabulary
    multiple_of: int = 256        # Require to calculate dim of feedfoward network
    ffn_dim_multiplier: Optional[float] = None  # Require to calculate dim of feedfoward network
    norm_eps: float = 1e-5                       # Default Epsilon value set for the RMSNorm calculation
    rope_theta: float = 10000.0   # Default theta value for the RePE calculation

    max_batch_size: int = 10     # Max batch size
    max_seq_len: int = 256         # Max sequence length

    epochs: int = 2500             # Total number of training iteration
    log_interval: int = 10        # Number of interval to print the logs and loss values   
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'   # Assign device to cuda or cpu based on availability 
## Step2a: The RMSNorm

class RMSNorm(nn.Module):
  def __init__(self, dim: int, eps: float = 1e-6):
    super().__init__()
    device = ModelArgs.device
    self.eps = eps
    # Scaling parameter gamma, initialized with one and the no of parameters is equal to the size of dim
    self.weight = nn.Parameter(torch.ones(dim).to(device))

  def _norm(self, x):
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps).to(device)

  def forward(self, x):
    #Shape: x[bs,seq,dim]
    output = self._norm(x.float()).type_as(x)

    #Shape: x[bs,seq,dim] -> x_norm[bs,seq,dim]
    return output * self.weight

### Test: RMSNorm Code ###
# You need take out the triple quotes below to perform testing
"""
x = torch.randn((ModelArgs.max_batch_size, ModelArgs.max_seq_len, ModelArgs.dim), device=device)
rms_norm = RMSNorm(dim=ModelArgs.dim)
x_norm = rms_norm(x)

print(f"Shape of x: {x.shape}")
print(f"Shape of x_norm: {x_norm.shape}")
"""
### Test Results: ###
"""
Shape of x: torch.Size([10, 256, 512])
Shape of x_norm: torch.Size([10, 256, 512])
"""

2b. 旋转位置编码(RoPE):

为什么我们需要旋转位置编码 (RoPE)?在讨论为什么之前,让我们先回顾一下到目前为止所做的工作。首先,我们将输入文本转换为嵌入。接下来,我们将 RMNSNorm 应用于嵌入。此时,您一定已经注意到有些不对劲。假设输入文本是“我爱苹果”或“苹果爱我”,模型仍会将这两个句子视为相同并将其学习为相同。因为嵌入中没有定义模型要学习的顺序。因此,顺序对于任何语言模型都非常重要。在 Llama 3 模型架构中,RePE 用于定义句子中每个标记的位置,不仅保持顺序,还保持句子中标记的相对位置。

那么,什么是旋转位置编码,它是如何工作的?如上文“为什么”部分所述,RoPE 是一种位置编码,它对嵌入进行编码,通过添加绝对位置信息来维护句子中标记的顺序,并合并标记之间的相对位置信息。它通过使用称为旋转矩阵的特殊矩阵旋转给定的嵌入来执行编码操作。这种使用旋转矩阵的简单但非常强大的数学推导是 RoPE 的核心。

上图中的旋转矩阵旋转了一个二维向量。但是,Llama 3 模型中的维数为 4096,这要多得多。让我们看看如何对更高维度的嵌入应用旋转。

我们现在知道,嵌入的旋转涉及每个嵌入位置 (m) 值与每对嵌入维度的 theta (θ) 相乘。这就是 RoPE 如何通过实现旋转矩阵来捕​​获绝对位置以及相对位置信息。

注意:在执行旋转之前,旋转矩阵需要转换为极坐标形式,嵌入向量需要转换为复数。旋转完成后,需要将旋转后的嵌入转换回实数以进行注意操作。此外,RoPE 仅适用于查询和键嵌入。它不适用于值嵌入。

让我们深入研究 RoPE 编码:

## Step2b: The RoPE
def precompute_freqs_cis(dim:int, seq_len: int, theta: float=10000.0):
  # Computing Theta value for each dim pair which is dim/2
  device = ModelArgs.device
  freqs = 1.0 / (theta ** (torch.arange(0, dim, 2,device=device)[:(dim//2)].float()/dim))

  # Computing range of positions(m) in the sequence
  t = torch.arange(seq_len, dtype=torch.float32, device=device)

  # freqs gives all the Theta value range for all the position of tokens in the sequence
  freqs = torch.outer(t, freqs).to(device)

  # This is the rotation matrix which needs to be converted to Polar form in order to perform rotation to the embedding
  freqs_cis = torch.polar(torch.ones_like(freqs).to(device), freqs).to(device)
  return freqs_cis

def reshape_for_broadcast(freqs_cis, x):
  ndim = x.ndim
  assert 0<=1<ndim
  assert freqs_cis.shape == (x.shape[1],x.shape[-1]), "the last two dimension of freqs_cis, x must match"
  shape = [d if i==1 or i==ndim-1 else 1 for i,d in enumerate(x.shape)]
  return freqs_cis.view(*shape)

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor)->Tuple[torch.Tensor, torch.Tensor]:
  device = ModelArgs.device
  # Applying rotary positional encoding to both query and key embedding together
  # First: The last dimension of xq and xk embedding needs to be reshaped to make it a pair. As rotation matrix is applied to each pair of dim.
  # Next: convert both xq and xk to complex number as the rotation matrix is only applicable to complex number
  xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)).to(device)    #xq_:[bsz, seq_len, n_heads, head_dim/2]
  xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)).to(device)    #xk_:[bsz, seq_len, n_heads, head_dim/2]

  # The rotation matrix(freqs_cis) dimensions across seq_len(dim=1) and head_dim(dim=3) should match with the embedding
  # Also, the shape freqs_cis should be the same with xq and xk, hence change the shape of freqs_cis:[seq_len,head_dim] -> freqs_cis:[1,seq_len,1,head_dim]
  freqs_cis = reshape_for_broadcast(freqs_cis, xq_)

  #Finally, perform rotation operation by multiplying with freqs_cis.
  #After the rotation is completed, convert both xq_out and xk_out back to real number and return
  xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).to(device) #xq_out:[bsz, seq_len, n_heads, head_dim]
  xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).to(device) #xk_out:[bsz, seq_len, n_heads, head_dim]
  return xq_out.type_as(xq), xk_out.type_as(xk)

### Test: RoPE Code ###
# Note: x_norm is calculated during RMSNorm and is being used for testing here.
# You need take out the triple quotes below to perform testing
"""
head_dim = ModelArgs.dim//ModelArgs.n_heads
wq = nn.Linear(ModelArgs.dim, ModelArgs.n_heads * head_dim, bias=False, device=device)
wk = nn.Linear(ModelArgs.dim, ModelArgs.n_kv_heads * head_dim, bias=False, device=device)
xq = wq(x_norm)
xk = wk(x_norm)
print(f"xq.shape: {xq.shape}")
print(f"xk.shape: {xk.shape}")

xq = xq.view(xq.shape[0],xq.shape[1],ModelArgs.n_heads, head_dim)
xk = xk.view(xk.shape[0],xk.shape[1],ModelArgs.n_kv_heads, head_dim)
print(f"xq.re-shape: {xq.shape}")
print(f"xk.re-shape: {xk.shape}")

freqs_cis = precompute_freqs_cis(dim=head_dim, seq_len=ModelArgs.max_seq_len)
print(f"freqs_cis.shape: {freqs_cis.shape}")

xq_rotate, xk_rotate = apply_rotary_emb(xq, xk, freqs_cis)
print(f"xq_rotate.shape: {xq_rotate.shape}")
print(f"xk_rotate.shape: {xk_rotate.shape}")
"""
### Test Results: ###
"""
xq.shape: torch.Size([10, 256, 512])
xk.shape: torch.Size([10, 256, 256])
xq.re-shape: torch.Size([10, 256, 8, 64])
xk.re-shape: torch.Size([10, 256, 4, 64])
freqs_cis.shape: torch.Size([256, 32])
xq_rotate.shape: torch.Size([10, 256, 8, 64])
xk_rotate.shape: torch.Size([10, 256, 4, 64])
"""

2c. KV 缓存(仅在推理时需要):

什么是 KV-Cache?在 Llama 3 架构中,在推理时引入了 KV-Cache 的概念,以 Key 和 Value 缓存的形式存储先前生成的 token。这些缓存将用于计算自注意力以生成下一个 token。只有 key 和 value token 会被缓存,而查询 token 不会被缓存,因此称为 KV Cache。

为什么我们需要KV Cache?我们先看下面的图来理清我们的好奇心。

  • 在图中的 A 块中,在生成 output3 token 时,之前的输出 token(output1,output2)还在计算,这是完全没有必要的。这导致在计算注意力时需要进行额外的矩阵乘法,因此计算资源增加了很多。
  • 在图中的块 B 中,输出标记替换了查询嵌入中的输入标记。KV Cache存储了先前生成的标记。在注意力得分计算期间,我们只需使用查询中的 1 个标记并使用键和值缓存中的先前标记。它将从块 A 到块 B 的矩阵乘法从 3×3 减少到 1×3,减少了近 66%。在现实世界中,由于序列长度和批量大小巨大,这将有助于显著降低计算能力。最后,将始终只生成一个最新的输出标记。这是引入 KV-Cache 的主要原因。

2d. 组查询注意:

组查询注意力机制与之前的模型(例如 Llama 1)中使用的多头注意力机制相同,唯一的区别在于查询使用单独的头,键/值使用单独的头。通常,分配给查询的头的数量是键和值头数量的 n 倍。让我们看一下图表以进一步加深理解。

在给定的图表中,多头注意力在所有查询、键和值中具有相同数量的头,即 n_heads = 8。

组查询注意块有 8 个用于查询的头 (n_heads) 和 4 个用于键和值的头 (n_kv_heads),比查询头少 2 倍。

既然 MultiHead Attention 已经这么好了,为什么还需要 Group query Attention?要回答这个问题,我们需要先回顾一下 KV Cache。KV Cache 有助于大大减少计算资源。然而,随着 KV Cache 存储越来越多的先前 token,内存资源将显著增加。这对于模型性能和财务角度来说都不是一件好事。因此,引入了 Group query Attention。减少 K 和 V 的 head 数量会减少要存储的参数数量,因此使用的内存更少。各种测试结果证明,采用这种方法,模型准确率保持在同一范围内。

让我们用代码来实现它:

## The Attention Block [Step2c: The KV Cache; Step2d: Group Query Attention]
## As mentioned before, the naming convention follows original the meta's LLama3 GitHub

class Attention(nn.Module):
  def __init__(self, args: ModelArgs):
    super().__init__()
    self.args = args
    # Embedding dimension
    self.dim = args.dim
    # Number of heads assigned to Query
    self.n_heads = args.n_heads
    # Number of heads assigned to Key and values. If "None", the number will be same as Query.
    self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
    # Dimension of each head relative to model dimension
    self.head_dim = args.dim // args.n_heads
    # Number of repetition in order to make time Key, Value heads to match Query heads number
    self.n_rep = args.n_heads // args.n_kv_heads

    # Weight initialize for Keys, Querys, Values and Oupt. Notice that the out_feature value of weight for q and kv are based on it's heads
    self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False, device=device)
    self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False, device=device)
    self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False, device=device)
    self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False, device=device)

    # Initialize caches to store Key, Values at start. (KV Cache Implementation)
    self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim), device=args.device)
    self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim), device=args.device)

  def forward(self, x: torch.Tensor, start_pos, inference):
    # Shape of the input embedding: [bsz,seq_len,dim]
    bsz, seq_len, _ = x.shape
    # Mask will be used during 'Training' and is not required for 'inference' due to the use of KV cache.
    mask = None

    xq = self.wq(x)  #x[bsz,seq_len,dim]*wq[dim,n_heads * head_dim] -> q[bsz,seq_len,n_heads * head_dim]
    xk = self.wk(x)  #x[bsz,seq_len,dim]*wq[dim,n_kv_heads * head_dim] -> k[bsz,seq_len,n_kv_heads * head_dim]
    xv = self.wv(x)  #x[bsz,seq_len,dim]*wq[dim,n_kv_heads * head_dim] -> v[bsz,seq_len,n_kv_heads * head_dim]

    # Reshaping Querys, Keys and Values by their number of heads. (Group Query Attention Implementation)
    xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)      #xq[bsz,seq_len,n_heads, head_dim]
    xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)   #xk[bsz,seq_len,n_kv_heads, head_dim]
    xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)   #xv[bsz,seq_len,n_kv_heads, head_dim]

    # Model - Inference Mode: kv-cache is enabled at inference mode only.
    if inference:
      # Compute rotation matrix for each position in the sequence
      freqs_cis = precompute_freqs_cis(dim=self.head_dim, seq_len=self.args.max_seq_len * 2)
      # During inferencing, we should only take the rotation matrix range from the current position of the tokens.
      freqs_cis = freqs_cis[start_pos : start_pos + seq_len]
      # Apply RoPE to Queries and Keys embeddings
      xq, xk = apply_rotary_emb(xq, xk, freqs_cis)

      self.cache_k = self.cache_k.to(xq)
      self.cache_v = self.cache_v.to(xq)
      # Store Keys and Values token embedding into their respective cache [KV Cache Implementation]
      self.cache_k[:bsz, start_pos:start_pos + seq_len] = xk
      self.cache_v[:bsz, start_pos:start_pos + seq_len] = xv

      # Assign all the previous tokens embeddings upto current tokens position to Keys and Values variable for Attention Calculation
      keys = self.cache_k[:bsz, :start_pos + seq_len]
      values = self.cache_v[:bsz, :start_pos + seq_len]

      # At this point, they Keys and Values shape aren't same with Queries Embedding which has to be in order to computer attention score
      # Use repeat_kv function to make Keys,Values shape same as queries shape
      keys = repeat_kv(keys, self.n_rep)      #keys[bsz,seq_len,n_heads,head_dim]
      values = repeat_kv(values, self.n_rep)  #values[bsz,seq_len,n_heads,head_dim]

    # Mode - Training mode: KV-Cache not implemented
    else:
      # Compute rotation matrix and apply RoPE to queries and keys for for training.
      freqs_cis = precompute_freqs_cis(dim=self.head_dim, seq_len=self.args.max_seq_len)

      #xq[bsz,seq_len,n_heads, head_dim], xk[bsz,seq_len,n_heads, head_dim]
      xq, xk = apply_rotary_emb(xq, xk, freqs_cis)

      # Use repeat_kv function to make Keys,Values shape same as the queries shape
      #keys[bsz,seq_len,n_heads,head_dim], #values[bsz,seq_len,n_heads,head_dim]
      keys = repeat_kv(xk, self.n_rep)
      values = repeat_kv(xv, self.n_rep)

      # For training mode, we'll compute mask and apply to the attention score later
      mask = torch.full((seq_len, seq_len),float("-inf"),device=self.args.device)
      mask = torch.triu(mask, diagonal=1).to(self.args.device)

    # To compute attention, we'll need to perform a transpose operation to reshape all queries, keys and values bring heads at dim 1 and seq at dim 2
    xq = xq.transpose(1,2)                  #xq[bsz,n_heads,seq_len,head_dim]
    keys = keys.transpose(1,2)              #keys[bsz,n_heads,seq_len,head_dim]
    values = values.transpose(1,2)          #values[bsz,n_heads,seq_len,head_dim]

    # Computing attention score
    scores = torch.matmul(xq, keys.transpose(2,3)).to(self.args.device)/math.sqrt(self.head_dim)
    if mask is not None:
      scores = scores + mask

    # Apply softmax to the attention score
    scores = F.softmax(scores.float(), dim=-1).type_as(xq)
    # Matrix multiplication of attention score with the values
    output = torch.matmul(scores, values).to(self.args.device)

    # We get the contextual embedding for each head
    # All heads need to be reshaped back and combined to give a single single contextual attention output
    # Shape change: output[bsz,n_heads,seq_len,head_dim] -> output[bsz,seq_len, n_heads,head_dim] -> output[bsz,seq_len, n_heads * head_dim]
    output = output.transpose(1,2).contiguous().view(bsz, seq_len, -1)

    # shape: output [bsz,seq_len,dim]
    return self.wo(output)

# If the number of keys/values heads is less than query heads, this function expands the key/values embeddings with the required number of repetition
def repeat_kv(x:torch.Tensor, n_rep: int)-> torch.Tensor:
  bsz, seq_len, n_kv_heads, head_dim = x.shape
  if n_rep == 1:
    return x
  return (
      x[:,:,:,None,:]
      .expand(bsz,seq_len,n_kv_heads,n_rep, head_dim)
      .reshape(bsz,seq_len,n_kv_heads * n_rep, head_dim)
  )


### Test: Repeat_kv function ###
# note: xk, x_norm is already calculated during RoPE, RMSNorm testing and is being used for testing here.
# You need take out the triple quotes below to perform testing
"""
n_rep = ModelArgs.n_heads // ModelArgs.n_kv_heads
keys = repeat_kv(xk, n_rep)
print(f"xk.shape: {xk.shape}")
print(f"keys.shape: {keys.shape}")

## Test: Attention function
# You need take out the triple quotes below to perform testing

attention = Attention(ModelArgs)
x_out = attention(x_norm,start_pos=0, inference=False)
print(f"x_out.shape: {x_out.shape}")
"""
### Test Results: ###
"""
xk.shape: torch.Size([10, 256, 4, 64])
keys.shape: torch.Size([10, 256, 8, 64])
x_out.shape: torch.Size([10, 256, 512])
"""

2e. 前馈网络(SwiGLU 激活):

前馈网络在解码器块中起什么作用?如上图所示,注意力输出首先在 RMSNorm 期间进行归一化,然后馈入前馈网络。在前馈网络内部,注意力输出嵌入将在其隐藏层中扩展到更高维度,并学习 token 的更复杂特征。

为什么要使用 SwiGLU 而不是 ReLU?我们看一下图来寻找答案。

如上图所示,SwiGLU 函数在正轴上的表现与 ReLU 几乎相同。然而,在负轴上,SwiGLU 输出一些负值,这可能有助于学习较小的 0,而不是 ReLU 情况下的平坦 0。总体而言,根据作者的说法,SwiGLU 的性能优于 ReLU;因此,它被选中。

让我们深入研究一下 FeedForward 代码:

## Step2e: The Feedfoward Network (SwiGLU activation)
class FeedForward(nn.Module):
  def __init__(self, dim:int, hidden_dim:int, multiple_of:int, ffn_dim_multiplier: Optional[float]):
    super().__init__()
    # Models embedding dimension
    self.dim = dim

    # We must use the hidden dimensions calculation shared by Meta which is the ideal one for this model
    # Hidden dimension are calculated such that it is a multiple of 256.
    hidden_dim = int(2 * hidden_dim/3)
    if ffn_dim_multiplier is not None:
      hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

    # define hiddne layers weights
    self.w1 = nn.Linear(self.dim, hidden_dim, bias=False, device=device)
    self.w2 = nn.Linear(hidden_dim, self.dim, bias=False, device=device)
    self.w3 = nn.Linear(self.dim, hidden_dim, bias=False, device=device)

  def forward(self, x):
    # Shape: [bsz,seq_len,dim]
    return self.w2(F.silu(self.w1(x)) * self.w3(x))



### Test: FeedForward module ###
# note: x_out is already computed at Attention testing and is being used for testing here.
# You need take out the triple quotes below to perform testing
"""
feed_forward = FeedForward(ModelArgs.dim, 4 * ModelArgs.dim, ModelArgs.multiple_of, ModelArgs.ffn_dim_multiplier)
x_out = rms_norm(x_out)
x_out = feed_forward(x_out)
print(f"feed forward output: x_out.shape: {x_out.shape}")
"""

### Test Results: ###
"""
feed forward output: x_out.shape: torch.Size([10, 256, 512])
"""

2f.解码器块:

如上图(第一张图)所示,解码器块由多个子组件组成,我们在前面的部分(2a — 2f)中已经学习并编码过这些子组件。下面是在解码器块内部执行的逐点运算。

  1. 输入块中的嵌入被输入到 Attention-RMSNorm 块中。这将进一步输入到 Group Query Attention 块中。
  2. 然后,来自输入块的相同嵌入将被添加到注意力输出中。
  3. 之后,注意力输出被馈入 FeedFoward-RMSNorm,并进一步馈入 FeedFoward 网络块。
  4. 然后将 FeedFoward 网络的输出再次与注意力输出相加。
  5. 产生的输出称为解码器输出。然后,此解码器输出被馈送到另一个解码器块作为输入。接下来的 31 个解码器块将重复此相同操作。然后,第 32 个解码器块的最终解码器输出被传递到输出块。

我们在下面的代码中看看这个动作:

## Step2f: The Decoder Block. The class name is assigned as TransformerBlock to match the name of Meta llama 3 code base.

class TransformerBlock(nn.Module):
  def __init__(self, args: ModelArgs):
    super().__init__()
    self.args = args
    # Initilizate RMSNorm for attention
    self.attention_norm = RMSNorm(dim=args.dim, eps = args.norm_eps)
    # Initilizate Attention class
    self.attention = Attention(args)
    # Initilizate RMSNorm for feedfoward class
    self.ff_norm = RMSNorm(dim=args.dim, eps = args.norm_eps)
    # Initilizate feedfoward class
    self.feedforward = FeedForward(args.dim, 4 * args.dim, args.multiple_of, args.ffn_dim_multiplier)

  def forward(self, x, start_pos, inference):
    # start_pos = token position for inference mode, inference = True for inference and False for training mode
    # i) pass input embedding to attention_norm and then pass to attention block.
    # ii) the output of attention is then added to embedding(before norm)
    h = x + self.attention(self.attention_norm(x), start_pos, inference)

    # i) pass attention output to ff_norm and then pass to the feedforward network.
    # ii) the output of feedforward network is then added to the attention output(before ff_norm)
    out = h + self.feedforward(self.ff_norm(h))
    # Shape: [bsz,seq_len,dim]
    return out


### Test: TransformerBlock ###
# You need take out the triple quotes below to perform testing
"""
x = torch.randn((ModelArgs.max_batch_size, ModelArgs.max_seq_len, ModelArgs.dim), device=device)
transformer_block = TransformerBlock(ModelArgs)
transformer_block_out = transformer_block(x,start_pos=0, inference=False)
print(f"transformer_block_out.shape: {transformer_block_out.shape}")
"""

### Test Results: ###
"""
transformer_block_out.shape: torch.Size([10, 64, 128])
"""

步骤 3:输出块

最终解码器块的解码器输出将馈送到输出块。它首先被馈送到 RMSNorm。然后,它将馈送到生成 logits 的线性层。接下来,将发生以下两个操作之一。

  • 如果模式为推理,则计算 top_p 概率并生成下一个 token。如果达到最大生成长度或生成句子结尾的 token 作为下一个 token,则生成下一个 token 将会停止。
  • 如果模式为训练,则使用目标标签计算损失,并重复训练,直到达到最大时期长度。

为了更清楚起见,让我们看一下输出块流程图。

最后,让我们将 3 个块(输入块、解码器块和输出块)的所有组件组合在一起。这给出了我们最终的 Llama 3 模型。

让我们编写最终的 Llama 3 模型:

## Step3: The Output Block
# This is the Llama 3 model. Again, the class name is maintained as Transformer to match with Meta Llama 3 model.

class Transformer(nn.Module):
  def __init__(self, params: ModelArgs):
    super().__init__()
    # set all the ModelArgs in params variable
    self.params = params
    # Initilizate embedding class from the input block
    self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)

    # Initialize the decoder block and store it inside the ModuleList. 
    # This is because we've 4 decoder blocks in our Llama 3 model. (Official Llama 3 has 32 blocks)
    self.layers = nn.ModuleList()
    for layer_id in range(params.n_layers):
      self.layers.append(TransformerBlock(args=params))

    # Initilizate RMSNorm for the output block
    self.norm = RMSNorm(params.dim, eps = params.norm_eps)
    
    # Initilizate linear layer at the output block.
    self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

  def forward(self, x, start_pos=0, targets=None):
    
    # start_pos = token position for inference mode, inference = True for inference and False for training mode
    # x is the batch of token_ids generated from the texts or prompts using tokenizers.
    # x[bsz, seq_len] -> h[bsz, seq_len, dim]
    h = self.tok_embeddings(x)

    # If the target is none, Inference mode is activated and set to "True" and "False" if Training mode is activated.
    if targets is None:
      inference = True
    else:
      inference = False

    # The embeddings (h) will then pass though all the decoder blocks.
    for layer in self.layers:
      h = layer(h, start_pos, inference)

    # The output from the final decoder block will feed into the RMSNorm
    h = self.norm(h)

    # After normalized, the embedding h will then feed into the Linear layer. 
    # The main task of the Linear layer is to generate logits that maps the embeddings with the vocabulary size.
    # h[bsz, seq_len, dim] -> logits[bsz, seq_len, vocab_size]
    logits = self.output(h).float()
    loss = None

    # Inference mode is activated if the targets is not available
    if targets is None:
      loss = None
    # Training mode is activated if the targets are available. And Loss will be calculated for further model training. 
    else:
      loss = F.cross_entropy(logits.view(-1, self.params.vocab_size), targets.view(-1))

    return logits, loss


### Test: Transformer (Llama Model) ###
# You need take out the triple quotes below to perform testing
"""
model = Transformer(ModelArgs).to(ModelArgs.device)
print(model)
"""

我们刚刚构建的 Llama 3 模型看起来非常完美。现在我们可以开始训练过程了。

步骤 4:训练我们的 Llama 3 模型:

输出块流程图(步骤 3)提供了训练流程。如果您希望在开始训练之前有更清晰的了解,请再次参考该流程。让我们开始编写训练代码。我还将在代码块中提供必要的解释。

## Step 4: Train Llama 3 Model:

# Create a dataset by encoding the entire tiny_shakespeare data token_ids list using the tokenizer's encode function that we've built at the input block section
dataset = torch.tensor(encode(data), dtype=torch.int).to(ModelArgs.device)
print(f"dataset-shape: {dataset.shape}")

# Define function to generate batches from the given dataset
def get_dataset_batch(data, split, args:ModelArgs):
  seq_len = args.max_seq_len
  batch_size = args.max_batch_size
  device = args.device

  train = data[:int(0.8 * len(data))]
  val = data[int(0.8 * len(data)): int(0.9 * len(data))]
  test = data[int(0.9 * len(data)):]

  batch_data = train
  if split == "val":
    batch_data = val

  if split == "test":
    batch_data = test
  
  # Picking random starting points from the dataset to give random samples for training, validation and testing.
  
  ix = torch.randint(0, len(batch_data) - seq_len - 3, (batch_size,)).to(device)
  x = torch.stack([torch.cat([token_bos, batch_data[i:i+seq_len-1]]) for i in ix]).long().to(device)
  y = torch.stack([torch.cat([batch_data[i+1:i+seq_len], token_eos]) for i in ix]).long().to(device)
  
  return x,y

### Test: get_dataset function ###
"""
xs, ys = get_dataset_batch(dataset, split="train", args=ModelArgs)
print([(decode(xs[i].tolist()), decode(ys[i].tolist())) for i in range(len(xs))])
"""

# Define a evaluate loss function to calculate and store training and validation loss for logging and plotting
@torch.no_grad()
def evaluate_loss(model, args:ModelArgs):
  out = {}
  model.eval()

  for split in ["train", "val"]:
    losses = []
    for _ in range(10):      
      xb, yb = get_dataset_batch(dataset, split, args)
      _, loss = model(x=xb, targets=yb)
      losses.append(loss.item())
    out[split] = np.mean(losses)

  model.train()
  return out

# Define a training function to perform model training
def train(model, optimizer, args:ModelArgs):
    epochs = args.epochs
    log_interval = args.log_interval
    device = args.device
    losses = []   
    start_time = time.time()

    for epoch in range(epochs):
        optimizer.zero_grad()
        
        xs, ys = get_dataset_batch(dataset, 'train', args)
        xs = xs.to(device)
        ys = ys.to(device)
        logits, loss = model(x=xs, targets=ys)
        loss.backward()
        optimizer.step()

        if epoch % log_interval == 0:
            batch_time = time.time() - start_time
            x = evaluate_loss(model, args)
            losses += [x]            
            print(f"Epoch {epoch} | val loss {x['val']:.3f} | Time {batch_time:.3f}")
            start_time = time.time()
    
    # Print the final validation loss
    print("validation loss: ", losses[-1]['val'])
    # Display the interval losses in plot 
    return pd.DataFrame(losses).plot()

现在,我们已经定义了训练函数。让我们从以下代码块开始训练,并在训练完成后在图中观察训练结果。

## Start training our Llama 3 model
model = Transformer(ModelArgs).to(ModelArgs.device)
optimizer = torch.optim.Adam(model.parameters())

train(model, optimizer, ModelArgs)

上图显示了训练和验证损失图。训练已进行了 2500 多个时期。使用默认 GPU 和 RAM 设置的 Google Colab 完成训练过程大约需要 10 分钟,速度非常快。最后一个时期的验证损失为 2.19,考虑到我们使用的训练数据量和时期数,这个数字是可以接受的。为了显著减少损失,我们必须增加训练数据的大小、增加时期数和提高 GPU 或处理能力。

现在我们已经完成了训练。让我们进入最后一步 — 推理,看看模型在新的输入提示下如何生成输出文本。

步骤5:推理Llama 3模型:

输出块流程图(步骤 3)中提供了推理流程。让我们开始编写推理代码。

## Step 5: Inference Llama 3 Model:
# This function generates text sequences based on provided prompts using the LLama 3 model we've built and trained.

def generate(model, prompts: str, params: ModelArgs, max_gen_len: int=500, temperature: float = 0.6, top_p: float = 0.9):

    # prompt_tokens: List of user input texts or prompts
    # max_gen_len: Maximum length of the generated text sequence.
    # temperature: Temperature value for controlling randomness in sampling. Defaults to 0.6.
    # top_p: Top-p probability threshold for sampling prob output from the logits. Defaults to 0.9.
    # prompt_tokens = [0]
    bsz = 1  #For inferencing, in general user just input one prompt which we'll take it as 1-batch
    prompt_tokens = token_bos.tolist() + encode(prompts)
    assert len(prompt_tokens) <= params.max_seq_len, "prompt token length should be small than max_seq_len"
    total_len = min(len(prompt_tokens)+max_gen_len, params.max_seq_len)   

    # this tokens matrix is to store the input prompts and all the output that is generated by model.
    # later we'll use the tokenizers decode function to decode this token to view results in text format
    tokens = torch.full((bsz,total_len), fill_value=token_pad.item(), dtype=torch.long, device=params.device)

    # fill in the prompt tokens into the token matrix
    tokens[:,:len(prompt_tokens)] = torch.tensor(prompt_tokens, dtype=torch.long, device=params.device)

    #create a prompt_mask_token for later use to identify if the token is a prompt token or a padding token
    # True if it is a prompt token, False if it is a padding token
    input_text_mask = tokens != token_pad.item()

    #now we can start inferencing using one token at a time from the prompt_tokens list starting with the first position.
    prev_pos = 0
    for cur_pos in range(1, total_len):
      with torch.no_grad():
        logits, _ = model(x=tokens[:,prev_pos:cur_pos], start_pos=prev_pos)
      if temperature > 0:      
        probs = torch.softmax(logits[:, -1]/temperature, dim=-1)
        next_token = sample_top_p(probs, top_p)        
      else:
        next_token = torch.argmax(logits[:, -1], dim=-1)        

      next_token = next_token.reshape(-1)

      # only replace the token if it's a padding token
      next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
      tokens[:, cur_pos] = next_token

      prev_pos = cur_pos
      if tokens[:,cur_pos]==token_pad.item() and next_token == token_eos.item():
        break

    output_tokens, output_texts = [], []    

    for i, toks in enumerate(tokens.tolist()):
      # eos_idx = toks.index(token_eos.item())
      if token_eos.item() in toks:
        eos_idx = toks.index(token_eos.item())
        toks = toks[:eos_idx]

      output_tokens.append(toks)
      output_texts.append(decode(toks))
    return output_tokens, output_texts

# Perform top-p (nucleus) sampling on a probability distribution.
# probs (torch.Tensor): Probability distribution tensor derived from the logits.
# p: Probability threshold for top-p sampling.
# According to the paper, Top-p sampling selects the smallest set of tokens whose cumulative probability mass exceeds the threshold p. 
# The distribution is renormalized based on the selected tokens.
def sample_top_p(probs, p):
    probs_sort, prob_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(prob_idx, -1, next_token)    
    # Sampled token indices from the vocabular is returned 
    return next_token

让我们对新的提示进行推理并检查生成的输出:

## Perform the inferencing on user input prompts
prompts = "Consider you what services he has done"
output_tokens, output_texts = generate(model, prompts, ModelArgs)
output_texts = output_texts[0].replace("<|begin_of_text|>", "")
print(output_texts)

## Output ##
"""
Consider you what services he has done o eretrane
adetranytnn i eey i ade hs rcuh i eey,ad hsatsTns rpae,T
eon o i hseflns o i eee ee hs ote i ocal ersl,Bnnlnface
o i hmr a il nwye ademto nt i a ere
h i ees.
Frm oe o etrane o oregae,alh,t orede i oeral
"""

是的,我们可以看到我们的 Llama 3 模型能够根据新提示进行推理并生成文本,尽管考虑到我们用于训练的训练数据量和训练周期,输出似乎并不理想。我相信,有了更大的训练数据,我们将实现更好的准确性。

予人玫瑰,手有余香。如果您觉得本文对您有帮助,请点赞或打赏。

文章评论

您需要 登录 之后才可以评论