# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Common utilities for testing autotvm"""
import time

import tvm
from tvm import autotvm
from tvm.autotvm import MeasureInput, MeasureResult

@autotvm.template
def matmul(N, L, M, dtype):
    """Tunable matrix-multiplication template used by the autotvm tests.

    Declares placeholders ``A`` of shape (N, L) and ``B`` of shape (L, M),
    computes ``C`` of shape (N, M) as the sum of ``A[i, k] * B[k, j]`` over
    the reduction axis ``k``, and splits both output axes of ``C`` in two
    using the tunable knobs ``tile_y`` and ``tile_x``.

    Parameters
    ----------
    N, L, M
        Extents used to build the shapes of ``A``, ``B`` and ``C``.
    dtype
        Data type handed to both placeholders.

    Returns
    -------
    tuple
        ``(schedule, [A, B, C])``.
    """
    lhs = tvm.placeholder((N, L), name='A', dtype=dtype)
    rhs = tvm.placeholder((L, M), name='B', dtype=dtype)

    red = tvm.reduce_axis((0, L), name='k')

    def _dot(i, j):
        # One output element: reduce the products along ``red``.
        return tvm.sum(lhs[i, red] * rhs[red, j], axis=red)

    out = tvm.compute((N, M), _dot, name='C')
    sched = tvm.create_schedule(out.op)

    # Pull the iteration axes back out of the scheduled stage.
    stage = sched[out]
    row, col = stage.op.axis
    red = stage.op.reduce_axis[0]

    # ---- tuning space: one two-way split per output axis ----
    config = autotvm.get_config()
    config.define_split("tile_y", row, num_outputs=2)
    config.define_split("tile_x", col, num_outputs=2)

    # ---- apply whichever split the current config selects ----
    row_outer, row_inner = config["tile_y"].apply(sched, out, row)
    col_outer, col_inner = config["tile_x"].apply(sched, out, col)

    # Outer tiles first, then the reduction, then the inner tiles.
    stage.reorder(row_outer, col_outer, red, row_inner, col_inner)

    return sched, [lhs, rhs, out]

@autotvm.template
def bad_matmul(N, L, M, dtype):
    """Matmul template that differs from ``matmul`` on a 'bad_device' target.

    If ``'bad_device'`` is not among the current target's keys, this simply
    returns ``matmul(N, L, M, dtype)``. Otherwise it builds a variant whose
    reduction axis spans ``(0, L-1)`` rather than ``(0, L)``, and which
    defines the ``tile_y`` / ``tile_x`` split knobs without applying them.
    NOTE(review): the shortened reduction looks deliberate, presumably to
    give tests a faulty kernel -- confirm against the callers.

    Parameters
    ----------
    N, L, M
        Extents used to build the shapes of ``A``, ``B`` and ``C``.
    dtype
        Data type handed to both placeholders.

    Returns
    -------
    tuple
        ``(schedule, [A, B, C])``.
    """
    # Ordinary targets get the regular template.
    if 'bad_device' not in tvm.target.current_target().keys:
        return matmul(N, L, M, dtype)

    lhs = tvm.placeholder((N, L), name='A', dtype=dtype)
    rhs = tvm.placeholder((L, M), name='B', dtype=dtype)

    # The reduction extent is L-1 here, not L.
    red = tvm.reduce_axis((0, L-1), name='k')

    def _dot(i, j):
        return tvm.sum(lhs[i, red] * rhs[red, j], axis=red)

    out = tvm.compute((N, M), _dot, name='C')
    sched = tvm.create_schedule(out.op)

    # Define the same knobs as ``matmul`` but leave the schedule untouched.
    row, col = sched[out].op.axis
    config = autotvm.get_config()
    config.define_split("tile_y", row, num_outputs=2)
    config.define_split("tile_x", col, num_outputs=2)
    return sched, [lhs, rhs, out]

def get_sample_task(n=128):
    """Build a sample autotvm task for the tests.

    Parameters
    ----------
    n : int, optional
        Passed as all three size arguments of the ``matmul`` template
        (default 128); the dtype argument is ``'float32'``.

    Returns
    -------
    tuple
        ``(task, target)``, where ``target`` is created from ``"llvm"``.
    """
    llvm_target = tvm.target.create("llvm")
    template_args = (n, n, n, 'float32')
    sample = autotvm.task.create(matmul, args=template_args, target=llvm_target)
    return sample, llvm_target

def get_sample_records(n):
    """Build ``n`` sample (MeasureInput, MeasureResult) pairs for the tests.

    The records are generated for the default sample task
    (``get_sample_task()`` called without arguments). Record ``i`` uses the
    ``i``-th entry of the task's config space.

    Parameters
    ----------
    n : int
        Number of records to produce.

    Returns
    -------
    list of tuple
        ``[(MeasureInput, MeasureResult), ...]`` with ``n`` entries.
    """
    task, tgt = get_sample_task()

    # MeasureResult is filled positionally with (idx+1,), 0, idx and the
    # current wall-clock time. NOTE(review): presumably costs, error code,
    # total cost and timestamp -- confirm against MeasureResult.
    return [
        (
            MeasureInput(tgt, task, task.config_space.get(idx)),
            MeasureResult((idx + 1,), 0, idx, time.time()),
        )
        for idx in range(n)
    ]