# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
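"""Tests for the tensorize schedule primitive and tensor intrinsic matching."""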
import tvm

def intrin_vadd(n):
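    """Declare a vector-add tensor intrinsic implemented by a packed "vadd" call."""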
    x = tvm.placeholder((n,), name='vx')
    y = tvm.placeholder((n,), name='vy')
    z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
    def intrin_func(ins, outs):
        xx, yy = ins
        zz = outs[0]
        return tvm.call_packed("vadd", xx, yy, zz)
    with tvm.build_config(offset_factor=16):
        return tvm.decl_tensor_intrin(z.op, intrin_func)

def intrin_gemv(m, n):
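    """Declare a GEMV tensor intrinsic with body, reset, and update implementations."""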
    w = tvm.placeholder((m, n), name='w')
    x = tvm.placeholder((n,), name='x')
    k = tvm.reduce_axis((0, n), name='k')
    z = tvm.compute((m,), lambda i:
                    tvm.sum(w[i, k] * x[k], axis=k), name='z')
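    # Bind w to a buffer with a symbolic leading stride ('ldw') so the
    # intrinsic can match rows taken out of a larger array.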
    Wb = tvm.decl_buffer(w.shape, w.dtype,
                         name="W",
                         offset_factor=16,
                         strides=[tvm.var('ldw'), 1])
    def intrin_func(ins, outs):
        ww, xx = ins
        zz = outs[0]
        ww_ptr = ww.access_ptr("r")
        xx_ptr = xx.access_ptr("r")
        zz_ptr = zz.access_ptr("w")
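        # body performs the full reduction in one call, reset zeroes the
        # output, and update accumulates without reinitializing.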
        body = tvm.call_packed(
            "gemv", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
        reset = tvm.call_packed(
            "fill_zero", zz_ptr, n)
        update = tvm.call_packed(
            "gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
        return body, reset, update

    with tvm.build_config(data_alignment=16,
                          offset_factor=16):
        return tvm.decl_tensor_intrin(z.op, intrin_func,
                                      binds={w: Wb})

def intrin_gemv_no_reset(m, n):
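    """Declare a GEMV tensor intrinsic that provides body and update but no reset."""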
    w = tvm.placeholder((m, n), name='w')
    x = tvm.placeholder((n,), name='x')
    k = tvm.reduce_axis((0, n), name='k')
    z = tvm.compute((m,), lambda i:
                    tvm.sum(w[i, k] * x[k], axis=k), name='z')
    Wb = tvm.decl_buffer(w.shape, w.dtype,
                         name="W",
                         offset_factor=16,
                         strides=[tvm.var('ldw'), 1])
    def intrin_func(ins, outs):
        ww, xx = ins
        zz = outs[0]
        ww_ptr = ww.access_ptr("r")
        xx_ptr = xx.access_ptr("r")
        zz_ptr = zz.access_ptr("w")
        body = tvm.call_packed(
            "gemv", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
        update = tvm.call_packed(
            "gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
        return body, None, update

    with tvm.build_config(data_alignment=16,
                          offset_factor=16):
        return tvm.decl_tensor_intrin(z.op, intrin_func,
                                      binds={w: Wb})


def test_tensorize_vadd():
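    """Tensorize the inner axis of an elementwise add and verify the inferred
    regions and the matched body against the intrinsic."""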
    m = 128
    x = tvm.placeholder((m,), name='x')
    y = tvm.placeholder((m,), name='y')
    z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')

    def check(factor):
        s = tvm.create_schedule(z.op)
        xo, xi = s[z].split(z.op.axis[0], factor=factor)
        vadd = intrin_vadd(factor)
        s[z].tensorize(xi, vadd)
        s = s.normalize()
        dom_map = tvm.schedule.InferBound(s)
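        # Test hooks exposed from the C++ side let us inspect tensorize's
        # region inference and body matching directly.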
        finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
        out_dom, in_dom = finfer(s[z], dom_map)
        assert tvm.ir_pass.Equal(out_dom[z.op.axis[0]].extent, factor)
        assert tvm.ir_pass.Equal(out_dom[z.op.axis[0]].min, xo * factor)
        assert tvm.ir_pass.Equal(in_dom.items()[0][1][0].extent, factor)
        fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
        body = fmatch(s[z], out_dom, in_dom, vadd)
        assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]),
                                 tvm.ir_pass.CanonicalSimplify(vadd.op.body[0]))
        stmt = tvm.schedule.ScheduleOps(s, dom_map)
        tvm.lower(s, [x, y, z])

    check(16)


def test_tensorize_matmul():
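    """Tensorize the inner axis of a matmul, with and without splitting the
    reduction, and verify region inference and body matching."""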
    n = 1024
    m = n
    l = n
    A = tvm.placeholder((n, l), name='A')
    B = tvm.placeholder((m, l), name='B')
    k = tvm.reduce_axis((0, l), name='k')
    C = tvm.compute((n, m), lambda i, j:
                    tvm.sum(B[j, k] * A[i, k], axis=k), name='C')

    def check(factor):
        s = tvm.create_schedule(C.op)
        x, y = C.op.axis
        yo, yi = s[C].split(y, factor=factor)
        gemv = intrin_gemv(factor, l)
        s[C].tensorize(yi, gemv)
        s = s.normalize()
        dom_map = tvm.schedule.InferBound(s)
        finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
        out_dom, in_dom = finfer(s[C], dom_map)
        assert tvm.ir_pass.Equal(out_dom[x].extent, 1)
        assert tvm.ir_pass.Equal(out_dom[y].extent, factor)
        assert tvm.ir_pass.Equal(out_dom[y].min, yo * factor)
        fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
        body = fmatch(s[C], out_dom, in_dom, gemv)
        assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]),
                                 tvm.ir_pass.CanonicalSimplify(gemv.op.body[0]))
        stmt = tvm.schedule.ScheduleOps(s, dom_map)
        tvm.lower(s, [A, B, C])


    def check_rfactor(factor, rfactor):
        s = tvm.create_schedule(C.op)
        x, y = C.op.axis
        rk = C.op.reduce_axis[0]
        yo, yi = s[C].split(y, factor=factor)
        ro, ri = s[C].split(rk, factor=rfactor)
        s[C].reorder(yo, ro, yi, ri)
        gemv = intrin_gemv(factor, rfactor)
        s[C].tensorize(yi, gemv)
        s = s.normalize()
        dom_map = tvm.schedule.InferBound(s)
        finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
        out_dom, in_dom = finfer(s[C], dom_map)
        assert tvm.ir_pass.Equal(out_dom[x].extent, 1)
        assert tvm.ir_pass.Equal(out_dom[y].extent, factor)
        assert tvm.ir_pass.Equal(out_dom[y].min, yo * factor)
        fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
        body = fmatch(s[C], out_dom, in_dom, gemv)
        assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]),
                                 tvm.ir_pass.CanonicalSimplify(gemv.op.body[0]))
        stmt = tvm.schedule.ScheduleOps(s, dom_map)
        tvm.lower(s, [A, B, C])

    def check_rfactor_no_reset(factor, rfactor):
        s = tvm.create_schedule(C.op)
        x, y = C.op.axis
        rk = C.op.reduce_axis[0]
        yo, yi = s[C].split(y, factor=factor)
        ro, ri = s[C].split(rk, factor=rfactor)
        s[C].reorder(yo, ro, yi, ri)
        gemv = intrin_gemv_no_reset(factor, rfactor)
        s[C].tensorize(yi, gemv)
        s = s.normalize()
        dom_map = tvm.schedule.InferBound(s)
        finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
        out_dom, in_dom = finfer(s[C], dom_map)
        assert tvm.ir_pass.Equal(out_dom[x].extent, 1)
        assert tvm.ir_pass.Equal(out_dom[y].extent, factor)
        assert tvm.ir_pass.Equal(out_dom[y].min, yo * factor)
        fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
        body = fmatch(s[C], out_dom, in_dom, gemv)
        assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]),
                                 tvm.ir_pass.CanonicalSimplify(gemv.op.body[0]))
        stmt = tvm.schedule.ScheduleOps(s, dom_map)
        tvm.lower(s, [A, B, C])

    def check_rfactor_no_reset_multi_reduction(factor, rfactor):
        s = tvm.create_schedule(C.op)
        x, y = C.op.axis
        rk = C.op.reduce_axis[0]
        yo, yi = s[C].split(y, factor=factor)
        ro, ri = s[C].split(rk, factor=rfactor)
        roo, roi = s[C].split(ro, factor=2)
        s[C].reorder(yo, roo, roi, yi, ri)
        gemv = intrin_gemv_no_reset(factor, rfactor)
        s[C].tensorize(yi, gemv)
        s = s.normalize()
        dom_map = tvm.schedule.InferBound(s)
        finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
        out_dom, in_dom = finfer(s[C], dom_map)
        assert tvm.ir_pass.Equal(out_dom[x].extent, 1)
        assert tvm.ir_pass.Equal(out_dom[y].extent, factor)
        assert tvm.ir_pass.Equal(out_dom[y].min, yo * factor)
        fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
        body = fmatch(s[C], out_dom, in_dom, gemv)
        assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]),
                                 tvm.ir_pass.CanonicalSimplify(gemv.op.body[0]))
        stmt = tvm.schedule.ScheduleOps(s, dom_map)
        tvm.lower(s, [A, B, C])

    check(16)
    check_rfactor(16, 16)
    check_rfactor_no_reset(16, 16)
    check_rfactor_no_reset_multi_reduction(16, 16)

# This test checks that the algorithm and intrinsic expressions are
# simplified as much as possible before they are compared for equality.
# See Issue #696
def test_tensorize_op():
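    """Tensorize an op whose index expressions only match the intrinsic
    after simplification."""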
    def op_intrin():
        bh = 9
        bw = 9
        x = tvm.placeholder((5, 5), name='A')
        y = tvm.compute((bh, bw), lambda i, j: x[j/3 + i%3, j%3 + i/3])

        def intrin_func(ins, outs):
            xx, = ins
            zz = outs[0]
            return tvm.call_packed("op", xx, zz)

        with tvm.build_config(offset_factor=2):
            return tvm.decl_tensor_intrin(y.op, intrin_func)

    A = tvm.placeholder((5, 5), name='A')
    B = tvm.compute((9, 9), lambda i, j: A[j/3 + i%3, j%3 + i/3])
    bt = op_intrin()
    s = tvm.create_schedule(B.op)

    x, y = B.op.axis
    s[B].tensorize(x, bt)
    s = s.normalize()
    tvm.lower(s, [A, B])

# This test asserts that tensorize does not have any effect on
# TensorComputeOp operations
def test_tensorize_tensor_compute_op():
    # An intrinsic called "multivadd" whose definition (pattern)
    # is a loop over another intrinsic called "vadd".
    def intrin_multivadd(n):
        n_a = tvm.var("n_a")
        Ab = tvm.decl_buffer((n,), tvm.float32, strides=[n_a])

        n_b = tvm.var("n_b")
        Bb = tvm.decl_buffer((n,), tvm.float32, strides=[n_b])

        n_c = tvm.var("n_c")
        Cb = tvm.decl_buffer((n,), tvm.float32, strides=[n_c])

        z = tvm.compute((n,), lambda i: tvm.call_extern("float32", 'vadd',
                                                        Ab.access_ptr("w", offset=n_a*i),
                                                        Bb.access_ptr("r", offset=n_b*i),
                                                        Cb.access_ptr("r", offset=n_c*i)))

        # Replace the pattern with the multivadd call.
        # TODO: figure out how to pass it the right parameters.
        def intrin_func(ins, outs):
            return tvm.call_packed("multivadd")

        with tvm.build_config():
            return tvm.decl_tensor_intrin(z.op, intrin_func, name="multivadd")

    def intrin_vadd(n):
        dtype = 'float32'
        x = tvm.placeholder((n,), dtype=dtype, name='vx')
        y = tvm.placeholder((n,), dtype=dtype, name='vy')
        z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
        s = tvm.create_schedule(z.op)

        def create_buffer(t):
            return tvm.decl_buffer(t.shape, t.dtype,
                                   name='W'+t.name,
                                   offset_factor=16)

        def intrin_func(ins, outs):
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_extern("float32", 'vadd',
                                    ins[0].access_ptr("r"), ins[1].access_ptr('r'),
                                    outs[0].access_ptr('wr')))
            return ib.get()

        with tvm.build_config(offset_factor=16):
            return tvm.decl_tensor_intrin(z.op, intrin_func, binds={x: create_buffer(x),
                                                                    y: create_buffer(y),
                                                                    z: create_buffer(z)})

    # Build a TensorComputeOp: each row of C is computed by the vadd intrinsic.
    M = 1024
    factor = 16
    dtype = 'float32'

    A = tvm.placeholder((M//factor, factor), name="A", dtype=dtype)
    B = tvm.placeholder((M//factor, factor), name="B", dtype=dtype)

    vadd = intrin_vadd(factor)
    C = tvm.compute((M//factor, factor),
                    lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name='C')

    s = tvm.create_schedule(C.op)
    multivadd = intrin_multivadd(64)
    s[C].tensorize(C.op.axis[0], multivadd)
    s = s.normalize()
    dom_map = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, dom_map)
    # The loop we tried to tensorize still exists in the lowered code,
    # i.e. tensorize had no effect on the TensorComputeOp, as this test expects.
    assert isinstance(stmt.body.body.body, tvm.stmt.For)
    assert stmt.body.body.body.loop_var.name == C.op.axis[0].var.name



if __name__ == "__main__":
    test_tensorize_vadd()
    test_tensorize_matmul()
    test_tensorize_op()
    test_tensorize_tensor_compute_op()