[文档]classFunction:r"""Defines a block of operations with customizable differentiation. The computation should be defined in ``forward`` method, with gradient computation defined in ``backward`` method. Each instance of ``Function`` should be used only once during forwardding. Examples: .. code-block:: class Sigmoid(Function): def forward(self, x): y = 1 / (1 + F.exp(-x)) self.y = y return y def backward(self, dy): y = self.y """
    def forward(self, *args, **kwargs):
        r"""Applies operations to ``inputs`` and returns results. It must be
        overridden by all subclasses.

        Args:
            input: input tensors.

        Returns:
            a tuple of Tensor or a single Tensor.

        Note:
            * This method should return a tuple of Tensor or a single Tensor
              representing the output of the function.
            * Positional arguments should all be Tensor.
        """
        raise NotImplementedError
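To illustrate these conventions, here is a sketch of a ``Function`` with two Tensor inputs and a single output; ``Mul`` and its saved attributes are hypothetical names used only for illustration.

.. code-block::

    class Mul(Function):
        def forward(self, a, b):
            # all positional inputs must be Tensors; save what backward needs
            self.a = a
            self.b = b
            return a * b

        def backward(self, dy):
            # one gradient per input, in the same order as forward's arguments
            return dy * self.b, dy * self.a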
    def backward(self, *output_grads):
        r"""Computes the gradient of the forward function. It must be
        overridden by all subclasses.

        Args:
            output_grads: gradients of outputs that are returned by :meth:`forward`.

        Note:
            * In case some output tensors are not related to the loss function,
              the corresponding values in ``output_grads`` would be ``None``.
            * This method should return a tuple containing the gradients of all
              inputs, in the same order as the ``inputs`` argument of
              :meth:`forward`. A ``Tensor`` could be returned instead if there
              is only one input. If users want to stop the propagation of some
              gradients, the corresponding returned values should be set to
              ``None``.
        """
        raise NotImplementedError
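As a sketch of the ``None`` convention, the hypothetical ``ScaleBy`` below propagates a gradient to its first input only and stops propagation to the second one.

.. code-block::

    class ScaleBy(Function):
        def forward(self, x, scale):
            self.scale = scale
            return x * scale

        def backward(self, dy):
            # gradient for ``x``, then ``None`` to stop propagation to ``scale``
            return dy * self.scale, None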
    def _default_rule(self, *args):
        ret = self.forward(*args)
        self.__single_output = isinstance(ret, core2.Tensor)
        return ret

    def _grad_rule(self, *args):
        return self._default_rule(*args), self.backward

    def __call__(self, *args):
        from ...tensor import Tensor

        for arg in args:
            if not isinstance(arg, Tensor):
                raise TypeError(
                    "op Function expect type Tensor as inputs, got {}".format(type(arg))
                )

        grad_key = core2.get_grad_key(args)
        if grad_key is None:
            # no GradManager is recording these inputs: just run forward
            return self._default_rule(*args)

        grad = Grad.key2grad[grad_key]
        group = [ref() for ref in grad._group]

        origin_args = [Tensor(arg) for arg in args]

        # run forward with gradient recording suppressed, then restore it
        for grad in group:
            grad.suppress()
        outputs, backward = self._grad_rule(*args)
        for grad in reversed(group):
            grad.resume()

        def normalized_backward(*output_grads):
            # always hand a tuple of input gradients to the core
            input_grads = backward(*output_grads)
            if isinstance(input_grads, Tensor) or input_grads is None:
                input_grads = (input_grads,)
            return input_grads

        if self.__single_output:
            outputs = (outputs,)
        outputs = core2.set_grad(normalized_backward, origin_args, outputs)
        if self.__single_output:
            (outputs,) = outputs
        return outputs

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)
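A small sketch of the dispatch behavior implemented by ``__call__``: when no ``GradManager`` is recording the inputs, the call falls through to ``forward`` directly, and non-Tensor positional inputs are rejected with a ``TypeError``. ``Sigmoid`` is the class from the example above; the values are illustrative.

.. code-block::

    import numpy as np
    from megengine import Tensor

    y = Sigmoid()(Tensor([0.0]))    # nothing is recording: forward runs eagerly
    try:
        Sigmoid()(np.array([0.0]))  # rejected: positional inputs must be Tensors
    except TypeError as err:
        print(err)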