Module

class Module(name=None)

Base Module class.

Parameters:

name – the module’s name; can be initialized through the kwargs parameter of a child class.

apply(fn)

Applies function fn to all the modules within this module, including itself.

Parameters:

fn (Callable[[Module], Any]) – the function to be applied on modules.

Return type:

None
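The traversal that apply performs can be pictured with a small pure-Python sketch. ToyModule below is a hypothetical stand-in, not MegEngine's Module: it assumes that any attribute holding a ToyModule is a submodule and that modules() walks the tree depth-first, starting from the module itself. apply then calls fn on each yielded module and discards the results.

```python
from typing import Any, Callable, Iterator


class ToyModule:
    """Hypothetical stand-in for Module: submodules are attributes of type ToyModule."""

    def __init__(self, name: str) -> None:
        self.name = name

    def modules(self) -> Iterator["ToyModule"]:
        # Depth-first: the module itself first, then every nested submodule.
        yield self
        for value in vars(self).values():
            if isinstance(value, ToyModule):
                yield from value.modules()

    def apply(self, fn: Callable[["ToyModule"], Any]) -> None:
        # Call fn on every module in the tree, including self; return values are ignored.
        for module in self.modules():
            fn(module)


root = ToyModule("root")
root.block = ToyModule("block")
root.block.conv = ToyModule("conv")

visited = []
root.apply(lambda m: visited.append(m.name))
print(visited)  # ['root', 'block', 'conv']
```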

buffers(recursive=True, **kwargs)

Returns an iterable for the buffers of the module.

A buffer is defined as a Tensor that is not a Parameter.

Parameters:

recursive (bool) – if True, returns all buffers within this module, else only returns buffers that are direct attributes of this module.

Return type:

Iterable[Tensor]
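A sketch of the buffer rule, assuming Parameter is a subclass of Tensor so the filter has to exclude it explicitly. Tensor, Parameter and ToyModule here are hypothetical minimal classes, not MegEngine's; the recursive flag decides whether nested submodules are searched.

```python
from typing import Iterator


class Tensor:
    """Hypothetical minimal tensor type carrying only a label."""

    def __init__(self, label: str) -> None:
        self.label = label


class Parameter(Tensor):
    """A Parameter is also a Tensor, so buffers() must filter it out."""


class ToyModule:
    def buffers(self, recursive: bool = True) -> Iterator[Tensor]:
        for value in vars(self).values():
            if isinstance(value, Tensor) and not isinstance(value, Parameter):
                yield value  # a direct buffer attribute
            elif recursive and isinstance(value, ToyModule):
                yield from value.buffers(recursive=True)


bn = ToyModule()
bn.weight = Parameter("bn.weight")
bn.running_mean = Tensor("bn.running_mean")

net = ToyModule()
net.bn = bn
net.step = Tensor("step")

print([t.label for t in net.buffers()])                 # ['bn.running_mean', 'step']
print([t.label for t in net.buffers(recursive=False)])  # ['step']
```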

children(**kwargs)

Returns an iterable for all the submodules that are direct attributes of this module.

Return type:

Iterable[Module]

disable_quantize(value=True)

Sets the module’s quantize_disabled attribute and returns the module. Can be used as a decorator.
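Because the method returns the object it was called on, the call can be chained onto construction. The sketch below uses a hypothetical ToyModule that models only the quantize_disabled flag and does not model any propagation to submodules. In the toy, the same return-the-argument behaviour also lets the method decorate a class.

```python
class ToyModule:
    """Hypothetical stand-in for Module; models only the quantize_disabled flag."""

    quantize_disabled = False

    def disable_quantize(self, value=True):
        # Set the flag and return the same object so the call can be chained.
        self.quantize_disabled = value
        return self


class Linear(ToyModule):
    pass


fc = Linear().disable_quantize()  # chained onto construction
print(fc.quantize_disabled)       # True


@ToyModule.disable_quantize  # toy-only: sets a class attribute and returns the class
class Head(ToyModule):
    pass


print(Head().quantize_disabled)   # True
print(Linear.quantize_disabled)   # False: only the fc instance was changed
```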

eval()

Sets the training mode of all the modules within this module (including itself) to False. See train for details.

Return type:

None

load_state_dict(state_dict, strict=True)

Loads a given dictionary created by state_dict into this module. If strict is True, the keys of the given state_dict must exactly match the keys returned by the module’s state_dict method.

Users can also pass a closure of type Function[key: str, var: Tensor] -> Optional[np.ndarray] as state_dict to handle complex situations. For example, to load everything except for the final linear classifier:

state_dict = {...}  #  Dict[str, np.ndarray]
model.load_state_dict({
    k: None if k.startswith('fc') else v
    for k, v in state_dict.items()
}, strict=False)

Here returning None means skipping parameter k.
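The closure contract (return an array to load it, return None to skip the key) can be sketched without MegEngine. load_with_closure is a hypothetical helper operating on plain lists rather than tensors, and raising on skipped keys under strict=True is an assumption of this sketch, consistent with the examples here passing strict=False whenever keys are skipped.

```python
from typing import Callable, Dict, List, Optional

Array = List[float]  # stand-in for np.ndarray in this sketch


def load_with_closure(
    params: Dict[str, Array],
    closure: Callable[[str, Array], Optional[Array]],
    strict: bool = True,
) -> List[str]:
    """Hypothetical loader: asks the closure for each key's new value."""
    skipped = []
    for key, current in params.items():
        new_value = closure(key, current)
        if new_value is None:
            skipped.append(key)  # None: leave this entry untouched
            continue
        if len(new_value) != len(current):
            raise ValueError(f"shape mismatch for {key!r}")
        params[key] = list(new_value)
    if strict and skipped:
        raise KeyError(f"missing keys: {skipped}")
    return skipped


model_params = {"conv.weight": [0.0, 0.0], "fc.weight": [0.0], "fc.bias": [0.0]}
checkpoint = {"conv.weight": [0.5, -0.5], "fc.weight": [9.0], "fc.bias": [9.0]}

# Same intent as the dict comprehension above: skip every key starting with 'fc'.
skipped = load_with_closure(
    model_params,
    lambda k, _: None if k.startswith("fc") else checkpoint[k],
    strict=False,
)
print(skipped)                      # ['fc.weight', 'fc.bias']
print(model_params["conv.weight"])  # [0.5, -0.5]
```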

To prevent shape mismatches (e.g. when loading PyTorch weights), we can reshape before loading:

state_dict = {...}
def reshape_accordingly(k, v):
    return state_dict[k].reshape(v.shape)
model.load_state_dict(reshape_accordingly)

We can also perform in-place re-initialization or pruning:

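import numpy as np
import megengine.module as M
def reinit_and_pruning(k, v):
    if 'bias' in k:
        M.init.zeros_(v)  # in-place re-init; returning None skips loading k
    if 'conv' in k:  # illustrative pruning rule: zero out weights with |w| < 1e-3
        return v.numpy() * (np.abs(v.numpy()) > 1e-3)
model.load_state_dict(reinit_and_pruning, strict=False)
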
modules(**kwargs)

Returns an iterable for all the modules within this module, including itself.

Return type:

Iterable[Module]

named_buffers(prefix=None, recursive=True, **kwargs)

Returns an iterable for key-buffer pairs of the module, where key is the dotted path from this module to the buffer.

A buffer is defined as a Tensor that is not a Parameter.

Parameters:
  • prefix (Optional[str]) – prefix prepended to the keys.

  • recursive (bool) – if True, returns all buffers within this module, else only returns buffers that are direct attributes of this module.

Return type:

Iterable[Tuple[str, Tensor]]

named_children(**kwargs)

Returns an iterable of key-submodule pairs for all the submodules that are direct attributes of this module, where ‘key’ is the attribute name of submodules.

Return type:

Iterable[Tuple[str, Module]]

named_modules(prefix=None, **kwargs)

Returns an iterable of key-module pairs for all the modules within this module, including itself, where ‘key’ is the dotted path from this module to the submodules.

Parameters:

prefix (Optional[str]) – prefix prepended to the path.

Return type:

Iterable[Tuple[str, Module]]
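A sketch of how dotted keys and the optional prefix could be built, and how named_children differs by stopping at direct attributes. ToyModule is hypothetical; keying the root module with the empty string when no prefix is given is an assumption of this sketch.

```python
from typing import Iterator, Optional, Tuple


class ToyModule:
    """Hypothetical stand-in: any attribute holding a ToyModule is a submodule."""

    def named_children(self) -> Iterator[Tuple[str, "ToyModule"]]:
        # Direct submodules only, keyed by attribute name.
        for key, value in vars(self).items():
            if isinstance(value, ToyModule):
                yield key, value

    def named_modules(self, prefix: Optional[str] = None) -> Iterator[Tuple[str, "ToyModule"]]:
        # The module itself is keyed by the prefix ("" when none is given).
        prefix = prefix or ""
        yield prefix, self
        for key, child in self.named_children():
            child_path = f"{prefix}.{key}" if prefix else key
            yield from child.named_modules(prefix=child_path)


net = ToyModule()
net.features = ToyModule()
net.features.conv1 = ToyModule()
net.fc = ToyModule()

print([k for k, _ in net.named_children()])        # ['features', 'fc']
print([k for k, _ in net.named_modules()])         # ['', 'features', 'features.conv1', 'fc']
print([k for k, _ in net.named_modules("model")])  # ['model', 'model.features', 'model.features.conv1', 'model.fc']
```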

named_parameters(prefix=None, recursive=True, **kwargs)

Returns an iterable for key-Parameter pairs of the module, where key is the dotted path from this module to the Parameter.

Parameters:
  • prefix (Optional[str]) – prefix prepended to the keys.

  • recursive (bool) – if True, returns all Parameters within this module, else only returns Parameters that are direct attributes of this module.

Return type:

Iterable[Tuple[str, Parameter]]
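named_parameters follows the same dotted-path scheme, with recursive only controlling whether submodules are descended into. The Parameter and ToyModule classes below are hypothetical minimal stand-ins, not MegEngine's.

```python
from typing import Iterator, Optional, Tuple


class Parameter:
    """Hypothetical stand-in for a Parameter holding a single float."""

    def __init__(self, value: float) -> None:
        self.value = value


class ToyModule:
    def named_parameters(
        self, prefix: Optional[str] = None, recursive: bool = True
    ) -> Iterator[Tuple[str, Parameter]]:
        for key, value in vars(self).items():
            path = f"{prefix}.{key}" if prefix else key  # dotted path to the attribute
            if isinstance(value, Parameter):
                yield path, value
            elif recursive and isinstance(value, ToyModule):
                yield from value.named_parameters(prefix=path, recursive=True)


net = ToyModule()
net.scale = Parameter(1.0)
net.fc = ToyModule()
net.fc.weight = Parameter(0.5)
net.fc.bias = Parameter(0.0)

print([k for k, _ in net.named_parameters()])                 # ['scale', 'fc.weight', 'fc.bias']
print([k for k, _ in net.named_parameters(recursive=False)])  # ['scale']
print([k for k, _ in net.named_parameters(prefix="model")])   # ['model.scale', 'model.fc.weight', 'model.fc.bias']
```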

named_tensors(prefix=None, recursive=True, **kwargs)

Returns an iterable for key-tensor pairs of the module, where key is the dotted path from this module to the tensor.

Parameters:
  • prefix (Optional[str]) – prefix prepended to the keys.

  • recursive (bool) – if True, returns all tensors within this module, else only returns tensors that are direct attributes of this module.

Return type:

Iterable[Tuple[str, Tensor]]

parameters(recursive=True, **kwargs)

Returns an iterable for the Parameters of the module.

Parameters:

recursive (bool) – If True, returns all Parameters within this module, else only returns Parameters that are direct attributes of this module.

Return type:

Iterable[Parameter]

register_forward_hook(hook)

Registers a hook to handle forward results. hook should be a function that receives the module, inputs, and outputs, and returns modified outputs or None.

This method returns a handler with a remove interface for deleting the hook.

Return type:

HookHandler
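The hook contract (a returned value replaces the outputs, None keeps them) and the handler's remove interface can be sketched as follows. ToyModule and this HookHandler are hypothetical classes written for illustration; MegEngine's own HookHandler may work differently internally.

```python
from typing import Any, Callable, Dict, Optional, Tuple

Hook = Callable[[Any, Tuple[Any, ...], Any], Optional[Any]]


class HookHandler:
    """Hypothetical handle; remove() unregisters the hook it was created for."""

    def __init__(self, hooks: Dict[int, Hook], key: int) -> None:
        self._hooks, self._key = hooks, key

    def remove(self) -> None:
        self._hooks.pop(self._key, None)


class ToyModule:
    def __init__(self) -> None:
        self._forward_hooks: Dict[int, Hook] = {}
        self._next_id = 0

    def register_forward_hook(self, hook: Hook) -> HookHandler:
        key, self._next_id = self._next_id, self._next_id + 1
        self._forward_hooks[key] = hook
        return HookHandler(self._forward_hooks, key)

    def forward(self, x):
        return x * 2

    def __call__(self, *inputs):
        outputs = self.forward(*inputs)
        for hook in list(self._forward_hooks.values()):
            modified = hook(self, inputs, outputs)
            if modified is not None:  # None keeps the current outputs
                outputs = modified
        return outputs


m = ToyModule()
handler = m.register_forward_hook(lambda mod, inputs, out: out + 1)
print(m(3))  # 7: forward gives 6, the hook adds 1
handler.remove()
print(m(3))  # 6: the hook is gone
```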

register_forward_pre_hook(hook)

Registers a hook to handle forward inputs. hook should be a function.

Parameters:

hook (Callable) – a function that receives the module and inputs, and returns modified inputs or None.

Return type:

HookHandler

Returns:

a handler with a remove interface for deleting the hook.
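A pre-hook runs before forward and may replace the inputs tuple. This sketch uses a hypothetical ToyModule and returns a types.SimpleNamespace whose remove() unregisters the hook, standing in for HookHandler.

```python
from types import SimpleNamespace


class ToyModule:
    """Hypothetical stand-in that runs pre-hooks on the inputs before forward."""

    def __init__(self) -> None:
        self._pre_hooks = []

    def register_forward_pre_hook(self, hook):
        self._pre_hooks.append(hook)
        # The handle's remove() unregisters exactly this hook.
        return SimpleNamespace(remove=lambda: self._pre_hooks.remove(hook))

    def forward(self, x, y):
        return x - y

    def __call__(self, *inputs):
        for hook in list(self._pre_hooks):
            modified = hook(self, inputs)
            if modified is not None:  # None leaves the inputs unchanged
                inputs = modified
        return self.forward(*inputs)


m = ToyModule()
swap = m.register_forward_pre_hook(lambda mod, inputs: (inputs[1], inputs[0]))
print(m(10, 3))  # -7: the pre-hook swapped the inputs to (3, 10)
swap.remove()
print(m(10, 3))  # 7
```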

replace_param(params, start_pos, seen=None)

Replaces the module’s parameters with params; used by ParamPack to speed up multi-machine training.

Deprecated since version 1.0.

state_dict(rst=None, prefix='', keep_var=False)

Returns a dictionary containing the whole state of the module.
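A minimal sketch of the kind of dictionary state_dict produces and of a strict key check when loading it back. ToyModule is hypothetical and stores plain floats instead of tensors; using dotted attribute paths as keys is an assumption that mirrors named_tensors.

```python
from typing import Dict


class ToyModule:
    """Hypothetical stand-in: float attributes play the role of tensors."""

    def state_dict(self, prefix: str = "") -> Dict[str, float]:
        state = {}
        for key, value in vars(self).items():
            if isinstance(value, float):
                state[prefix + key] = value
            elif isinstance(value, ToyModule):
                state.update(value.state_dict(prefix=prefix + key + "."))
        return state

    def load_state_dict(self, state: Dict[str, float], strict: bool = True) -> None:
        expected = set(self.state_dict())
        if strict and set(state) != expected:
            raise KeyError(f"key mismatch: {sorted(set(state) ^ expected)}")
        for path, value in state.items():
            if path not in expected:
                continue  # non-strict: ignore unknown keys
            *parents, leaf = path.split(".")
            owner = self
            for name in parents:
                owner = getattr(owner, name)
            setattr(owner, leaf, value)


net = ToyModule()
net.fc = ToyModule()
net.fc.weight = 0.5
net.fc.bias = 0.0
print(net.state_dict())  # {'fc.weight': 0.5, 'fc.bias': 0.0}

clone = ToyModule()
clone.fc = ToyModule()
clone.fc.weight = clone.fc.bias = 1.0
clone.load_state_dict(net.state_dict())
print(clone.fc.weight, clone.fc.bias)  # 0.5 0.0
```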

tensors(recursive=True, **kwargs)

Returns an iterable for the Tensors of the module.

Parameters:

recursive (bool) – If True, returns all Tensors within this module, else only returns Tensors that are direct attributes of this module.

Return type:

Iterable[Tensor]

train(mode=True, recursive=True)

Sets the training mode of all the modules within this module (including itself) to mode. This effectively sets the training attributes of those modules to mode, but it only changes the behavior of certain modules (e.g. BatchNorm2d, Dropout, Observer).

Parameters:
  • mode (bool) – the training mode to be set on modules.

  • recursive (bool) – whether to recursively call submodules’ train().

Return type:

None
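Recursive mode setting, and eval() as train(False), can be sketched with a hypothetical ToyModule that models only the training flag; this is an illustration, not MegEngine's implementation.

```python
class ToyModule:
    """Hypothetical stand-in modelling only the `training` flag."""

    def __init__(self) -> None:
        self.training = True

    def children(self):
        return [v for v in vars(self).values() if isinstance(v, ToyModule)]

    def train(self, mode: bool = True, recursive: bool = True) -> None:
        self.training = mode
        if recursive:
            for child in self.children():
                child.train(mode, recursive=True)

    def eval(self) -> None:
        # eval() is train(False) applied to the whole tree.
        self.train(False)


net = ToyModule()
net.bn = ToyModule()
net.bn.inner = ToyModule()

net.eval()
print(net.training, net.bn.training, net.bn.inner.training)  # False False False

net.train(True, recursive=False)
print(net.training, net.bn.training)  # True False
```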

zero_grad()

Sets all parameters’ gradients to zero.

Deprecated since version 1.0.

Return type:

None