Creating, Saving, and Loading Models
Core concepts
Before we begin, let’s understand the two key classes:
1. TorchModuleCfg: This is a specialized configuration class that defines all the parameters needed for a PyTorch nn.Module, such as layer counts, dimensions, and activation functions. Your model will be instantiated from it.
2. ModelMixin: This is a mixin class. Your custom model class should inherit from it. It provides the essential save_model and load_model methods.
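The idea of a config object that builds its own model can be illustrated with the standard library alone. The sketch below is not how robo_orchard_lab implements TorchModuleCfg; TinyModel and TinyModelCfg are hypothetical stand-ins that only show the pattern of storing a class next to its hyperparameters and instantiating it by calling the config.

```python
from dataclasses import dataclass
from typing import Any


class TinyModel:
    """Stand-in for a model class; it only keeps the config it was built from."""

    def __init__(self, cfg: "TinyModelCfg") -> None:
        self.cfg = cfg


@dataclass
class TinyModelCfg:
    """Hypothetical config: hyperparameters plus the class they configure."""

    class_type: type = TinyModel
    input_size: int = 784
    hidden_size: int = 128

    def __call__(self) -> Any:
        # Calling the config builds the configured class and passes the config in.
        return self.class_type(self)


cfg = TinyModelCfg(hidden_size=256)
model = cfg()
print(type(model).__name__, model.cfg.hidden_size)  # TinyModel 256
```

Keeping the class reference inside the config means a single serialized object is enough to rebuild the model later, which is the property the save and load steps rely on.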
Step 1: Define Your Model and Configuration
First, we need to define the model’s architecture and its corresponding configuration.
Let’s create a simple fully-connected network called SimpleNet.
- Create the config class SimpleNetCfg: This class must inherit from TorchModuleCfg and should define the parameters your model needs (e.g., input_size, hidden_size, output_size).
- Create the model class SimpleNet: This class must inherit from ModelMixin. In its __init__ method, it accepts a config object cfg and calls super().__init__(cfg) to complete the setup.
import torch.nn as nn

from robo_orchard_lab.models.mixin import (
    ClassType_co,
    ModelMixin,
    TorchModuleCfg,
)


# 1. Define the model class, inheriting from ModelMixin
class SimpleNet(ModelMixin):
    def __init__(self, cfg: "SimpleNetCfg"):
        # It's crucial to call super().__init__ and pass the cfg
        super().__init__(cfg)
        self.fc1 = nn.Linear(cfg.input_size, cfg.hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(cfg.hidden_size, cfg.output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


# 2. Define the configuration class for the model
class SimpleNetCfg(TorchModuleCfg[SimpleNet]):
    class_type: ClassType_co[SimpleNet] = SimpleNet
    input_size: int = 784
    hidden_size: int = 128
    output_size: int = 10
Step 2: Instantiate and Save the Model
The ModelMixin provides the save_model method, which automatically performs two actions:
- Saves the model’s configuration (cfg) to model.config.json.
- Saves the model’s weights (state_dict) to model.safetensors.
import os
import shutil
config = SimpleNetCfg(hidden_size=256)
Instantiate the model by calling the config object. This relies on the functionality provided by ClassInitFromConfigMixin.
model = config()
print("Model created:", model)
Model created: SimpleNet(
  (fc1): Linear(in_features=784, out_features=256, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=256, out_features=10, bias=True)
)
Call the save_model method
output_dir = "./checkpoint"
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
model.save_model(output_dir)
import subprocess
print(f"Model has been saved to the `{output_dir}` directory.")
print(subprocess.check_output(["tree", output_dir]).decode())
Model has been saved to the `./checkpoint` directory.
./checkpoint
├── model.config.json
└── model.safetensors
0 directories, 2 files
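The listing above depends on the external tree command, which is not installed on every system. A portable alternative, sketched here with the standard library only, walks the directory with pathlib. The helper name list_checkpoint_files is hypothetical, and the demo uses empty placeholder files rather than a real checkpoint.

```python
import tempfile
from pathlib import Path


def list_checkpoint_files(directory: str) -> list[str]:
    """Return the relative paths of all files under `directory`, sorted."""
    root = Path(directory)
    return sorted(
        p.relative_to(root).as_posix() for p in root.rglob("*") if p.is_file()
    )


# Demo on a throwaway directory holding two empty placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "model.config.json").write_text("{}")
    (Path(tmp) / "model.safetensors").write_bytes(b"")
    listing = list_checkpoint_files(tmp)
    print(listing)  # ['model.config.json', 'model.safetensors']
```

Calling list_checkpoint_files(output_dir) after save_model would give the same information as the tree output without requiring an extra binary.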
Step 3: Load the Model
Loading the model is just as easy. The load_model method automatically reads model.config.json to build the model architecture and then loads the weights from model.safetensors.
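The "config first, weights second" order can be sketched in a few lines of standard-library Python. This is only an illustration of the general pattern, not the layout robo_orchard_lab actually writes to model.config.json. The class_type key, the REGISTRY dict, and the helper names are all assumptions made for the example, and the weight-loading step is left out.

```python
import json
import tempfile
from pathlib import Path


class TinyModel:
    """Stand-in model with a single hyperparameter."""

    def __init__(self, hidden_size: int) -> None:
        self.hidden_size = hidden_size


# Hypothetical registry mapping a stored class name back to the class.
REGISTRY = {"TinyModel": TinyModel}


def save_config(model: TinyModel, directory: str) -> None:
    # Store the class name next to the hyperparameters needed to rebuild it.
    payload = {"class_type": type(model).__name__, "hidden_size": model.hidden_size}
    (Path(directory) / "model.config.json").write_text(json.dumps(payload))


def load_from_config(directory: str) -> TinyModel:
    # Read the config, look up the class, then rebuild the architecture.
    payload = json.loads((Path(directory) / "model.config.json").read_text())
    cls = REGISTRY[payload.pop("class_type")]
    return cls(**payload)  # a real loader would restore the weights after this


with tempfile.TemporaryDirectory() as tmp:
    save_config(TinyModel(hidden_size=256), tmp)
    restored = load_from_config(tmp)
    print(type(restored).__name__, restored.hidden_size)  # TinyModel 256
```

Because the class identity travels with the config, the caller does not need to know the concrete model type in advance, which matches how the tutorial loads SimpleNet through ModelMixin.load_model.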
loaded_model = ModelMixin.load_model(output_dir)
print("Model loaded:", loaded_model)
Model loaded: SimpleNet(
  (fc1): Linear(in_features=784, out_features=256, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=256, out_features=10, bias=True)
)
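Matching printed architectures do not prove that the weights survived the round trip. One way to check this, assuming plain PyTorch, is to compare the two state dicts tensor by tensor. The helper name state_dicts_match is hypothetical, and two nn.Linear layers stand in for model and loaded_model so the snippet runs on its own.

```python
import torch
import torch.nn as nn


def state_dicts_match(a: nn.Module, b: nn.Module) -> bool:
    """Return True if both modules hold the same parameter names and values."""
    sd_a, sd_b = a.state_dict(), b.state_dict()
    if sd_a.keys() != sd_b.keys():
        return False
    return all(torch.equal(sd_a[key], sd_b[key]) for key in sd_a)


# Stand-ins for `model` and `loaded_model` from the tutorial.
original = nn.Linear(4, 2)
restored = nn.Linear(4, 2)
restored.load_state_dict(original.state_dict())

print(state_dicts_match(original, restored))
```

In the tutorial's context the equivalent call would be state_dicts_match(model, loaded_model).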
Integrating with Hugging Face Accelerator
ModelMixin provides a pre-built hook, accelerator_save_state_pre_hook, for seamless integration with the 🤗 Accelerate training library.
from accelerate import Accelerator
accelerator = Accelerator()
accelerator.register_save_state_pre_hook(
    SimpleNet.accelerator_save_state_pre_hook
)
<torch.utils.hooks.RemovableHandle object at 0x7fb55f82eb60>
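For readers unfamiliar with save-state pre-hooks, Accelerate documents the expected signature as hook(models, weights, input_dir) and calls registered hooks from Accelerator.save_state() before the checkpoint is written. The sketch below is a hypothetical hook with that signature. It does not reproduce what accelerator_save_state_pre_hook does internally, and it is exercised directly with a dummy object so it runs without torch or accelerate installed.

```python
import json
import os
import tempfile


def write_manifest_pre_hook(models, weights, input_dir):
    """Hypothetical pre-hook: record which model classes are being saved.

    Uses the signature documented for Accelerate save-state pre-hooks:
    hook(models, weights, input_dir) -> None.
    """
    os.makedirs(input_dir, exist_ok=True)
    manifest = {"models": [type(m).__name__ for m in models]}
    with open(os.path.join(input_dir, "manifest.json"), "w") as f:
        json.dump(manifest, f)


class DummyModel:
    """Stand-in so the hook can be called without a real nn.Module."""


with tempfile.TemporaryDirectory() as tmp:
    # Call the hook by hand, the way save_state() would call it.
    write_manifest_pre_hook([DummyModel()], [{}], tmp)
    with open(os.path.join(tmp, "manifest.json")) as f:
        manifest = json.load(f)
    print(manifest)  # {'models': ['DummyModel']}
```

A hook like this would be registered with accelerator.register_save_state_pre_hook(write_manifest_pre_hook), and the returned handle can be used to remove it again.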