Commit 2fb1cc6e by Josh Pollock Committed by Tianqi Chen

[Relay] DCGAN port (#2010)

parent a1dfb9ae
...@@ -4,3 +4,4 @@ from __future__ import absolute_import as _abs ...@@ -4,3 +4,4 @@ from __future__ import absolute_import as _abs
from . import mlp from . import mlp
from . import resnet from . import resnet
from . import dqn from . import dqn
from . import dcgan
# pylint: disable=unused-argument
"""
Net of the generator of DCGAN
Adapted from:
https://github.com/tqchen/mxnet-gan/blob/master/mxgan/generator.py
Reference:
Radford, Alec, Luke Metz, and Soumith Chintala.
"Unsupervised representation learning with deep convolutional generative adversarial networks."
arXiv preprint arXiv:1511.06434 (2015).
"""
from tvm import relay
from . import layers
from .init import create_workload
def deconv2d(data, ishape, oshape, kshape, name, stride=(2, 2)):
    """Transposed-convolution layer that enlarges the feature map.

    Parameters
    ----------
    data : relay.Expr
        The input expression.
    ishape : tuple
        Shape of the input feature map. Accepted for call-site symmetry;
        this function does not read it.
    oshape : tuple
        Requested output shape; the first entry is used as the channel count
        and the last two entries as the spatial size.
    kshape : tuple
        Kernel size, indexed as (height, width).
    name : str
        Layer name forwarded to ``layers.conv2d_transpose``.
    stride : tuple, optional
        Strides, indexed as (height, width).

    Returns
    -------
    net : relay.Expr
        The result of ``layers.conv2d_transpose``.
    """
    kernel_h, kernel_w = kshape[0], kshape[1]
    out_h, out_w = oshape[-2], oshape[-1]

    # Pad by half of (kernel - 1) on each spatial axis.
    padding = ((kernel_h - 1) // 2, (kernel_w - 1) // 2)

    # The remainder with respect to the stride is handed to output_padding.
    output_padding = (
        (out_h + 2 * padding[0] - kernel_h) % stride[0],
        (out_w + 2 * padding[1] - kernel_w) % stride[1],
    )

    return layers.conv2d_transpose(
        data,
        kernel_size=kshape,
        strides=stride,
        channels=oshape[0],
        padding=padding,
        output_padding=output_padding,
        name=name,
    )
def deconv2d_bn_relu(data, prefix, **kwargs):
    """A block of deconv + batch norm + relu.

    Parameters
    ----------
    data : relay.Expr
        The input expression.
    prefix : str
        Name prefix for this block; both the deconv layer and the batch-norm
        layer derive their names from it.
    kwargs : dict
        Remaining arguments forwarded to ``deconv2d`` (``ishape``, ``oshape``,
        ``kshape`` and optionally ``stride``).

    Returns
    -------
    net : relay.Expr
        The output of the final relu.
    """
    eps = 1e-5 + 1e-12
    net = deconv2d(data, name="%s_deconv" % prefix, **kwargs)
    # Name the batch norm after the prefix too. A fixed "batch_norm" name
    # would be shared by every block built with this helper, so their
    # batch-norm layers could not be told apart by name.
    # NOTE(review): assumes batch_norm_infer derives its variable names from
    # `name`, as the conv wrappers do -- confirm in layers.
    net = layers.batch_norm_infer(net, epsilon=eps, name="%s_batch_norm" % prefix)
    return relay.nn.relu(net)
def get_net(batch_size, random_len=100, oshape=(3, 64, 64), ngf=128, code=None, dtype="float32"):
    """Build the Relay function of the DCGAN generator.

    Parameters
    ----------
    batch_size : int
        Batch size of the "data" input created when ``code`` is not given.
    random_len : int, optional
        Length of the random input vector.
    oshape : tuple, optional
        Output image shape; the last two entries must both be 64.
    ngf : int, optional
        Base number of feature maps; earlier stages use multiples of it.
    code : relay.Expr, optional
        Input expression to use in place of a freshly created "data" variable.
    dtype : str, optional
        Data type of the "data" variable.

    Returns
    -------
    net : relay.Function
        Function over the free variables of the generator output.
    """
    assert oshape[-1] == 64, "Only support 64x64 image"
    assert oshape[-2] == 64, "Only support 64x64 image"

    if code is None:
        code = relay.var("data", dtype=dtype, shape=(batch_size, random_len))

    weight = relay.var("dense_weight")
    net = relay.nn.dense(code, weight=weight, units=4 * 4 * ngf * 8)
    net = relay.nn.relu(net)
    # 4 x 4
    net = relay.reshape(net, newshape=(-1, ngf * 8, 4, 4))

    # Each stage doubles the spatial size: 4 -> 8 -> 16 -> 32.
    # Entries are (prefix, input channels, output channels, input size).
    stages = (
        ("g2", ngf * 8, ngf * 4, 4),
        ("g3", ngf * 4, ngf * 2, 8),
        ("g4", ngf * 2, ngf, 16),
    )
    for prefix, in_channels, out_channels, size in stages:
        net = deconv2d_bn_relu(
            net,
            ishape=(in_channels, size, size),
            oshape=(out_channels, size * 2, size * 2),
            kshape=(4, 4),
            prefix=prefix,
        )

    # 64x64: the last deconv has no batch norm / relu and feeds tanh directly.
    net = deconv2d(
        net, ishape=(ngf, 32, 32), oshape=oshape[-3:], kshape=(4, 4), name="g5_deconv")
    output = relay.tanh(net)
    return relay.Function(relay.ir_pass.free_vars(output), output)
def get_workload(batch_size, oshape=(3, 64, 64), ngf=128, random_len=100, dtype="float32"):
    """Get benchmark workload for a DCGAN generator.

    Parameters
    ----------
    batch_size : int
        The batch size used in the model.
    oshape : tuple, optional
        The shape of the output image, layout="CHW".
    ngf : int, optional
        The number of final feature maps in the generator.
    random_len : int, optional
        The length of the random input.
    dtype : str, optional
        The data type.

    Returns
    -------
    net : relay.Function
        The generator function built by ``get_net``, as handed back by
        ``create_workload``.
    params : dict of str to NDArray
        The parameters.
    """
    generator = get_net(
        batch_size,
        random_len=random_len,
        oshape=oshape,
        ngf=ngf,
        dtype=dtype,
    )
    workload = create_workload(generator)
    return workload
...@@ -30,15 +30,25 @@ def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32" ...@@ -30,15 +30,25 @@ def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"
"""get symbol of nature dqn""" """get symbol of nature dqn"""
data_shape = (batch_size,) + image_shape data_shape = (batch_size,) + image_shape
data = relay.var("data", shape=data_shape, dtype=dtype) data = relay.var("data", shape=data_shape, dtype=dtype)
conv1_bias = relay.var("conv1_bias")
conv1 = layers.conv2d(data, kernel_size=(8, 8), strides=(4, 4), padding=(0, 0), conv1 = layers.conv2d(data, kernel_size=(8, 8), strides=(4, 4), padding=(0, 0),
channels=32, name="conv1") channels=32, name="conv1")
conv1 = relay.nn.bias_add(conv1, conv1_bias)
relu1 = relay.nn.relu(conv1) relu1 = relay.nn.relu(conv1)
conv2_bias = relay.var("conv2_bias")
conv2 = layers.conv2d(relu1, kernel_size=(4, 4), strides=(2, 2), padding=(0, 0), conv2 = layers.conv2d(relu1, kernel_size=(4, 4), strides=(2, 2), padding=(0, 0),
channels=64, name="conv2") channels=64, name="conv2")
conv2 = relay.nn.bias_add(conv2, conv2_bias)
relu2 = relay.nn.relu(conv2) relu2 = relay.nn.relu(conv2)
conv3_bias = relay.var("conv3_bias")
conv3 = layers.conv2d(relu2, kernel_size=(3, 3), strides=(1, 1), padding=(0, 0), conv3 = layers.conv2d(relu2, kernel_size=(3, 3), strides=(1, 1), padding=(0, 0),
channels=64, name="conv3") channels=64, name="conv3")
conv3 = relay.nn.bias_add(conv3, conv3_bias)
relu3 = relay.nn.relu(conv3) relu3 = relay.nn.relu(conv3)
bf1 = relay.nn.batch_flatten(relu3) bf1 = relay.nn.batch_flatten(relu3)
dense1 = layers.dense_add_bias(bf1, units=512, name="dense1") dense1 = layers.dense_add_bias(bf1, units=512, name="dense1")
relu4 = relay.nn.relu(dense1) relu4 = relay.nn.relu(dense1)
......
...@@ -80,6 +80,30 @@ def conv2d(data, weight=None, **kwargs): ...@@ -80,6 +80,30 @@ def conv2d(data, weight=None, **kwargs):
weight = relay.var(name + "_weight") weight = relay.var(name + "_weight")
return relay.nn.conv2d(data, weight, **kwargs) return relay.nn.conv2d(data, weight, **kwargs)
def conv2d_transpose(data, weight=None, **kwargs):
    """Wrapper of conv2d_transpose which automatically creates weights if not given.

    Parameters
    ----------
    data : relay.Expr
        The input expression.

    weight : relay.Expr
        The weight to conv2d_transpose. When None, a variable named
        ``<name>_weight`` is created.

    kwargs : dict
        Additional arguments. Must contain ``name``, which is consumed here;
        everything else is forwarded to ``relay.nn.conv2d_transpose``.

    Returns
    -------
    result : relay.Expr
        The result.

    Raises
    ------
    KeyError
        If ``name`` is not supplied in ``kwargs``.
    """
    # `name` is only used to label the weight; it must not reach the operator.
    name = kwargs.pop("name")
    # Compare against None explicitly so a supplied weight is never replaced
    # because of how the expression object evaluates in a boolean context.
    if weight is None:
        weight = relay.var(name + "_weight")
    return relay.nn.conv2d_transpose(data, weight, **kwargs)
def dense_add_bias(data, weight=None, bias=None, **kwargs): def dense_add_bias(data, weight=None, bias=None, **kwargs):
"""Wrapper of dense which automatically creates weights if not given. """Wrapper of dense which automatically creates weights if not given.
......
...@@ -106,13 +106,18 @@ def test_resnet(): ...@@ -106,13 +106,18 @@ def test_resnet():
def test_dqn(): def test_dqn():
net, params = tvm.relay.testing.dqn.get_workload(batch_size=1) net, params = tvm.relay.testing.dqn.get_workload(batch_size=1)
show(net.astext()) net.astext()
def test_dcgan():
    """Build the batch-size-1 DCGAN workload and render its net as text."""
    workload = tvm.relay.testing.dcgan.get_workload(batch_size=1)
    net, _params = workload
    net.astext()
if __name__ == "__main__": if __name__ == "__main__":
do_print[0] = True do_print[0] = True
test_resnet() test_resnet()
test_mlp() test_mlp()
test_dqn() test_dqn()
test_dcgan()
test_func() test_func()
test_env() test_env()
test_meta_data() test_meta_data()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment