# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
import numpy

def collect_visit(stmt, f):
    ret = []
    tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x : ret.append(f(x)))
    return ret

def find_top_produce(stmt):
    def f(x, ret):
        if isinstance(x, tvm.tir.ProducerConsumer):
            ret.append(x)
    ret = []
    tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x : f(x, ret))
    return ret[-1]

def lower(sch, args):
    binds = {}
    arg_list = []
    for x in args:
        if isinstance(x, te.tensor.Tensor):
            buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
            assert x not in binds
            binds[x] = buf
            arg_list.append(buf)
        else:
            raise ValueError("args must be Tensor, Buffer or Var")
    sch = sch.normalize()
    bounds = tvm.te.schedule.InferBound(sch)
    stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
    stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
    stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64)
    stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
    stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
    stmt = tvm.tir.ir_pass.Simplify(stmt)
    return stmt

def test_basic():
    n = te.size_var('n')
    A = te.placeholder((n, ), name='A')
    B = te.placeholder((n, ), name='B')

    T = te.compute((n, ), lambda i: A[i]+B[i])
    s = te.create_schedule(T.op)
    xo, xi = s[T].split(T.op.axis[0], factor=4)

    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
    stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
    stmt = tvm.tir.ir_pass.Simplify(stmt)
    assert('if' not in str(stmt.body.body.body[0]))
    assert('if' in str(stmt.body.body.body[1]))

def test_const_loop():
    n = 21
    A = te.placeholder((n, ), name='A')
    B = te.placeholder((n, ), name='B')

    T = te.compute((n, ), lambda i: A[i]+B[i])
    s = te.create_schedule(T.op)
    xo, xi = s[T].split(T.op.axis[0], factor=4)

    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
    stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.tir.ir_pass.Simplify(stmt)
    assert('if' not in str(stmt.body.body.body[0]))

def test_multi_loop():
    ib = tvm.tir.ir_builder.create()
    m = te.size_var('m')
    n = te.size_var('n')
    with ib.for_range(0, 4, "i") as i:
        with ib.for_range(0, n, "j") as j:
            with ib.for_range(0, m, "k") as k:
                with ib.if_scope(ib.likely(i*m+j+k < n)):
                    ib.emit(tvm.tir.Evaluate(m))
                with ib.else_scope():
                    ib.emit(tvm.tir.Evaluate(n))
    stmt = ib.get()
    stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
    stmt = tvm.tir.ir_pass.Simplify(stmt)
    assert(not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))))

def test_multi_if():
    ib = tvm.tir.ir_builder.create()
    m = te.size_var('m')
    n = te.size_var('n')
    with ib.for_range(0, 4, 'i') as i:
        with ib.for_range(0, n, 'j') as j:
            with ib.for_range(0, m, 'k') as k:
                with ib.if_scope(ib.likely(i*m+j+k < n)):
                    ib.emit(tvm.tir.Evaluate(m))
                with ib.else_scope():
                    ib.emit(tvm.tir.Evaluate(n))
                with ib.if_scope(ib.likely(i*m+j-k < n)):
                    ib.emit(tvm.tir.Evaluate(m))
                with ib.else_scope():
                    ib.emit(tvm.tir.Evaluate(n))
    stmt = ib.get()
    stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
    stmt = tvm.tir.ir_pass.Simplify(stmt)
    assert('if' not in str(stmt.body[0]))

def test_thread_axis():
    m = te.size_var('m')
    l = te.size_var('l')
    A = te.placeholder((m, l), name='A')
    B = te.compute((m, l), lambda i, j: A[i, j] + 3, name='B')
    s = te.create_schedule(B.op)

    s[B].set_scope("shared")
    num_thread = 16
    xo, xi = s[B].split(B.op.axis[0], 32)
    xi0, xi1 = s[B].split(xi, nparts=num_thread)
    s[B].bind(xi0, te.thread_axis("threadIdx.x"))

    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
    stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
    stmt = tvm.tir.ir_pass.Simplify(stmt)
    assert('if' not in str(stmt.body.body.body[0]))

def test_vectorize():
    n = te.size_var('n')
    A = te.placeholder((n,), name='A')
    B = te.placeholder((n,), name='B')
    bias = te.size_var("bias", dtype="float32")
    scale = te.size_var("scale", dtype="float32")
    C = te.compute(A.shape, lambda *i: A(*i) + B(*i) * scale + bias, name='C')
    # schedule
    s = te.create_schedule(C.op)
    # create iter var and assign them tags.
    num_thread = 32
    bx, x = s[C].split(C.op.axis[0], factor=num_thread*4)
    tx, x = s[C].split(x, nparts=num_thread)
    _, x = s[C].split(x, factor=4)
    s[C].bind(bx, te.thread_axis("blockIdx.x"))
    s[C].bind(tx, te.thread_axis("threadIdx.x"))
    s[C].vectorize(x)
    stmt = lower(s, [A, B])
    body = stmt.body.body.body.body.body
    assert(x.var.name not in str(body.condition))
    assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp))))

def test_condition():
    ib = tvm.tir.ir_builder.create()
    m = te.size_var('m')
    n = te.size_var('n')
    with ib.for_range(0, tvm.tir.truncdiv(n+3,4), 'i') as i:
      with ib.for_range(0, 4, 'j') as j:
        ib.emit(tvm.tir.Evaluate(
          tvm.tir.Select(ib.likely(i*4+j<n), m, n)))
    stmt = ib.get()
    stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
    stmt = tvm.tir.ir_pass.Simplify(stmt)
    assert(not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tir.Select))))

def test_condition_EQ():
    ib = tvm.tir.ir_builder.create()
    m = te.size_var('m')
    n = te.size_var('n')
    with ib.for_range(0, 10, 'i') as i:
            ib.emit(tvm.tir.Evaluate(
                tvm.tir.Select(ib.likely(tvm.tir.EQ(i, 5)), m, n)))
    stmt = ib.get()
    stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.tir.ir_pass.Simplify(stmt)
    assert(not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tir.Select))))

def test_thread_axis2():
    n = tvm.runtime.convert(4096)
    m = te.size_var('m')
    A = te.placeholder((n,), name='A')
    B = te.placeholder((n,), name='B')
    C = te.compute(A.shape, lambda i: A[i] + B[i], name='C')
    s = te.create_schedule(C.op)
    num_thread = 32
    bx, x = s[C].split(C.op.axis[0], factor=32)
    tx, x = s[C].split(x, nparts=num_thread)
    _,  x = s[C].split(x, factor=m)
    s[C].bind(bx, te.thread_axis("blockIdx.x"))
    s[C].bind(tx, te.thread_axis("threadIdx.x"))
    stmt = lower(s, [A, B])
    for_body = stmt.body.body.body.body.body[0]
    assert('threadIdx' not in str(for_body.extent))

def test_everything_during_deduction():
    m = te.size_var('m')
    n = te.size_var('n')
    ib = tvm.tir.ir_builder.create()
    with ib.for_range(0, n, 'i') as i:
        with ib.for_range(0, 32, 'j') as j:
            with ib.if_scope(ib.likely(tvm.tir.truncdiv(i,j) < m)):
                # this guard will produce everything during deduction
                ib.emit(tvm.tir.Evaluate(m))
    stmt = ib.get()
    stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
    stmt = tvm.tir.ir_pass.Simplify(stmt)
    assert(isinstance(stmt.body.body, tvm.tir.IfThenElse))

def test_single_likely():
    n = 60
    A = te.placeholder((n, ), name='A')
    B = te.placeholder((n, ), name='B')

    T = te.compute((n, ), lambda i: A[i]+B[i])
    s = te.create_schedule(T.op)
    x = T.op.axis[0]
    xo, xi = s[T].split(x, factor=16)

    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
    stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.tir.ir_pass.Simplify(stmt)
    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))

def test_multi_likely():
    n = 94
    m = 62
    A = te.placeholder((n, m), name='A')
    B = te.placeholder((n, m), name='B')

    T = te.compute((n, m), lambda i, j: A[i, j]+B[i, j])
    s = te.create_schedule(T.op)
    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
    x, y = T.op.axis
    xo, xi = s[T].split(x, factor=16)
    yo, yi = s[T].split(y, factor=16)
    s[T].reorder(xo, yo, xi, yi)

    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
    stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.tir.ir_pass.Simplify(stmt)
    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))

def test_oneD_pool():
    m = te.size_var('m')
    ib = tvm.tir.ir_builder.create()
    #data = te.placeholder((16,), name = 'data')
    data = ib.pointer("float32", name="A")
    out = ib.pointer("float32", name="A")
    with ib.for_range(0, 16, 'ow') as ow:
        with ib.for_range(0, 3, 'kw') as kw:
            with ib.if_scope(ib.likely(ow > 0)):
                with ib.if_scope(ib.likely(ow < 15)):
                    out[ow] = tvm.te.max(out[ow], data[ow + kw - 1])
    with ib.for_range(0, 16, 'ow') as ow:
        with ib.for_range(0, 3, 'kw') as kw:
            with ib.if_scope(ib.likely(ow < 1)):
                with ib.if_scope(ib.likely(kw > 0)):
                    out[ow] = tvm.te.max(out[ow], data[ow + kw - 1])
    with ib.for_range(0, 16, 'ow') as ow:
        with ib.for_range(0, 3, 'kw') as kw:
            with ib.if_scope(ib.likely(ow > 14)):
                with ib.if_scope(ib.likely(kw < 2)):
                    out[ow] = tvm.te.max(out[ow], data[ow + kw - 1])

    stmt = ib.get()
    stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.tir.ir_pass.Simplify(stmt)
    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))

def test_cce_loop_1():
  ib = tvm.tir.ir_builder.create()
  dtype = 'float16'
  n = 514
  m = 514
  _A = te.placeholder((n*m,), name = 'A')
  Ab = tvm.tir.decl_buffer((n*m,), dtype, name="A")
  A = ib.buffer_ptr(Ab)
  _B = te.placeholder((n*m,), name = 'B')
  Bb = tvm.tir.decl_buffer((n*m,), dtype, name="B")
  B = ib.buffer_ptr(Bb)
  #for i in 0 to n-1:
  with ib.for_range(0, 11, name="i") as i:
      with ib.for_range(0, 160, name="j") as j:
          with ib.if_scope(ib.likely(((i*160) + j) < 1600)):
               A[(i+1)*m+j+1] = B[(i)*m+j+1] + B[(i+1)*m+j+1] + B[(i+2)*m+j+1]
  stmt = ib.get()
  stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
  stmt = tvm.tir.ir_pass.Simplify(stmt)
  assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))

def test_cce_loop_2():
  ib = tvm.tir.ir_builder.create()
  len = 112
  tile = 32
  loop = (len + tile - 1) // tile
  with ib.for_range(0, loop, 'i') as i:
    head = i * tile
    with ib.if_scope(ib.likely(head + tile > len)):
      tail = len
      ib.emit(tvm.tir.call_extern('float32', "cce_intrisic", head, tail))
    with ib.else_scope():
      tail = head + tile
      ib.emit(tvm.tir.call_extern('float32', "cce_intrisic", head, tail))

  stmt = ib.get()
  stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
  stmt = tvm.tir.ir_pass.Simplify(stmt)
  assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))


def test_cce_loop_3():
    ib = tvm.tir.ir_builder.create()
    loop1 = 4
    loop2 = 9998
    tile = 39991
    with ib.for_range(0,loop2,'i') as i:
        with ib.for_range(0,loop1,'j') as j:
            head1 = i
            head2 = j
            with ib.if_scope(ib.likely(head1*loop1 + head2 < tile)):
                ib.emit(tvm.tir.call_extern('float16',"cce_intrisic",head1))

    stmt = ib.get()
    stmt = tvm.tir.ir_pass.LoopPartition(stmt,True)
    stmt = tvm.tir.ir_pass.Simplify(stmt)
    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))

def test_conv_tiling():
    HSTR = WSTR = 1
    in_channel = 128
    kernel_height = kernel_width = 3
    out_channel = 64
    batch_size = 1
    in_height = in_width = 64
    out_height = out_width = in_height - kernel_height + 1
    data = te.placeholder((batch_size, in_channel, in_height, in_width), name='data')
    kernel = te.placeholder((kernel_height, kernel_width, in_channel,
        out_channel), name='kernel')
    ic = te.reduce_axis((0, in_channel), name='ic')
    kh = te.reduce_axis((0, kernel_height), name='kh')
    kw = te.reduce_axis((0, kernel_width), name='kw')
    conv = te.compute((batch_size, out_channel, out_height, out_width),
                       lambda n, oc, oh, ow: te.sum(data[n, ic, oh*HSTR + kh, ow*WSTR + kw] *
                                                     kernel[kh, kw, ic, oc],
                                                     axis=[ic, kh, kw]),
                       name="conv2d")
    s = te.create_schedule(conv.op)

    n, oc, oh, ow = conv.op.axis
    oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16)
    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
    stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.tir.ir_pass.Simplify(stmt)
    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))


def test_multilevel_splitting_with_indivisble_factors():
    import topi
    A = te.placeholder((130,), dtype="float32")
    B = topi.nn.relu(A)
    s = te.create_schedule(B.op)
    (y,) = s[B].op.axis
    (yo, yi) = s[B].split(y, factor=8)
    (yoo, yoi) = s[B].split(yo, factor=16)
    s[B].reorder(yoo, yoi, yi)
    s[B].unroll(yi)

    ## But this does the right thing.
    with tvm.target.build_config(partition_const_loop=True):
        lowered_body = tvm.lower(s, [A, B]).body
        def visit_stmt(op):
            return(isinstance(op, tvm.tir.Max))
        num_max = collect_visit(lowered_body, visit_stmt)
        assert num_max.count(True) == 10


def test_double_splitting_with_indivisible_factors():
    m = 48
    dtype="float32"
    A = te.placeholder((m,), name='A', dtype=dtype)
    C = te.compute((m,), lambda i: A[i], name='C')
    D = te.compute((m,), lambda i: C[i], name='D')

    s = te.create_schedule(D.op)
    co, ci = s[C].split(C.op.axis[0], factor=10)
    do, di = s[D].split(D.op.axis[0], 32)
    s[C].compute_at(s[D], do)

    target = 'llvm'
    with tvm.target.build_config(partition_const_loop=True):
        f = tvm.lower(s, [A, C, D], name="fadd1", simple_mode=False)
        func = tvm.build(f, target=target)

    # Find the beginning of the Halide IR corresponding to kernel code
    # and make sure it doesn't have an if statements left
    top_produce = find_top_produce(f.body)
    assert(not any(collect_visit(top_produce, lambda x: isinstance(x, tvm.tir.IfThenElse))))

    # check functional correctness of generated code
    ctx = tvm.context(target, 0)
    a = tvm.nd.array(numpy.ones(m,).astype(dtype), ctx)
    c = tvm.nd.array(numpy.zeros(m,).astype(dtype), ctx)
    d = tvm.nd.array(numpy.zeros(m,).astype(dtype), ctx)
    func(a, c, d)
    tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy(), rtol=1e-5)
    tvm.testing.assert_allclose(d.asnumpy(), a.asnumpy(), rtol=1e-5)

def test_simple_rfactor():
    K = 16*4+4
    k = te.reduce_axis((0, K), 'k')

    A = te.placeholder((1, K), name='A')

    B = te.compute( (1,), lambda b:
            te.sum(A[b, k], axis=k),
            name='B'
    )

    s = te.create_schedule(B.op)
    ko, _ = s[B].split(s[B].op.reduce_axis[0], 16)
    BF = s.rfactor(B, ko, 0)

    s.normalize()
    bounds = tvm.te.schedule.InferBound(s)

    stmt1 = tvm.te.schedule.ScheduleOps(s, bounds)
    stmt1 = tvm.tir.ir_pass.Simplify(stmt1)

    stmt2 = tvm.tir.ir_pass.LoopPartition(stmt1, True)
    stmt2 = tvm.tir.ir_pass.Simplify(stmt2)

    #make sure loop partition actually did something
    assert not tvm.ir.structural_equal(stmt1.body, stmt2.body)


if __name__ == "__main__":
    test_basic()
    test_const_loop()
    test_multi_loop()
    test_multi_if()
    test_thread_axis()
    test_vectorize()
    test_condition()
    test_condition_EQ()
    test_thread_axis2()
    test_everything_during_deduction()
    test_single_likely()
    test_multi_likely()
    test_oneD_pool()
    test_cce_loop_1()
    test_cce_loop_2()
    test_cce_loop_3()
    test_conv_tiling()
    test_double_splitting_with_indivisible_factors()
    test_multilevel_splitting_with_indivisble_factors()
    test_simple_rfactor()