Supplementary code for the *Build a Large Language Model From Scratch* book by Sebastian Raschka. Code repository: https://github.com/rasbt/LLMs-from-scratch
![]() |
Bonus Code for Chapter 5
Alternative Weight Loading from PyTorch state dicts
In the main chapter, we loaded the GPT model weights directly from OpenAI.
This notebook provides alternative weight loading code to load the model weights from PyTorch state dict files that I created from the original TensorFlow files and uploaded to the Hugging Face Model Hub at https://huggingface.co/rasbt/gpt2-from-scratch-pytorch.
This is conceptually the same as loading the weights of a PyTorch model via the state-dict method described in chapter 5:
# weights_only=True loads only the tensor data instead of executing arbitrary
# pickled code, matching how the weights are loaded later in this notebook.
state_dict = torch.load("model_state_dict.pth", weights_only=True)
model.load_state_dict(state_dict)
Choose model
from importlib.metadata import PackageNotFoundError, version

# Print the installed version of each package this notebook depends on.
pkgs = ["torch"]
for p in pkgs:
    try:
        print(f"{p} version: {version(p)}")
    except PackageNotFoundError:
        # Report a missing package instead of aborting the whole cell with a
        # traceback; later cells will fail with a clearer ImportError anyway.
        print(f"{p} version: not installed")
---------------------------------------------------------------------------
PackageNotFoundError Traceback (most recent call last)
Cell In[1], line 5
3 pkgs = ["torch"]
4 for p in pkgs:
----> 5 print(f"{p} version: {version(p)}")
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/importlib/metadata/__init__.py:946, in version(distribution_name)
939 def version(distribution_name):
940 """Get the version string for the named package.
941
942 :param distribution_name: The name of the distribution package to query.
943 :return: The version string for the package as defined in the package's
944 "Version" metadata key.
945 """
--> 946 return distribution(distribution_name).version
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/importlib/metadata/__init__.py:919, in distribution(distribution_name)
913 def distribution(distribution_name):
914 """Get the ``Distribution`` instance for the named package.
915
916 :param distribution_name: The name of the distribution package as a string.
917 :return: A ``Distribution`` instance (or subclass thereof).
918 """
--> 919 return Distribution.from_name(distribution_name)
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/importlib/metadata/__init__.py:518, in Distribution.from_name(cls, name)
516 return dist
517 else:
--> 518 raise PackageNotFoundError(name)
PackageNotFoundError: No package metadata was found for torch
# Settings shared by every GPT-2 size.
_SHARED_SETTINGS = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "drop_rate": 0.0,        # Dropout rate
    "qkv_bias": True,        # Query-key-value bias
}

# Size-specific architecture hyperparameters, keyed by model name.
model_configs = {
    "gpt2-small (124M)": dict(emb_dim=768, n_layers=12, n_heads=12),
    "gpt2-medium (355M)": dict(emb_dim=1024, n_layers=24, n_heads=16),
    "gpt2-large (774M)": dict(emb_dim=1280, n_layers=36, n_heads=20),
    "gpt2-xl (1558M)": dict(emb_dim=1600, n_layers=48, n_heads=25),
}

CHOOSE_MODEL = "gpt2-small (124M)"

# Final config: shared settings first, then the chosen size's values.
BASE_CONFIG = {**_SHARED_SETTINGS, **model_configs[CHOOSE_MODEL]}
Download file
# Derive the checkpoint name from CHOOSE_MODEL so the downloaded weights always
# match the architecture in BASE_CONFIG, e.g.
# "gpt2-small (124M)" -> "gpt2-small-124M.pth".
# Available files: gpt2-small-124M.pth, gpt2-medium-355M.pth,
# gpt2-large-774M.pth, gpt2-xl-1558M.pth
file_name = CHOOSE_MODEL.replace(" (", "-").replace(")", "") + ".pth"

import os
import urllib.request

url = f"https://huggingface.co/rasbt/gpt2-from-scratch-pytorch/resolve/main/{file_name}"

if not os.path.exists(file_name):
    # Download to a temporary name and rename only after success, so an
    # interrupted download never leaves a truncated file behind that the
    # os.path.exists check above would mistake for a complete one.
    tmp_name = file_name + ".part"
    urllib.request.urlretrieve(url, tmp_name)
    os.replace(tmp_name, file_name)
    print(f"Downloaded to {file_name}")
Downloaded to gpt2-small-124M.pth
Load weights
import torch
from llms_from_scratch.ch04 import GPTModel
# For llms_from_scratch installation instructions, see:
# https://github.com/rasbt/LLMs-from-scratch/tree/main/pkg

gpt = GPTModel(BASE_CONFIG)

# map_location="cpu" makes loading independent of the device the checkpoint
# tensors were saved from (e.g. CUDA tensors on a CPU-only machine);
# weights_only=True avoids executing pickled code.
state_dict = torch.load(file_name, map_location="cpu", weights_only=True)
gpt.load_state_dict(state_dict)
gpt.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gpt.to(device);
Generate text
import tiktoken
from llms_from_scratch.ch05 import generate, text_to_token_ids, token_ids_to_text

# Seed the RNG so the generated continuation is reproducible.
torch.manual_seed(123)
tokenizer = tiktoken.get_encoding("gpt2")

start_context = "Every effort moves"
gpt_on_device = gpt.to(device)
encoded = text_to_token_ids(start_context, tokenizer).to(device)

token_ids = generate(
    model=gpt_on_device,
    idx=encoded,
    max_new_tokens=30,
    context_size=BASE_CONFIG["context_length"],
    top_k=1,  # keep only the top-scoring token per step (assuming standard top-k filtering in ch05.generate)
    temperature=1.0,
)

decoded_text = token_ids_to_text(token_ids, tokenizer)
print("Output text:\n", decoded_text)
Output text:
Every effort moves forward, but it's not enough.
"I'm not going to sit here and say, 'I'm not going to do this,'
Alternative safetensors file
In addition, the https://huggingface.co/rasbt/gpt2-from-scratch-pytorch repository contains so-called `.safetensors` versions of the state dicts. The appeal of `.safetensors` files lies in their secure design, as they only store tensor data and avoid the execution of potentially malicious code during loading. In newer versions of PyTorch (e.g., 2.0 and newer), a `weights_only=True` argument can be used with `torch.load` (e.g., `torch.load("model_state_dict.pth", weights_only=True)`) to improve safety by skipping the execution of code and loading only the weights (this is now enabled by default in PyTorch 2.6 and newer); so in that case, loading the weights from the state dict files should not be a concern (anymore). However, the code block below briefly shows how to load the model from these `.safetensors` files.
# Derive the file name from CHOOSE_MODEL so the downloaded weights always
# match the architecture in BASE_CONFIG, e.g.
# "gpt2-small (124M)" -> "gpt2-small-124M.safetensors".
# Available files: gpt2-small-124M.safetensors, gpt2-medium-355M.safetensors,
# gpt2-large-774M.safetensors, gpt2-xl-1558M.safetensors
file_name = CHOOSE_MODEL.replace(" (", "-").replace(")", "") + ".safetensors"

import os
import urllib.request

url = f"https://huggingface.co/rasbt/gpt2-from-scratch-pytorch/resolve/main/{file_name}"

if not os.path.exists(file_name):
    # Download to a temporary name and rename only after success, so an
    # interrupted download never leaves a truncated file behind that the
    # os.path.exists check above would mistake for a complete one.
    tmp_name = file_name + ".part"
    urllib.request.urlretrieve(url, tmp_name)
    os.replace(tmp_name, file_name)
    print(f"Downloaded to {file_name}")
Downloaded to gpt2-small-124M.safetensors
# Rebuild the model and fill it with the weights stored in the .safetensors file
from safetensors.torch import load_file

gpt = GPTModel(BASE_CONFIG)
safetensors_state = load_file(file_name)
gpt.load_state_dict(safetensors_state)
gpt.eval();
# Run the same prompt through the model loaded from the .safetensors file.
# nn.Module.to moves the model in place and returns it, so `gpt` is then on `device`.
gpt.to(device)
prompt_ids = text_to_token_ids("Every effort moves", tokenizer).to(device)

generation_kwargs = dict(
    max_new_tokens=30,
    context_size=BASE_CONFIG["context_length"],
    top_k=1,
    temperature=1.0,
)
token_ids = generate(model=gpt, idx=prompt_ids, **generation_kwargs)

print("Output text:\n", token_ids_to_text(token_ids, tokenizer))
Output text:
Every effort moves forward, but it's not enough.
"I'm not going to sit here and say, 'I'm not going to do this,'