"""External function interface to MPS libraries."""
from __future__ import absolute_import as _abs
from .. import api as _api
from .. import intrin as _intrin

# pylint: disable=C0103,W0612

def matmul(lhs, rhs, transa=False, transb=False):
    """Create an extern op that computes the matrix product of lhs and rhs with MPS.

    This function serves as an example of how to call external libraries.

    Parameters
    ----------
    lhs : Tensor
        The left matrix operand.
    rhs : Tensor
        The right matrix operand.
    transa : bool
        Whether to transpose lhs.
    transb : bool
        Whether to transpose rhs.

    Returns
    -------
    C : Tensor
        The result tensor of shape (m, n): m is the row count of lhs
        (its column count when transa is set) and n is the column count
        of rhs (its row count when transb is set).
    """
    # Only the output extent is computed here. The transpose flags are
    # forwarded unchanged to the packed function, which presumably applies
    # the transposes itself. Truthiness is used so that the shape logic and
    # the forwarded flags always agree (e.g. for transa=0).
    m = lhs.shape[1] if transa else lhs.shape[0]
    n = rhs.shape[0] if transb else rhs.shape[1]
    return _api.extern(
        (m, n), [lhs, rhs],
        lambda ins, outs: _intrin.call_packed(
            "tvm.contrib.mps.matmul", ins[0], ins[1], outs[0], transa, transb),
        name="C")

def conv2d(data, weight, pad='SAME', stride=1):
    """
    Create an extern op that convolves data with weight using MPS.

    Parameters:
    ----------
    data: Tensor
        The input data, format NHWC
    weight: Tensor
        The conv weight, format output_feature * kH * kW * input_feature
    pad: str
        Padding method, 'SAME' or 'VALID'. Passed to the runtime as an
        integer flag: 0 for exactly 'SAME', 1 for anything else.
    stride: int
        Convolution stride, applied to both spatial dimensions

    Returns
    -------
    output: Tensor
        The result tensor, shaped (N, H // stride, W // stride, out_features)

    NOTE(review): the output spatial size is H // stride for both 'SAME'
    and 'VALID'; VALID padding would normally also shrink by the kernel
    size. Confirm this matches what the MPS runtime writes.
    """
    # Full 4-way unpacking doubles as a rank check on both tensors.
    batch, in_h, in_w, _ = data.shape
    out_c, _, _, _ = weight.shape

    if pad == 'SAME':
        pad_flag = 0
    else:
        pad_flag = 1

    out_shape = (batch, in_h // stride, in_w // stride, out_c)

    def _emit(ins, outs):
        # Hand the inputs, output buffer, and conv settings to the runtime.
        return _intrin.call_packed(
            "tvm.contrib.mps.conv2d", ins[0], ins[1], outs[0], pad_flag, stride)

    return _api.extern(out_shape, [data, weight], _emit, name="C")