Save and Load Models (S&L)

During model development, we often need to save and load models, for example:

  • To avoid losing progress when training is interrupted by unexpected events, it is a good habit to save the model periodically during training (for example, every few epochs);

  • If training runs for too long, the model may overfit the training dataset, so it is useful to save multiple checkpoints and select the one with the best result;

  • In some cases, we need to load a pre-trained model's parameters and other required information to resume training or fine-tune…

MegEngine wraps Python's built-in pickle module to implement binary serialization and deserialization of Python object structures (such as Module objects). The core interfaces you need to know are megengine.save and megengine.load:

>>> megengine.save(model, PATH)
>>> model = megengine.load(PATH)

This syntax saves and loads the entire model object concisely and intuitively, but it is not recommended. The preferred approach is to save and load the state_dict object, or to use checkpoints. The following sections explain these options in more detail and provide best practices for saving and loading models in common scenarios. You can skip concepts you are already familiar with and jump directly to the code for your use case.

  • save/load entire model: not recommended under any circumstances ❌

  • save/load model state dictionary: suitable for inference ✅, but not sufficient for resuming training 😅

  • save/load checkpoint: suitable for inference or resuming training 💡

  • Export static graph models (Dump): suitable for inference and high-performance deployment 🚀

Note

In pickle terminology, serialization and deserialization are also called pickling and unpickling.

Pickle protocol compatibility

Because the pickle data stream format (protocol) may differ between Python versions, a MegEngine model saved with a newer version of Python may fail to load in an older one. There are two solutions:

  • When calling megengine.save, pass a more widely compatible protocol version (such as 4) through the pickle_protocol parameter;

  • Both megengine.save and megengine.load accept a pickle_module parameter for using a specific pickle module, for example installing pickle5 and using it in place of Python's built-in pickle module:

    >>> import pickle5 as pickle
    
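The effect of pickle_protocol can be seen with the standard library alone. The sketch below uses plain pickle rather than MegEngine, and the small dictionary is only a stand-in for a state dictionary: pinning protocol 4 writes that number into the stream header, and loading reads it back from there. (pickle5 is a PyPI backport that brings protocol 5 to Python 3.6 and 3.7.)

```python
import io
import pickle

state = {"fc2.bias": [0.0] * 10}  # a tiny stand-in for a state dictionary

# Pin protocol 4 (readable by Python 3.4+) instead of the interpreter's default
buf = io.BytesIO()
pickle.dump(state, buf, protocol=4)
payload = buf.getvalue()

# Byte 0 is the PROTO opcode (0x80), byte 1 is the protocol number
print(payload[0] == 0x80, payload[1])  # True 4
print(pickle.HIGHEST_PROTOCOL)  # depends on the interpreter

# Loading needs no protocol argument: it is read from the stream header
assert pickle.loads(payload) == state
```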

The pickle module is not safe!

  • A malicious actor can craft pickle data that executes arbitrary code when it is unpickled;

  • Therefore, never unpickle data from untrusted sources or data that may have been tampered with.
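To see why, the sketch below (plain pickle, with a hypothetical Payload class) uses __reduce__, the hook pickle calls to learn how to rebuild an object. The stream records a callable and its arguments, and loading simply calls it. Here the callable is eval on harmless arithmetic, but crafted data could name any importable function.

```python
import pickle


class Payload:
    # __reduce__ tells pickle: "to rebuild me, call this callable with these args".
    # Crafted data can name any importable callable here.
    def __reduce__(self):
        return (eval, ("6 * 7",))


blob = pickle.dumps(Payload())
result = pickle.loads(blob)  # runs eval("6 * 7") instead of rebuilding a Payload
print(result)  # 42
```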

Below is the ConvNet model used in the examples:

import megengine
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
import numpy as np

class ConvNet(M.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = M.Conv2d(1, 10, 5)
        self.pool1 = M.MaxPool2d(2, 2)
        self.conv2 = M.Conv2d(10, 20, 5)
        self.pool2 = M.MaxPool2d(2, 2)
        self.fc1 = M.Linear(320, 50)
        self.fc2 = M.Linear(50, 10)

    def forward(self, input):
        x = self.pool1(F.relu(self.conv1(input)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = F.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        return x

model = ConvNet()

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.1)
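The 320 input features of fc1 follow from the layer sizes. Assuming a single-channel 28×28 input, as in the export example at the end of this page, the usual output-size formula for convolution and pooling (no padding, stride 1 for the convolutions) gives:

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Standard output-size formula for convolution and pooling layers
    return (size + 2 * padding - kernel) // stride + 1


side = 28                    # assumed input height/width
side = conv_out(side, 5)     # conv1, kernel 5       -> 24
side = conv_out(side, 2, 2)  # pool1, 2x2, stride 2  -> 12
side = conv_out(side, 5)     # conv2, kernel 5       -> 8
side = conv_out(side, 2, 2)  # pool2, 2x2, stride 2  -> 4

flatten_size = 20 * side * side  # conv2 has 20 output channels
print(side, flatten_size)  # 4 320
```

If the input resolution changes, the input size of fc1 has to change with it.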

save/load entire model

save:

>>> megengine.save(model, PATH)

load:

>>> model = megengine.load(PATH)
>>> model.eval()

Note

We do not recommend this method because of a limitation of pickle itself: for a user-defined class such as ConvNet, pickle does not serialize the class itself. Instead, it stores a reference to the location where the class is defined, such as project/model.py, and needs that path when loading the model. So if you later refactor the project (for example, rename model.py), loading the model will fail.
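This class-by-reference behavior can be reproduced with plain pickle. In the sketch below, a throwaway module named project_model (a hypothetical stand-in for project/model.py) defines a simplified ConvNet; the pickled bytes contain only the module and class names, and once the module is gone, loading fails:

```python
import pickle
import sys
import types

# Hypothetical module "project_model" standing in for project/model.py
mod = types.ModuleType("project_model")
exec(
    "class ConvNet:\n"
    "    def __init__(self):\n"
    "        self.weight = [0.1, 0.2]\n",
    mod.__dict__,
)
sys.modules["project_model"] = mod

blob = pickle.dumps(mod.ConvNet())

# The stream names the class by module + class name; no class code is stored
stored_by_reference = b"project_model" in blob and b"ConvNet" in blob
print(stored_by_reference)  # True

# Simulate a refactor: the module path no longer exists
del sys.modules["project_model"]
load_error = None
try:
    pickle.loads(blob)
except ModuleNotFoundError as exc:
    load_error = exc
print(type(load_error).__name__)  # ModuleNotFoundError
```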

Warning

If you still use this method to load the model for inference, remember to switch to evaluation mode by calling model.eval() first.

save/load model state dictionary

save:

>>> megengine.save(model.state_dict(), PATH)

load:

>>> model = ConvNet()
>>> model.load_state_dict(megengine.load(PATH))
>>> model.eval()

When saving a model for inference only, you only need to save the model's learned parameters. Rather than saving the entire model, we recommend saving the model's state dictionary (state_dict), which gives more flexibility when restoring the model later.

Warning

  • Unlike loading the entire model, here megengine.load() returns a state dictionary object, which you must then load into the model with the model.load_state_dict() method. Do not write model = megengine.load(PATH); instead, deserialize the state dictionary and pass it to model.load_state_dict();

  • After loading the state dictionary successfully, remember to call model.eval() to switch the model to evaluation mode.

Note

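By convention, we use the .pkl file extension when saving models, e.g. mge_checkpoint_xxx.pkl.

Note the difference between .pkl and .mge files

A .mge file is usually produced by exporting a MegEngine model with Export serialized model file (Dump), and is used for inference deployment.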

What is a state dictionary

Because pickling the entire model ties it to the source-code path of its class, we instead record the model's internal state with native Python data structures, which are easy to serialize and deserialize. As mentioned in Module base class concept and interface introduction, each Module has a state dictionary member that records the Tensors inside the model (i.e. its Parameter and Buffer members):

>>> for tensor in model.state_dict():
...     print(tensor, "\t", model.state_dict()[tensor].shape)
conv1.bias       (1, 10, 1, 1)
conv1.weight     (10, 1, 5, 5)
conv2.bias       (1, 20, 1, 1)
conv2.weight     (20, 10, 5, 5)
fc1.bias         (50,)
fc1.weight       (50, 320)
fc2.bias         (10,)
fc2.weight       (10, 50)

The state dictionary is a simple Python dictionary object, so it can be easily saved and loaded with the help of pickle.
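As a minimal illustration with plain pickle and NumPy (not MegEngine), the hand-built OrderedDict below mimics two entries of the ConvNet state dictionary listed above. It round-trips without any reference to the model class, which is what makes this format robust to refactoring:

```python
import pickle
from collections import OrderedDict

import numpy as np

# Hand-built stand-in for two ConvNet state_dict entries (same shapes as above)
state = OrderedDict(
    [
        ("conv1.bias", np.zeros((1, 10, 1, 1), dtype=np.float32)),
        ("conv1.weight", np.random.randn(10, 1, 5, 5).astype(np.float32)),
    ]
)

blob = pickle.dumps(state)
restored = pickle.loads(blob)

# Keys, order and array contents survive the round trip
print(list(restored))  # ['conv1.bias', 'conv1.weight']
print(np.array_equal(restored["conv1.weight"], state["conv1.weight"]))  # True
# Only generic types (OrderedDict, NumPy) are referenced, never the model class
print(b"ConvNet" in blob)  # False
```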

Note

Each optimizer (Optimizer) also has a state dictionary, which contains the optimizer's state and the hyperparameters used. If you later need to restore the model and continue training, saving only the model's state dictionary is not enough: you also need to save the optimizer's state dictionary and related information, which is the “checkpoint” technique described below.

See also

For more about state dictionaries, see: Module state dictionary / Optimizer state dictionary

save/load checkpoint

save:

import megengine

megengine.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss,
                # ... add any other information you need
               }, PATH)

load:

import megengine
import megengine.optimizer as optim

model = ConvNet()
# Recreate the optimizer over the model's parameters before loading its state
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.1)

checkpoint = megengine.load(PATH)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]

model.eval()   # for inference
# - or -
model.train()  # to resume training

  • The purpose of saving a checkpoint is to be able to restore exactly the state reached during training: you need to restore not only the Module state dictionary but also the Optimizer state dictionary. Depending on your needs, you can also record the last completed epoch and the latest loss value.

  • After the checkpoint is loaded, set the model to train or evaluation mode, depending on whether you want to continue training or use it for inference.
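The sketch below walks through the save-then-resume cycle with plain pickle standing in for megengine.save and megengine.load; save_checkpoint and load_checkpoint are hypothetical helpers, and a single scalar weight with a momentum buffer plays the role of the model and optimizer. The momentum value shows why the optimizer state belongs in the checkpoint: without it, a resumed run would restart momentum from zero.

```python
import os
import pickle
import tempfile


def save_checkpoint(path, checkpoint):
    # Hypothetical stand-in for megengine.save
    with open(path, "wb") as f:
        pickle.dump(checkpoint, f)


def load_checkpoint(path):
    # Hypothetical stand-in for megengine.load
    with open(path, "rb") as f:
        return pickle.load(f)


path = os.path.join(tempfile.mkdtemp(), "mge_checkpoint_demo.pkl")

# First run: "train" for 3 epochs, checkpointing after each one
weights, momentum = 0.0, 0.0  # toy stand-ins for model and optimizer state
for epoch in range(3):
    momentum = 0.9 * momentum + 1.0
    weights += momentum
    save_checkpoint(path, {
        "epoch": epoch,
        "model_state_dict": {"w": weights},
        "optimizer_state_dict": {"momentum": momentum},
        "loss": 1.0 / (epoch + 1),
    })

# Second run (e.g. after a crash): restore everything, continue at the next epoch
checkpoint = load_checkpoint(path)
weights = checkpoint["model_state_dict"]["w"]
momentum = checkpoint["optimizer_state_dict"]["momentum"]
start_epoch = checkpoint["epoch"] + 1
print(start_epoch)  # 3
```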

Warning

Saving a full checkpoint takes more disk space than saving only the model's state dictionary. If you are sure you will only need the model for inference, you don't need to save checkpoints. Alternatively, use different saving frequencies according to your actual needs, such as saving a state dictionary every 10 epochs and a full checkpoint every 100 epochs.
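One way to express such a mixed schedule is a small helper like the hypothetical what_to_save below, which uses the 10- and 100-epoch intervals from the example above and counts epochs from 1:

```python
def what_to_save(epoch, state_dict_every=10, checkpoint_every=100):
    """Return which artifact to write after `epoch` (1-based), or None."""
    if epoch % checkpoint_every == 0:
        return "checkpoint"  # model + optimizer + epoch + loss
    if epoch % state_dict_every == 0:
        return "state_dict"  # model parameters only, smaller on disk
    return None


for epoch in (7, 10, 50, 100, 200):
    print(epoch, what_to_save(epoch))
```

Checking the larger interval first means an epoch that matches both intervals gets the full checkpoint rather than a second, redundant state dictionary.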

See also

See how the official ResNet model saves and loads checkpoints:

official/vision/classification/resnet

The relevant code can be found in train.py, test.py and inference.py.

Export static graph models

To deploy the trained model to a production environment, the last step of model development is to export a static graph:

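import megengine
import megengine.functional as F
import numpy as np
from megengine import jit

model = ConvNet()
model.load_state_dict(megengine.load(PATH))
model.eval()

@jit.trace(symbolic=True, capture_as_const=True)
def infer_func(data, *, model):
    pred = model(data)
    pred_normalized = F.softmax(pred)
    return pred_normalized

# A dummy float32 input with shape (batch, channel, height, width)
data = megengine.Tensor(np.random.randn(1, 1, 28, 28).astype("float32"))
output = infer_func(data, model=model)

# Write the traced graph to a separate .mge file so the .pkl at PATH is not overwritten
infer_func.dump("model.mge", arg_names=["data"])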

See also

See Export serialized model file (Dump) for a more detailed explanation.