Generative Adversarial Networks
Tags: vision, gan
Generative adversarial networks (pretrained weights on CIFAR-10)
```python
import megengine.hub as hub
import megengine_mimicry.nets.dcgan.dcgan_cifar as dcgan
import megengine_mimicry.utils.vis as vis

# Build the CIFAR-10 DCGAN generator and load the pretrained weights
netG = dcgan.DCGANGeneratorCIFAR()
netG.load_state_dict(hub.load_serialized_obj_from_url("https://data.megengine.org.cn/models/weights/dcgan_cifar.pkl"))

# Sample 64 images in NCHW format with pixel values normalized to [0, 1]
images = netG.generate_images(num_images=64)
grid = vis.make_grid(images)  # HW3 grid of [0, 255] BGR pixels for visualization
vis.save_image(grid, "visual.png")
```
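The comments in the snippet describe a conversion from an NCHW batch in [0, 1] to a single HW3 BGR image in [0, 255]. As an illustration only, the NumPy sketch below performs that kind of conversion; it is not the implementation of `vis.make_grid`. The helper name `make_grid_sketch`, the RGB channel order of the input, the 8 images per row and the 2-pixel black padding are all assumptions. A MegEngine tensor would first need converting to a NumPy array, e.g. with `.numpy()`.

```python
import numpy as np


def make_grid_sketch(images: np.ndarray, per_row: int = 8, padding: int = 2) -> np.ndarray:
    """Hypothetical helper: tile an NCHW batch in [0, 1] into one HxWx3 uint8 BGR image.

    Assumes the input channels are in RGB order; gaps between tiles are black.
    """
    n, c, h, w = images.shape
    assert c == 3, "expects 3-channel images"
    rows = (n + per_row - 1) // per_row  # ceil division: rows needed for n tiles
    grid = np.zeros(
        (rows * (h + padding) + padding, per_row * (w + padding) + padding, 3),
        dtype=np.uint8,
    )
    # [0, 1] floats -> [0, 255] uint8, then NCHW -> NHWC, then RGB -> BGR
    tiles = (np.clip(images, 0.0, 1.0) * 255.0).round().astype(np.uint8)
    tiles = tiles.transpose(0, 2, 3, 1)[..., ::-1]
    for idx in range(n):
        r, col = divmod(idx, per_row)
        top = padding + r * (h + padding)
        left = padding + col * (w + padding)
        grid[top:top + h, left:left + w] = tiles[idx]
    return grid
```

For 64 images of 32 x 32 pixels this produces an 8 x 8 grid of 274 x 274 pixels.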
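Training hyperparameters
Resolution | Batch size | Learning rate | β<sub>1</sub> | β<sub>2</sub> | Decay policy | n<sub>dis</sub> | n<sub>iter</sub> |
---|---|---|---|---|---|---|---|
32 x 32 | 64 | 2e-4 | 0.0 | 0.9 | Linear | 5 | 100K |

n<sub>dis</sub> is the number of discriminator updates per generator update, and n<sub>iter</sub> is the total number of training iterations.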
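The "Linear" decay policy and n<sub>dis</sub> can be read as the schedule sketched below. This is a minimal sketch under stated assumptions, not the training code behind these weights: it assumes n<sub>iter</sub> counts generator updates, that the learning rate falls linearly from 2e-4 to zero over those updates, and that the discriminator takes n<sub>dis</sub> steps before each generator step. `linear_decay_lr` and `training_schedule` are hypothetical names.

```python
from typing import Iterator, Tuple


def linear_decay_lr(step: int, base_lr: float = 2e-4, n_iter: int = 100_000) -> float:
    """Learning rate after `step` generator updates, decaying linearly to 0 at `n_iter`.

    Assumption: both the generator and discriminator optimizers follow this schedule.
    """
    remaining = max(0.0, 1.0 - step / n_iter)  # fraction of training still to go
    return base_lr * remaining


def training_schedule(n_iter: int, n_dis: int = 5) -> Iterator[Tuple[str, int]]:
    """Yield the update order: n_dis discriminator ("D") steps, then one generator ("G") step."""
    for step in range(n_iter):
        for _ in range(n_dis):
            yield ("D", step)
        yield ("G", step)
```

Under this reading, 100K generator steps with n<sub>dis</sub> = 5 amount to 500K discriminator steps, and the learning rate is 1e-4 halfway through training.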
Evaluation metrics
Metric | Method |
---|---|
Inception Score (IS) | 50K samples in total, split into 10 partitions |
Fréchet Inception Distance (FID) | 50K real / generated samples |
Kernel Inception Distance (KID) | 50K real / generated samples, averaged over 10 splits |
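The NumPy sketch below shows how these three metrics are commonly computed, including the split-and-average step listed in the table. It is an illustration, not the evaluation code used for the results here: it assumes the Inception softmax outputs (for IS) and Inception pool features (for FID and KID) have already been extracted, leaves the Inception network itself out, and uses the hypothetical function names `inception_score`, `fid` and `kid`. The KID kernel is the cubic polynomial kernel k(x, y) = (x·y / d + 1)<sup>3</sup>.

```python
import numpy as np


def inception_score(probs: np.ndarray, splits: int = 10, eps: float = 1e-12):
    """IS from Inception softmax outputs p(y|x), shape (N, num_classes).

    Returns the mean and standard deviation of exp(E_x[KL(p(y|x) || p(y))]) over splits.
    """
    scores = []
    for part in np.array_split(probs, splits):
        p_y = part.mean(axis=0, keepdims=True)  # marginal class distribution of this split
        kl = (part * (np.log(part + eps) - np.log(p_y + eps))).sum(axis=1)
        scores.append(np.exp(kl.mean()))
    return float(np.mean(scores)), float(np.std(scores))


def fid(feat_real: np.ndarray, feat_fake: np.ndarray) -> float:
    """Fréchet distance between Gaussians fitted to two (N, D) feature sets, D >= 2."""
    mu1, mu2 = feat_real.mean(axis=0), feat_fake.mean(axis=0)
    s1 = np.cov(feat_real, rowvar=False)
    s2 = np.cov(feat_fake, rowvar=False)
    # Tr((s1 s2)^(1/2)) is the sum of square roots of the eigenvalues of s1 @ s2;
    # these are real and non-negative for covariance matrices, up to rounding noise.
    eig = np.linalg.eigvals(s1 @ s2)
    tr_sqrt = np.sqrt(np.clip(eig.real, 0.0, None)).sum()
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(s1) + np.trace(s2) - 2.0 * tr_sqrt)


def kid(feat_real: np.ndarray, feat_fake: np.ndarray, splits: int = 10) -> float:
    """Unbiased MMD^2 with the cubic polynomial kernel, averaged over splits."""
    d = feat_real.shape[1]

    def kernel(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        return (a @ b.T / d + 1.0) ** 3

    values = []
    for x, y in zip(np.array_split(feat_real, splits), np.array_split(feat_fake, splits)):
        m, n = len(x), len(y)
        k_xx, k_yy, k_xy = kernel(x, x), kernel(y, y), kernel(x, y)
        # Drop the diagonal (self-similarity) terms to get the unbiased estimator.
        term_xx = (k_xx.sum() - np.trace(k_xx)) / (m * (m - 1))
        term_yy = (k_yy.sum() - np.trace(k_yy)) / (n * (n - 1))
        values.append(term_xx + term_yy - 2.0 * k_xy.mean())
    return float(np.mean(values))
```

Higher IS and lower FID and KID indicate generated samples that are closer to the real data.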
CIFAR-10 results
Model | FID Score | IS Score | KID Score |
---|---|---|---|
DCGAN | 27.2 | 7.0 | 0.0242 |
WGAN-WC | 30.5 | 6.7 | 0.0249 |
References
- Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks, Alec Radford, Luke Metz, and Soumith Chintala.
- Wasserstein GAN, Martin Arjovsky, Soumith Chintala, and Léon Bottou.