# dcgan.py
# pylint: disable=unused-argument
"""
The MXNet symbol of DCGAN generator

Adapted from:

Radford, Alec, Luke Metz, and Soumith Chintala.
"Unsupervised representation learning with deep convolutional generative adversarial networks."
arXiv preprint arXiv:1511.06434 (2015).
"""

import mxnet as mx

def deconv2d(data, ishape, oshape, kshape, name, stride=(2, 2)):
    """Build a deconvolution layer that enlarges the feature map.

    Parameters
    ----------
    data : mx.sym.Symbol
        Input symbol passed to ``mx.sym.Deconvolution``.
    ishape : tuple
        Input shape. Accepted so call sites can state both shapes, but it is
        not used in the computation.
    oshape : tuple
        Desired output shape. ``oshape[0]`` is used as the number of filters
        and the last two entries as the target height and width.
    kshape : tuple
        Kernel size as ``(height, width)``.
    name : str
        Name assigned to the created Deconvolution symbol.
    stride : tuple, optional
        Stride as ``(y, x)``; defaults to ``(2, 2)``.

    Returns
    -------
    mx.sym.Symbol
        The Deconvolution symbol.
    """
    target_shape = (oshape[-2], oshape[-1])
    # Pad by half the kernel (rounded down) on each spatial axis.
    pad_y = (kshape[0] - 1) // 2
    pad_x = (kshape[1] - 1) // 2
    # ``adj`` takes up the remainder left by the stride so the output size
    # can land on ``target_shape``.
    adj_y = (target_shape[0] + 2 * pad_y - kshape[0]) % stride[0]
    adj_x = (target_shape[1] + 2 * pad_x - kshape[1]) % stride[1]

    # kernel and num_filter are required by the operator; stride and name
    # must be forwarded, otherwise the padding/adj arithmetic above and the
    # layer names chosen by the callers have no effect.
    net = mx.sym.Deconvolution(data,
                               kernel=kshape,
                               stride=stride,
                               pad=(pad_y, pad_x),
                               adj=(adj_y, adj_x),
                               num_filter=oshape[0],
                               no_bias=True,
                               name=name)
    return net

def deconv2d_bn_relu(data, prefix, **kwargs):
    """Stack a deconv layer, a batch norm and a ReLU activation.

    The three layers are named ``<prefix>_deconv``, ``<prefix>_bn`` and
    ``<prefix>_act``. All remaining keyword arguments are forwarded
    unchanged to ``deconv2d``.
    """
    deconv_name = "%s_deconv" % prefix
    bn_name = "%s_bn" % prefix
    act_name = "%s_act" % prefix

    upsampled = deconv2d(data, name=deconv_name, **kwargs)
    # NOTE(review): eps is nudged just above 1e-5, presumably to stay clear
    # of a backend minimum-eps limit -- confirm.
    normalized = mx.sym.BatchNorm(upsampled, eps=1e-5 + 1e-12, name=bn_name)
    return mx.sym.Activation(normalized, name=act_name, act_type='relu')

def get_symbol(oshape=(3, 64, 64), ngf=128, code=None):
    """Build the symbol of the DCGAN generator.

    Parameters
    ----------
    oshape : tuple, optional
        Output image shape; its last two entries must both be 64.
    ngf : int, optional
        Base number of generator filters; earlier layers use multiples of it.
    code : mx.sym.Symbol, optional
        Input symbol. When ``None``, a variable named ``"data"`` is created.

    Returns
    -------
    mx.sym.Symbol
        The generator output after the final tanh activation.
    """
    assert oshape[-1] == 64, "Only support 64x64 image"
    assert oshape[-2] == 64, "Only support 64x64 image"

    if code is None:
        code = mx.sym.Variable("data")

    # Project the input and reshape it into a 4 x 4 feature map.
    projected = mx.sym.FullyConnected(
        code, name="g1", num_hidden=ngf*8*4*4, no_bias=True, flatten=False)
    projected = mx.sym.Activation(projected, act_type='relu')
    net = mx.sym.reshape(projected, shape=(-1, ngf * 8, 4, 4))

    # Each stage doubles the spatial size: 4x4 -> 8x8 -> 16x16 -> 32x32.
    upsampling_stages = (
        ("g2", (ngf * 8, 4, 4), (ngf * 4, 8, 8)),
        ("g3", (ngf * 4, 8, 8), (ngf * 2, 16, 16)),
        ("g4", (ngf * 2, 16, 16), (ngf, 32, 32)),
    )
    for stage_prefix, in_shape, out_shape in upsampling_stages:
        net = deconv2d_bn_relu(
            net, ishape=in_shape, oshape=out_shape, kshape=(4, 4), prefix=stage_prefix)

    # Final 32x32 -> 64x64 deconv has no batch norm and is followed by tanh.
    image = deconv2d(
        net, ishape=(ngf, 32, 32), oshape=oshape[-3:], kshape=(4, 4), name="g5_deconv")
    return mx.sym.Activation(image, act_type='tanh')