Supplementary code for the Build a Large Language Model From Scratch book by Sebastian Raschka
Code repository: https://github.com/rasbt/LLMs-from-scratch
Memory-efficient Model Weight Loading#
This notebook provides tips for loading larger pretrained or finetuned models when GPU (or CPU) memory is limited
Specifically, it focuses on cases where you saved the model using
torch.save(model.state_dict(), "model.pth")
(for example, in chapters 5-7) and want to load it in a new session later for continued pretraining or additional finetuning
While the example uses an LLM, the methods explained in this notebook are general and apply to loading any PyTorch model, not just LLMs

from importlib.metadata import version
pkgs = [
    "torch",
]
for p in pkgs:
    print(f"{p} version: {version(p)}")
1. Benchmark utilities#
First, let’s define some utility code to track VRAM (GPU memory)
Later, we will also introduce a tool to track the main system RAM (CPU memory)
The purpose of these functions will become clear when we apply them later
import gc
import time
import torch
def start_memory_tracking():
    """Initialize GPU memory tracking."""
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    else:
        print("This notebook is intended for CUDA GPUs but CUDA is not available.")

def print_memory_usage():
    max_gpu_memory = torch.cuda.max_memory_allocated() / (1024 ** 3)  # Convert bytes to GB
    print(f"Maximum GPU memory allocated: {max_gpu_memory:.1f} GB")

def cleanup():
    gc.collect()
    torch.cuda.empty_cache()
    time.sleep(3)  # some buffer time to allow memory to clear
    torch.cuda.reset_peak_memory_stats()
    # Note: relies on the global `device`, which is defined in the model setup section below
    max_memory_allocated = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
    print(f"Maximum GPU memory allocated: {max_memory_allocated:.1f} GB")
2. Model setup#
This code section sets up the model itself
Here, we use the "gpt2-xl (1558M)" model to make things more interesting (you may use the "gpt2-small (124M)" model to lower the memory requirements and execution time of this notebook)
from previous_chapters import GPTModel
# If the `previous_chapters.py` file is not available locally,
# you can import it from the `llms-from-scratch` PyPI package.
# For details, see: https://github.com/rasbt/LLMs-from-scratch/tree/main/pkg
# E.g.,
# from llms_from_scratch.ch04 import GPTModel
BASE_CONFIG = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "drop_rate": 0.0,        # Dropout rate
    "qkv_bias": True         # Query-key-value bias
}

model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}
CHOOSE_MODEL = "gpt2-xl (1558M)"
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
Now, let’s see the GPU memory functions in action:
start_memory_tracking()
model = GPTModel(BASE_CONFIG)
device = torch.device("cuda")
model.to(device)
print_memory_usage()
Maximum GPU memory allocated: 6.4 GB
Additionally, let’s make sure that the model runs okay by passing in some example tensor
# Test if the model works (no need to track memory here)
test_input = torch.tensor([[1, 2, 3]]).to(device)
model.eval()
with torch.no_grad():
    model(test_input)
Next, imagine we were pretraining the model and saving it for later use
We skip the actual pretraining here for simplicity and just save the initialized model (but the same concept applies)
# Training code would go here...
model.train()
torch.save(model.state_dict(), "model.pth")
Lastly, we delete the model and example tensor in the Python session to reset the GPU memory
del model, test_input
cleanup()
Maximum GPU memory allocated: 0.0 GB
3. Weight loading#
Now begins the interesting part where we load the pretrained model weights
Let’s see how much GPU memory is required to load the previously saved model
# Then load pretrained weights
start_memory_tracking()
model = GPTModel(BASE_CONFIG)
model.to(device)
model.load_state_dict(
    torch.load("model.pth", map_location=device, weights_only=True)
)
model.to(device)
model.eval();
print_memory_usage()
Maximum GPU memory allocated: 12.8 GB
Notice that the memory is 2x as large as in the previous session
This is because we have the same model in memory twice, for a short period of time:
The first time via model.to(device)
The second time via the code line model.load_state_dict(torch.load("model.pth", map_location=device, weights_only=True))
Eventually, the loaded model weights will be copied into the model, and the state_dict will be discarded, but for a brief amount of time, we have both the main model and the loaded state_dict in memory
The remaining sections focus on addressing this
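As a rough back-of-the-envelope check (assuming float32 weights, i.e., 4 bytes per parameter): the GPT-2 XL configuration has on the order of 1.6 billion parameters, so a single copy of the weights occupies roughly 1.6 × 10⁹ × 4 bytes ≈ 6 GB, and briefly holding both the model and the loaded state_dict roughly doubles that, in line with the ~6.4 GB and ~12.8 GB peaks reported above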
But first, let’s test the model and reset the GPU memory
# Test if the model works (no need to track memory here)
test_input = torch.tensor([[1, 2, 3]]).to(device)
model.eval()
with torch.no_grad():
    model(test_input)
del model, test_input
cleanup()
Maximum GPU memory allocated: 0.0 GB
4. Loading weights sequentially#
One workaround for the problem of having the model weights in GPU memory twice, as highlighted in the previous section, is to load the model sequentially
Below, we:
first load the model into GPU memory
then load the model weights into CPU memory
and finally copy each parameter one by one into GPU memory
start_memory_tracking()
model = GPTModel(BASE_CONFIG).to(device)
state_dict = torch.load("model.pth", map_location="cpu", weights_only=True)
print_memory_usage()
# Sequentially copy weights to the model's parameters
with torch.no_grad():
    for name, param in model.named_parameters():
        if name in state_dict:
            param.copy_(state_dict[name].to(device))
        else:
            print(f"Warning: {name} not found in state_dict.")
print_memory_usage()
Maximum GPU memory allocated: 6.4 GB
Maximum GPU memory allocated: 6.7 GB
As we can see above, the memory usage is much lower than before
Notice that the memory increases from 6.4 to 6.7 GB because initially, we only have the model in memory, and then we have the model plus one parameter tensor in memory (we temporarily move each parameter tensor to the GPU so we can copy it into the corresponding model parameter)
Overall, this is a significant improvement
Again, let’s briefly test the model and then reset the GPU memory for the next section
# Test if the model works (no need to track memory here)
test_input = torch.tensor([[1, 2, 3]]).to(device)
model.eval()
with torch.no_grad():
    model(test_input)
del model, test_input, state_dict, param
cleanup()
Maximum GPU memory allocated: 0.0 GB
5. Loading the model with low CPU memory#
In the previous section, we reduced GPU memory use by loading the weights (state_dict) into CPU memory first before copying them one-by-one into the model
However, what do we do if we have limited CPU memory?
This section uses PyTorch's so-called "meta" device approach to load a model on machines with large GPU memory but small CPU memory
But first, let's define a convenience function to monitor CPU memory
import os
import psutil
from threading import Thread
def memory_usage_in_gb(func, *args, **kwargs):
    process = psutil.Process(os.getpid())

    # Measure the baseline memory usage before running the function
    baseline_mem = process.memory_info().rss / 1024 ** 3  # in GB

    # Start monitoring memory in a separate thread
    mem_usage = []
    done = False

    def monitor_memory():
        while not done:
            mem_usage.append(process.memory_info().rss / 1024 ** 3)  # Convert to GB
            time.sleep(0.1)

    t = Thread(target=monitor_memory)
    t.start()

    # Run the function
    func(*args, **kwargs)

    # Stop monitoring
    done = True
    t.join()

    peak_mem_usage_gb = max(mem_usage) - baseline_mem
    return peak_mem_usage_gb
To start with, let’s track the CPU memory of the sequential weight loading approach from the previous section
def load_sequentially():
    start_memory_tracking()

    model = GPTModel(BASE_CONFIG).to(device)

    state_dict = torch.load("model.pth", map_location="cpu", weights_only=True)

    print_memory_usage()

    # Sequentially copy weights to the model's parameters
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in state_dict:
                param.copy_(state_dict[name].to(device))
            else:
                print(f"Warning: {name} not found in state_dict.")

    print_memory_usage()
peak_memory_used = memory_usage_in_gb(load_sequentially)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
Maximum GPU memory allocated: 6.4 GB
Maximum GPU memory allocated: 6.7 GB
-> Maximum CPU memory allocated: 6.3 GB
Now, suppose we have a machine with low CPU memory but large GPU memory
We can trade off CPU memory and GPU memory usage by introducing PyTorch’s so-called “meta” device
PyTorch’s meta device is a special device type that allows you to create tensors without allocating actual memory for their data, effectively creating “meta” tensors
This is useful for tasks like model analysis or architecture definition, where you need tensor shapes and types without the overhead of memory allocation
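As a minimal, standalone sketch (not part of the original notebook), the snippet below illustrates that a tensor created on the meta device carries only shape and dtype information and allocates no data:
meta_tensor = torch.empty(1000, 1000, device="meta")
print(meta_tensor.shape)   # torch.Size([1000, 1000])
print(meta_tensor.device)  # meta
# The tensor has no underlying storage; it can be materialized later,
# e.g., via torch.empty_like(meta_tensor, device="cuda") or Module.to_empty(device=...)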
def load_sequentially_with_meta():
    start_memory_tracking()

    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)

    model = model.to_empty(device=device)

    state_dict = torch.load("model.pth", map_location=device, weights_only=True)

    print_memory_usage()

    # Sequentially copy weights to the model's parameters
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in state_dict:
                param.copy_(state_dict[name])
            else:
                print(f"Warning: {name} not found in state_dict.")

    print_memory_usage()
peak_memory_used = memory_usage_in_gb(load_sequentially_with_meta)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
Maximum GPU memory allocated: 12.8 GB
Maximum GPU memory allocated: 12.8 GB
-> Maximum CPU memory allocated: 1.3 GB
As we can see above, by creating the model on the meta-device and loading the weights directly into GPU memory, we effectively reduced the CPU memory requirements
One might ask: “Is the sequential weight loading still necessary then, and how does that compare to the original approach?”
Let’s check the simple PyTorch weight loading approach for comparison (from the first weight loading section in this notebook):
def baseline():
    start_memory_tracking()

    model = GPTModel(BASE_CONFIG)
    model.to(device)

    model.load_state_dict(torch.load("model.pth", map_location=device, weights_only=True))
    model.to(device)
    model.eval();

    print_memory_usage()
peak_memory_used = memory_usage_in_gb(baseline)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
Maximum GPU memory allocated: 12.8 GB
-> Maximum CPU memory allocated: 4.4 GB
As we can see above, the “simple” weight loading without the meta device uses more memory
In other words, if you have a machine with limited CPU memory, you can use the meta device approach to directly load the model weights into GPU memory to reduce peak CPU memory usage
6. Using mmap=True (recommended)#
As an intermediate or advanced torch.load user, you may wonder how these approaches compare to the mmap=True setting in PyTorch
The mmap=True setting in PyTorch enables memory-mapped file I/O, which allows the tensor to access data directly from disk storage, thus reducing memory usage by not loading the entire file into RAM if RAM is limited
Also, see the helpful comment by mikaylagawarecki
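To see the effect of mmap=True in isolation before combining it with the meta device below, here is a minimal sketch (not part of the original notebook): with mmap=True, torch.load memory-maps the checkpoint file so tensor data is paged in from disk on demand rather than copied into RAM up front
mmap_state_dict = torch.load("model.pth", map_location="cpu", weights_only=True, mmap=True)
first_key = next(iter(mmap_state_dict))
print(first_key, mmap_state_dict[first_key].shape)  # data is only read from disk when accessed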
At first glance, it may look less efficient than the sequential approaches above:
def best_practices():
    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)

    model.load_state_dict(
        torch.load("model.pth", map_location=device, weights_only=True, mmap=True),
        assign=True
    )

    print_memory_usage()
peak_memory_used = memory_usage_in_gb(best_practices)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
Maximum GPU memory allocated: 6.4 GB
-> Maximum CPU memory allocated: 5.9 GB
The reason why the CPU RAM usage is so high is that there’s enough CPU RAM available on this machine
However, if you were to run this on a machine with limited CPU RAM, the mmap approach would use less memory
7. Other methods#
This notebook is focused on simple, built-in methods for loading weights in PyTorch
The recommended approach for limited CPU memory cases is the mmap=True approach explained above
Alternatively, one other option is a brute-force approach that saves and loads each weight tensor separately:
model = GPTModel(BASE_CONFIG)
# Assume `model` is your trained model
state_dict = model.state_dict()
# Create a directory to store individual parameter files
os.makedirs("model_parameters", exist_ok=True)
# Save each parameter tensor separately
for name, param in state_dict.items():
    torch.save(param.cpu(), f"model_parameters/{name}.pt")
del model
def load_individual_weights():
    start_memory_tracking()

    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)

    model = model.to_empty(device=device)

    print_memory_usage()
    param_dir = "model_parameters"

    with torch.no_grad():
        for name, param in model.named_parameters():
            weight_path = os.path.join(param_dir, f"{name}.pt")
            if os.path.exists(weight_path):
                param_data = torch.load(weight_path, map_location="cpu", weights_only=True)
                param.copy_(param_data)
                del param_data  # Free memory
            else:
                print(f"Warning: {name} not found in {param_dir}.")

    print_memory_usage()
peak_memory_used = memory_usage_in_gb(load_individual_weights)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
Maximum GPU memory allocated: 6.4 GB
Maximum GPU memory allocated: 6.4 GB
-> Maximum CPU memory allocated: 0.3 GB