Supplementary code for the Build a Large Language Model From Scratch book by Sebastian Raschka

Code repository: https://github.com/rasbt/LLMs-from-scratch

Bonus Code for Chapter 5#

Alternative Weight Loading from PyTorch state dicts#

  • In the main chapter, we loaded the GPT model weights directly from OpenAI

  • This notebook provides alternative weight loading code to load the model weights from PyTorch state dict files that I created from the original TensorFlow files and uploaded to the Hugging Face Model Hub at https://huggingface.co/rasbt/gpt2-from-scratch-pytorch

  • This is conceptually the same as loading the weights of a PyTorch model via the state_dict approach described in chapter 5:

state_dict = torch.load("model_state_dict.pth")
model.load_state_dict(state_dict) 

Choose model#

from importlib.metadata import version

pkgs = ["torch"]
for p in pkgs:
    print(f"{p} version: {version(p)}")

BASE_CONFIG = {
    "vocab_size": 50257,    # Vocabulary size
    "context_length": 1024, # Context length
    "drop_rate": 0.0,       # Dropout rate
    "qkv_bias": True        # Query-key-value bias
}

model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}


CHOOSE_MODEL = "gpt2-small (124M)"
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
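
As an optional sanity check (not part of the original notebook), we can verify that the merged configuration is internally consistent: the embedding dimension must split evenly across the attention heads.

# Optional sanity check: the embedding dimension must be divisible by
# the number of attention heads (e.g., 768 / 12 = 64 for gpt2-small)
assert BASE_CONFIG["emb_dim"] % BASE_CONFIG["n_heads"] == 0
print(f"Head dimension: {BASE_CONFIG['emb_dim'] // BASE_CONFIG['n_heads']}")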

Download file#

file_name = "gpt2-small-124M.pth"
# file_name = "gpt2-medium-355M.pth"
# file_name = "gpt2-large-774M.pth"
# file_name = "gpt2-xl-1558M.pth"
import os
import urllib.request

url = f"https://huggingface.co/rasbt/gpt2-from-scratch-pytorch/resolve/main/{file_name}"

if not os.path.exists(file_name):
    urllib.request.urlretrieve(url, file_name)
    print(f"Downloaded to {file_name}")
Downloaded to gpt2-small-124M.pth
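
If the connection drops mid-download, urlretrieve can leave a truncated file behind that os.path.exists would then treat as complete. A minimal variant of the download cell with basic cleanup (download_if_missing is a hypothetical helper, not part of the book's code) could look like this:

import os
import urllib.request

def download_if_missing(file_name):
    # Hypothetical helper: skip the download if the file already exists,
    # and remove partial files if the transfer fails midway
    url = f"https://huggingface.co/rasbt/gpt2-from-scratch-pytorch/resolve/main/{file_name}"
    if os.path.exists(file_name):
        print(f"{file_name} already exists; skipping download")
        return
    try:
        urllib.request.urlretrieve(url, file_name)
        print(f"Downloaded to {file_name}")
    except Exception:
        if os.path.exists(file_name):
            os.remove(file_name)  # avoid leaving a truncated file behind
        raise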

Load weights#

import torch
from llms_from_scratch.ch04 import GPTModel
# For llms_from_scratch installation instructions, see:
# https://github.com/rasbt/LLMs-from-scratch/tree/main/pkg


gpt = GPTModel(BASE_CONFIG)
gpt.load_state_dict(torch.load(file_name, weights_only=True))
gpt.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gpt.to(device);
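
As a quick check (not in the original notebook), we can count the loaded parameters. Note that the raw total is higher than the advertised 124M because the chapter 4 GPTModel keeps a separate output head (out_head), whereas the original GPT-2 ties it to the token embedding matrix:

# Optional: confirm the model size after loading the weights
total_params = sum(p.numel() for p in gpt.parameters())
print(f"Total parameters: {total_params:,}")

# Subtracting the output head approximates the weight-tied "124M" figure
out_head_params = sum(p.numel() for p in gpt.out_head.parameters())
print(f"Without out_head: {total_params - out_head_params:,}")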

Generate text#

import tiktoken
from llms_from_scratch.ch05 import generate, text_to_token_ids, token_ids_to_text


torch.manual_seed(123)

tokenizer = tiktoken.get_encoding("gpt2")

token_ids = generate(
    model=gpt.to(device),
    idx=text_to_token_ids("Every effort moves", tokenizer).to(device),
    max_new_tokens=30,
    context_size=BASE_CONFIG["context_length"],
    top_k=1,
    temperature=1.0
)

print("Output text:\n", token_ids_to_text(token_ids, tokenizer))
Output text:
 Every effort moves forward, but it's not enough.

"I'm not going to sit here and say, 'I'm not going to do this,'

Alternative safetensors file#

  • In addition, the https://huggingface.co/rasbt/gpt2-from-scratch-pytorch repository contains .safetensors versions of the state dicts

  • The appeal of .safetensors files lies in their secure design: they store only tensor data, so loading them cannot execute potentially malicious code

  • In newer versions of PyTorch (2.0 and newer), torch.load accepts a weights_only=True argument (e.g., torch.load("model_state_dict.pth", weights_only=True)), which improves safety by skipping code execution and loading only the weights; since this is enabled by default in PyTorch 2.6 and newer, loading the weights from the state dict files is no longer a safety concern in those versions

  • The code below shows how to load the model from these .safetensors files; a short sketch after the download cell also demonstrates how to inspect such a file

file_name = "gpt2-small-124M.safetensors"
# file_name = "gpt2-medium-355M.safetensors"
# file_name = "gpt2-large-774M.safetensors"
# file_name = "gpt2-xl-1558M.safetensors"
import os
import urllib.request

url = f"https://huggingface.co/rasbt/gpt2-from-scratch-pytorch/resolve/main/{file_name}"

if not os.path.exists(file_name):
    urllib.request.urlretrieve(url, file_name)
    print(f"Downloaded to {file_name}")
Downloaded to gpt2-small-124M.safetensors
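
Before loading the weights, we can optionally peek at the file's contents. Because a .safetensors file is just a small header plus raw tensor data, this inspection (a short sketch using safetensors' safe_open) cannot trigger code execution:

from safetensors import safe_open

# Optional: list a few of the stored tensors without loading everything;
# safe_open reads the header and only the tensors we explicitly request
with safe_open(file_name, framework="pt", device="cpu") as f:
    for name in list(f.keys())[:5]:  # first few entries only
        print(name, tuple(f.get_tensor(name).shape))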
# Load file

from safetensors.torch import load_file

gpt = GPTModel(BASE_CONFIG)
gpt.load_state_dict(load_file(file_name))
gpt.eval();
token_ids = generate(
    model=gpt.to(device),
    idx=text_to_token_ids("Every effort moves", tokenizer).to(device),
    max_new_tokens=30,
    context_size=BASE_CONFIG["context_length"],
    top_k=1,
    temperature=1.0
)

print("Output text:\n", token_ids_to_text(token_ids, tokenizer))
Output text:
 Every effort moves forward, but it's not enough.

"I'm not going to sit here and say, 'I'm not going to do this,'