# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#   http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm

def test_copy2d():
    """Check the buffers given to the copy callback for a plain 2-D copy.

    ``B[i, j] = A[i, j]`` over a symbolic ``(m, l)`` shape is tagged with the
    "memcpy" pragma on its outermost axis, lowered through InferBound,
    ScheduleOps and StorageFlatten, and then handed to InjectCopyIntrin.
    The callback asserts the strides of both buffers and the source shape.
    """
    rows = tvm.var('m')
    cols = tvm.var('l')
    source = tvm.placeholder((rows, cols), name='A')
    target = tvm.compute((rows, cols), lambda i, j: source[i, j], name='B')

    sched = tvm.create_schedule(target.op)
    outer_axis = target.op.axis[0]
    sched[target].pragma(outer_axis, "memcpy")

    def check_copy(src, dst, pad_before, pad_after, pad_value):
        # Destination strides are expected to be (l, 1).
        assert dst.strides[0] == cols
        assert dst.strides[1].value == 1
        # The source uses the same leading stride and the full (m, l) shape.
        assert src.strides[0] == cols
        assert tuple(src.shape) == (rows, cols)
        return tvm.make.Evaluate(0)

    lowered = tvm.schedule.ScheduleOps(sched, tvm.schedule.InferBound(sched))
    binds = {
        source: tvm.decl_buffer(source.shape, source.dtype, name='A'),
        target: tvm.decl_buffer(target.shape, target.dtype, name='B'),
    }
    flattened = tvm.ir_pass.StorageFlatten(lowered, binds, 64)
    tvm.ir_pass.InjectCopyIntrin(flattened, "memcpy", check_copy)

def test_copy_pad():
    """Check the padding reported to the copy callback for a padded copy.

    ``B`` has shape ``(m + 2, l)``: rows ``1 .. m`` are read from ``A`` and
    every other row is filled with ``1.0``. The callback asserts that the
    source offset simplifies to zero, that one element of padding is
    reported before and after along axis 0 and none along axis 1, and that
    the pad value is ``1.0``.
    """
    rows = tvm.var('m')
    cols = tvm.var('l')
    source = tvm.placeholder((rows, cols), name='A')

    def padded(i, j):
        # Interior rows come from A (shifted by one); the border rows get 1.0.
        inside = tvm.all(i >= 1, i < rows + 1)
        return tvm.if_then_else(inside, source[i - 1, j], 1.0)

    target = tvm.compute((rows + 2, cols), padded, name='B')

    sched = tvm.create_schedule(target.op)
    outer_axis = target.op.axis[0]
    sched[target].pragma(outer_axis, "memcpy")

    def check_copy(src, dst, pad_before, pad_after, pad_value):
        src_offset = tvm.ir_pass.Simplify(src.elem_offset)
        assert src_offset.value == 0
        assert pad_before[0].value == 1
        assert pad_before[1].value == 0
        assert pad_after[0].value == 1
        assert pad_after[1].value == 0
        assert pad_value.value == 1.0
        return tvm.make.Evaluate(0)

    lowered = tvm.schedule.ScheduleOps(sched, tvm.schedule.InferBound(sched))
    binds = {
        source: tvm.decl_buffer(source.shape, source.dtype, name='A'),
        target: tvm.decl_buffer(target.shape, target.dtype, name='B'),
    }
    flattened = tvm.ir_pass.StorageFlatten(lowered, binds, 64)
    tvm.ir_pass.InjectCopyIntrin(flattened, "memcpy", check_copy)

def test_single_point_test():
    """Check the copy callback arguments when copying a single element.

    Both ``A`` and ``B`` have shape ``(1,)``. The callback asserts that the
    element offsets of both buffers simplify to 0 and that their only
    stride simplifies to 1.
    """
    source = tvm.placeholder((1,), name='A')
    target = tvm.compute((1,), lambda i: source[i], name='B')

    sched = tvm.create_schedule(target.op)
    sched[target].pragma(target.op.axis[0], "memcpy")

    def simplified_value(expr):
        # Reduce the expression and read the resulting constant's value.
        return tvm.ir_pass.Simplify(expr).value

    def check_copy(src, dst, pad_before, pad_after, pad_value):
        assert simplified_value(src.elem_offset) == 0
        assert simplified_value(dst.elem_offset) == 0
        assert simplified_value(src.strides[0]) == 1
        assert simplified_value(dst.strides[0]) == 1
        return tvm.make.Evaluate(0)

    lowered = tvm.schedule.ScheduleOps(sched, tvm.schedule.InferBound(sched))
    binds = {
        source: tvm.decl_buffer(source.shape, source.dtype, name='A'),
        target: tvm.decl_buffer(target.shape, target.dtype, name='B'),
    }
    flattened = tvm.ir_pass.StorageFlatten(lowered, binds, 64)
    tvm.ir_pass.InjectCopyIntrin(flattened, "memcpy", check_copy)

def assert_expr_equal(a, b):
    """Assert that ``a - b`` simplifies to a constant whose value is 0."""
    residual = tvm.ir_pass.Simplify(a - b)
    assert residual.value == 0

def test_copy_pad_split():
    """Check copy-callback arguments for a padded stage computed per tile.

    ``Apad`` is ``A`` (length 12) with one zero element on each side. ``B``
    sums three neighbouring ``Apad`` values and is split by a factor of 4.
    ``Apad`` is computed at the outer split axis and tagged with the
    "memcpy" pragma. After flattening and simplification, the callback
    asserts the offsets, padding and source extent as expressions of the
    outer axis.
    """
    length = 4 * 3
    source = tvm.placeholder((length, ), name="A")

    def pad_one(i):
        # Indices 1 .. length map onto A; the two border elements are 0.0.
        in_range = tvm.all(i >= 1, i <= length)
        return tvm.if_then_else(in_range, source[i - 1], 0.0)

    padded = tvm.compute((length + 2,), pad_one, "Apad")
    summed = tvm.compute(
        (length,), lambda i: padded[i] + padded[i + 1] + padded[i + 2])

    sched = tvm.create_schedule(summed.op)
    xo, _inner = sched[summed].split(summed.op.axis[0], factor=4)
    sched[padded].compute_at(sched[summed], xo)
    sched[padded].pragma(sched[padded].op.axis[0], "memcpy")

    def check_copy(src, dst, pad_before, pad_after, pad_value):
        assert(dst.elem_offset.value == 0)
        assert_expr_equal(src.elem_offset, tvm.max(xo * 4, 1) - 1)

        # Expected padding on each side, expressed in terms of the outer axis.
        want_before = tvm.max(1 - xo * 4, 0)
        want_after = tvm.max(xo * 4 - 7, 0)
        assert_expr_equal(pad_before[0], want_before)
        assert_expr_equal(pad_after[0], want_after)
        assert_expr_equal(src.shape[0], 6 - want_before - want_after)
        return tvm.make.Evaluate(0)

    lowered = tvm.schedule.ScheduleOps(sched, tvm.schedule.InferBound(sched))
    binds = {
        source: tvm.decl_buffer(source.shape, source.dtype, name='A'),
        summed: tvm.decl_buffer(summed.shape, summed.dtype, name='B'),
    }
    lowered = tvm.ir_pass.StorageFlatten(lowered, binds, 64)
    # Simplify before injection, in the same order as always: Simplify first,
    # then CanonicalSimplify.
    lowered = tvm.ir_pass.Simplify(lowered)
    lowered = tvm.ir_pass.CanonicalSimplify(lowered)
    tvm.ir_pass.InjectCopyIntrin(lowered, "memcpy", check_copy)

if __name__ == "__main__":