参数量/计算量统计与可视化

有时候我们经常需要统计模型的参数量和计算量,目前的实现方式有两种:

  • 基于 module 的实现——

    • 优点:可以在 Python 代码中嵌入调用,随时可以看统计信息

    • 缺点:只能统计 module 的信息,无法统计 functional 的调用

  • 基于 dump 的实现——

    • 优点:可以覆盖所有的算子

    • 缺点:需要先进行 dump 操作

基于 module 的统计

实现在 module_stats 中, 可以支持 float32/qat/qint8 模型的统计,使用方式很简单:

from megengine.hub import load
from megengine.utils.module_stats import module_stats

# 构建一个 net module,这里从 model hub 中获取 resnet18 模型
net = load("megengine/models", "resnet18", pretrained=True)

# 指定输入 shape
input_shape = (1, 3, 224, 224)

# Float model.
total_params, total_flops = module_stats(
    net, input_shape, log_params=True, log_flops=True
)
print("params {} flops {}".format(total_params, total_flops))

可以通过 log_paramslog_flops 参数来控制是否输出 parameter 和 flops 细节信息,返回总的参数量和计算量。

基于 dump 图的可视化与统计

基于 Python Graph 的图结构解析功能实现:

  • 输入 mge 格式的 dump 模型路径以及 log 存储目录

  • 可将图结构信息存成 TensorBoard 可读的格式。

命令行调用

python3 -m megengine.tools.network_visualize ./resnet18.mge ./log --log_flops --log_params

其中各个参数说明如下:

./resnet18.mge (第一个参数)

必填参数 ,指定模型文件名。

./log (第二个参数)

必填参数 ,指定 log 存储目录。

--log_flops

指定当前屏打印出 FLOPs 信息。

--log_params

指定当前屏打印出 Parameters 信息。

Python 中调用

以下代码等效于上方的命令行调用方式:

from megengine.tools.network_visualize import visualize

total_params, total_flops = visualize(
    "./resnet18.mge", "./log"
)
print("params {} flops {}".format(total_params, total_flops))

进行可视化

完成上面的步骤后,再在对应目录(例子中为 ./log )启动 tensorboard, 即可在本机打开 tensorboard 进程:

tensorboard --logdir ./log

注解

TensorBoard 的安装和使用请参考 TensorBoard 官网

如果启动服务器为远程 ssh 登陆,可用以下命令映射端口到本地(可使用 sshconfig 中的服务器名缩写):

ssh <user>@<host_name> -L 6006:0.0.0.0:6006 -N