def mxnet_check():
    """This is a simple test function for MXNet bridge

    Builds ``(x + y) * scale`` with TOPI/TVM for the CUDA target, wraps the
    compiled function with ``to_mxnet_func`` and checks the output produced
    through the MXNet engine against a NumPy reference.

    It is not included as nosetests, because of its dependency on mxnet.
    It also targets CUDA and allocates on ``mx.gpu(0)``, so a GPU is needed.

    User can directly run this script to verify correctness.
    """
    import mxnet as mx
    import topi
    import tvm
    # tvm.testing is a submodule: import it explicitly instead of relying on
    # ``import tvm`` having loaded it as a side effect.
    import tvm.testing
    from tvm.contrib.mxnet import to_mxnet_func

    # build a TVM function through topi
    n = 20
    shape = (n,)
    # Runtime value passed for the symbolic ``scale``; reused by the check.
    scale_value = 10.0
    scale = tvm.var("scale", dtype="float32")
    x = tvm.placeholder(shape)
    y = tvm.placeholder(shape)
    z = topi.broadcast_add(x, y)
    zz = tvm.compute(shape, lambda *i: z(*i) * scale)

    target = tvm.target.cuda()

    # build the function; the schedule is created inside the target scope
    # so topi.generic dispatches to the CUDA implementation
    with target:
        s = topi.generic.schedule_injective(zz)
        f = tvm.build(s, [x, y, zz, scale])

    # get a mxnet version; const_loc=[0, 1] marks argument positions 0 and 1
    # (the two inputs) as read-only arrays for the MXNet engine
    mxf = to_mxnet_func(f, const_loc=[0, 1])

    ctx = mx.gpu(0)
    xx = mx.nd.uniform(shape=shape, ctx=ctx)
    yy = mx.nd.uniform(shape=shape, ctx=ctx)
    out = mx.nd.empty(shape=shape, ctx=ctx)

    # invoke mxf: this runs in mxnet engine. It is called twice; every call
    # rewrites all n output elements, so the result must be unchanged
    # (presumably this exercises repeated launches through the engine).
    mxf(xx, yy, out, scale_value)
    mxf(xx, yy, out, scale_value)

    # asnumpy() copies back to host, waiting for pending engine work on the
    # arrays before reading them.
    tvm.testing.assert_allclose(
        out.asnumpy(), (xx.asnumpy() + yy.asnumpy()) * scale_value)


if __name__ == "__main__":
    # Only runs when executed directly (not on import): this check is not
    # collected by the test runner, so invoke it manually here.
    mxnet_check()