AllreduceCallback

class AllreduceCallback(reduce_method, group=WORLD, backend=None)

Allreduce callback with tensor fusion optimization.
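Tensor fusion means that the callback does not launch one allreduce per parameter gradient. Instead it packs many small gradients into larger contiguous buffers and reduces each buffer with a single collective call, which cuts per-call overhead. The following is a minimal, framework-free sketch of that idea. It assumes NumPy arrays stand in for gradients and that a plain sum over per-worker buffers stands in for the real collective. `fuse_buckets`, `simulated_allreduce`, `fused_allreduce` and the byte threshold are hypothetical and do not reflect MegEngine's internal implementation.

```python
import numpy as np


def fuse_buckets(grads, threshold_bytes):
    """Group gradient indices into buckets whose total size stays under threshold_bytes.

    A single gradient larger than the threshold gets a bucket of its own.
    """
    buckets, current, size = [], [], 0
    for i, g in enumerate(grads):
        if current and size + g.nbytes > threshold_bytes:
            buckets.append(current)
            current, size = [], 0
        current.append(i)
        size += g.nbytes
    if current:
        buckets.append(current)
    return buckets


def simulated_allreduce(per_worker_buffers):
    # Stand-in for a real collective: every worker receives the elementwise sum.
    total = np.sum(per_worker_buffers, axis=0)
    return [total.copy() for _ in per_worker_buffers]


def fused_allreduce(per_worker_grads, threshold_bytes):
    """per_worker_grads[w][i] is gradient i on worker w; returns reduced grads per worker."""
    n_workers = len(per_worker_grads)
    buckets = fuse_buckets(per_worker_grads[0], threshold_bytes)
    result = [[None] * len(per_worker_grads[0]) for _ in range(n_workers)]
    for bucket in buckets:
        # Pack: flatten and concatenate this bucket's gradients into one buffer per worker.
        packed = [
            np.concatenate([per_worker_grads[w][i].ravel() for i in bucket])
            for w in range(n_workers)
        ]
        reduced = simulated_allreduce(packed)  # one collective call per bucket
        # Unpack: split each reduced buffer back into the original shapes.
        for w in range(n_workers):
            offset = 0
            for i in bucket:
                g = per_worker_grads[w][i]
                result[w][i] = reduced[w][offset:offset + g.size].reshape(g.shape)
                offset += g.size
    return result, buckets


# Two simulated workers, three float32 gradients of 16, 12 and 16 bytes.
grads_w0 = [np.ones((2, 2), dtype=np.float32), np.ones(3, dtype=np.float32), np.ones(4, dtype=np.float32)]
grads_w1 = [2 * g for g in grads_w0]
reduced, buckets = fused_allreduce([grads_w0, grads_w1], threshold_bytes=32)
print(buckets)  # [[0, 1], [2]]
```

With a 32-byte threshold the first two gradients share one buffer and the third gets its own, so three gradients cost two collective calls instead of three.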

Parameters
  • reduce_method (str) – the method used to reduce gradients; must be “sum” or “mean”.

  • group (Group, optional) – communication group, an instance of megengine.distributed.group.Group. Default: WORLD.

  • backend (str, optional) – distributed backend to use for the allreduce, overriding the default. If None, the backend set in dist.launcher is used. Default: None.
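The reduce_method parameter decides what every worker ends up holding after the allreduce. With "sum", each worker receives the elementwise sum of all workers' gradients. With "mean", each worker receives that sum divided by the number of workers in the group. The following is a minimal simulation that uses a list of NumPy arrays in place of real devices. `reduce_gradients` is a hypothetical helper and not part of MegEngine; raising ValueError for other methods is this sketch's own choice.

```python
import numpy as np


def reduce_gradients(per_worker_grads, reduce_method):
    """Simulate an allreduce of one gradient across the workers of a group."""
    method = reduce_method.lower()
    if method not in ("sum", "mean"):
        raise ValueError('reduce_method should be "sum" or "mean"')
    total = np.sum(per_worker_grads, axis=0)
    if method == "mean":
        total = total / len(per_worker_grads)
    # Every worker ends up holding the same reduced gradient.
    return [total.copy() for _ in per_worker_grads]


grads = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
print(reduce_gradients(grads, "sum")[0])   # [4. 6.]
print(reduce_gradients(grads, "mean")[0])  # [2. 3.]
```

With "mean", the gradient magnitude does not grow with the number of workers. With "sum", you usually have to account for the group size yourself, for example when scaling the loss or the learning rate.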

Examples

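import megengine as mge
import megengine.autodiff as ad
import megengine.distributed as dist
import megengine.module as M

# Run inside a worker process started by dist.launcher (the process group must be initialized).
linear_cls = M.Linear(10, 10)  # any Module whose gradients should be allreduced

gm = ad.GradManager()
gm.attach(linear_cls.parameters(), callbacks=[dist.make_allreduce_cb("sum")])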