[文档]classOptimizer(metaclass=ABCMeta):r"""Base class for all optimizers. Args: params: specifies what Tensors should be optimized. defaults: a dict of default parameters of Optimizer, like learning rate or momentum. """def__init__(# pylint: disable=too-many-branchesself,params:Union[Iter[Parameter],dict],defaults:dict,):self._state=dict()self._defaults=defaultsself._disable_type_convert=Falseifisinstance(params,(Parameter,dict)):params=[params]else:ifnotisinstance(params,Iterable):raiseTypeError("params argument given to the optimizer should be ""Parameter or dict, or Iterable of them")self.param_groups=[]# type: listparam_groups=list(params)iflen(param_groups)==0:raiseValueError("optimizer got an empty parameter list")param_type=type(param_groups[0])forparaminparam_groups:ifnotisinstance(param,param_type):raiseTypeError("types of params argument given to the optimizer shoud be same")ifnotisinstance(param_groups[0],dict):param_groups=[{"params":param_groups}]forgroupinparam_groups:self.add_param_group(group)forgroupinself.param_groups:self._create_state(group)
[文档]defadd_param_group(self,param_group:dict):r"""Add a param group to ``param_groups`` of the :class:`~megengine.optim.optimizer.Optimizer`. This can be useful when fine tuning a pre-trained network as frozen layers can be made trainable and added to the :class:`~megengine.optim.optimizer.Optimizer` as training progresses. Args: param_group: specifies what tensors should be optimized along with group. """assertisinstance(param_group,dict),"param group must be a dict"ifisinstance(param_group["params"],Parameter):param_group["params"]=[param_group["params"]]else:param_group["params"]=list(param_group["params"])forparaminparam_group["params"]:ifnotisinstance(param,Parameter):raiseTypeError("optimizer can only optimize Parameters, but one of the params is "+str(type(param)))param[...]=Tensor(param,no_cache=True)forname,defaultinself._defaults.items():ifdefaultisrequiredandnamenotinparam_group:raiseValueError("parameter group didn't specify a value of ""required optimization parameter "+name)param_group.setdefault(name,default)param_set=set()forgroupinself.param_groups:param_set.update(set(map(id,group["params"])))assertparam_set.isdisjoint(set(map(id,param_group["params"]))),"some parameters appear in more than one parameter group"self.param_groups.append(param_group)
[文档]defstep(self):r"""Performs a single optimization step."""# set the globle state `_enable_convert_inputs` to `False` to disable# the `convert_inputs` for param updatesset_option("record_computing_path",0)_origin_auto_format=get_auto_format_convert()set_auto_format_convert(False)ifself._disable_type_convert:backup=set_convert_inputs(False)forgroupinself.param_groups:ifisinstance(group["params"],set):raiseTypeError("optimized parameters need to be organized in ordered collections, ""but the ordering of parameters in sets will change between runs. ""Please use a list instead.")push_scope("step")self._updates(group)pop_scope("step")ifself._disable_type_convert:# restore the globle state `_enable_convert_inputs`set_convert_inputs(backup)set_option("record_computing_path",1)set_auto_format_convert(_origin_auto_format)returnself
[文档]defclear_grad(self):r"""Set the grad attribute to None for all parameters."""forparam_groupinself.param_groups:push_scope("clear_grad")forparaminparam_group["params"]:param.grad=Nonepop_scope("clear_grad")
[文档]defstate_dict(self,keep_var=False)->Dict:r"""Export the optimizer state. Return: optimizer state. Can be loaded by :meth:`load_state_dict`. """param_groups=[]state=dict()param2id=dict()cur_id=0forgroupinself.param_groups:forparamingroup["params"]:ifparamnotinparam2id:param2id[param]=cur_idcur_id+=1forparam,stinself._state.items():_st=copy.copy(st)ifnotkeep_var:fork,vinst.items():_st[k]=v.numpy()state[param2id[param]]=_stforgroupinself.param_groups:param_group={k:vfork,vingroup.items()ifk!="params"}param_group["params"]=[param2id[param]forparamingroup["params"]]param_groups.append(param_group)return{"param_groups":param_groups,"state":state}
[文档]defload_state_dict(self,state:dict):r"""Loads the optimizer state. Args: state: optimizer state. Should be an object returned from a call to :meth:`state_dict`. """iflen(self.param_groups)!=len(state["param_groups"]):raiseValueError("loaded state dict has a different number of parameter groups")forgroup_new,group_savedinzip(self.param_groups,state["param_groups"]):iflen(group_new["params"])!=len(group_saved["params"]):raiseValueError("loaded state dict contains a parameter group that ""doesn't match the size of optimizer's group")forparam_new,param_savedinzip(group_new["params"],group_saved["params"]):p=param_newself._state[p]=state["state"][param_saved].copy()fork,vinself._state[p].items():ifisinstance(v,Tensor):self._state[p][k]=v.detach()else:self._state[p][k]=Tensor(v)ifset(group_new.keys())!=set(group_saved.keys()):raiseValueError("loaded state dict contains a parameter group that ""doesn't match the keys of optimizer's group")forkeyingroup_new.keys():ifkey!="params":group_new[key]=group_saved[key]iflen(self._state.keys())!=len(state["state"].keys()):raiseValueError("loaded state dict contains a state that doesn't match ""the size of optimizer's state")