# -*- coding: utf-8 -*-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")## Copyright (c) 2014-2021 Megvii Inc. All rights reserved.## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.importcollectionsimportfnmatchimportitertoolsimportpickleimportrefromcollectionsimportOrderedDictfromtypingimportAny,Dict,List,Optional,Sequencefrom..coreimport_imperative_rtfrom..core._imperative_rtimportComputingGraph,SerializationMetadatafrom..core._trace_optionimportset_symbolic_shapeas_set_symbolic_shapefrom..core.tensorimportmegbrain_graphasGfrom..loggerimportget_loggerfrom.comp_graph_toolsimportget_dep_vars,get_opr_type,get_oprs_seqfrom.network_nodeimport(ConstOpBase,Host2DeviceCopy,ImmutableTensor,NetworkNode,OpNode,VarNode,str_to_mge_class,)logger=get_logger(__name__)
class Network:
    def __init__(self):
        self.input_vars = []  # input vars of the graph
        self._orig_inputs = []
        self.output_vars = []  # output vars of the graph
        self._orig_outputs = []
        # NOTE(review): the original inline comments on these two maps were
        # swapped; _add_opr/_get_var show oprs_map keyed by OperatorNode.id
        # and vars_map keyed by VarNode.id.
        self.all_oprs_map = OrderedDict()  # _imperative_rt.graph.OperatorNode.id -> OpNode
        self.all_vars_map = OrderedDict()  # _imperative_rt.graph.VarNode.id -> VarNode
        self.graph = ComputingGraph()
        self._metadata = None

    @property
    def metadata(self):
        r"""Load metadata as a dict.

        Returns ``None`` when no metadata is available (e.g. the Network was
        constructed directly rather than loaded from a file) or when the
        stored metadata is not valid.
        """
        # BUGFIX: a freshly constructed Network has ``_metadata is None``;
        # the original code dereferenced it unconditionally and raised
        # AttributeError instead of returning None.
        if self._metadata is None:
            return None
        if not self._metadata.is_valid:
            logger.info("metadata is not valid!")
            return None
        ret = dict()
        try:
            user_info = pickle.loads(self._metadata.user_info)
        # narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
        # still propagate; any unpickling failure falls back to raw bytes
        except Exception:  # pylint: disable=broad-except
            logger.warning(
                "can't parse user info by pickle, so return the original bytes object!"
            )
            user_info = self._metadata.user_info
        ret["user_info"] = user_info
        ret["graph_modified"] = self._metadata.graph_modified
        ret["optimized_for_inference"] = self._metadata.optimized_for_inference
        if ret["optimized_for_inference"]:
            ret.update(G.deserialize_infer_option(self._metadata.optimize_options))
        return ret
[文档]@classmethoddefload(cls,model_path:str,outspec:List[str]=None):r"""Loads a computing graph as a Network object. Args: model_path: file path of mge model. outspec: only load the subgraph with outspec as its endpoints. """self=cls()ret=G.load_graph(model_path)outputs,self._metadata=ret.output_vars_list,ret.metadataifoutspecisnotNone:output_spec=outspec.copy()all_vars=get_dep_vars(outputs)+outputsnew_outputs={}foriinall_vars:ifi.nameinoutput_spec:new_outputs[i.name]=ioutput_spec.remove(i.name)assertlen(output_spec)==0,"Can not find {} in this model".format(output_spec)outputs=[new_outputs[i]foriinoutspec]self._orig_outputs=outputsforxinself._orig_outputs:self.output_vars.append(self._get_var(x))self.add_dep_oprs()forxinself._orig_inputs:self.input_vars.append(self._get_var(x))self.graph=self._orig_outputs[0].graphreturnself
[文档]defoptimize_for_inference(self,dest_vars,**kwargs):r"""Applies optimize_for_inference pass for operator graph. Args: dest_vars: list of output vars in the operator graph Keyword Arguments: * enable_io16xc32 -- whether to use float16 for I/O between oprs and use float32 as internal computation precision. Note the output var would be changed to float16. * enable_ioc16 -- whether to use float16 for both I/O and computation precision. * enable_hwcd4 -- whether to use NHWCD4 data layout. This is faster on some OpenCL backend. * enable_nchw88 -- whether to use NCHW88 data layout, currently used in X86 AVX backend. * enable_nchw44 -- whether to use NCHW44 data layout, currently used in arm backend. * enable_nchw44_dot -- whether to use NCHW44_dot data layout, currently used in armv8.2+dotprod backend. * enable_nchw4 -- whether to use NCHW4 data layout, currently used in nvidia backend(based on cudnn). * enable_nchw32 -- whether to use NCHW32 data layout, currently used in nvidia backend with tensorcore(based on cudnn). * enable_chwn4 -- whether to use CHWN4 data layout, currently used in nvidia backend with tensorcore. * enable_nchw64 -- whether to use NCHW64 data layout, used for fast int4 support on Nvidia GPU. * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty into one opr. * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z input for inference on nvidia backend(this optimization pass will result in mismatch of the precision of output of training and inference) """ifnotisinstance(dest_vars,Sequence):dest_vars=[dest_vars]dest_vars=list(G.VarNode(var.var)forvarindest_vars)new_vars=G.optimize_for_inference(dest_vars,**kwargs)returnlist(self._get_var(var)forvarinnew_vars)
[文档]defdump(self,file,*,keep_var_name:int=1,keep_opr_name:bool=False,keep_param_name:bool=False,keep_opr_priority:bool=False,strip_info_file=None,append_json=False,optimize_for_inference=True,append=False,user_info:Any=None,enable_metadata=True,**kwargs):r"""Serializes graph to file. Args: file: output file, could be file object or filename. append: whether output is appended to ``file``. Only works when ``file`` is str. keep_var_name: level for keeping variable names: * 0: none of the names are kept * 1: (default)keep names of output vars * 2: keep names of all (output and internal) vars keep_opr_name: whether to keep operator names. keep_param_name: whether to keep param names, so param values can be easily manipulated after loading model keep_opr_priority: whether to keep priority setting for operators strip_info_file: a string for path or a file handler. if is not None, then the dump information for code strip would be written to ``strip_info_file`` append_json: will be check when `strip_info_file` is not None. if set true, the information for code strip will be append to strip_info_file. if set false, will rewrite strip_info_file optimize_for_inference: enbale optmizations, will skip all optimize options if this is False. Default: True user_info: any type object, which will be pickled to bytes. enable_metadata: whether to save metadata into output file. See more detials in :meth:`~.trace.dump`. 
"""def_set_var_name(var):graph_var=G.VarNode(var.var)graph_var.name=var.namereturngraph_varself._compile()out=list(map(_set_var_name,self.output_vars))ifkwargs.pop("arg_names",False):logger.warning('"arg_names" is not supported in Network.dump, rename input vars directly')ifkwargs.pop("output_names",False):logger.warning('"output_names" is not supported in Network.dump, rename output vars directly')ifoptimize_for_inference:out,optimize_options=G.optimize_for_inference(out,**kwargs)metadata=SerializationMetadata()ifenable_metadata:metadata.is_valid=Truemetadata.graph_modified=Truemetadata.user_info=pickle.dumps(user_info)ifoptimize_for_inference:metadata.optimize_options=optimize_optionsG.set_priority_to_id([o._nodeifisinstance(o,G.VarNode)elseoforoinout])dump_content,_=G.dump_graph(out,keep_var_name=keep_var_name,keep_opr_name=keep_opr_name,keep_param_name=keep_param_name,keep_opr_priority=keep_opr_priority,strip_info_file=strip_info_file,append_json=append_json,metadata=metadata,)ifisinstance(file,str):permission="wb"ifappend==Falseelse"ab"file=open(file,permission)file.write(dump_content)
[文档]defmake_const(self,data,name=None,device=None):r"""Makes an ImmutableTensor OpNode to provide a parameter for the network."""node=ImmutableTensor(data,name,device,self.graph)node.compile(self.graph)returnnode.outputs[0]
[文档]defmake_input_node(self,shape,dtype,name=None,device=None):r"""Makes a Host2DeviceCopy OpNode to provide an input varnode for the network."""node=Host2DeviceCopy(shape,dtype,name,device)node.compile(self.graph)returnnode.outputs[0]
[文档]defadd_output(self,*vars:VarNode):r"""Adds vars into the network output node list"""ifnotall([var.ownerforvarinvars]):self.add_dep_oprs(*vars)forvarinvars:# use method 'is' instead of 'in' to avoid# compare VarNode use elemwise equalifnotany(varis_for_inself.output_vars):self.output_vars.append(var)
[文档]defremove_output(self,*vars:VarNode):r"""Removes vars from the network output node list"""forvarinvars:# use list pop instead of remove to avoid# compare VarNode use elemwise equalforidx,out_varinenumerate(self.output_vars):ifvarisout_var:self.output_vars.pop(idx)
    def add_dep_oprs(self, *vars):
        # Build OpNodes for every operator the given VarNodes depend on,
        # wiring up each VarNode's ``owner`` link as it goes. Called with no
        # arguments it processes the network's current output vars.
        if len(vars) == 0:
            vars = self.output_vars
        assert all(isinstance(var, VarNode) for var in vars), "Only support add VarNode"
        # worklist of VarNodes whose owner opr has not been built yet
        q = list(vars)
        while len(q) > 0:
            cur = q.pop(0)
            # already wired up earlier — skip
            if cur.owner is not None:
                continue
            if cur.name is None:
                cur.name = cur.var.name
            self.all_vars_map[cur.var.id] = cur
            mge_opr = cur.var.owner
            # Host2DeviceCopy oprs are the graph inputs; remember their raw vars
            if get_opr_type(mge_opr) == "Host2DeviceCopy":
                self._orig_inputs.extend(mge_opr.outputs)
            # _add_opr returns None when the opr was built on an earlier visit;
            # reuse the cached OpNode then and do not descend again (its inputs
            # were already enqueued the first time).
            cur.owner = self._add_opr(mge_opr)
            if cur.owner is None:
                cur.owner = self.all_oprs_map[mge_opr.id]
                continue
            q.extend(cur.owner.inputs)
        return list(vars)
[文档]defmodify_opr_names(self,modifier):r"""Modifies names of operators **inplace**; useful for merging loaded network into another network Args: modifier(str or callable): a string to be prepended to the name, or a function that maps from name to name """ifisinstance(modifier,str):om=modifiermodifier=lambdav:"{}.{}".format(om,v)assertisinstance(modifier,collections.Callable)foriinself.all_oprs:v0=i.namev1=modifier(v0)assertisinstance(v1,str)i.name=v1
[文档]defreset_batch_size(self,batchsize,*,blacklist=()):r"""Helper for reset batch size; first dimension of all data providers not in blacklist are assumed to be the batch size Args: blacklist: data provider names whose first dimension is not batchbatch size """blacklist=set(blacklist)prev_batchsize=Noneforiinself.data_providers_filter:ifi.nameinblacklist:blacklist.remove(i.name)else:shp=list(i.shape)ifprev_batchsizeisNone:prev_batchsize=shp[0]else:assertprev_batchsize==shp[0],("batchsize mismatch: batchsize={} ""shape={} dp={}".format(prev_batchsize,shp,i.name))shp[0]=batchsizei.shape=tuple(shp)self._compile()assertprev_batchsizeisnotNone,"no data provider found"assertnotblacklist,"unused items in blacklist: {}".format(blacklist)
    def replace_vars(self, repl_dict: Dict[VarNode, VarNode]):
        r"""Replaces vars in the graph.

        Args:
            repl_dict: the map {old_var: new_var} that specifies how to replace the vars.
        """
        # make sure every replacement var is wired into this network first
        if not all([var.owner for var in repl_dict.values()]):
            self.add_dep_oprs(*list(repl_dict.values()))
        for var in self.all_vars:
            if var in repl_dict:
                repl_var = repl_dict[var]
                # replacing a var with itself would clear its users below
                if repl_var is var:
                    continue
                for opnode in var.users:
                    # use method 'is' instead of 'in' to avoid
                    # compare VarNode use elemwise equal
                    assert any([var is _ for _ in opnode.inputs])
                    opnode.inputs = [repl_var if var is i else i for i in opnode.inputs]
                    if opnode not in repl_var.users:
                        repl_var.users.append(opnode)
                # the old var no longer feeds any operator
                var.users.clear()
        self._compile()
    def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]):
        r"""Replaces operators in the graph.

        Args:
            repl_dict: the map {old_opr: new_opr} that specifies how to replace the operators.
        """
        for opr in self.all_oprs:
            if opr in repl_dict:
                assert len(opr.outputs) == len(
                    repl_dict[opr].outputs
                ), "can not replace {} with {}".format(type(opr), type(repl_dict[opr]))
                for ind, var in enumerate(opr.outputs):
                    # keep the original VarNode objects alive (downstream
                    # oprs hold references to them): repoint their owner and
                    # copy the replacement output's state into them
                    var.owner = repl_dict[opr]
                    var.__dict__.update(repl_dict[opr].outputs[ind].__dict__)
                    var.var = repl_dict[opr].outputs[ind].var
                # the replacement opr now owns the original output nodes
                repl_dict[opr].outputs = opr.outputs
        self._compile()
[文档]defget_opr_by_type(self,oprcls,unique=True):assertissubclass(oprcls,OpNode)rst=self.opr_filter.type(oprcls).as_list()ifunique:assertlen(rst)==1,"{} operators of type {} found".format(len(rst),oprcls)(rst,)=rstreturnrst
[文档]defget_opr_by_name(self,name,unique=True):rst=self.opr_filter.name(name).as_list()ifunique:assertlen(rst)==1,"{} operators of type {} found".format(len(rst),name)(rst,)=rstreturnrst
[文档]defget_var_by_name(self,name,unique=True):rst=self.var_filter.name(name).as_list()ifunique:assertlen(rst)==1,"{} operators of type {} found".format(len(rst),name)(rst,)=rstreturnrst
[文档]defget_var_receive_oprs(self,var):r"""Gets all oprs which use var as input"""returnself.opr_filter.has_input(var).as_list()
[文档]defget_dep_oprs(self,var):r"""Gets dependent oprs of var"""returnget_oprs_seq(var,False,False)
    @property
    def opr_filter(self):
        r"""Filter on all opnodes of the Network."""
        oprs = self.all_oprs
        # islice over the full length yields a plain iterator over the oprs
        return NodeFilter(itertools.islice(oprs, len(oprs)))

    @property
    def var_filter(self):
        r"""Filter on all varnode of the Network."""
        vars = self.all_vars
        return NodeFilter(itertools.islice(vars, len(vars)))

    @property
    def params_filter(self):  # all immutable tensor
        r"""Filter on all parameters (ImmutableTensor Opr) of the Network"""
        return self.opr_filter.param_provider()

    @property
    def data_providers_filter(self):  # all host2devicecopy
        r"""Filter on all input nodes (Host2DeviceCopy Opr) of the Network"""
        return self.opr_filter.data_provider()

    @property
    def dest_vars(self):
        r"""Output varnodes of the Network."""
        return self.output_vars

    @property
    def all_oprs(self):
        # topologically ordered operator sequence ending at the output vars
        return get_oprs_seq(self.output_vars, False, False)

    @property
    def all_vars(self):
        # every var the output vars depend on
        return get_dep_vars(self.output_vars)

    @property
    def all_vars_dict(self):
        return self.var_filter.as_dict()

    @property
    def all_oprs_dict(self):
        return self.opr_filter.as_dict()

    def _add_opr(self, opr) -> Optional[OpNode]:
        r"""Used for loading and building graph."""
        assert isinstance(opr, _imperative_rt.graph.OperatorNode)
        # TODO: use megbrain C++ RTTI to replace type string
        if opr.id not in self.all_oprs_map:
            # first visit: build the OpNode and wire its input/output vars
            opnode = str_to_mge_class(get_opr_type(opr)).load(opr)
            self.all_oprs_map[opr.id] = opnode
            for var in opr.inputs:
                varnode = self._get_var(var)
                opnode.add_inp_var(varnode)
                varnode.users.append(opnode)
            for var in opr.outputs:
                opnode.add_out_var(self._get_var(var))
            return opnode
        else:
            # overwrite the opnode 'new' output VarNode with
            # original one when output number larger than 1,
            # or will cause dependence issue in _compiler step.
            if len(opr.outputs) > 1:
                opnode = self.all_oprs_map[opr.id]
                for idx, output in enumerate(opnode.outputs):
                    if output.var.id in self.all_vars_map:
                        opnode.outputs[idx] = self.all_vars_map[output.var.id]
            # None tells the caller the opr already existed (see add_dep_oprs)
            return None

    def _get_opr(self, x):
        # cached OpNode lookup by raw operator id; None when not built yet
        if x.id in self.all_oprs_map:
            return self.all_oprs_map[x.id]
        else:
            return None

    def _get_var(self, x):
        r"""Convert :class:`~._imperative_rt.graph.VarNode` to :class:`~.VarNode`."""
        assert isinstance(x, _imperative_rt.graph.VarNode)
        # rebuild the wrapper when missing or when the cached wrapper points
        # at a different raw var object
        if x.id not in self.all_vars_map or self.all_vars_map[x.id].var != x:
            self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner))
        return self.all_vars_map[x.id]
def set_symbolic_shape(option: bool):
    r"""Sets whether VarNode uses symbolic shape, returning the previous status.

    Set this to True (and restore the old value after dump) if you want to
    change the input batch size.

    Args:
        option: True to enable symbolic shape.
    """
    previous = _set_symbolic_shape(option)
    return previous
def as_varnode(obj):
    r"""convert a :class:`.VarNode` compatible object to :class:`.VarNode`.

    Args:
        obj: it must be one of the following:

            1. a :class:`.VarNode` object
            2. a :class:`.OpNode` object that has unique output
            3. an iterable that produces either type 1 or 2, with length 1
    """
    if type(obj) is VarNode:
        return obj

    if isinstance(obj, OpNode):
        assert len(obj.outputs) == 1, (
            "operator {} must have one output to be converted to VarNode; "
            "got {} actually".format(obj, len(obj.outputs))
        )
        ret = obj.outputs[0]
        assert type(ret) is VarNode
        return ret

    # BUGFIX: collections.Iterable was removed in Python 3.10; the ABC now
    # lives in collections.abc.
    assert isinstance(
        obj, collections.abc.Iterable
    ), "{} is not compatible with VarNode".format(obj)

    val = list(obj)
    assert (
        len(val) == 1
    ), "can not convert sequence of length {} to VarNode ({})".format(
        len(val), (lambda s: s if len(s) < 50 else s[:50] + " ...")(str(val))
    )
    return as_varnode(val[0])
def as_oprnode(obj):
    r"""convert a :class:`.OpNode` compatible object to :class:`.OpNode`;
    it works like :func:`as_varnode`.
    """
    if type(obj) is VarNode:
        return obj.owner

    if isinstance(obj, OpNode):
        return obj

    # BUGFIX: collections.Iterable was removed in Python 3.10; the ABC now
    # lives in collections.abc.
    assert isinstance(
        obj, collections.abc.Iterable
    ), "{} is not compatible with OpNode".format(obj)

    val = list(obj)
    assert (
        len(val) == 1
    ), "can not convert sequence of length {} to OpNode({})".format(len(val), val)
    return as_oprnode(val[0])
class NodeFilter:
    r"""Filter on node iterator. This class is an iterator of
    :class:`.NetworkNode` objects and multiple filtering conditions and
    mappers can be chained.

    Example:

        .. code-block::

           # find all :class:`.ImmutableTensor` nodes
           for i in NodeFilter(node_iter).param_provider():
               print(i)

           # find all :class:`.ImmutableTensor` nodes that end with ':W'
           for i in NodeFilter(node_iter).param_provider().name('*:W'):
               print(i)

           # number of inputs
           nr_input = NodeFilter(node_iter).data_provider().as_count()
    """

    _iter = None

    def __init__(self, node_iter):
        """
        :param node_iter: iterator to :class:`.NetworkNode`, or a
            :class:`.VarNode`-compatible object; in the later case, its
            dependent oprs would be used
        """
        if isinstance(node_iter, VarNode):
            oprs = get_oprs_seq(node_iter, False, False)
            node_iter = itertools.islice(oprs, len(oprs) - 1)
        if isinstance(node_iter, OpNode):
            oprs = get_oprs_seq(node_iter.inputs, False, False)
            node_iter = itertools.islice(oprs, len(oprs) - 1)

        # BUGFIX: collections.Iterable was removed in Python 3.10; the ABC
        # now lives in collections.abc.
        assert isinstance(node_iter, collections.abc.Iterable)
        # wrap with a type-checking filter unless we already are one (or the
        # input is itself a NodeFilter chain)
        if (not isinstance(node_iter, NodeFilter)) and type(
            self
        ) is not NodeFilterCheckType:
            node_iter = NodeFilterCheckType(node_iter, NetworkNode)
        self._iter = node_iter
[文档]@classmethoddefmake_all_deps(cls,*dest_vars):r"""make a :class:`NodeFilter` that contains all deps of given vars"""returncls(list(get_oprs_seq(dest_vars,False,False)))
    def __iter__(self):
        r"""to be overwritten by subclass to implement filters"""
        # base class applies no filtering: just iterate the wrapped source
        return iter(self._iter)
[文档]deftype(self,node_type):r"""filter by specific node type Args: node_type: node type class Returns: a new :class:`NodeFilter` object """returnNodeFilterType(self,node_type)
[文档]defcheck_type(self,node_type):r"""assert that all oprs produced by this iterator are instances of certain type Args: node_type: node type class Returns: a new :class:`NodeFilter` object Raises: TypeError if type check failed """returnNodeFilterCheckType(self,node_type)
[文档]defnot_type(self,node_type):r"""remove oprs of specific type Args: node_type: node type class Returns: a new :class:`NodeFilter` object """returnNodeFilterNotType(self,node_type)
[文档]defparam_provider(self):r"""get :class:`~.ParamProvider` oprs; shorthand for ``.type(ParamProvider)`` """returnself.type(ImmutableTensor)
[文档]defdata_provider(self):r"""get :class:`.DataProvider` oprs; shorthand for ``.type(DataProvider)`` """returnself.type(Host2DeviceCopy)
[文档]defname(self,pattern,ignorecase=True):r"""filter by node name Args: pattern(class:`str`): a string in glob syntax that can contain ``?`` and ``*`` to match a single or arbitrary characters. ignorecase(bool, optional): whether to ignroe case Returns: a new :class:`NodeFilter` object """returnNodeFilterName(self,pattern,ignorecase)
[文档]defhas_input(self,var):r"""an opr is kept if it has given var as one of its inputs Args: var: var node to checked Returns: a new :class:`NodeFilter` object """returnNodeFilterHasInput(self,var)
[文档]defas_list(self):r"""consume this iterator and return its content as a list"""returnlist(self)
[文档]defas_unique(self):r"""assert that this iterator yields only one node and return it Returns: class:`.GraphNodeBase`: the unique node Raises: ValueError if this iterator does not yield a unique node """(opr,)=selfreturnopr
[文档]defas_dict(self):r"""construct an ordered dict to map from node names to objects in this iterator """returncollections.OrderedDict((i.name,i)foriinself)
[文档]defas_count(self):r"""consume this iterator and get the number of elements"""returnsum(1for_inself)
class NodeFilterCheckType(NodeFilterType):
    r"""see :meth:`NodeFilter.check_type`"""

    def __iter__(self):
        for node in self._iter:
            # unlike NodeFilterType, a mismatch is an error, not a skip
            if not isinstance(node, self._node_type):
                raise TypeError(
                    "all nodes should be {}; got {!r}".format(self._node_type, node)
                )
            yield node
class NodeFilterHasInput(NodeFilter):
    r"""see :meth:`NodeFilter.has_input`"""

    # NOTE(review): instances actually store the var on ``self.var`` below;
    # this class attribute appears unused — confirm before removing.
    _var = None

    def __init__(self, node_iter, var):
        # normalize to a VarNode before handing node_iter to the base class
        var = as_varnode(var)
        super().__init__(node_iter)
        self.var = var

    def __iter__(self):
        for opnode in self._iter:
            assert isinstance(
                opnode, OpNode
            ), "has_input() must be used with OpNode; " "got {!r}".format(opnode)
            # identity comparison: VarNode equality is elementwise, so a
            # plain ``in`` membership test cannot be used
            if any(self.var is inp for inp in opnode.inputs):
                yield opnode