megengine.module.SlidingWindow¶
- class SlidingWindow(kernel_size, padding=0, stride=1, dilation=1, **kwargs)[源代码]¶
- Apply a sliding window to input tensor and copy content in the window to corresponding output location. Assume input shape is \((N, C, IH, IW)\), then output shape would be \((N, C, OH, OW, window_h, window_w)\) where \((OH, OW)\) would be computed from padding, stride, window and \((IH, IW)\), as in convolution. For each output location, we have; \[\begin{split}out_{n, c, oh, ow, wh, ww} &= src_{n, c, ih+wh, iw+ww} \\ \text{where } & ih=-pad_h+oh \times stride_h + (wh-1) \times (dilation_h-1) \\ & iw=-pad_w+ow \times stride_w + (ww-1) \times (dilation_w-1)\end{split}\]- 参数
 - 示例 - from megengine import tensor import megengine.module as M import numpy as np inp = tensor(np.arange(30).reshape(1,1,5,6)) op = M.SlidingWindow(kernel_size=3, padding=1, stride=2, dilation=2) out = op(inp) print(out.numpy()) - 输出: - [[[[[[ 0 0 0] [ 0 7 9] [ 0 19 21]] [[ 0 0 0] [ 7 9 11] [19 21 23]]] [[[ 0 7 9] [ 0 19 21] [ 0 0 0]] [[ 7 9 11] [19 21 23] [ 0 0 0]]]]]]- 方法 - apply(fn)- 对当前模块中的所有模块应用函数 - fn,包括当前模块本身。- buffers([recursive])- 返回该模块中对于buffers的一个可迭代对象。 - children(**kwargs)- 返回一个可迭代对象,可遍历所有属于当前模块的直接属性的子模块。 - disable_quantize([value])- 设置 - module的- quantize_diabled属性,并返回- module。- eval()- 当前模块中所有模块的 - training属性(包括自身)置为- False,并将其切换为推理模式。- forward(inp)- load_state_dict(state_dict[, strict])- 加载一个参数字典,这个字典通常使用 - state_dict得到。- modules(**kwargs)- 返回一个可迭代对象,可以遍历当前模块中的所有模块,包括其本身。 - named_buffers([prefix, recursive])- 返回可遍历模块中 key 与 buffer 的键值对的可迭代对象,其中 - key为从该模块至 buffer 的点路径(dotted path)。- named_children(**kwargs)- 返回可迭代对象,可以遍历属于当前模块的直接属性的所有子模块(submodule)与键(key)组成的”key-submodule”对,其中'key'是子模块对应的属性名。 - named_modules([prefix])- 返回可迭代对象,可以遍历当前模块包括自身在内的所有其内部模块所组成的key-module键-模块对,其中'key'是从当前模块到各子模块的点路径(dotted path)。 - named_parameters([prefix, recursive])- 返回一个可迭代对象,可以遍历当前模块中key与 - Parameter组成的键值对。其中- key是从模块到- Parameter的点路径(dotted path)。- named_tensors([prefix, recursive])- Returns an iterable for key tensor pairs of the module, where - keyis the dotted path from this module to the tensor.- parameters([recursive])- 返回一个可迭代对象,遍历当前模块中的所有 - Parameter- register_forward_hook(hook)- 给模块输出注册一个回调函数。 - 给模块输入注册一个回调函数。 - replace_param(params, start_pos[, seen])- Replaces module's parameters with - params, used by- ParamPackto- state_dict([rst, prefix, keep_var])- tensors([recursive])- Returns an iterable for the - Tensorof the module.- train([mode, recursive])- 当前模块中所有模块的 - training属性(包括自身)置为- mode。- 将所有参数的梯度置0。