Supplementary code for the Build a Large Language Model From Scratch book by Sebastian Raschka. Code repository: https://github.com/rasbt/LLMs-from-scratch
Converting a From-Scratch GPT Architecture to Llama 2#
In this notebook, we convert the original GPT architecture into a Llama 2 model step by step (note that GPT and GPT-2 share the same architecture)
Why not Llama 1 or Llama 3?
The Llama 1 architecture is similar to Llama 2, except that Llama 2 has a larger context window (which is nice); the Llama 1 weights are not readily available and have more usage restrictions, so it makes more sense to focus on Llama 2
Regarding Llama 3, I will share a separate notebook to convert Llama 2 to Llama 3 (there are only a few small additional changes)
The explanations are purposefully kept minimal in this notebook so as not to bloat it unnecessarily and to keep the focus on the main code
For more information, please see the Llama 2 paper: Llama 2: Open Foundation and Fine-Tuned Chat Models (2023)

Packages that are being used in this notebook:
from importlib.metadata import version
pkgs = [
"huggingface_hub", # to download pretrained weights
"sentencepiece", # to implement the tokenizer
"torch", # to implement the model
]
for p in pkgs:
print(f"{p} version: {version(p)}")
1. Convert the GPT model implementation step by step#
In this section, we go through the GPT model code from chapter 4 and modify it step by step to implement the Llama 2 architecture
Later, we load the original Llama 2 weights shared by Meta AI
1.1 Replace LayerNorm with RMSNorm layer#
First, we replace LayerNorm with Root Mean Square Layer Normalization (RMSNorm)
LayerNorm normalizes inputs using mean and variance, while RMSNorm uses only the root mean square, which improves computational efficiency
The RMSNorm operation is as follows, where \(x\) is the input, \(\gamma\) is a trainable parameter (vector), and \(\epsilon\) is a small constant to avoid zero-division errors:
\[ y_i = \frac{x_i}{\text{RMS}(x)} \, \gamma_i, \quad \text{where} \quad \text{RMS}(x) = \sqrt{\epsilon + \frac{1}{n} \sum_{i=1}^{n} x_i^2} \]
For more details, please see the paper Root Mean Square Layer Normalization (2019)
import torch
import torch.nn as nn
#####################################
# Chapter 4
#####################################
# class LayerNorm(nn.Module):
# def __init__(self, emb_dim):
# super().__init__()
# self.eps = 1e-5
# self.scale = nn.Parameter(torch.ones(emb_dim))
# self.shift = nn.Parameter(torch.zeros(emb_dim))
# def forward(self, x):
# mean = x.mean(dim=-1, keepdim=True)
# var = x.var(dim=-1, keepdim=True, unbiased=False)
# norm_x = (x - mean) / torch.sqrt(var + self.eps)
# return self.scale * norm_x + self.shift
class RMSNorm(nn.Module):
def __init__(self, emb_dim, eps=1e-5):
super().__init__()
self.eps = eps
self.emb_dim = emb_dim
self.weight = nn.Parameter(torch.ones(emb_dim)).float()
def forward(self, x):
means = x.pow(2).mean(dim=-1, keepdim=True)
x_normed = x * torch.rsqrt(means + self.eps)
return (x_normed * self.weight).to(dtype=x.dtype)
The following code cell checks that this implementation works the same as PyTorch’s built-in implementation:
torch.manual_seed(123)
example_batch = torch.randn(2, 3, 4)
rms_norm = RMSNorm(emb_dim=example_batch.shape[-1])
rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5)
assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))
1.2 Replace GELU with SiLU activation#
Llama uses the SiLU activation function (instead of GELU), which is also known as the Swish function:
\[ \text{SiLU}(x) = x \cdot \sigma(x), \quad \text{where} \quad \sigma(x) = \frac{1}{1 + e^{-x}} \]
For more information, see the SiLU paper: Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning (2017)
#####################################
# Chapter 4
#####################################
# class GELU(nn.Module):
# def __init__(self):
# super().__init__()
# def forward(self, x):
# return 0.5 * x * (1 + torch.tanh(
# torch.sqrt(torch.tensor(2.0 / torch.pi)) *
# (x + 0.044715 * torch.pow(x, 3))
# ))
class SiLU(nn.Module):
def __init__(self):
super(SiLU, self).__init__()
def forward(self, x):
return x * torch.sigmoid(x)
silu = SiLU()
assert torch.allclose(silu(example_batch), torch.nn.functional.silu(example_batch))
1.3 Update the FeedForward module#
In fact, Llama uses a “Gated Linear Unit” (GLU) variant of SiLU called SwiGLU, which essentially results in a slightly differently structured `FeedForward` module
SwiGLU uses a gating mechanism in the feedforward layer, with the formula:
\[ \text{SwiGLU}(x) = \text{SiLU}(\text{Linear}_1(x)) * \text{Linear}_2(x) \]
Here, \(\text{Linear}_1\) and \(\text{Linear}_2\) are two linear layers, and \(*\) denotes element-wise multiplication
The third linear layer, \(\text{Linear}_3\), is applied after this gated activation
For more information, see SwiGLU paper: GLU Variants Improve Transformer (2020)
#####################################
# Chapter 4
#####################################
# class FeedForward(nn.Module):
# def __init__(self, cfg):
# super().__init__()
# self.layers = nn.Sequential(
# nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
# GELU(),
# nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
# )
# def forward(self, x):
# return self.layers(x)
class FeedForward(nn.Module):
def __init__(self, cfg):
super().__init__()
self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)
self.silu = SiLU()
def forward(self, x):
x_fc1 = self.fc1(x)
x_fc2 = self.fc2(x)
x = self.silu(x_fc1) * x_fc2
return self.fc3(x)
Note that we also added a `dtype=cfg["dtype"]` setting above, which will allow us to load the model directly in lower precision formats later to reduce memory usage (versus instantiating it in the original 32-bit precision format and then converting it)
We also set `bias=False` since Llama doesn't use any bias units
1.4 Implement RoPE#
In the GPT model, the positional embeddings are implemented as follows:
self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
Unlike traditional absolute positional embeddings, Llama uses rotary position embeddings (RoPE), which enable it to capture both absolute and relative positional information simultaneously
The reference paper for RoPE is RoFormer: Enhanced Transformer with Rotary Position Embedding (2021)
def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096):
assert head_dim % 2 == 0, "Embedding dimension must be even"
# Compute the inverse frequencies
inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))
# Generate position indices
positions = torch.arange(context_length)
# Compute the angles
angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)
# Expand angles to match the head_dim
angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)
# Precompute sine and cosine
cos = torch.cos(angles)
sin = torch.sin(angles)
return cos, sin
def compute_rope(x, cos, sin):
# x: (batch_size, num_heads, seq_len, head_dim)
batch_size, num_heads, seq_len, head_dim = x.shape
assert head_dim % 2 == 0, "Head dimension must be even"
# Split x into first half and second half
x1 = x[..., : head_dim // 2] # First half
x2 = x[..., head_dim // 2 :] # Second half
# Adjust sin and cos shapes
cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
# Apply the rotary transformation
rotated = torch.cat((-x2, x1), dim=-1)
x_rotated = (x * cos) + (rotated * sin)
return x_rotated.to(dtype=x.dtype)
The following is an example of applying RoPE to the `q` and `k` tensors:
# Settings
batch_size = 2
context_len = 5
num_heads = 4
head_dim = 16
# Instantiate RoPE parameters
cos, sin = precompute_rope_params(head_dim=head_dim, context_length=context_len)
# Dummy query and key tensors
torch.manual_seed(123)
queries = torch.randn(batch_size, num_heads, context_len, head_dim)
keys = torch.randn(batch_size, num_heads, context_len, head_dim)
# Apply rotary position embeddings
queries_rot = compute_rope(queries, cos, sin)
keys_rot = compute_rope(keys, cos, sin)
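Since RoPE only rotates pairs of feature dimensions, it leaves the length of each query and key vector unchanged; the following optional check (illustrative, not part of the original notebook) verifies this for the dummy tensors above:
# Optional check: RoPE is a pure rotation, so per-vector norms are preserved
assert torch.allclose(queries.norm(dim=-1), queries_rot.norm(dim=-1), atol=1e-5)
assert torch.allclose(keys.norm(dim=-1), keys_rot.norm(dim=-1), atol=1e-5)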
1.5 Add RoPE to MultiHeadAttention module#
It’s important to note that GPT applies the positional embeddings to the inputs, whereas Llama applies rotations to the query and key vectors in the self-attention mechanism itself
Here, we modify the `MultiHeadAttention` class with the appropriate RoPE code
In addition, we remove the `qkv_bias` option and hardcode the `bias=False` setting
Also, we add a `dtype` setting to be able to instantiate the model with a lower precision later
Tip: since the `TransformerBlock`s (in the next section) are repeated exactly, we could simplify the code and only initialize the buffers once instead of initializing them for each `MultiHeadAttention` module; however, we add the precomputed RoPE parameters to the `MultiHeadAttention` class here so that it can function as a standalone module (a simple sketch of the buffer-sharing idea follows below)
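The following is a minimal sketch of that buffer-sharing idea (illustrative only; the `get_shared_rope_params` helper is not used elsewhere in this notebook): it caches the precomputed `cos`/`sin` tensors in a module-level dictionary so that repeated blocks could reuse one copy instead of each registering its own buffers:
# Minimal sketch of the buffer-sharing idea (illustrative, not used below)
_rope_cache = {}

def get_shared_rope_params(head_dim, theta_base=10_000, context_length=4096):
    # Reuse previously computed cos/sin tensors for identical settings
    key = (head_dim, theta_base, context_length)
    if key not in _rope_cache:
        _rope_cache[key] = precompute_rope_params(head_dim, theta_base, context_length)
    return _rope_cache[key]  # (cos, sin)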
#####################################
# Chapter 3
#####################################
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, num_heads, dtype=None): # ,dropout, num_heads, qkv_bias=False):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by n_heads"
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim
################################### NEW ###################################
# Set bias=False and dtype=dtype for all linear layers below
###########################################################################
self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
self.W_key = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
self.W_value = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype) # Linear layer to combine head outputs
# self.dropout = nn.Dropout(dropout)
self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))
################################### NEW ###################################
cos, sin = precompute_rope_params(head_dim=self.head_dim, context_length=context_length)
self.register_buffer("cos", cos)
self.register_buffer("sin", sin)
###########################################################################
def forward(self, x):
b, num_tokens, d_in = x.shape
keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
queries = self.W_query(x)
values = self.W_value(x)
# We implicitly split the matrix by adding a `num_heads` dimension
# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)
################################### NEW ###################################
keys = compute_rope(keys, self.cos, self.sin)
queries = compute_rope(queries, self.cos, self.sin)
###########################################################################
# Compute scaled dot-product attention (aka self-attention) with a causal mask
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
# Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
# Use the mask to fill attention scores
attn_scores.masked_fill_(mask_bool, -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
# attn_weights = self.dropout(attn_weights)
# Shape: (b, num_tokens, num_heads, head_dim)
context_vec = (attn_weights @ values).transpose(1, 2)
# Combine heads, where self.d_out = self.num_heads * self.head_dim
context_vec = context_vec.reshape(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec) # optional projection
return context_vec
Below is an example using the `MultiHeadAttention` module on an example input:
# Settings
batch_size = 1
context_len = 100
max_context_len = 4096
embed_dim = 128
num_heads = 4
example_batch = torch.randn((batch_size, context_len, embed_dim))
mha = MultiHeadAttention(
d_in=embed_dim,
d_out=embed_dim,
context_length=max_context_len,
num_heads=num_heads
)
mha(example_batch)
del mha # delete to free up memory
1.6 Update the TransformerBlock module#
At this stage, most of the hard work is already done; we can now update the `TransformerBlock` to use the code we implemented above
This means we
replace LayerNorm with RMSNorm
remove dropout
remove the `qkv_bias` setting
add the `dtype` setting
class TransformerBlock(nn.Module):
def __init__(self, cfg):
super().__init__()
self.att = MultiHeadAttention(
d_in=cfg["emb_dim"],
d_out=cfg["emb_dim"],
context_length=cfg["context_length"],
num_heads=cfg["n_heads"],
dtype=cfg["dtype"] # NEW
# dropout=cfg["drop_rate"],
# qkv_bias=cfg["qkv_bias"]
)
self.ff = FeedForward(cfg)
################################### NEW ###################################
# self.norm1 = LayerNorm(cfg["emb_dim"])
# self.norm2 = LayerNorm(cfg["emb_dim"])
self.norm1 = RMSNorm(cfg["emb_dim"])
self.norm2 = RMSNorm(cfg["emb_dim"])
###########################################################################
# self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
def forward(self, x):
# Shortcut connection for attention block
shortcut = x
x = self.norm1(x)
x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
# x = self.drop_shortcut(x)
x = x + shortcut # Add the original input back
# Shortcut connection for feed-forward block
shortcut = x
x = self.norm2(x)
x = self.ff(x)
# x = self.drop_shortcut(x)
x = x + shortcut # Add the original input back
return x
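As a quick check that the block wires up correctly, we can run it on a small dummy input; the config values below are illustrative and much smaller than the real Llama 2 settings:
# Quick shape check with small, illustrative settings (not the real Llama 2 config)
example_cfg = {
    "emb_dim": 128,
    "context_length": 256,
    "n_heads": 4,
    "hidden_dim": 256,
    "dtype": torch.float32,
}
block = TransformerBlock(example_cfg)
example_input = torch.randn(1, 10, example_cfg["emb_dim"])
print(block(example_input).shape)  # torch.Size([1, 10, 128])
del block  # delete to free up memory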
1.7 Update the model class#
As you may recall from chapter 5, the `TransformerBlock` is a repeated block within the main model
Our Llama model is almost complete; we just have to update the model code surrounding the `TransformerBlock`
This means we
remove absolute positional embeddings since we have RoPE embeddings now
replace LayerNorm with RMSNorm
remove dropout
add the dtype setting
# class GPTModel(nn.Module):
class Llama2Model(nn.Module):
def __init__(self, cfg):
super().__init__()
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
# self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
# self.drop_emb = nn.Dropout(cfg["drop_rate"])
self.trf_blocks = nn.Sequential(
*[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
################################### NEW ###################################
# self.final_norm = LayerNorm(cfg["emb_dim"])
self.final_norm = RMSNorm(cfg["emb_dim"])
###########################################################################
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
def forward(self, in_idx):
# batch_size, seq_len = in_idx.shape
tok_embeds = self.tok_emb(in_idx)
# pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
x = tok_embeds # + pos_embeds # Shape [batch_size, num_tokens, emb_size]
# x = self.drop_emb(x)
x = self.trf_blocks(x)
x = self.final_norm(x)
logits = self.out_head(x)
return logits
2. Initialize model#
The model code is now complete, and we are ready to initialize it
In chapter 5, we used the following config file to specify the 124M-parameter GPT model:
GPT_CONFIG_124M = {
"vocab_size": 50257, # Vocabulary size
"context_length": 1024, # Context length
"emb_dim": 768, # Embedding dimension
"n_heads": 12, # Number of attention heads
"n_layers": 12, # Number of layers
"drop_rate": 0.1, # Dropout rate
"qkv_bias": False # Query-Key-Value bias
}
For reference, the 1.5B parameter GPT model config is shown below as well:
GPT_CONFIG_1558M = {
"vocab_size": 50257, # Vocabulary size
"context_length": 1024, # Context length
"emb_dim": 1600, # Embedding dimension
"n_heads": 25, # Number of attention heads
"n_layers": 48, # Number of layers
"drop_rate": 0.1, # Dropout rate
"qkv_bias": False # Query-Key-Value bias
}
Similarly, we can define a Llama 2 config file for the 7B model (we ignore the other larger models for simplicity here):
LLAMA2_CONFIG_7B = {
"vocab_size": 32000, # Vocabulary size
"context_length": 4096, # Context length
"emb_dim": 4096, # Embedding dimension
"n_heads": 32, # Number of attention heads
"n_layers": 32, # Number of layers
"hidden_dim": 11008, # NEW: Size of the intermediate dimension in FeedForward
"dtype": torch.bfloat16 # NEW: Lower-precision dtype to reduce memory usage
}
Using these settings, we can now initialize a Llama 2 7B model (note that this requires ~26 GB of memory)
model = Llama2Model(LLAMA2_CONFIG_7B)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")
Total number of parameters: 6,738,415,616
As shown above, the model contains 6.7 billion parameters (commonly rounded and referred to as a 7B model)
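For reference, the same number can be derived directly from the config (a small sanity check, not part of the original notebook): the token embedding and output head each contribute vocab_size × emb_dim parameters, each transformer block contributes four emb_dim × emb_dim attention matrices, three emb_dim × hidden_dim feedforward matrices, and two RMSNorm weight vectors, and the final RMSNorm adds another emb_dim:
# Analytic parameter count from the config (sanity check)
cfg = LLAMA2_CONFIG_7B
per_block = (
    4 * cfg["emb_dim"] ** 2                   # W_query, W_key, W_value, out_proj
    + 3 * cfg["emb_dim"] * cfg["hidden_dim"]  # fc1, fc2, fc3
    + 2 * cfg["emb_dim"]                      # norm1, norm2
)
expected_params = (
    2 * cfg["vocab_size"] * cfg["emb_dim"]    # tok_emb and out_head
    + cfg["n_layers"] * per_block
    + cfg["emb_dim"]                          # final_norm
)
print(f"Expected number of parameters: {expected_params:,}")  # 6,738,415,616
assert expected_params == total_params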
Additionally, we can calculate the memory requirements for this model using the code below:
def model_memory_size(model, input_dtype=torch.float32):
total_params = 0
total_grads = 0
for param in model.parameters():
# Calculate total number of elements per parameter
param_size = param.numel()
total_params += param_size
# Check if gradients are stored for this parameter
if param.requires_grad:
total_grads += param_size
# Calculate buffer size (non-parameters that require memory)
total_buffers = sum(buf.numel() for buf in model.buffers())
# Size in bytes = (Number of elements) * (Size of each element in bytes)
# We assume parameters and gradients are stored in the same type as input dtype
element_size = torch.tensor(0, dtype=input_dtype).element_size()
total_memory_bytes = (total_params + total_grads + total_buffers) * element_size
# Convert bytes to gigabytes
total_memory_gb = total_memory_bytes / (1024**3)
return total_memory_gb
print(f"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB")
print(f"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB")
float32 (PyTorch default): 52.33 GB
bfloat16: 26.17 GB
Lastly, we can also transfer the model to an NVIDIA or Apple Silicon GPU if applicable:
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
model.to(device);
3. Load tokenizer#
In this section, we are going to load the tokenizer for the model
Llama 2 uses Google’s SentencePiece tokenizer instead of OpenAI’s Tiktoken (but Llama 3 uses Tiktoken)
Meta AI shared the original Llama 2 model weights and tokenizer vocabulary on the Hugging Face Hub
We will download the tokenizer vocabulary from the Hub and load it into SentencePiece
Uncomment and run the following code to install the required libraries:
# !pip install huggingface_hub sentencepiece
Please note that Meta AI requires that you accept the Llama 2 licensing terms before you can download the files; to do this, you have to create a Hugging Face Hub account and visit the meta-llama/Llama-2-7b repository to accept the terms
Next, you will need to create an access token; to generate an access token with READ permissions, click on the profile picture in the upper right and click on “Settings”

Then, create and copy the access token so you can copy & paste it into the next code cell

from huggingface_hub import login
import json
with open("config.json", "r") as config_file:
config = json.load(config_file)
access_token = config["HF_ACCESS_TOKEN"]
login(token=access_token)
The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /root/.cache/huggingface/token
Login successful
After login via the access token, which is necessary to verify that we accepted the Llama 2 licensing terms, we can now download the tokenizer vocabulary:
from huggingface_hub import hf_hub_download
tokenizer_file = hf_hub_download(
repo_id="meta-llama/Llama-2-7b",
filename="tokenizer.model",
local_dir="Llama-2-7b"
)
/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning:
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
warnings.warn(
To provide a more familiar interface for the tokenizer, we define a small `LlamaTokenizer` wrapper class:
import sentencepiece as spm
class LlamaTokenizer:
def __init__(self, tokenizer_file):
sp = spm.SentencePieceProcessor()
sp.load(tokenizer_file)
self.tokenizer = sp
def encode(self, text):
return self.tokenizer.encode_as_ids(text)
def decode(self, ids):
return self.tokenizer.decode_pieces(ids)
tokenizer = LlamaTokenizer(tokenizer_file)
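As a quick round trip (the example sentence is arbitrary), we can encode a string into token IDs and decode it back:
# Round-trip check with an arbitrary example sentence
ids = tokenizer.encode("Every effort moves you")
print(ids)
print(tokenizer.decode(ids))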
We can now use the `generate` function to have the Llama 2 model generate new text:
from previous_chapters import generate, text_to_token_ids, token_ids_to_text
# If the `previous_chapters.py` file is not available locally,
# you can import it from the `llms-from-scratch` PyPI package.
# For details, see: https://github.com/rasbt/LLMs-from-scratch/tree/main/pkg
# E.g.,
# from llms_from_scratch.ch05 import generate, text_to_token_ids, token_ids_to_text
torch.manual_seed(123)
token_ids = generate(
model=model,
idx=text_to_token_ids("Every effort moves", tokenizer).to(device),
max_new_tokens=30,
context_size=LLAMA2_CONFIG_7B["context_length"],
top_k=1,
temperature=0.
)
print("Output text:\n", token_ids_to_text(token_ids, tokenizer))
Output text:
Every effort movesαllRadius deletingpretcc否']; future eer napulate lackус während inter DES издаSchéon로жа Bass differencespadxsnu ;; ctx始
Of course, as we can see above, the text is nonsensical since we haven’t trained the Llama 2 model yet
In the next section, instead of training it ourselves, which would cost tens to hundreds of thousands of dollars, we load the pretrained weights from Meta AI
4. Load pretrained weights#
We are loading the “meta-llama/Llama-2-7b” base model below, which is a simple text completion model before finetuning
Alternatively, you can load the instruction-finetuned and aligned “meta-llama/Llama-2-7b-chat” model by modifying the string in the next code cell accordingly
weights_file = hf_hub_download(
repo_id="meta-llama/Llama-2-7b",
filename="consolidated.00.pth",
local_dir="Llama-2-7b"
)
weights = torch.load(weights_file, weights_only=True)
The `weights` dictionary contains the following tensors (only the first 15 are shown for simplicity):
list(weights.keys())[:15]
['tok_embeddings.weight',
'norm.weight',
'output.weight',
'layers.0.attention.wq.weight',
'layers.0.attention.wk.weight',
'layers.0.attention.wv.weight',
'layers.0.attention.wo.weight',
'layers.0.feed_forward.w1.weight',
'layers.0.feed_forward.w2.weight',
'layers.0.feed_forward.w3.weight',
'layers.0.attention_norm.weight',
'layers.0.ffn_norm.weight',
'layers.1.attention.wq.weight',
'layers.1.attention.wk.weight',
'layers.1.attention.wv.weight']
The following function, modeled after the `load_weights_into_gpt` function in chapter 5, loads the pretrained weights into our Llama 2 model:
def assign(left, right):
if left.shape != right.shape:
raise ValueError(f"Shape mismatch. Left: {left.shape}, Right: {right.shape}")
if isinstance(right, torch.Tensor):
return torch.nn.Parameter(right.clone().detach())
else:
return torch.nn.Parameter(torch.tensor(right))
def load_weights_into_llama(model, param_config, params):
model.tok_emb.weight = assign(model.tok_emb.weight, params["tok_embeddings.weight"])
for l in range(param_config["n_layers"]):
# Load attention weights
model.trf_blocks[l].att.W_query.weight = assign(
model.trf_blocks[l].att.W_query.weight,
params[f"layers.{l}.attention.wq.weight"]
)
model.trf_blocks[l].att.W_key.weight = assign(
model.trf_blocks[l].att.W_key.weight,
params[f"layers.{l}.attention.wk.weight"]
)
model.trf_blocks[l].att.W_value.weight = assign(
model.trf_blocks[l].att.W_value.weight,
params[f"layers.{l}.attention.wv.weight"]
)
model.trf_blocks[l].att.out_proj.weight = assign(
model.trf_blocks[l].att.out_proj.weight,
params[f"layers.{l}.attention.wo.weight"]
)
model.trf_blocks[l].norm1.weight = assign(
model.trf_blocks[l].norm1.weight,
params[f"layers.{l}.attention_norm.weight"]
)
# Load FeedForward weights
model.trf_blocks[l].ff.fc1.weight = assign(
model.trf_blocks[l].ff.fc1.weight,
params[f"layers.{l}.feed_forward.w1.weight"]
)
# Note: in Meta's checkpoint, w1 is the gate projection, w3 the up projection,
# and w2 the down projection, so w3 maps to fc2 and w2 maps to fc3 here
model.trf_blocks[l].ff.fc2.weight = assign(
model.trf_blocks[l].ff.fc2.weight,
params[f"layers.{l}.feed_forward.w3.weight"]
)
model.trf_blocks[l].ff.fc3.weight = assign(
model.trf_blocks[l].ff.fc3.weight,
params[f"layers.{l}.feed_forward.w2.weight"]
)
model.trf_blocks[l].norm2.weight = assign(
model.trf_blocks[l].norm2.weight,
params[f"layers.{l}.ffn_norm.weight"]
)
# Load output layer weights
model.final_norm.weight = assign(model.final_norm.weight, params["norm.weight"])
model.out_head.weight = assign(model.out_head.weight, params["output.weight"])
load_weights_into_llama(model, LLAMA2_CONFIG_7B, weights)
model.to(device);
Next, we are ready to use the model for text generation
torch.manual_seed(123)
token_ids = generate(
model=model,
idx=text_to_token_ids("Every effort", tokenizer).to(device),
max_new_tokens=25,
context_size=LLAMA2_CONFIG_7B["context_length"],
top_k=1,
temperature=0.
)
print("Output text:\n", token_ids_to_text(token_ids, tokenizer))
Output text:
Every effort has been made to ensure that the information contained in this website is accurate and up to date and correct at the time of publication
5. Using the instruction-finetuned model#
As mentioned earlier, we used the pretrained base model above; if you want to use a model capable of following instructions, use the `"meta-llama/Llama-2-7b-chat"` model instead, as shown below
del model # to free up memory
weights_file = hf_hub_download(
repo_id="meta-llama/Llama-2-7b-chat",
filename="consolidated.00.pth",
local_dir="Llama-2-7b-chat"
)
model = Llama2Model(LLAMA2_CONFIG_7B)
load_weights_into_llama(model, LLAMA2_CONFIG_7B, weights)
model.to(device);
torch.manual_seed(123)
token_ids = generate(
model=model,
idx=text_to_token_ids("What do llamas eat?", tokenizer).to(device),
max_new_tokens=25,
context_size=LLAMA2_CONFIG_7B["context_length"],
top_k=1,
temperature=0.
)
print("Output text:\n", token_ids_to_text(token_ids, tokenizer))
Output text:
What do llamas eat?
Llamas and alpacas are herbivores, which means they eat grasses, leaves, grass
What’s next?#
This notebook converted the original GPT-2 architecture into a Llama 2 model
If you are interested in how to convert Llama 2 into Llama 3, Llama 3.1, and Llama 3.2, check out the converting-llama2-to-llama3.ipynb notebook