megengine.quantization.LSQ.load_state_dict
LSQ.load_state_dict(state_dict, strict=True)
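The loading rules for this method can be modeled in a few lines of plain Python. This is a minimal sketch, not MegEngine's implementation. It relies on two points the documentation states: a dict or a closure Function[key, var] -> Optional[value] may be passed, and None means the parameter is skipped. Everything else is an assumption. In particular, toy_load_state_dict is a hypothetical helper, Python lists stand in for tensors, and strict=True is assumed to reject both skipped parameters and dict keys that were never loaded.

```python
from typing import Any, Callable, Dict, Optional, Union

# A "closure" maps (key, current value) to a new value, or None to skip.
Closure = Callable[[str, Any], Optional[Any]]


def toy_load_state_dict(
    params: Dict[str, Any],
    state_dict: Union[Dict[str, Any], Closure],
    strict: bool = True,
) -> Dict[str, Any]:
    """Hypothetical model of load_state_dict; returns an updated copy of params."""
    if isinstance(state_dict, dict):
        source = state_dict
        offered = set(source)  # keys the caller supplied

        def closure(key: str, _var: Any) -> Optional[Any]:
            # A missing key or an explicit None value both mean "skip".
            return source.get(key)

    elif callable(state_dict):
        closure = state_dict
        offered = set()  # a closure supplies no keys of its own
    else:
        raise TypeError("state_dict must be a dict or a callable")

    updated = dict(params)
    loaded, skipped = set(), set()
    for key, var in params.items():
        value = closure(key, var)
        if value is None:
            skipped.add(key)
        else:
            updated[key] = value
            loaded.add(key)

    # Assumed strict rule: no skipped parameters and no leftover dict keys.
    unused = offered - loaded
    if strict and (unused or skipped):
        raise KeyError(
            f"strict=True violated: unused={sorted(unused)}, missing={sorted(skipped)}"
        )
    return updated


# Usage with list-valued stand-ins for tensors.
params = {"conv.weight": [0.0, 0.0], "fc.weight": [0.0]}
pretrained = {"conv.weight": [1.0, 2.0], "fc.weight": [3.0]}

# Dict form: None skips the classifier, so strict=False is needed in this model.
partial = toy_load_state_dict(
    params,
    {k: None if k.startswith("fc") else v for k, v in pretrained.items()},
    strict=False,
)
print(partial)  # {'conv.weight': [1.0, 2.0], 'fc.weight': [0.0]}

# Closure form: transform each pretrained value before it is loaded.
doubled = toy_load_state_dict(params, lambda k, v: [2 * x for x in pretrained[k]])
print(doubled)  # {'conv.weight': [2.0, 4.0], 'fc.weight': [6.0]}
```

Treating the dict form as a key-lookup closure lets both argument types share a single loading loop. Only the "skip" and "load" outcomes differ between keys.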
Loads a dictionary created by state_dict() into this module. If strict is True, the keys of state_dict must exactly match the keys returned by this module's state_dict().

Users can also pass a closure, Function[key: str, var: Tensor] -> Optional[np.ndarray], as state_dict to handle more complex situations. For example, to load everything except the final linear classifier:

    state_dict = {...}  # Dict[str, np.ndarray]
    model.load_state_dict({
        # map every classifier ('fc') key to None so it is not loaded
        k: None if k.startswith('fc') else v
        for k, v in state_dict.items()
    }, strict=False)

Here, a None value means that parameter k is skipped.

To prevent shape mismatches (e.g., when loading PyTorch weights), the values can be reshaped before loading:

    state_dict = {...}

    def reshape_accordingly(k, v):
        # reshape the stored array to the shape of this module's tensor v
        return state_dict[k].reshape(v.shape)

    model.load_state_dict(reshape_accordingly)

In-place re-initialization or pruning is also possible:

    import megengine.module as M

    def reinit_and_pruning(k, v):
        if 'bias' in k:
            M.init.zero_(v)
        if 'conv' in k: