Generative Adversarial Networks
vision
gan
Developer: MegEngine Team
Generative Adversarial Networks (CIFAR-10 pretrained weights)
import megengine.hub as hub
import megengine_mimicry.nets.dcgan.dcgan_cifar as dcgan
import megengine_mimicry.utils.vis as vis

# Build the CIFAR-10 DCGAN generator and load the pretrained weights
netG = dcgan.DCGANGeneratorCIFAR()
netG.load_state_dict(hub.load_serialized_obj_from_url("https://data.megengine.org.cn/models/weights/dcgan_cifar.pkl"))
netG.eval()  # inference mode: BatchNorm uses its stored running statistics

# Sample 64 images in NCHW format with normalized pixel values in [0, 1]
images = netG.generate_images(num_images=64)
grid = vis.make_grid(images)  # HW3 grid of [0, 255] BGR pixels for visualization
vis.save_image(grid, "visual.png")
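The comments in the usage code describe a format conversion: a batch of NCHW images in [0, 1] becomes a single HW3 image in [0, 255] with BGR channel order. The sketch below shows roughly what such a tiling involves in plain NumPy. `make_grid_sketch` is a hypothetical helper, not the actual `vis.make_grid`; the RGB channel order of the generator output, the 8-images-per-row layout and the 2-pixel padding are all assumptions.

```python
import numpy as np


def make_grid_sketch(images: np.ndarray, per_row: int = 8, padding: int = 2) -> np.ndarray:
    """Tile an NCHW float batch in [0, 1] into one HxWx3 uint8 BGR image.

    Hypothetical helper for illustration; NOT megengine_mimicry.utils.vis.make_grid.
    Assumes the input images are in RGB channel order.
    """
    n, c, h, w = images.shape
    assert c == 3, "expects 3-channel images"
    n_rows = (n + per_row - 1) // per_row  # ceiling division
    grid = np.zeros(
        (n_rows * (h + padding) + padding, per_row * (w + padding) + padding, 3),
        dtype=np.uint8,
    )
    # Scale [0, 1] -> [0, 255], then NCHW -> NHWC, then RGB -> BGR (reverse channels)
    pixels = (np.clip(images, 0.0, 1.0) * 255.0).round().astype(np.uint8)
    pixels = pixels.transpose(0, 2, 3, 1)[..., ::-1]
    for idx in range(n):
        row, col = divmod(idx, per_row)
        top = padding + row * (h + padding)
        left = padding + col * (w + padding)
        grid[top:top + h, left:left + w] = pixels[idx]
    return grid


# Usage: 64 CIFAR-sized images become an 8 x 8 grid of 274 x 274 pixels
fake = np.random.default_rng(0).random((64, 3, 32, 32)).astype(np.float32)
print(make_grid_sketch(fake).shape)
```

Reversing the last axis is the whole RGB-to-BGR step; the BGR order matters only because OpenCV-style writers expect it.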

Training Parameters

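| Resolution | Batch size | Learning rate | β<sub>1</sub> | β<sub>2</sub> | Decay policy | n<sub>dis</sub> | n<sub>iter</sub> |
|---|---|---|---|---|---|---|---|
| 32 x 32 | 64 | 2e-4 | 0.0 | 0.9 | Linear | 5 | 100K |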
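A minimal sketch of how these settings fit together, assuming "Linear" means the learning rate decays linearly from 2e-4 to 0 over the 100K generator iterations, that the same schedule drives both networks, and that n<sub>dis</sub> = 5 means five discriminator updates per generator update. `GANTrainConfig`, `linear_decay_lr`, `train_loop` and the `d_step`/`g_step` callbacks are hypothetical names standing in for real optimizer steps. In MegEngine, β<sub>1</sub> and β<sub>2</sub> would typically be passed as `megengine.optimizer.Adam(params, lr=2e-4, betas=(0.0, 0.9))`.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class GANTrainConfig:
    # Values copied from the training-parameter table
    resolution: int = 32
    batch_size: int = 64
    lr: float = 2e-4
    beta1: float = 0.0
    beta2: float = 0.9
    n_dis: int = 5          # discriminator updates per generator update
    n_iter: int = 100_000   # generator iterations


def linear_decay_lr(base_lr: float, step: int, total_steps: int) -> float:
    """Assumed schedule: base_lr at step 0, decaying linearly to 0 at total_steps."""
    frac = min(max(step, 0), total_steps) / total_steps
    return base_lr * (1.0 - frac)


def train_loop(
    cfg: GANTrainConfig,
    d_step: Callable[[float], None],   # hypothetical: one discriminator update
    g_step: Callable[[float], None],   # hypothetical: one generator update
    num_gen_steps: Optional[int] = None,
) -> None:
    """Alternate cfg.n_dis discriminator updates with one generator update."""
    total = cfg.n_iter if num_gen_steps is None else num_gen_steps
    for step in range(total):
        lr = linear_decay_lr(cfg.lr, step, cfg.n_iter)
        for _ in range(cfg.n_dis):
            d_step(lr)
        g_step(lr)


# Usage: record the learning rates passed to each (dummy) update
d_lrs, g_lrs = [], []
train_loop(GANTrainConfig(), d_lrs.append, g_lrs.append, num_gen_steps=3)
print(len(d_lrs), len(g_lrs))
```

Under these assumptions, the full 100K-iteration run performs 500K discriminator updates.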

Evaluation Metrics

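| Metric | Method |
|---|---|
| Inception Score (IS) | 50K samples in total, split into 10 folds |
| Fréchet Inception Distance (FID) | 50K real/generated samples |
| Kernel Inception Distance (KID) | 50K real/generated samples, averaged over 10 splits |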
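The three metrics can be sketched in NumPy on precomputed features, assuming the Inception-v3 softmax outputs (for IS) and pool features (for FID and KID) have already been extracted; feature extraction is omitted. The 10 splits are taken as consecutive chunks via `np.array_split`, which is an assumption about how the splitting is done. KID here uses the cubic polynomial kernel k(x, y) = (x·y/d + 1)<sup>3</sup> with the unbiased MMD² estimator. This is an illustrative sketch, not the evaluation code that produced the reported numbers.

```python
import numpy as np


def inception_score(probs: np.ndarray, n_splits: int = 10, eps: float = 1e-12):
    """IS = exp(E_x[KL(p(y|x) || p(y))]) per split; returns (mean, std) over splits.

    probs: (N, K) class probabilities from the classifier.
    """
    scores = []
    for part in np.array_split(probs, n_splits):
        p_y = part.mean(axis=0, keepdims=True)  # marginal class distribution p(y)
        kl = (part * (np.log(part + eps) - np.log(p_y + eps))).sum(axis=1)
        scores.append(np.exp(kl.mean()))
    return float(np.mean(scores)), float(np.std(scores))


def _sqrt_psd(mat: np.ndarray) -> np.ndarray:
    """Symmetric square root of a positive semi-definite matrix via eigh."""
    w, v = np.linalg.eigh(mat)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T


def fid(real_feats: np.ndarray, fake_feats: np.ndarray) -> float:
    """FID = ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^{1/2})."""
    mu_r, mu_g = real_feats.mean(axis=0), fake_feats.mean(axis=0)
    s_r = np.cov(real_feats, rowvar=False)
    s_g = np.cov(fake_feats, rowvar=False)
    # Tr((S_r S_g)^{1/2}) = sum of sqrt-eigenvalues of S_r^{1/2} S_g S_r^{1/2},
    # which is symmetric PSD, so eigvalsh applies
    root_r = _sqrt_psd(s_r)
    eig = np.linalg.eigvalsh(root_r @ s_g @ root_r)
    tr_covmean = np.sqrt(np.clip(eig, 0.0, None)).sum()
    mean_term = ((mu_r - mu_g) ** 2).sum()
    return float(mean_term + np.trace(s_r) + np.trace(s_g) - 2.0 * tr_covmean)


def _poly_kernel(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    d = x.shape[1]
    return (x @ y.T / d + 1.0) ** 3


def _mmd2_unbiased(x: np.ndarray, y: np.ndarray) -> float:
    """Unbiased MMD^2: diagonal terms are excluded from the within-set means."""
    m, n = len(x), len(y)
    k_xx, k_yy, k_xy = _poly_kernel(x, x), _poly_kernel(y, y), _poly_kernel(x, y)
    term_xx = (k_xx.sum() - np.trace(k_xx)) / (m * (m - 1))
    term_yy = (k_yy.sum() - np.trace(k_yy)) / (n * (n - 1))
    return float(term_xx + term_yy - 2.0 * k_xy.mean())


def kid(real_feats: np.ndarray, fake_feats: np.ndarray, n_splits: int = 10) -> float:
    """KID = unbiased MMD^2 averaged over n_splits chunks of real/fake features."""
    pairs = zip(np.array_split(real_feats, n_splits), np.array_split(fake_feats, n_splits))
    return float(np.mean([_mmd2_unbiased(r, f) for r, f in pairs]))
```

Lower FID and KID mean the generated features are closer to the real ones; a higher IS means confident and diverse class predictions.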

CIFAR-10 Results

| Model | FID (lower is better) | IS (higher is better) | KID (lower is better) |
|---|---|---|---|
| DCGAN | 27.2 | 7.0 | 0.0242 |
| WGAN-WC | 30.5 | 6.7 | 0.0249 |

References