"""Test alter op layout pass"""

from tvm import relay
from tvm.relay.op import register_alter_op_layout
from tvm.relay.ir_pass import *

def test_alter_op():
    """Test directly replacing an operator with a new one"""
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var('weight', shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(x, weight,
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    @register_alter_op_layout("nn.conv2d", level=100)
    def alter_conv2d(attrs, inputs, tinfos):
        data, weight = inputs
        weight = relay.multiply(weight, relay.const(2.0, "float32"))
        return relay.nn.conv2d(data, weight, **attrs)

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var('weight', shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0, "float32")),
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    a = before()
    a = infer_type(a)
    a = alter_op_layout(a)

    b = expected()
    b = infer_type(b)

    assert(alpha_equal(a, b))


def test_alter_return_none():
    """Test doing nothing by returning 'None' """
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        y = relay.nn.global_max_pool2d(x)
        y = relay.Function([x], y)
        return y

    called = [False]

    @register_alter_op_layout("nn.global_max_pool2d", level=101)
    def alter_conv2d(attrs, inputs, tinfos):
        called[0] = True
        return None

    a = before()
    a = infer_type(a)
    a = alter_op_layout(a)

    b = before()
    b = infer_type(b)
    assert(alpha_equal(a, b))
    assert(called[0])


def test_alter_layout():
    """Test alternating the layout of a conv2d.
    The layout of broadcast operators and the weight should be changed accordingly.
    """
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        bias = relay.var("bias")
        weight = relay.var("weight")
        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
        y = relay.nn.bias_add(y, bias)
        # a useless tuple, which will be eliminated
        y = relay.Tuple([y])[0]
        y = relay.nn.relu(y)
        y = relay.nn.batch_flatten(y)
        y = relay.Function(free_vars(y), y)
        return y

    @register_alter_op_layout("nn.conv2d", level=102)
    def alter_conv2d(attrs, inputs, tinfos):
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs['data_layout'] = 'NCHW16c'
        new_attrs['weight_layout'] = 'OIHW16i'
        return relay.nn.conv2d(data, weight, **new_attrs)

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        bias = relay.var("bias", shape=(64,))
        weight = relay.var("weight", shape=(64, 64, 3, 3))

        y = relay.layout_transform(x, "NCHW", "NCHW16c")
        w = relay.layout_transform(weight, "OIHW", "OIHW16i")
        y = relay.nn.conv2d(y, w,
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            weight_layout="OIHW16i",
                            data_layout="NCHW16c")
        b = relay.expand_dims(bias, axis=1, num_newaxis=2)
        b = relay.layout_transform(b, "CHW", "CHW16c")
        y = relay.add(y, b)

        y = relay.nn.relu(y)
        y = relay.layout_transform(y, "NCHW16c", "NCHW")
        y = relay.nn.batch_flatten(y)
        y = relay.Function(free_vars(y), y)
        return y

    a = before()
    a = infer_type(a)
    a = canonicalize_ops(a)
    a = infer_type(a)
    a = alter_op_layout(a)
    a = infer_type(a)

    b = expected()
    b = infer_type(b)

    assert(alpha_equal(a, b))


def test_alter_layout_dual_path():
    """
    Test alternating the layout with two outputs.
    One path continues to use the new layout while one path fall backs to old layout.
    """
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight1 = relay.var('weight1')
        weight2 = relay.var('weight2')
        y = relay.nn.conv2d(x, weight1,
                            channels=32,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.nn.relu(y)
        y1 = relay.nn.conv2d(y, weight2,
                             channels=32,
                             kernel_size=(3, 3),
                             padding=(1, 1))
        y1 = relay.nn.relu(y1)
        y2 = relay.nn.batch_flatten(y)
        ret = relay.Tuple([y1, y2])
        y = relay.Function(free_vars(ret), ret)
        return y

    @register_alter_op_layout("nn.conv2d", level=103)
    def alter_conv2d(attrs, inputs, tinfos):
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs['data_layout'] = 'NCHW16c'
        return relay.nn.conv2d(data, weight, **new_attrs)

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight1 = relay.var('weight1')
        weight2 = relay.var('weight2')
        y = relay.layout_transform(x, "NCHW", "NCHW16c")
        y = relay.nn.conv2d(y, weight1,
                            channels=32,
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            data_layout="NCHW16c")
        y = relay.nn.relu(y)
        y1 = relay.nn.conv2d(y, weight2,
                             channels=32,
                             kernel_size=(3, 3),
                             padding=(1, 1),
                             data_layout='NCHW16c')
        y1 = relay.nn.relu(y1)
        y1 = relay.layout_transform(y1, "NCHW16c", "NCHW")
        y2 = relay.layout_transform(y, "NCHW16c", "NCHW")
        y2 = relay.nn.batch_flatten(y2)
        ret = relay.Tuple([y1, y2])
        y = relay.Function(free_vars(ret), ret)
        return y

    a = before()
    a = infer_type(a)
    a = alter_op_layout(a)
    a = infer_type(a)

    b = expected()
    b = infer_type(b)

    assert(alpha_equal(a, b))

def test_alter_layout_resnet():
    """Test alternating the layout of a residual block
    This also tests the elimination of duplicated transformation.
    If a same transformation applies to a same node twice, only one transformation will be created.
    """
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight1 = relay.var('weight1')
        weight2 = relay.var('weight2')
        y = relay.nn.conv2d(x, weight1,
                            channels=32,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.nn.relu(y)
        y2 = relay.nn.conv2d(x, weight2,
                             channels=32,
                             kernel_size=(1, 1))
        y2 = relay.nn.relu(y2)
        y = y + y2
        y = relay.nn.global_max_pool2d(y)
        return relay.Function(free_vars(y), y)

    @register_alter_op_layout("nn.conv2d", level=104)
    def alter_conv2d(attrs, inputs, tinfos):
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs['data_layout'] = 'NCHW16c'
        return relay.nn.conv2d(data, weight, **new_attrs)

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight1 = relay.var('weight1')
        weight2 = relay.var('weight2')
        x = relay.layout_transform(x, "NCHW", "NCHW16c")
        y = relay.nn.conv2d(x, weight1,
                            channels=32,
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            data_layout="NCHW16c")
        y = relay.nn.relu(y)
        y2 = relay.nn.conv2d(x, weight2,
                             channels=32,
                             kernel_size=(1, 1),
                             data_layout='NCHW16c')
        y2 = relay.nn.relu(y2)
        y = y + y2
        y = relay.nn.global_max_pool2d(y, layout="NCHW16c")
        y = relay.layout_transform(y, "NCHW16c", "NCHW")
        return relay.Function(free_vars(y), y)

    a = before()
    a = infer_type(a)
    a = alter_op_layout(a)
    a = infer_type(a)

    b = expected()
    b = infer_type(b)

    assert(alpha_equal(a, b))


def test_alter_layout_broadcast_op():
    """Test boradcast operators """
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        bias = relay.var("bias", shape=(64,))
        scale = relay.var("scale", shape=(64, 1, 1))
        weight = relay.var("weight")
        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
        y = relay.nn.bias_add(y, bias) # test broadcasting to lhs
        y = relay.multiply(scale, y)         # test broadcasting to rhs
        y = relay.Function(free_vars(y), y)
        return y

    @register_alter_op_layout("nn.conv2d", level=102)
    def alter_conv2d(attrs, inputs, tinfos):
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs['data_layout'] = 'NCHW16c'
        return relay.nn.conv2d(data, weight, **new_attrs)

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        bias = relay.var("bias", shape=(64,))
        scale = relay.var("scale", shape=(64, 1, 1))
        weight = relay.var("weight")
        x = relay.layout_transform(x, "NCHW", "NCHW16c")
        bias = relay.expand_dims(bias, 1, 2)
        bias = relay.layout_transform(bias, "CHW", "CHW16c")
        scale = relay.layout_transform(scale, "CHW", "CHW16c")
        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
                            data_layout="NCHW16c")
        y = relay.add(y, bias)          # test broadcasting to lhs
        y = relay.multiply(scale, y)      # test broadcasting to rhs
        y = relay.layout_transform(y, "NCHW16c", "NCHW")
        y = relay.Function(free_vars(y), y)
        return y

    a = before()
    a = infer_type(a)
    a = canonicalize_ops(a)
    a = infer_type(a)
    a = alter_op_layout(a)
    a = infer_type(a)

    b = expected()
    b = infer_type(b)

    assert(alpha_equal(a, b))


if __name__ == "__main__":
    test_alter_op()
    test_alter_return_none()
    test_alter_layout()
    test_alter_layout_dual_path()
    test_alter_layout_resnet()
    test_alter_layout_broadcast_op()