# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""A prelude containing useful global functions and ADT definitions."""
import contextlib

from tvm.ir import IRModule, TypeCall

from .ty import GlobalTypeVar, TensorType, Any, scalar_type
from .expr import Var, GlobalVar, If, const
from .function import Function
from .op.tensor import add, subtract, equal
from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard
from . import op, transform


def get_tensor_array_shape(expr, dtype, prelude):
    """Get the static shape of a tensor array if it has fixed rank shape.

    By design, static ADT tensor in TVM has type name in the format
    of static_tensor_<dtype>_dim0_dim1_..._dimN_t, or
    static_tensor_<dtype>_scalar_t for rank-0 tensors. Dynamic dimensions
    are spelled "?".

    Note: this sets ``prelude.mod["main"]`` to a function returning ``expr``
    so that type inference can run, overwriting any existing "main".

    Parameters
    ----------
    expr : Relay Expr
        Input expression.

    dtype : str
        Data type.

    prelude : Prelude
        Tensor array prelude

    Returns
    -------
    shape : tuple of (int, Any) or None
        The output shape. None if input tensor array
        has dynamic shape.
    """
    mod = prelude.mod
    mod["main"] = Function([], expr)
    mod = transform.InferType()(mod)
    checked_type = mod["main"].body.checked_type
    assert isinstance(checked_type, TypeCall), "Input must be a tensor array."
    ta_type_str = checked_type.args[0].func.name_hint

    # Match the exact "<prefix><dims><suffix>" layout rather than doing global
    # string replacement, so only the trailing "_t" is stripped and the dtype
    # must be followed by "_" (no accidental match on a longer dtype name).
    prefix = "static_tensor_{}_".format(dtype)
    suffix = "_t"
    if not (ta_type_str.startswith(prefix) and ta_type_str.endswith(suffix)):
        return None

    shape_str = ta_type_str[len(prefix):-len(suffix)]
    if shape_str == "scalar":
        return ()
    return tuple(Any() if dim_str == "?" else int(dim_str)
                 for dim_str in shape_str.split("_"))

def _get_name_static(canonical, dtype, shape):
    """Get name for static shape tensor array op corresponding
    to the canonical name"""
    shape_str = '_'.join([str(dim) for dim in shape])
    if len(shape_str) == 0:
        shape_str = "scalar"
    if canonical == 'tensor_t':
        return 'static_tensor_{}_{}_t'.format(dtype, shape_str)
    return "{}_{}_{}".format(canonical, dtype, shape_str)

class StaticTensorArrayOps(object):
    """Contains tensor array related ops for fixed rank tensor array.

    Every op defined here is specialized on ``(self.dtype, self.shape)``,
    registered into ``prelude.mod`` and (for most ops) also exposed as an
    attribute of ``prelude`` under the name produced by ``_get_name_static``.

    Several ``define_*`` methods need ops for a *different* shape (e.g.
    ``tensor_concatenate`` on the ``[Any(), ...]`` output shape). They get
    them by temporarily switching ``self.shape`` via ``_shape_override``,
    which restores the original shape object even if registration raises.
    """

    def __init__(self, prelude, dtype, shape):
        """Create tensor array ops registry.

        Parameters
        ----------
        prelude : Prelude
            Prelude whose module receives the op definitions.
        dtype : str
            Element data type of the tensors.
        shape : list or tuple of (int, Any)
            Static tensor shape; an empty shape denotes a scalar.
        """
        self.prelude = prelude
        self.dtype = dtype
        self.shape = shape

    def get_name(self, canonical):
        """Get name corresponding to the canonical name"""
        return _get_name_static(canonical, self.dtype, self.shape)

    def get_var(self, canonical):
        """Get var corresponding to the canonical name"""
        name = self.get_name(canonical)
        return getattr(self.prelude, name)

    def define_tensor_adt(self):
        """Defines the static tensor ADT, which is the container for tensors
        with fixed shapes.

        The ADT has a payload-free ``tensor_nil`` constructor and a
        ``tensor_constructor`` wrapping ``Tensor[self.shape, self.dtype]``.
        Nothing is done if a type with the same name is already registered.
        """
        tensor_type_name = self.get_name('tensor_t')
        # Skip register if tensor type is already registered.
        if any(g_ty_var.name_hint == tensor_type_name
               for g_ty_var in self.prelude.mod.get_global_type_vars()):
            return

        tensor_type_var = GlobalTypeVar(tensor_type_name)
        setattr(self.prelude, tensor_type_name, tensor_type_var)
        tensor_type = TensorType(self.shape, self.dtype)
        tensor_constructor_name = self.get_name('tensor_constructor')

        tensor_nil_name = self.get_name('tensor_nil')
        tensor_nil_case = Constructor(tensor_nil_name, [], tensor_type_var)
        tensor_case = Constructor(tensor_constructor_name, [tensor_type], tensor_type_var)

        setattr(self.prelude, tensor_nil_name, tensor_nil_case)
        setattr(self.prelude, tensor_constructor_name, tensor_case)
        self.prelude.mod[tensor_type_var] = TypeData(tensor_type_var,
                                                     [],
                                                     [tensor_nil_case, tensor_case])

    def define_tensor_array(self):
        """Defines a function to create a tensor array with size n.
        tensor_array(n) : Tensor[(), int32] -> list[tensor_t]

        Every element of the created list is ``tensor_nil``.
        """
        tensor_array_constructor_name = self.get_name("tensor_array")
        tensor_array_constructor_var = self._create_global_var(tensor_array_constructor_name)
        setattr(self.prelude, tensor_array_constructor_name, tensor_array_constructor_var)
        tensor_nil_var = self.get_var('tensor_nil')
        tensor_type_var = self.get_var('tensor_t')
        n = Var("x", scalar_type('int32'))
        # Recursive definition: n == 0 -> nil, else cons(tensor_nil, tensor_array(n - 1)).
        body = If(equal(n, const(0)),
                  self.prelude.nil(),
                  self.prelude.cons(tensor_nil_var(),
                                    tensor_array_constructor_var(subtract(n, const(1)))))
        self.prelude.mod[tensor_array_constructor_var] = \
            Function([n], body, self.prelude.l(tensor_type_var()), [])

    def define_tensor_take(self):
        """Defines a function to return a range of tensor_t on axis 0.
            tensor_take(t, lower, upper) :
            tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t

        The result type is the static tensor with shape
        ``[Any()] + self.shape[1:]``.
        """
        # We don't register take for scalar tensor.
        ndim = len(self.shape)
        if ndim == 0:
            return

        take_name = self.get_name("tensor_take")
        take_var = self._create_global_var(take_name)
        setattr(self.prelude, take_name, take_var)
        origin_tensor_constructor = self.get_var('tensor_constructor')

        output_shape = [Any(),] + list(self.shape[1:])
        tensor_type_var, tensor_constructor = \
            self._get_adt_by_shape(output_shape)

        t = Var('tensor', self.get_var('tensor_t')())
        lower = Var('lower', scalar_type('int32'))
        upper = Var('upper', scalar_type('int32'))
        tvar = Var('t')
        case = Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(tvar)]),
                      tensor_constructor(op.take(tvar,
                                                 op.arange(lower, upper, dtype='int32'),
                                                 axis=0)))
        self.prelude.mod[take_var] = \
            Function([t, lower, upper],
                     Match(t, [case], False), tensor_type_var(), [])

    def define_tensor_concatenate(self):
        """Defines a function to concatenate two tensor_t on axis 0.
        tensor_concatenate(x, y) : tensor_t -> tensor_t -> tensor_t

        The result type is the static tensor with shape
        ``[Any()] + self.shape[1:]``.
        """
        # We don't register concatenate for scalar tensor.
        ndim = len(self.shape)
        if ndim == 0:
            return

        concat_name = self.get_name("tensor_concatenate")
        concat_var = self._create_global_var(concat_name)
        setattr(self.prelude, concat_name, concat_var)
        output_shape = [Any(),] + list(self.shape[1:])
        tensor_type_var, tensor_constructor = \
            self._get_adt_by_shape(output_shape)

        origin_tensor_constructor = self.get_var('tensor_constructor')
        origin_tensor_type_var = self.get_var('tensor_t')
        x = Var("x", origin_tensor_type_var())
        y = Var("y", origin_tensor_type_var())
        t1 = Var("t1")
        t2 = Var("t2")

        # Unwrap both operands, concatenate the raw tensors, re-wrap the result.
        case = Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(t1)]),
                      Match(y,
                            [Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(t2)]),
                                    tensor_constructor(op.concatenate([t1, t2], axis=0)))],
                            False))

        self.prelude.mod[concat_var] = \
            Function([x, y], Match(x, [case], False), tensor_type_var(), [])

    def define_tensor_expand_dims(self):
        """Defines a function to grow a tensor_t's rank by adding one dimension in front
        of the original tensor_t.
        tensor_expand_dims(t) : tensor_t -> tensor_t
        """
        expand_dims_name = self.get_name("tensor_expand_dims")
        expand_dims_var = self._create_global_var(expand_dims_name)
        setattr(self.prelude, expand_dims_name, expand_dims_var)
        origin_tensor_type_var = self.get_var('tensor_t')
        origin_tensor_constructor = self.get_var('tensor_constructor')
        x = Var("x", origin_tensor_type_var())

        # Note: we set the added axis to be Any() instead of 1 due to
        # in stack op, we need to recursively concatenate.
        tensor_type_var, tensor_constructor = \
            self._get_adt_by_shape([Any(),] + list(self.shape))
        t = Var("t")
        case = Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(t)]),
                      tensor_constructor(op.expand_dims(t, 0, 1)))

        self.prelude.mod[expand_dims_var] = \
            Function([x], Match(x, [case], False), tensor_type_var(), [])

    def define_tensor_array_read(self):
        """Defines a function to get the nth element of a tensor array.
        tensor_array_read(ta, n) :
            list[static_tensor_t] -> Tensor[(), int32] -> static_tensor_t

        ``n`` is assumed to be a valid index into the list.
        """
        read_name = self.get_name("tensor_array_read")
        read_var = self._create_global_var(read_name)
        setattr(self.prelude, read_name, read_var)
        tensor_type_var = self.get_var('tensor_t')

        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
        n = Var("x", scalar_type('int32'))
        self.prelude.mod[read_var] = \
            Function([tensor_array, n], self.prelude.nth(tensor_array, n), tensor_type_var(), [])

    def define_tensor_array_write(self):
        """Defines a function to update a tensor array at index n with value v.
        tensor_array_write(ta, n, v) :
            list[static_tensor_t] -> Tensor[(), int32] -> static_tensor_t ->
            list[static_tensor_t]
        """
        write_name = self.get_name("tensor_array_write")
        write_var = self._create_global_var(write_name)
        setattr(self.prelude, write_name, write_var)
        tensor_type_var = self.get_var('tensor_t')
        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
        n = Var("x", scalar_type('int32'))
        v = Var("v", tensor_type_var())
        self.prelude.mod[write_var] = \
            Function([tensor_array, n, v], self.prelude.update(tensor_array, n, v),
                     self.prelude.l(tensor_type_var()), [])

    def define_tensor_array_unstack(self):
        """Defines a function to unstack the values of a tensor_t in a tensor array.
        tensor_array_unstack(t) :
            Tensor[self.shape, self.dtype] -> list[tensor_t of shape self.shape[1:]]
        """
        ndim = len(self.shape)
        # We don't register unstack for scalar tensor array
        if ndim == 0:
            return

        helper_name = self.get_name("tensor_array_unstack_helper")
        helper_var = self._create_global_var(helper_name)
        setattr(self.prelude, helper_name, helper_var)
        tensor = Var("t", TensorType(self.shape, self.dtype))
        up = Var("up", scalar_type('int32'))
        i = Var("i", scalar_type('int32'))
        tensor_var = Var("tensor", TensorType(self.shape, self.dtype))

        reduced_tensor_type_var, tensor_constructor = \
            self._get_adt_by_shape(self.shape[1:])
        # helper(i, up, t) = [t[i], t[i+1], ..., t[up-1]] (slices along axis 0).
        helper_body = \
            If(equal(i, up),
               self.prelude.nil(),
               self.prelude.cons(tensor_constructor(op.take(tensor, i, axis=0)),
                                 helper_var(add(i, const(1)), up, tensor)))
        self.prelude.mod[helper_var] = \
            Function([i, up, tensor], helper_body, self.prelude.l(reduced_tensor_type_var()), [])

        unstack_name = self.get_name("tensor_array_unstack")
        unstack_var = self._create_global_var(unstack_name)
        setattr(self.prelude, unstack_name, unstack_var)
        unstack_length = op.take(op.shape_of(tensor_var), const(0))
        self.prelude.mod[unstack_var] = \
            Function([tensor_var], helper_var(const(0), unstack_length, tensor_var),
                     self.prelude.l(reduced_tensor_type_var()), [])

    def define_tensor_array_scatter(self, indices_shape=None, force_update=False):
        """Defines a function to scatter the values of a tensor_t in indices of a tensor array.
        tensor_array_scatter(ta, indices, value) :
            list[tensor_t] -> Tensor[(Any), int32] -> list[tensor_t] -> list[tensor_t]

        Set static indices shape by specifying indices_shape.
        Set force_update to get static indices shape operator.
        """
        # When this operator has already been registered, only update
        # when force_update is set. This should be used only when we need to
        # redefine this op for static indices shape.
        tensor_array_scatter_name = self.get_name("tensor_array_scatter")
        if hasattr(self.prelude, tensor_array_scatter_name) and not force_update:
            return

        tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper")
        tensor_array_scatter_helper_var = \
            self._create_global_var(tensor_array_scatter_helper_name)
        tensor_type_var = self.get_var('tensor_t')
        ta = Var("ta", self.prelude.l(tensor_type_var()))
        current = Var("current", scalar_type('int32'))
        limit = Var("limit", scalar_type('int32'))
        indices_ = Var('indices_', TensorType(indices_shape or [Any()], 'int32'))
        values_ = Var('values_', self.prelude.l(tensor_type_var()))
        write_var = self.get_var('tensor_array_write')
        read_var = self.get_var('tensor_array_read')
        # helper writes values_[current] to ta[indices_[current]] for current < limit.
        helper_body = If(equal(current, limit),
                         ta,
                         tensor_array_scatter_helper_var(
                             write_var(ta, op.take(indices_, current),
                                       read_var(values_, current)),
                             add(current, const(1)),
                             limit, indices_, values_))
        self.prelude.mod[tensor_array_scatter_helper_var] = \
            Function([ta, current, limit, indices_, values_],
                     helper_body, self.prelude.l(tensor_type_var()), [])

        tensor_array_scatter_var = self._create_global_var(tensor_array_scatter_name)
        setattr(self.prelude, tensor_array_scatter_name, tensor_array_scatter_var)
        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))

        indices = Var('indices', TensorType(indices_shape or [Any()], 'int32'))
        values = Var('values', self.prelude.l(tensor_type_var()))
        # Number of scatter steps: dynamic length of indices, or the static one.
        if indices_shape is None:
            scatter_limit = op.take(op.shape_of(indices), const(0))
        else:
            scatter_limit = const(indices_shape[0])

        body = tensor_array_scatter_helper_var(tensor_array, const(0), scatter_limit,
                                               indices, values)
        self.prelude.mod[tensor_array_scatter_var] = \
            Function([tensor_array, indices, values], body,
                     self.prelude.l(tensor_type_var()), [])

    def define_tensor_array_split(self,
                                  value_shape=None,
                                  lengths_shape=None,
                                  force_update=False):
        """Defines a function to split the values of a tensor_t into a tensor array.
        tensor_array_split(ta, value, lengths) :
            list[tensor_t] -> tensor_t -> Tensor[(Any), int32] -> list[tensor_t]

        Set static value and lengths shapes by specifying value_shape and lengths_shape.
        Set force_update to get static value and lengths shape operator.
        """
        # Skip scalar case
        ndim = len(self.shape)
        if ndim == 0:
            return

        # When this operator has already been registered, only update
        # when force_update is set. This should be used only when we need to
        # redefine this op for static value/indices shape.
        split_name = self.get_name("tensor_array_split")
        if hasattr(self.prelude, split_name) and not force_update:
            return

        tensor_type_var = self.get_var('tensor_t')
        tensor_array_split_helper_name = self.get_name("ta_split_helper")
        tensor_array_split_helper_var = \
            self._create_global_var(tensor_array_split_helper_name)
        setattr(self.prelude, tensor_array_split_helper_name, tensor_array_split_helper_var)
        output_shape = [Any(),] + list(self.shape[1:])
        output_tensor_type_var, _ = self._get_adt_by_shape(output_shape)

        if value_shape is None:
            value_type_var = tensor_type_var
            take_var = self.get_var('tensor_take')
        else:
            value_type_var, _ = self._get_adt_by_shape(value_shape)
            # Also get static shape take operator
            with self._shape_override(value_shape):
                self.define_tensor_take()
                take_var = self.get_var('tensor_take')

        ta1 = Var("tensor_array", self.prelude.l(output_tensor_type_var()))
        value1 = Var('value1', value_type_var())
        offset1 = Var('offset1', scalar_type('int32'))
        current1 = Var('current1', scalar_type('int32'))
        limit1 = Var('limit1', scalar_type('int32'))
        lengths1 = Var('lengths', TensorType(lengths_shape or [Any()], 'int32'))

        # Register write for output shape
        with self._shape_override(output_shape):
            self.define_tensor_array_write()
            write_var = self.get_var('tensor_array_write')

        # Element `current1` is value1[offset1 : offset1 + lengths1[current1]];
        # recurse with the offset advanced by that length.
        helper1_body = If(equal(current1, limit1),
                          ta1,
                          write_var(
                              tensor_array_split_helper_var(
                                  ta1,
                                  value1,
                                  add(offset1, op.take(lengths1, current1)),
                                  add(current1, const(1)),
                                  limit1,
                                  lengths1
                              ),
                              current1,
                              take_var(value1,
                                       offset1,
                                       add(op.take(lengths1, current1), offset1))))
        self.prelude.mod[tensor_array_split_helper_var] = \
            Function([ta1, value1, offset1, current1, limit1, lengths1],
                     helper1_body, self.prelude.l(output_tensor_type_var()), [])
        split_var = self._create_global_var(split_name)
        setattr(self.prelude, split_name, split_var)
        tensor_array = Var("tensor_array", self.prelude.l(output_tensor_type_var()))

        value = Var('value', value_type_var())
        lengths = Var('lengths', TensorType(lengths_shape or [Any()], 'int32'))
        # Number of output pieces: dynamic length of lengths, or the static one.
        if lengths_shape is None:
            lengths_limit = op.take(op.shape_of(lengths), const(0))
        else:
            lengths_limit = const(lengths_shape[0])
        body = tensor_array_split_helper_var(
            tensor_array,
            value,
            const(0),
            const(0),
            lengths_limit,
            lengths)

        self.prelude.mod[split_var] = \
            Function([tensor_array, value, lengths], body,
                     self.prelude.l(output_tensor_type_var()), [])

    def define_tensor_array_concat(self):
        """Defines a function to return the values in the tensor array as concatenated tensor_t.
        tensor_array_concat(ta) : list[tensor_t] -> tensor_t

        An empty list yields ``tensor_nil``; a single-element list yields that
        element unchanged.
        """
        # We don't register concat for scalar tensor array.
        ndim = len(self.shape)
        if ndim == 0:
            return

        concat_name = self.get_name("tensor_array_concat")
        concat_var = self._create_global_var(concat_name)
        setattr(self.prelude, concat_name, concat_var)

        output_shape = [Any(),] + list(self.shape[1:])
        tensor_type_var, _ = self._get_adt_by_shape(output_shape)

        # Register tensor concatenate and get tensor_nil var for output shape
        with self._shape_override(output_shape):
            self.define_tensor_concatenate()
            tensor_concat_var = self.get_var('tensor_concatenate')
            tensor_nil_var = self.get_var('tensor_nil')

        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
        hd = Var("hd")
        tl = Var("tl")
        nil_case = Clause(PatternConstructor(self.prelude.nil), tensor_nil_var())
        cons_case = Clause(PatternConstructor(self.prelude.cons, [PatternVar(hd), PatternVar(tl)]),
                           Match(tl, [
                               Clause(PatternConstructor(self.prelude.nil), hd),
                               Clause(PatternWildcard(),
                                      tensor_concat_var(hd, concat_var(tl)))
                           ], False))
        self.prelude.mod[concat_var] = \
            Function([tensor_array],
                     Match(tensor_array, [nil_case, cons_case], False), tensor_type_var(), [])

    def define_tensor_array_stack(self):
        """Defines a function to get the values in the tensor array as a stack tensor_t.
        tensor_array_stack(l) : list[tensor_t] -> tensor_t

        The result has shape ``[Any()] + self.shape``. Requires
        ``tensor_expand_dims`` to be registered already.
        """
        stack_name = self.get_name("tensor_array_stack")
        stack_var = self._create_global_var(stack_name)
        setattr(self.prelude, stack_name, stack_var)
        tensor_type_var = self.get_var('tensor_t')
        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
        expand_dims_var = self.get_var('tensor_expand_dims')

        # Register tensor_concatenate for output_shape
        output_shape = [Any(),] + list(self.shape)
        with self._shape_override(output_shape):
            self.define_tensor_concatenate()
            concat_var = self.get_var('tensor_concatenate')

        # Expand every element by a leading axis, then fold-concatenate them.
        tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array)
        tensors = self.prelude.foldl(concat_var,
                                     self.prelude.hd(tensor_array_expand_dims),
                                     self.prelude.tl(tensor_array_expand_dims))
        output_tensor_type_var, _ = self._get_adt_by_shape(output_shape)
        self.prelude.mod[stack_var] = Function([tensor_array], tensors,
                                               output_tensor_type_var(), [])

    def define_tensor_array_gather(self):
        """Defines a function to return the selected values in a tensor array as tensor_t.
        tensor_array_gather(ta, indices) : list[tensor_t] -> Tensor[(Any), int32] -> tensor_t

        Requires ``tensor_array_stack`` and ``tensor_array_read`` to be
        registered already.
        """
        helper_name = self.get_name("tensor_array_gather_helper")
        helper_var = self._create_global_var(helper_name)
        setattr(self.prelude, helper_name, helper_var)
        tensor_type_var = self.get_var('tensor_t')
        output_shape = [Any(),] + list(self.shape)
        output_tensor_type_var, _ = self._get_adt_by_shape(output_shape)
        stack_var = self.get_var('tensor_array_stack')
        read_var = self.get_var('tensor_array_read')
        ta = Var("ta", self.prelude.l(tensor_type_var()))
        accu = Var("accu", self.prelude.l(tensor_type_var()))
        current = Var("current", scalar_type('int32'))
        limit = Var("limit", scalar_type('int32'))
        indices_ = Var('indices_', TensorType([Any()], 'int32'))
        # Walk indices backwards, prepending ta[indices_[current - 1]] to accu,
        # so accu ends up in index order; stack it once current reaches 0.
        helper_body = \
            If(equal(current, const(0)),
               stack_var(accu),
               helper_var(
                   ta,
                   self.prelude.cons(
                       read_var(
                           ta, op.take(indices_, subtract(current, const(1)))), accu),
                   subtract(current, const(1)),
                   limit, indices_))
        self.prelude.mod[helper_var] = \
            Function([ta, accu, current, limit, indices_],
                     helper_body, output_tensor_type_var(), [])
        gather_name = self.get_name("tensor_array_gather")
        gather_var = self._create_global_var(gather_name)
        setattr(self.prelude, gather_name, gather_var)
        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
        indices = Var('indices', TensorType([Any()], 'int32'))
        indices_len = op.take(op.shape_of(indices), const(0))
        body = helper_var(tensor_array, self.prelude.nil(), indices_len, indices_len, indices)
        self.prelude.mod[gather_var] = \
            Function([tensor_array, indices], body, output_tensor_type_var(), [])

    def define_tensor_get_data(self, data_shape):
        """Defines a function to get a Tensor from tensor_t with given shape.
        tensor_get_data(t) : tensor_t -> Tensor[data_shape, self.dtype]
        """
        tensor_get_data_name = self.get_name("tensor_get_data")
        tensor_get_data_var = self._create_global_var(tensor_get_data_name)
        setattr(self.prelude, tensor_get_data_name, tensor_get_data_var)

        tensor_type_var, tensor_constructor = self._get_adt_by_shape(data_shape)
        t = Var('tensor', tensor_type_var())
        tvar = Var('t')
        case = \
            Clause(PatternConstructor(tensor_constructor, [PatternVar(tvar)]), tvar)
        self.prelude.mod[tensor_get_data_var] = \
            Function([t], Match(t, [case], False),
                     TensorType(data_shape, self.dtype), [])

    def register(self):
        """Register all tensor array ops in Prelude.

        Order matters: later definitions look up ops defined earlier
        (e.g. stack needs expand_dims, gather needs stack and read).
        """
        self.define_tensor_adt()
        self.define_tensor_take()
        self.define_tensor_concatenate()
        self.define_tensor_expand_dims()
        self.define_tensor_array()
        self.define_tensor_array_read()
        self.define_tensor_array_write()
        self.define_tensor_array_unstack()
        self.define_tensor_array_scatter()
        self.define_tensor_array_split()
        self.define_tensor_array_concat()
        self.define_tensor_array_stack()
        self.define_tensor_array_gather()

    def _get_adt_by_shape(self, shape):
        """Get ADT type and constructor with given shape, registering it if needed."""
        with self._shape_override(shape):
            self.define_tensor_adt()
            tensor_type_var = self.get_var("tensor_t")
            tensor_constructor = self.get_var("tensor_constructor")
        return tensor_type_var, tensor_constructor

    def _create_global_var(self, name):
        """Return the module's GlobalVar called ``name``, or a new one if absent.

        A newly created GlobalVar is not added to the module here.
        """
        mod = self.prelude.mod
        if any(g_var.name_hint == name for g_var in mod.get_global_vars()):
            return mod.get_global_var(name)
        return GlobalVar(name)

    @contextlib.contextmanager
    def _shape_override(self, shape):
        """Temporarily set ``self.shape`` to ``shape``.

        The original shape object is restored on exit, including when the body
        raises, so a failed nested registration cannot leave this registry
        pointing at the wrong shape.
        """
        origin_shape = self.shape
        self.shape = shape
        try:
            yield
        finally:
            self.shape = origin_shape

class TensorArrayOps(object):
    """Contains tensor array related ops.

    For a single ``dtype`` this registers a dynamic-shape ``tensor_t`` ADT
    and a family of list-based tensor array functions in ``prelude.mod``.
    Every generated GlobalTypeVar / Constructor / GlobalVar is also stored as
    an attribute of ``prelude`` under the name returned by ``get_name``. This
    lets later definitions (and callers) fetch it back with ``get_var``.
    """

    # Highest rank covered by the dynamic tensor_t ADT (constructors tensor0 ... tensor6).
    _MAX_RANK = 6
    # op.concatenate does not support tensor with rank higher than 4, so
    # tensor_concatenate (and the tensor_array_concat / tensor_array_stack
    # functions built on it) only handle ranks 1 ... 4.
    _MAX_CONCAT_RANK = 4

    def __init__(self, prelude, dtype):
        """Create tensor array ops registry"""
        self.prelude = prelude
        self.dtype = dtype

    def get_name(self, canonical):
        """Get name corresponding to the canonical name"""
        return self.prelude.get_name(canonical, self.dtype)

    def get_var(self, canonical):
        """Get var corresponding to the canonical name"""
        return self.prelude.get_var(canonical, self.dtype)

    def _tensor_type(self, rank):
        """Return a TensorType of ``self.dtype`` with ``rank`` dimensions, all ``Any()``."""
        return TensorType([Any() for _ in range(rank)], self.dtype)

    def _list_of_tensor_t(self):
        """Return a fresh ``List[tensor_t]`` type for this dtype."""
        return self.prelude.l(self.get_var('tensor_t')())

    def define_tensor_adt(self):
        """Defines the dynamic tensor ADT, which is the container for tensors
        with variable shapes.

        Constructors: ``tensor_nil`` (no payload) and ``tensor0`` ...
        ``tensor6``, where ``tensorK`` wraps a rank-K tensor whose dimensions
        are all ``Any()``.
        """
        tensor_type_name = self.get_name('tensor_t')
        tensor_type_var = GlobalTypeVar(tensor_type_name)
        setattr(self.prelude, tensor_type_name, tensor_type_var)

        tensor_nil_name = self.get_name('tensor_nil')
        tensor_nil_case = Constructor(tensor_nil_name, [], tensor_type_var)
        setattr(self.prelude, tensor_nil_name, tensor_nil_case)

        constructors = [tensor_nil_case]
        for rank in range(self._MAX_RANK + 1):
            ctor_name = self.get_name('tensor{}'.format(rank))
            ctor = Constructor(ctor_name, [self._tensor_type(rank)], tensor_type_var)
            setattr(self.prelude, ctor_name, ctor)
            constructors.append(ctor)

        self.prelude.mod[tensor_type_var] = TypeData(tensor_type_var, [], constructors)

    def define_tensor_take(self):
        """Defines a function to return a range of tensor_t on axis 0.
            tensor_take(t, lower, upper) :
            tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t

        Only ranks 1 ... 6 have a clause; ``tensor_nil`` and ``tensor0``
        inputs are not matched.
        """
        take_name = self.get_name("tensor_take")
        take_var = GlobalVar(take_name)
        setattr(self.prelude, take_name, take_var)
        tensor_t = self.get_var('tensor_t')
        t = Var('tensor', tensor_t())
        lower = Var('lower', scalar_type('int32'))
        upper = Var('upper', scalar_type('int32'))

        clauses = []
        for rank in range(1, self._MAX_RANK + 1):
            ctor = self.get_var('tensor{}'.format(rank))
            data = Var('t{}'.format(rank))
            indices = op.arange(lower, upper, dtype='int32')
            if rank == 1:
                # The rank-1 clause has always been built without ``axis``;
                # for a 1-D input that selects the same elements as axis=0.
                taken = op.take(data, indices)
            else:
                taken = op.take(data, indices, axis=0)
            clauses.append(Clause(PatternConstructor(ctor, [PatternVar(data)]), ctor(taken)))

        self.prelude.mod[take_var] = \
            Function([t, lower, upper], Match(t, clauses, False), tensor_t(), [])

    def define_tensor_expand_dims(self):
        """Defines a function to grow a tensor_t's rank by adding one dimension in front
        of the original tensor_t.
        tensor_expand_dims(t) : tensor_t -> tensor_t

        ``tensorK`` (K = 0 ... 5) becomes ``tensor(K+1)``; ``tensor6`` and
        ``tensor_nil`` have no clause.
        """
        expand_dims_name = self.get_name("tensor_expand_dims")
        expand_dims_var = GlobalVar(expand_dims_name)
        setattr(self.prelude, expand_dims_name, expand_dims_var)
        tensor_type_var = self.get_var('tensor_t')
        x = Var("x", tensor_type_var())

        clauses = []
        for rank in range(self._MAX_RANK):
            src_ctor = self.get_var('tensor{}'.format(rank))
            dst_ctor = self.get_var('tensor{}'.format(rank + 1))
            data = Var("t{}".format(rank))
            clauses.append(Clause(PatternConstructor(src_ctor, [PatternVar(data)]),
                                  dst_ctor(op.expand_dims(data, 0, 1))))

        self.prelude.mod[expand_dims_var] = Function([x], Match(x, clauses, False))

    def define_tensor_concat(self):
        """Defines a function to concatenate two tensor_t on the first axis

        tensor_concatenate(x, y) : tensor_t -> tensor_t -> tensor_t

        Both operands must use the same constructor; only ranks 1 ... 4 are
        handled because op.concatenate does not support higher ranks.
        """
        concat_name = self.get_name("tensor_concatenate")
        concat_var = GlobalVar(concat_name)
        setattr(self.prelude, concat_name, concat_var)
        tensor_type_var = self.get_var('tensor_t')
        x = Var("x", tensor_type_var())
        y = Var("y", tensor_type_var())

        clauses = []
        for rank in range(1, self._MAX_CONCAT_RANK + 1):
            ctor = self.get_var('tensor{}'.format(rank))
            lhs = Var("t{}1".format(rank))
            rhs = Var("t{}2".format(rank))
            # Inner match: y must carry the same rank as x.
            inner = Match(y, [Clause(PatternConstructor(ctor, [PatternVar(rhs)]),
                                     ctor(op.concatenate([lhs, rhs], axis=0)))],
                          False)
            clauses.append(Clause(PatternConstructor(ctor, [PatternVar(lhs)]), inner))

        self.prelude.mod[concat_var] = Function([x, y], Match(x, clauses, False))

    def define_tensor_array(self):
        """Defines a function to create a tensor array with size n.
        tensor_array(n) : Tensor[(), int32] -> list[tensor_t]

        Built recursively: n == 0 yields nil, otherwise ``tensor_nil`` is
        consed onto tensor_array(n - 1), giving n ``tensor_nil`` placeholders.
        """
        tensor_array_constructor_name = self.get_name("tensor_array")
        tensor_array_constructor_var = GlobalVar(tensor_array_constructor_name)
        setattr(self.prelude, tensor_array_constructor_name, tensor_array_constructor_var)
        tensor_nil_var = self.get_var('tensor_nil')
        tensor_type_var = self.get_var('tensor_t')
        n = Var("x", scalar_type('int32'))
        body = If(equal(n, const(0)),
                  self.prelude.nil(),
                  self.prelude.cons(tensor_nil_var(),
                                    tensor_array_constructor_var(subtract(n, const(1)))))
        self.prelude.mod[tensor_array_constructor_var] = \
            Function([n], body, self.prelude.l(tensor_type_var()), [])

    def define_tensor_array_read(self):
        """Defines a function to read the element at index n of a tensor array.

        tensor_array_read(ta, n) : list[tensor_t] -> Tensor[(), int32] -> tensor_t

        Delegates to the prelude's ``nth``.
        """
        read_name = self.get_name("tensor_array_read")
        read_var = GlobalVar(read_name)
        setattr(self.prelude, read_name, read_var)
        tensor_type_var = self.get_var('tensor_t')

        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
        n = Var("x", scalar_type('int32'))
        self.prelude.mod[read_var] =\
            Function([tensor_array, n], self.prelude.nth(tensor_array, n), tensor_type_var(), [])

    def define_tensor_array_write(self):
        """Defines a function to update a tensor array at index n with value v.
        tensor_array_write(ta, n, v) :
            list[tensor_t] -> Tensor[(), int32] -> tensor_t -> list[tensor_t]

        Delegates to the prelude's ``update``.
        """
        write_name = self.get_name("tensor_array_write")
        write_var = GlobalVar(write_name)
        setattr(self.prelude, write_name, write_var)
        tensor_type_var = self.get_var('tensor_t')
        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
        n = Var("x", scalar_type('int32'))
        v = Var("v", tensor_type_var())
        self.prelude.mod[write_var] =\
            Function([tensor_array, n, v], self.prelude.update(tensor_array, n, v),
                     self.prelude.l(tensor_type_var()), [])

    def _define_tensor_array_unstack(self, rank):
        """Define ``tensor_array_unstack_tensor<rank>`` and its recursive helper.

        tensor_array_unstack_tensor<rank>(t) : Tensor[rank dims of Any] -> list[tensor_t]

        The helper ``(i, up, t)`` returns nil when ``i == up``; otherwise it
        conses ``tensor<rank-1>(take(t, i))`` onto the result for ``i + 1``.
        The public function starts it at ``i = 0`` with ``up`` set to the
        size of axis 0 of its input.
        """
        helper_name = self.get_name("tensor_array_unstack_tensor{}_helper".format(rank))
        helper_var = GlobalVar(helper_name)
        setattr(self.prelude, helper_name, helper_var)
        tensor = Var("t", self._tensor_type(rank))
        up = Var("up", scalar_type('int32'))
        i = Var("i", scalar_type('int32'))
        if rank == 1:
            # Kept without ``axis`` exactly as the original rank-1 definition.
            sliced = op.take(tensor, i)
        else:
            sliced = op.take(tensor, i, axis=0)
        slice_ctor = self.get_var('tensor{}'.format(rank - 1))
        helper_body = If(equal(i, up),
                         self.prelude.nil(),
                         self.prelude.cons(slice_ctor(sliced),
                                           helper_var(add(i, const(1)), up, tensor)))
        self.prelude.mod[helper_var] = \
            Function([i, up, tensor], helper_body, self._list_of_tensor_t(), [])

        unstack_name = self.get_name("tensor_array_unstack_tensor{}".format(rank))
        unstack_var = GlobalVar(unstack_name)
        setattr(self.prelude, unstack_name, unstack_var)
        value = Var("tensor", self._tensor_type(rank))
        # Number of slices to produce: the size of axis 0.
        dim0_size = op.take(op.shape_of(value), const(0))
        self.prelude.mod[unstack_var] = \
            Function([value], helper_var(const(0), dim0_size, value),
                     self._list_of_tensor_t(), [])

    def define_tensor_array_unstack_tensor1(self):
        """Defines a function to unstack the values of a tensor_t with rank 1 in a tensor array.
        tensor_array_unstack_tensor1(t) : tensor_t -> list[tensor_t]
        """
        self._define_tensor_array_unstack(1)

    def define_tensor_array_unstack_tensor2(self):
        """Defines a function to unstack the values of a tensor_t with rank 2 in a tensor array.

        tensor_array_unstack_tensor2(t) : tensor_t -> list[tensor_t]
        """
        self._define_tensor_array_unstack(2)

    def define_tensor_array_unstack_tensor3(self):
        """Defines a function to unstack the values of a tensor_t with rank 3 in a tensor array.

        tensor_array_unstack_tensor3(t) : tensor_t -> list[tensor_t]
        """
        self._define_tensor_array_unstack(3)

    def define_tensor_array_unstack_tensor4(self):
        """Defines a function to unstack the values of a tensor_t with rank 4 in a tensor array.

        tensor_array_unstack_tensor4(t) : tensor_t -> list[tensor_t]
        """
        self._define_tensor_array_unstack(4)

    def define_tensor_array_unstack_tensor5(self):
        """Defines a function to unstack the values of a tensor_t with rank 5 in a tensor array.

        tensor_array_unstack_tensor5(t) : tensor_t -> list[tensor_t]
        """
        self._define_tensor_array_unstack(5)

    def define_tensor_array_unstack_tensor6(self):
        """Defines a function to unstack the values of a tensor_t with rank 6 in a tensor array.

        tensor_array_unstack_tensor6(t) : tensor_t -> list[tensor_t]
        """
        self._define_tensor_array_unstack(6)

    def define_tensor_array_scatter(self):
        """Defines a function to scatter the values of a tensor_t in indices of a tensor array.
        tensor_array_scatter(ta, indices, value) :
            list[tensor_t] -> Tensor[(Any), int32] -> tensor_t -> list[tensor_t]

        The helper iterates ``current`` from 0 to ``limit`` (= len(indices)),
        writing ``values[current]`` into ``ta[indices[current]]``.
        """
        tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper")
        tensor_array_scatter_helper_var = GlobalVar(tensor_array_scatter_helper_name)
        # Expose the helper on the prelude, like every other helper in this class.
        setattr(self.prelude, tensor_array_scatter_helper_name, tensor_array_scatter_helper_var)
        tensor_t = self.get_var('tensor_t')
        ta = Var("ta", self.prelude.l(tensor_t()))
        current = Var("current", scalar_type('int32'))
        limit = Var("limit", scalar_type('int32'))
        indices_ = Var('indices_', TensorType([Any()], 'int32'))
        values_ = Var('values_', self.prelude.l(tensor_t()))
        write_var = self.get_var('tensor_array_write')
        read_var = self.get_var('tensor_array_read')
        helper_body = If(equal(current, limit),
                         ta,
                         tensor_array_scatter_helper_var(
                             write_var(ta, op.take(indices_, current),
                                       read_var(values_, current)),
                             add(current, const(1)),
                             limit, indices_, values_))
        self.prelude.mod[tensor_array_scatter_helper_var] =\
            Function([ta, current, limit, indices_, values_],
                     helper_body, self.prelude.l(tensor_t()), [])
        tensor_array_scatter_name = self.get_name("tensor_array_scatter")
        tensor_array_scatter_var = GlobalVar(tensor_array_scatter_name)
        setattr(self.prelude, tensor_array_scatter_name, tensor_array_scatter_var)
        tensor_array = Var("tensor_array", self.prelude.l(tensor_t()))
        indices = Var('indices', TensorType([Any()], 'int32'))
        values = Var('values', self.prelude.l(tensor_t()))
        num_indices = op.take(op.shape_of(indices), const(0))
        body = tensor_array_scatter_helper_var(tensor_array, const(0), num_indices,
                                               indices, values)
        self.prelude.mod[tensor_array_scatter_var] =\
            Function([tensor_array, indices, values], body, self.prelude.l(tensor_t()), [])

    def define_tensor_array_split(self):
        """Defines a function to split the values of a tensor_t into a tensor array.
        tensor_array_split(ta, value, lengths) :
            list[tensor_t] -> tensor_t -> Tensor[(Any), int32] -> list[tensor_t]

        Element ``k`` of the result is ``tensor_take(value, offset_k,
        offset_k + lengths[k])``, where ``offset_k`` is the sum of the
        preceding lengths.
        """
        tensor_t = self.get_var('tensor_t')
        tensor_array_split_helper_name = self.get_name("ta_split_helper")
        tensor_array_split_helper_var = GlobalVar(tensor_array_split_helper_name)
        setattr(self.prelude, tensor_array_split_helper_name, tensor_array_split_helper_var)
        ta1 = Var("tensor_array", self.prelude.l(tensor_t()))
        value1 = Var('value1', tensor_t())
        offset1 = Var('offset1', scalar_type('int32'))
        current1 = Var('current1', scalar_type('int32'))
        limit1 = Var('limit1', scalar_type('int32'))
        lengths1 = Var('lengths', TensorType([Any()], 'int32'))
        write_var = self.get_var('tensor_array_write')
        take_var = self.get_var('tensor_take')
        # Recurse on the remaining pieces first, then write piece `current1`
        # into the array that recursion returns.
        helper1_body = If(equal(current1, limit1),
                          ta1,
                          write_var(
                              tensor_array_split_helper_var(
                                  ta1,
                                  value1,
                                  add(offset1, op.take(lengths1, current1)),
                                  add(current1, const(1)),
                                  limit1,
                                  lengths1
                              ),
                              current1,
                              take_var(value1,
                                       offset1,
                                       add(op.take(lengths1, current1), offset1))))
        self.prelude.mod[tensor_array_split_helper_var] = \
            Function([ta1, value1, offset1, current1, limit1, lengths1],
                     helper1_body, self.prelude.l(tensor_t()), [])
        split_name = self.get_name("tensor_array_split")
        split_var = GlobalVar(split_name)
        setattr(self.prelude, split_name, split_var)
        tensor_array = Var("tensor_array", self.prelude.l(tensor_t()))
        value = Var('value', tensor_t())
        lengths = Var('lengths', TensorType([Any()], 'int32'))
        lengths_shape = op.shape_of(lengths)
        lengths_limit = op.take(lengths_shape, const(0))
        body = tensor_array_split_helper_var(
            tensor_array,
            value,
            const(0),
            const(0),
            lengths_limit,
            lengths)
        self.prelude.mod[split_var] =\
            Function([tensor_array, value, lengths], body, self.prelude.l(tensor_t()), [])

    def define_tensor_array_concat(self):
        """Defines a function to return the values in the tensor array as concatenated tensor_t.
        tensor_array_concat(ta) : list[tensor_t] -> tensor_t

        An empty list yields ``tensor_nil``; a single-element list yields that
        element; otherwise ``tensor_concatenate(hd, tensor_array_concat(tl))``.
        """
        concat_name = self.get_name("tensor_array_concat")
        concat_var = GlobalVar(concat_name)
        setattr(self.prelude, concat_name, concat_var)
        tensor_concat_var = self.get_var('tensor_concatenate')
        tensor_t = self.get_var('tensor_t')
        tensor_nil_var = self.get_var('tensor_nil')
        tensor_array = Var("tensor_array", self.prelude.l(tensor_t()))
        hd = Var("hd")
        tl = Var("tl")
        nil_case = Clause(PatternConstructor(self.prelude.nil), tensor_nil_var())
        cons_case = Clause(PatternConstructor(self.prelude.cons, [PatternVar(hd), PatternVar(tl)]),
                           Match(tl, [
                               Clause(PatternConstructor(self.prelude.nil), hd),
                               Clause(PatternWildcard(),
                                      tensor_concat_var(hd, concat_var(tl)))
                           ], False))
        self.prelude.mod[concat_var] =\
            Function([tensor_array],
                     Match(tensor_array, [nil_case, cons_case], False), tensor_t(), [])

    def define_tensor_array_gather(self):
        """Defines a function to return the selected values in a tensor array as tensor_t.
        tensor_array_gather(ta, indices) : list[tensor_t] -> Tensor[(Any), int32] -> tensor_t

        The helper counts ``current`` down from len(indices) to 0, prepending
        ``ta[indices[current - 1]]`` to ``accu``. This way the accumulated
        list follows index order, and then passes it to ``tensor_array_stack``.
        """
        helper_name = self.get_name("tensor_array_gather_helper")
        helper_var = GlobalVar(helper_name)
        setattr(self.prelude, helper_name, helper_var)
        tensor_type_var = self.get_var('tensor_t')
        stack_var = self.get_var('tensor_array_stack')
        read_var = self.get_var('tensor_array_read')
        ta = Var("ta", self.prelude.l(tensor_type_var()))
        accu = Var("accu", self.prelude.l(tensor_type_var()))
        current = Var("current", scalar_type('int32'))
        limit = Var("limit", scalar_type('int32'))
        indices_ = Var('indices_', TensorType([Any()], 'int32'))
        helper_body =\
            If(equal(current, const(0)),
               stack_var(accu),
               helper_var(
                   ta,
                   self.prelude.cons(
                       read_var(
                           ta, op.take(indices_, subtract(current, const(1)))), accu),
                   subtract(current, const(1)),
                   limit, indices_))
        self.prelude.mod[helper_var] = \
            Function([ta, accu, current, limit, indices_], helper_body, tensor_type_var(), [])
        gather_name = self.get_name("tensor_array_gather")
        gather_var = GlobalVar(gather_name)
        setattr(self.prelude, gather_name, gather_var)
        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
        indices = Var('indices', TensorType([Any()], 'int32'))
        indices_shape = op.shape_of(indices)
        limit = op.take(indices_shape, const(0))
        body = helper_var(tensor_array, self.prelude.nil(), limit, limit, indices)
        self.prelude.mod[gather_var] =\
            Function([tensor_array, indices], body, tensor_type_var(), [])

    def define_tensor_array_stack(self):
        """Defines a function to get the values in the tensor array as a stack tensor_t.
        tensor_array_stack(l) : list[tensor_t] -> tensor_t

        Every element gets a new leading axis (``tensor_expand_dims``), then
        the list is folded with ``tensor_concatenate`` starting from its head.
        NOTE(review): relies on the prelude's ``hd``, so an empty list is
        presumably not supported.
        """
        stack_name = self.get_name("tensor_array_stack")
        stack_var = GlobalVar(stack_name)
        setattr(self.prelude, stack_name, stack_var)
        tensor_type_var = self.get_var('tensor_t')
        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
        expand_dims_var = self.get_var('tensor_expand_dims')
        concat_var = self.get_var('tensor_concatenate')
        tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array)
        tensors = self.prelude.foldl(concat_var,
                                     self.prelude.hd(tensor_array_expand_dims),
                                     self.prelude.tl(tensor_array_expand_dims))
        self.prelude.mod[stack_var] = Function([tensor_array], tensors, tensor_type_var(), [])

    def register(self):
        """Register all tensor array ops in Prelude.

        Order matters: later definitions fetch earlier ones through
        ``get_var`` (e.g. split uses ``tensor_take``, stack uses
        ``tensor_expand_dims`` and ``tensor_concatenate``).
        """
        self.define_tensor_adt()
        self.define_tensor_take()
        self.define_tensor_expand_dims()
        self.define_tensor_concat()
        self.define_tensor_array()
        self.define_tensor_array_read()
        self.define_tensor_array_write()
        self.define_tensor_array_unstack_tensor1()
        self.define_tensor_array_unstack_tensor2()
        self.define_tensor_array_unstack_tensor3()
        self.define_tensor_array_unstack_tensor4()
        self.define_tensor_array_unstack_tensor5()
        self.define_tensor_array_unstack_tensor6()
        self.define_tensor_array_scatter()
        self.define_tensor_array_split()
        self.define_tensor_array_concat()
        self.define_tensor_array_stack()
        # TODO(wweic): Gather fails in PartialEvaluate
        # self.define_tensor_array_gather()

class Prelude:
    """Standard Relay definitions loaded from ``prelude.rly``.

    The ADTs, their constructors, the common global functions and the
    per-dtype tensor array operators are all exposed as attributes.
    """

    def __init__(self, mod=None):
        # Extend the caller's module when one is given; otherwise start fresh.
        self.mod = IRModule() if mod is None else mod
        self.load_prelude()

    def get_name(self, canonical, dtype):
        """Return the dtype-specialized name for ``canonical``.

        ``'tensor_t'`` maps to ``'tensor_<dtype>_t'``.  Every other canonical
        name maps to ``'<canonical>_<dtype>'``.
        """
        if canonical != 'tensor_t':
            return "{}_{}".format(canonical, dtype)
        return 'tensor_{}_t'.format(dtype)

    def get_var(self, canonical, dtype):
        """Return the attribute bound under ``get_name(canonical, dtype)``."""
        return getattr(self, self.get_name(canonical, dtype))

    def get_name_static(self, canonical, dtype, shape):
        """Return the dtype- and shape-specialized name for ``canonical``.

        Delegates to the module-level ``_get_name_static`` helper.
        """
        return _get_name_static(canonical, dtype, shape)

    def get_var_static(self, canonical, dtype, shape):
        """Return the attribute bound under ``get_name_static(canonical, dtype, shape)``."""
        return getattr(self, self.get_name_static(canonical, dtype, shape))

    def load_prelude(self):
        """Import ``prelude.rly`` into ``self.mod`` and bind its definitions.

        Binds the following as attributes of this object:

        * the List, Option and Tree type vars and their constructors
          (``cons``/``nil``, ``some``/``none``, ``rose``);
        * the global functions named below;
        * the tensor array operators, registered once for each listed dtype.
        """
        mod = self.mod
        # TODO(@jroesch): we should remove this helper when we port over prelude
        mod.import_from_std("prelude.rly")

        self.l = mod.get_global_type_var("List")
        list_adt = mod[self.l]
        self.cons = list_adt.constructors[0]
        self.nil = list_adt.constructors[1]

        self.optional = mod.get_global_type_var("Option")
        optional_adt = mod[self.optional]
        self.some = optional_adt.constructors[0]
        self.none = optional_adt.constructors[1]

        self.tree = mod.get_global_type_var("Tree")
        tree_adt = mod[self.tree]
        self.rose = tree_adt.constructors[0]

        global_defs = (
            "id", "compose", "flip", "hd", "tl", "nth", "update", "map",
            "foldl", "foldr", "foldr1", "concat", "filter", "zip", "rev",
            "map_accuml", "map_accumr", "unfoldl", "unfoldr", "sum",
            "length", "tmap", "size", "iterate",
        )
        for def_name in global_defs:
            setattr(self, def_name, mod.get_global_var(def_name))

        # Each TensorArrayOps instance defines one dtype's tensor array ops
        # on this prelude when register() is called.
        tensor_dtypes = (
            'float32', 'float16', 'float64', 'int32', 'uint8',
            'int8', 'int16', 'uint16', 'int64',
        )
        for dtype in tensor_dtypes:
            TensorArrayOps(self, dtype).register()