# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
from tvm import relay
from tvm.relay.analysis import alpha_equal, detect_feature
from tvm.relay.transform import to_cps, un_cps
from tvm.relay.feature import Feature
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, run_opt_pass
from tvm.relay import create_executor
from tvm.relay import Function, transform


def rand(dtype='float32', *shape):
    """Return a ``tvm.nd`` array of the given shape filled with uniform [0, 1) samples.

    Parameters
    ----------
    dtype : str
        Element type of the result; NumPy's float64 samples are cast to it.
    *shape : int
        Dimensions of the result. With no dimensions, the intermediate NumPy
        array is zero-dimensional.
    """
    # np.random.rand() with no arguments returns a plain Python float, which has
    # no ``astype``; np.asarray turns it into a 0-d array so that case works too.
    # For a non-empty shape np.asarray is a no-op on the ndarray.
    samples = np.asarray(np.random.rand(*shape))
    return tvm.nd.array(samples.astype(dtype))


def test_id():
    """Check that the CPS-converted identity function type-checks."""
    x = relay.var("x", shape=[])
    # Named ``identity`` rather than ``id`` to avoid shadowing the builtin.
    identity = run_infer_type(relay.Function([x], x))
    # The check is that type inference on the CPS-converted function completes;
    # its result is not inspected further.
    run_infer_type(to_cps(identity))


def test_double():
    """Check that the CPS form of a polymorphic ``f(f(x))`` type-checks."""
    elem_t = relay.TypeVar("t")
    arg = relay.var("x", elem_t)
    fn = relay.var("f", relay.FuncType([elem_t], elem_t))
    # fn(x) applied twice; generic over the element type ``t``.
    twice = relay.Function([fn, arg], fn(fn(arg)), elem_t, [elem_t])
    typed_twice = run_infer_type(twice)
    # Type inference on the converted program is the actual check.
    run_infer_type(to_cps(typed_twice))


# Make sure CPS works for recursion.
def test_recursion():
    """Round-trip a recursive program through ``to_cps``/``un_cps`` and run it.

    ``nat_iterate`` applies the doubling function three times, so the result
    must equal ``8 * input``.
    """
    mod = relay.Module()
    prelude = Prelude(mod)
    add_nat_definitions(prelude)
    shape = (10, 10)
    dtype = 'float32'
    tensor_t = relay.TensorType(shape, dtype)

    a = relay.var("x", tensor_t)
    doubler = relay.Function([a], a + a)
    inp = relay.var("i", tensor_t)
    iterate_three_times = prelude.nat_iterate(doubler, make_nat_expr(prelude, 3))
    mod["main"] = relay.Function([inp], iterate_three_times(inp))

    # Store the CPS form in the module and read it back before undoing it,
    # exactly as two separate steps.
    mod["main"] = to_cps(mod["main"], mod=mod)
    mod["main"] = un_cps(mod["main"])

    executor = create_executor(mod=mod)
    data = rand(dtype, *shape)
    result = executor.evaluate()(data)
    expected = 8 * data.asnumpy()
    tvm.testing.assert_allclose(result.asnumpy(), expected)


# This serves as an integration test.
# It tests that, given a program with references, CPS conversion followed by
# partial evaluation can completely eliminate the reference allocations.
def test_cps_pe():
    """Check that CPS + partial evaluation removes every ``RefCreate``.

    Three programs are built: one that uses ``RefCreate`` explicitly, and two
    produced by ``relay.transform.gradient``. Each is handed to
    ``destroy_ref``, which asserts that no ``RefCreate`` feature remains after
    CPS conversion, partial evaluation and dead-code elimination.
    """
    def destroy_ref(expr):
        # Type-check, convert to CPS, and type-check the converted program.
        cps = run_infer_type(to_cps(run_infer_type(expr)))
        # un_cps is exercised and its output type-checked, but not inspected.
        run_infer_type(un_cps(cps))
        pipeline = transform.Sequential([
            transform.PartialEvaluate(),
            transform.DeadCodeElimination(inline_once=True),
        ])
        optimized = run_opt_pass(cps, pipeline)
        assert Feature.fRefCreate not in detect_feature(optimized)

    # Program 1: a reference to a closure, overwritten in either branch of an
    # ``if`` and then read and called.
    unit = relay.Function([], relay.const(0., dtype='float32'))
    f_ref = relay.Var("f_ref")
    one = relay.const(1., dtype='float32')
    two = relay.const(2., dtype='float32')
    cond = relay.var(shape=(), dtype='uint1', name_hint='cond')
    write_one = relay.RefWrite(f_ref, relay.Function([], one))
    write_two = relay.RefWrite(f_ref, relay.Function([], two))
    branch = relay.If(cond, write_one, write_two)
    call_ref = relay.Call(relay.RefRead(f_ref), [])
    body = relay.Let(f_ref, relay.RefCreate(unit),
                     relay.Let(relay.Var("x"), branch, call_ref))
    destroy_ref(relay.Function([cond], body))

    # Program 2: the gradient of a scalar function containing an ``if``.
    select = run_infer_type(relay.Function([cond], relay.If(cond, one, two)))
    destroy_ref(relay.transform.gradient(select))

    # Program 3: the gradient of ``(x if pred else y) + z`` over tensors.
    vec_shape = (1, 16)
    x = relay.var("x", shape=vec_shape)
    y = relay.var("y", shape=vec_shape)
    z = relay.var("z", shape=vec_shape)
    pred = relay.var("cond", shape=(), dtype='uint1')
    summed = relay.add(relay.If(pred, x, y), z)
    typed = run_infer_type(relay.Function([pred, x, y, z], summed))
    destroy_ref(relay.transform.gradient(typed))


if __name__ == '__main__':
    # Run every test in this module when executed directly, not just a subset.
    test_id()
    test_double()
    test_recursion()
    test_cps_pe()