# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test flop calculation"""

import tvm
import numpy as np

from tvm.autotvm.task.task import compute_flop

def random_dtypes():
    """Pick one (input dtype, accumulator dtype) pair at random.

    Returns
    -------
    pair : tuple of str
        ``(input_dtype, acc_dtype)`` drawn from a fixed set of combinations.
    """
    dtype_pairs = (
        ("float32", "float32"),
        ("float16", "float32"),
        ("int8", "int32"),
    )
    picked = np.random.choice(len(dtype_pairs))
    return dtype_pairs[picked]

def test_conv():
    """Check the flop count of a direct NCHW convolution (stride 1, no padding).

    Each reduction step performs one multiply and one accumulate, so the
    expected count is 2 * N * CO * OH * OW * CI * KH * KW.
    """
    for _ in range(5):
        N, H, W, CO, CI, KH, KW = [np.random.randint(10, 32) for _ in range(7)]
        (input_dtype, acc_dtype) = random_dtypes()

        # Clamp the kernel to the input size *before* declaring the weight
        # tensor, so that its shape agrees with the reduction extents below.
        KH = min(H, KH)
        KW = min(W, KW)

        D = tvm.placeholder((N, CI, H, W), dtype=input_dtype)
        K = tvm.placeholder((CO, CI, KH, KW), dtype=input_dtype)

        ci = tvm.reduce_axis((0, CI))
        kh = tvm.reduce_axis((0, KH))
        kw = tvm.reduce_axis((0, KW))

        OH = (H - KH) + 1
        OW = (W - KW) + 1

        # The data window slides with the kernel offsets and the weights are
        # indexed by the kernel offsets only.  The expected count below assumes
        # that this index arithmetic is not counted as flops.
        C = tvm.compute(
            (N, CO, OH, OW),
            lambda n, co, h, w: tvm.sum(
                D[n][ci][h + kh][w + kw].astype(acc_dtype) * K[co][ci][kh][kw].astype(acc_dtype),
                axis=[ci, kh, kw]))

        s = tvm.create_schedule([C.op])

        assert compute_flop(s) == 2 * N * CO * OH * OW * CI * KH * KW

def test_pack_gemm():
    """Check the flop count of a packed (tiled) matrix multiply.

    Both operands are repacked into tiles of ``bn`` rows, multiplied tile by
    tile, and the packed result is unpacked into an (N, M) tensor.  The
    expected count is 2 * N * L * M.
    """
    for _ in range(5):
        N, L, M = (np.random.randint(10, 128) * 4 for _ in range(3))
        input_dtype, acc_dtype = random_dtypes()

        bn = 4
        n_tiles = N // bn
        m_tiles = M // bn

        A = tvm.placeholder((N, L), dtype=input_dtype)
        B = tvm.placeholder((M, L), dtype=input_dtype)
        k = tvm.reduce_axis((0, L))

        # Repack each operand as (tile, reduction index, lane within tile).
        A_pack = tvm.compute(
            (n_tiles, L, bn),
            lambda tile, col, lane: A[tile * bn + lane][col])
        B_pack = tvm.compute(
            (m_tiles, L, bn),
            lambda tile, col, lane: B[tile * bn + lane][col])

        # Tile-by-tile product, accumulated in acc_dtype.
        C_pack = tvm.compute(
            (n_tiles, m_tiles, bn, bn),
            lambda ti, tj, ii, jj: tvm.sum(
                A_pack[ti, k, ii].astype(acc_dtype) * B_pack[tj, k, jj].astype(acc_dtype),
                axis=[k]))

        # Unpack the tiled result back to a plain (N, M) layout.
        C = tvm.compute(
            (N, M),
            lambda row, col: C_pack[tvm.indexdiv(row, bn)][tvm.indexdiv(col, bn)]
                                   [tvm.indexmod(row, bn)][tvm.indexmod(col, bn)])

        sched = tvm.create_schedule([C.op])
        assert compute_flop(sched) == 2 * N * L * M

def test_outer_dot():
    """Check the flop count of an outer product: one multiply per output element."""
    for _ in range(5):
        N, M = (np.random.randint(10, 128) * 4 for _ in range(2))
        input_dtype, acc_dtype = random_dtypes()

        A = tvm.placeholder((N,), dtype=input_dtype)
        B = tvm.placeholder((M,), dtype=input_dtype)

        def outer(i, j):
            # Both factors are widened to the accumulator dtype before multiplying.
            lhs = A[i].astype(acc_dtype)
            rhs = B[j].astype(acc_dtype)
            return lhs * rhs

        C = tvm.compute((N, M), outer)
        sched = tvm.create_schedule([C.op])

        expected = N * M
        assert compute_flop(sched) == expected

def test_max_pool():
    """Check the flop count of a max pooling (stride 1, no padding).

    Each reduction step performs a single ``max``, so the expected count is
    N * CO * OH * OW * KH * KW.
    """
    for _ in range(5):
        N, H, W, CO, KH, KW = [np.random.randint(10, 32) for _ in range(6)]
        input_dtype = random_dtypes()[0]

        # Pooling keeps the channel count, so the input's channel axis must
        # have the same extent (CO) as the output axis that indexes it.
        D = tvm.placeholder((N, CO, H, W), dtype=input_dtype)

        KH = min(H, KH)
        KW = min(W, KW)

        kh = tvm.reduce_axis((0, KH))
        kw = tvm.reduce_axis((0, KW))

        OH = (H - KH) + 1
        OW = (W - KW) + 1

        C = tvm.compute(
            (N, CO, OH, OW),
            lambda n, co, h, w: tvm.max(D[n][co][h + kh][w + kw], axis=[kh, kw]))

        s = tvm.create_schedule([C.op])

        assert compute_flop(s) == N * CO * OH * OW * KH * KW

def test_average_pool():
    """Check the flop count of an average pooling (stride 1, no padding).

    Each reduction step performs one division and one accumulate, so the
    expected count is 2 * N * CO * OH * OW * KH * KW.
    """
    for _ in range(5):
        N, H, W, CO, KH, KW = [np.random.randint(10, 32) for _ in range(6)]
        (input_dtype, acc_dtype) = random_dtypes()

        # Pooling keeps the channel count, so the input's channel axis must
        # have the same extent (CO) as the output axis that indexes it.
        D = tvm.placeholder((N, CO, H, W), dtype=input_dtype)

        KH = min(H, KH)
        KW = min(W, KW)

        kh = tvm.reduce_axis((0, KH))
        kw = tvm.reduce_axis((0, KW))

        OH = (H - KH) + 1
        OW = (W - KW) + 1

        # Every window element is divided by the window size before it is summed.
        C = tvm.compute(
            (N, CO, OH, OW),
            lambda n, co, h, w: tvm.sum(
                tvm.div(D[n][co][h + kh][w + kw].astype(acc_dtype), (KW * KH)), axis=[kh, kw]))

        s = tvm.create_schedule([C.op])

        assert compute_flop(s) == 2 * N * CO * OH * OW * KH * KW

def test_move():
    """A plain element copy contains no arithmetic, so the estimator should raise an error."""
    length = 1024

    src = tvm.placeholder((length,))
    copied = tvm.compute((length,), lambda idx: src[idx])
    sched = tvm.create_schedule([copied.op])

    try:
        compute_flop(sched)
    except RuntimeError:
        # Expected outcome: there is nothing to count in a pure move.
        return

    # Reaching this point means no error was raised for the flop-free compute.
    assert False

if __name__ == '__main__':
    # Run every test defined in this file, including both pooling tests.
    test_conv()
    test_pack_gemm()
    test_outer_dot()
    test_max_pool()
    test_average_pool()
    test_move()