# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""A prelude containing useful global functions and ADT definitions."""
from tvm.ir import IRModule

from .ty import GlobalTypeVar, TensorType, Any, scalar_type
from .expr import Var, GlobalVar, If, const
from .function import Function
from .op.tensor import add, subtract, equal
from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard
from . import op


class TensorArrayOps(object):
    """Contains tensor array related ops"""

    def __init__(self, prelude, dtype):
        """Bind the tensor array ops registry to a prelude and an element dtype.

        Parameters
        ----------
        prelude : object
            Prelude whose module receives the definitions and on which the
            generated vars are stored as attributes.
        dtype : object
            Element dtype, used for the tensor types and for the name suffixes.
        """
        self.prelude, self.dtype = prelude, dtype

    def get_name(self, canonical):
        """Return the dtype-specific name for ``canonical``, as computed by the prelude."""
        owner = self.prelude
        return owner.get_name(canonical, self.dtype)

    def get_var(self, canonical):
        """Return the var the prelude stores for ``canonical`` at this registry's dtype."""
        owner = self.prelude
        return owner.get_var(canonical, self.dtype)

    def define_tensor_adt(self):
        """Defines the dynamic tensor ADT, the container for tensors with variable shapes.

        The ADT has a field-less ``tensor_nil`` constructor followed by one
        constructor ``tensor<r>`` for every rank r in 0..6; each of those carries a
        single tensor of that rank whose dimensions are all ``Any``. The type var and
        every constructor are stored on the prelude under their dtype-specific names.
        """
        prelude = self.prelude
        adt_name = self.get_name('tensor_t')
        adt_var = GlobalTypeVar(adt_name)
        setattr(prelude, adt_name, adt_var)

        # (constructor name, field types), in the order they appear in the ADT:
        # tensor_nil first, then rank 0 through rank 6.
        specs = [(self.get_name('tensor_nil'), [])]
        for rank in range(7):
            field_type = TensorType([Any() for _ in range(rank)], self.dtype)
            specs.append((self.get_name('tensor{}'.format(rank)), [field_type]))

        constructors = []
        for ctor_name, field_types in specs:
            ctor = Constructor(ctor_name, field_types, adt_var)
            setattr(prelude, ctor_name, ctor)
            constructors.append(ctor)

        prelude.mod[adt_var] = TypeData(adt_var, [], constructors)

    def define_tensor_take(self):
        """Defines a function to return the range [lower, upper) of a tensor_t on axis 0.

        tensor_take(t, lower, upper) :
            tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t

        Only ranks 1 through 6 have a clause; the match is not marked complete.
        """
        fn_name = self.get_name("tensor_take")
        fn_var = GlobalVar(fn_name)
        setattr(self.prelude, fn_name, fn_var)
        adt = self.get_var('tensor_t')

        src = Var('tensor', adt())
        lower = Var('lower', scalar_type('int32'))
        upper = Var('upper', scalar_type('int32'))

        def slice_clause(rank):
            """Build the clause that slices the payload of the rank-``rank`` constructor."""
            ctor = self.get_var('tensor{}'.format(rank))
            payload = Var('t{}'.format(rank))
            positions = op.arange(lower, upper, dtype='int32')
            if rank == 1:
                # The rank-1 case calls take without an explicit axis.
                taken = op.take(payload, positions)
            else:
                taken = op.take(payload, positions, axis=0)
            return Clause(PatternConstructor(ctor, [PatternVar(payload)]), ctor(taken))

        clauses = [slice_clause(rank) for rank in range(1, 7)]
        self.prelude.mod[fn_var] = Function([src, lower, upper],
                                            Match(src, clauses, False),
                                            adt(), [])

    def define_tensor_expand_dims(self):
        """Defines a function to grow a tensor_t's rank by adding one dimension in front
        of the original tensor_t.

        tensor_expand_dims(t) : tensor_t -> tensor_t

        Ranks 0 through 5 are handled, each mapping to the constructor of the next
        rank; the match is not marked complete.
        """
        fn_name = self.get_name("tensor_expand_dims")
        fn_var = GlobalVar(fn_name)
        setattr(self.prelude, fn_name, fn_var)
        adt = self.get_var('tensor_t')
        x = Var("x", adt())

        def lift_clause(rank):
            """Build the clause turning a rank-``rank`` payload into rank ``rank + 1``."""
            src_ctor = self.get_var('tensor{}'.format(rank))
            dst_ctor = self.get_var('tensor{}'.format(rank + 1))
            payload = Var("t{}".format(rank))
            return Clause(PatternConstructor(src_ctor, [PatternVar(payload)]),
                          dst_ctor(op.expand_dims(payload, 0, 1)))

        clauses = [lift_clause(rank) for rank in range(6)]
        self.prelude.mod[fn_var] = Function([x], Match(x, clauses, False))

    def define_tensor_concat(self):
        """Defines a function to concatenate two tensor_t on the first axis.

        tensor_concatenate(x, y) : tensor_t -> tensor_t -> tensor_t

        Both arguments must use the same constructor; clauses exist for ranks 1
        through 4 only and neither match is marked complete.
        """
        fn_name = self.get_name("tensor_concatenate")
        fn_var = GlobalVar(fn_name)
        setattr(self.prelude, fn_name, fn_var)
        adt = self.get_var('tensor_t')
        x = Var("x", adt())
        y = Var("y", adt())

        def same_rank_clause(rank):
            """Build the clause for x of rank ``rank``; it matches y against the same rank."""
            ctor = self.get_var('tensor{}'.format(rank))
            lhs = Var("t{}1".format(rank))
            rhs = Var("t{}2".format(rank))
            joined = ctor(op.concatenate([lhs, rhs], axis=0))
            on_y = Match(y, [Clause(PatternConstructor(ctor, [PatternVar(rhs)]), joined)],
                         False)
            return Clause(PatternConstructor(ctor, [PatternVar(lhs)]), on_y)

        # op.concatenate does not support tensor with rank higher than 4
        clauses = [same_rank_clause(rank) for rank in range(1, 5)]
        self.prelude.mod[fn_var] = Function([x, y], Match(x, clauses, False))

    def define_tensor_array(self):
        """Defines a function to create a tensor array with size n.

        tensor_array(n) : Tensor[(), int32] -> list[tensor_t]

        The function is recursive: it returns nil when n equals 0, otherwise a
        tensor_nil element consed onto the array built for n - 1.
        """
        fn_name = self.get_name("tensor_array")
        fn_var = GlobalVar(fn_name)
        setattr(self.prelude, fn_name, fn_var)
        nil_ctor = self.get_var('tensor_nil')
        adt = self.get_var('tensor_t')
        prelude = self.prelude

        size = Var("x", scalar_type('int32'))
        is_empty = equal(size, const(0))
        shorter = fn_var(subtract(size, const(1)))
        body = If(is_empty, prelude.nil(), prelude.cons(nil_ctor(), shorter))
        prelude.mod[fn_var] = Function([size], body, prelude.l(adt()), [])

    def define_tensor_array_read(self):
        """Defines a function to read the element at index n of a tensor array.

        tensor_array_read(ta, n) : list[tensor_t] -> Tensor[(), int32] -> tensor_t

        The body is a call to the prelude's ``nth`` on the two arguments.
        """
        fn_name = self.get_name("tensor_array_read")
        fn_var = GlobalVar(fn_name)
        setattr(self.prelude, fn_name, fn_var)
        adt = self.get_var('tensor_t')
        prelude = self.prelude

        ta = Var("tensor_array", prelude.l(adt()))
        index = Var("x", scalar_type('int32'))
        element = prelude.nth(ta, index)
        prelude.mod[fn_var] = Function([ta, index], element, adt(), [])

    def define_tensor_array_write(self):
        """Defines a function to update a tensor array at index n with value v.

        tensor_array_write(ta, n, v) :
            list[tensor_t] -> Tensor[(), int32] -> tensor_t -> list[tensor_t]

        The body is a call to the prelude's ``update`` on the three arguments.
        """
        fn_name = self.get_name("tensor_array_write")
        fn_var = GlobalVar(fn_name)
        setattr(self.prelude, fn_name, fn_var)
        adt = self.get_var('tensor_t')
        prelude = self.prelude

        ta = Var("tensor_array", prelude.l(adt()))
        index = Var("x", scalar_type('int32'))
        value = Var("v", adt())
        updated = prelude.update(ta, index, value)
        prelude.mod[fn_var] = Function([ta, index, value], updated, prelude.l(adt()), [])

    def define_tensor_array_unstack_tensor1(self):
        """Defines a function to unstack the values of a rank-1 tensor into a tensor array.

        tensor_array_unstack_tensor1(t) : Tensor[(Any), dtype] -> list[tensor_t]

        A recursive helper, also stored on the prelude, counts ``i`` up to ``up`` and
        wraps ``take(t, i)`` in a tensor0 constructor at every step.
        """
        helper_name = self.get_name("tensor_array_unstack_tensor1_helper")
        helper_var = GlobalVar(helper_name)
        setattr(self.prelude, helper_name, helper_var)
        prelude = self.prelude

        src = Var("t", TensorType([Any()], self.dtype))
        up = Var("up", scalar_type('int32'))
        i = Var("i", scalar_type('int32'))
        adt = self.get_var('tensor_t')
        elem_ctor = self.get_var('tensor0')
        # Rank 1 takes a single element, so no axis is passed here.
        elem = elem_ctor(op.take(src, i))
        rest = helper_var(add(i, const(1)), up, src)
        helper_body = If(equal(i, up), prelude.nil(), prelude.cons(elem, rest))
        prelude.mod[helper_var] = Function([i, up, src], helper_body, prelude.l(adt()), [])

        unstack_name = self.get_name("tensor_array_unstack_tensor1")
        unstack_var = GlobalVar(unstack_name)
        setattr(prelude, unstack_name, unstack_var)
        tensor1 = Var("tensor", TensorType([Any()], self.dtype))
        # Extent of the first axis, read from the runtime shape.
        length = op.take(op.shape_of(tensor1), const(0))
        prelude.mod[unstack_var] = Function([tensor1],
                                            helper_var(const(0), length, tensor1),
                                            prelude.l(adt()), [])

    def define_tensor_array_unstack_tensor2(self):
        """Defines a function to unstack the values of a rank-2 tensor into a tensor array.

        tensor_array_unstack_tensor2(t) : Tensor[(Any, Any), dtype] -> list[tensor_t]

        A recursive helper, also stored on the prelude, counts ``i`` up to ``up`` and
        wraps ``take(t, i, axis=0)`` in a tensor1 constructor at every step.
        """
        rank = 2
        helper_name = self.get_name("tensor_array_unstack_tensor2_helper")
        helper_var = GlobalVar(helper_name)
        setattr(self.prelude, helper_name, helper_var)
        prelude = self.prelude

        src = Var("t", TensorType([Any() for _ in range(rank)], self.dtype))
        up = Var("up", scalar_type('int32'))
        i = Var("i", scalar_type('int32'))
        elem = self.get_var('tensor1')(op.take(src, i, axis=0))
        rest = helper_var(add(i, const(1)), up, src)
        helper_body = If(equal(i, up), prelude.nil(), prelude.cons(elem, rest))
        helper_ret = prelude.l(self.get_var('tensor_t')())
        prelude.mod[helper_var] = Function([i, up, src], helper_body, helper_ret, [])

        unstack_name = self.get_name("tensor_array_unstack_tensor2")
        unstack_var = GlobalVar(unstack_name)
        setattr(prelude, unstack_name, unstack_var)
        tensor2 = Var("tensor", TensorType([Any() for _ in range(rank)], self.dtype))
        # Extent of the first axis, read from the runtime shape.
        length = op.take(op.shape_of(tensor2), const(0))
        prelude.mod[unstack_var] = Function([tensor2],
                                            helper_var(const(0), length, tensor2),
                                            prelude.l(self.get_var('tensor_t')()), [])

    def define_tensor_array_unstack_tensor3(self):
        """Defines a function to unstack the values of a rank-3 tensor into a tensor array.

        tensor_array_unstack_tensor3(t) : Tensor[(Any, Any, Any), dtype] -> list[tensor_t]

        A recursive helper, also stored on the prelude, counts ``i`` up to ``up`` and
        wraps ``take(t, i, axis=0)`` in a tensor2 constructor at every step.
        """
        rank = 3
        helper_name = self.get_name("tensor_array_unstack_tensor3_helper")
        helper_var = GlobalVar(helper_name)
        setattr(self.prelude, helper_name, helper_var)
        prelude = self.prelude

        src = Var("t", TensorType([Any() for _ in range(rank)], self.dtype))
        up = Var("up", scalar_type('int32'))
        i = Var("i", scalar_type('int32'))
        elem = self.get_var('tensor2')(op.take(src, i, axis=0))
        rest = helper_var(add(i, const(1)), up, src)
        helper_body = If(equal(i, up), prelude.nil(), prelude.cons(elem, rest))
        helper_ret = prelude.l(self.get_var('tensor_t')())
        prelude.mod[helper_var] = Function([i, up, src], helper_body, helper_ret, [])

        unstack_name = self.get_name("tensor_array_unstack_tensor3")
        unstack_var = GlobalVar(unstack_name)
        setattr(prelude, unstack_name, unstack_var)
        tensor3 = Var("tensor", TensorType([Any() for _ in range(rank)], self.dtype))
        # Extent of the first axis, read from the runtime shape.
        length = op.take(op.shape_of(tensor3), const(0))
        prelude.mod[unstack_var] = Function([tensor3],
                                            helper_var(const(0), length, tensor3),
                                            prelude.l(self.get_var('tensor_t')()), [])

    def define_tensor_array_unstack_tensor4(self):
        """Defines a function to unstack the values of a rank-4 tensor into a tensor array.

        tensor_array_unstack_tensor4(t) :
            Tensor[(Any, Any, Any, Any), dtype] -> list[tensor_t]

        A recursive helper, also stored on the prelude, counts ``i`` up to ``up`` and
        wraps ``take(t, i, axis=0)`` in a tensor3 constructor at every step.
        """
        rank = 4
        helper_name = self.get_name("tensor_array_unstack_tensor4_helper")
        helper_var = GlobalVar(helper_name)
        setattr(self.prelude, helper_name, helper_var)
        prelude = self.prelude

        src = Var("t", TensorType([Any() for _ in range(rank)], self.dtype))
        up = Var("up", scalar_type('int32'))
        i = Var("i", scalar_type('int32'))
        elem = self.get_var('tensor3')(op.take(src, i, axis=0))
        rest = helper_var(add(i, const(1)), up, src)
        helper_body = If(equal(i, up), prelude.nil(), prelude.cons(elem, rest))
        helper_ret = prelude.l(self.get_var('tensor_t')())
        prelude.mod[helper_var] = Function([i, up, src], helper_body, helper_ret, [])

        unstack_name = self.get_name("tensor_array_unstack_tensor4")
        unstack_var = GlobalVar(unstack_name)
        setattr(prelude, unstack_name, unstack_var)
        tensor4 = Var("tensor", TensorType([Any() for _ in range(rank)], self.dtype))
        # Extent of the first axis, read from the runtime shape.
        length = op.take(op.shape_of(tensor4), const(0))
        prelude.mod[unstack_var] = Function([tensor4],
                                            helper_var(const(0), length, tensor4),
                                            prelude.l(self.get_var('tensor_t')()), [])

    def define_tensor_array_unstack_tensor5(self):
        """Defines a function to unstack the values of a rank-5 tensor into a tensor array.

        tensor_array_unstack_tensor5(t) :
            Tensor[(Any, Any, Any, Any, Any), dtype] -> list[tensor_t]

        A recursive helper, also stored on the prelude, counts ``i`` up to ``up`` and
        wraps ``take(t, i, axis=0)`` in a tensor4 constructor at every step.
        """
        rank = 5
        helper_name = self.get_name("tensor_array_unstack_tensor5_helper")
        helper_var = GlobalVar(helper_name)
        setattr(self.prelude, helper_name, helper_var)
        prelude = self.prelude

        src = Var("t", TensorType([Any() for _ in range(rank)], self.dtype))
        up = Var("up", scalar_type('int32'))
        i = Var("i", scalar_type('int32'))
        elem = self.get_var('tensor4')(op.take(src, i, axis=0))
        rest = helper_var(add(i, const(1)), up, src)
        helper_body = If(equal(i, up), prelude.nil(), prelude.cons(elem, rest))
        helper_ret = prelude.l(self.get_var('tensor_t')())
        prelude.mod[helper_var] = Function([i, up, src], helper_body, helper_ret, [])

        unstack_name = self.get_name("tensor_array_unstack_tensor5")
        unstack_var = GlobalVar(unstack_name)
        setattr(prelude, unstack_name, unstack_var)
        tensor5 = Var("tensor", TensorType([Any() for _ in range(rank)], self.dtype))
        # Extent of the first axis, read from the runtime shape.
        length = op.take(op.shape_of(tensor5), const(0))
        prelude.mod[unstack_var] = Function([tensor5],
                                            helper_var(const(0), length, tensor5),
                                            prelude.l(self.get_var('tensor_t')()), [])

    def define_tensor_array_unstack_tensor6(self):
        """Defines a function to unstack the values of a rank-6 tensor into a tensor array.

        tensor_array_unstack_tensor6(t) :
            Tensor[(Any, Any, Any, Any, Any, Any), dtype] -> list[tensor_t]

        A recursive helper, also stored on the prelude, counts ``i`` up to ``up`` and
        wraps ``take(t, i, axis=0)`` in a tensor5 constructor at every step.
        """
        rank = 6
        helper_name = self.get_name("tensor_array_unstack_tensor6_helper")
        helper_var = GlobalVar(helper_name)
        setattr(self.prelude, helper_name, helper_var)
        prelude = self.prelude

        src = Var("t", TensorType([Any() for _ in range(rank)], self.dtype))
        up = Var("up", scalar_type('int32'))
        i = Var("i", scalar_type('int32'))
        elem = self.get_var('tensor5')(op.take(src, i, axis=0))
        rest = helper_var(add(i, const(1)), up, src)
        helper_body = If(equal(i, up), prelude.nil(), prelude.cons(elem, rest))
        helper_ret = prelude.l(self.get_var('tensor_t')())
        prelude.mod[helper_var] = Function([i, up, src], helper_body, helper_ret, [])

        unstack_name = self.get_name("tensor_array_unstack_tensor6")
        unstack_var = GlobalVar(unstack_name)
        setattr(prelude, unstack_name, unstack_var)
        tensor6 = Var("tensor", TensorType([Any() for _ in range(rank)], self.dtype))
        # Extent of the first axis, read from the runtime shape.
        length = op.take(op.shape_of(tensor6), const(0))
        prelude.mod[unstack_var] = Function([tensor6],
                                            helper_var(const(0), length, tensor6),
                                            prelude.l(self.get_var('tensor_t')()), [])

    def define_tensor_array_scatter(self):
        """Defines a function to scatter a list of tensor_t into the given indices of a
        tensor array.

        tensor_array_scatter(ta, indices, values) :
            list[tensor_t] -> Tensor[(Any), int32] -> list[tensor_t] -> list[tensor_t]

        A recursive helper counts ``current`` up to ``limit`` (the length of
        ``indices``); at every step it writes element ``current`` of ``values`` into
        the tensor array at position ``indices[current]``. Both the helper and the
        entry point are stored on the prelude under their dtype-specific names.
        """
        tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper")
        tensor_array_scatter_helper_var = GlobalVar(tensor_array_scatter_helper_name)
        # Expose the helper on the prelude, as every other helper in this class does.
        setattr(self.prelude, tensor_array_scatter_helper_name, tensor_array_scatter_helper_var)
        tensor_t = self.get_var('tensor_t')
        ta = Var("ta", self.prelude.l(tensor_t()))
        current = Var("current", scalar_type('int32'))
        limit = Var("limit", scalar_type('int32'))
        indices_ = Var('indices_', TensorType([Any()], 'int32'))
        values_ = Var('values_', self.prelude.l(tensor_t()))
        write_var = self.get_var('tensor_array_write')
        read_var = self.get_var('tensor_array_read')
        # Stop once every index was consumed; otherwise write one element and recurse
        # on the updated array with the next position.
        helper_body = If(equal(current, limit),
                         ta,
                         tensor_array_scatter_helper_var(
                             write_var(ta, op.take(indices_, current),
                                       read_var(values_, current)),
                             add(current, const(1)),
                             limit, indices_, values_))
        self.prelude.mod[tensor_array_scatter_helper_var] =\
            Function([ta, current, limit, indices_, values_],
                     helper_body, self.prelude.l(tensor_t()), [])
        tensor_array_scatter_name = self.get_name("tensor_array_scatter")
        tensor_array_scatter_var = GlobalVar(tensor_array_scatter_name)
        setattr(self.prelude, tensor_array_scatter_name, tensor_array_scatter_var)
        tensor_array = Var("tensor_array", self.prelude.l(tensor_t()))
        indices = Var('indices', TensorType([Any()], 'int32'))
        values = Var('values', self.prelude.l(tensor_t()))
        # The number of writes is the extent of the indices tensor.
        indices_shape = op.shape_of(indices)
        limit = op.take(indices_shape, const(0))
        body = tensor_array_scatter_helper_var(tensor_array, const(0), limit, indices, values)
        self.prelude.mod[tensor_array_scatter_var] =\
            Function([tensor_array, indices, values], body, self.prelude.l(tensor_t()), [])

    def define_tensor_array_split(self):
        """Defines a function to split the values of a tensor_t into a tensor array.

        tensor_array_split(ta, value, lengths) :
            list[tensor_t] -> tensor_t -> Tensor[(Any), int32] -> list[tensor_t]

        A recursive helper, also stored on the prelude, tracks a running ``offset1``
        and a position ``current1``. It first recurses for the remaining positions and
        then writes ``tensor_take(value, offset, offset + lengths[current])`` at
        position ``current1`` of the resulting array.
        """
        adt = self.get_var('tensor_t')
        helper_name = self.get_name("ta_split_helper")
        helper_var = GlobalVar(helper_name)
        setattr(self.prelude, helper_name, helper_var)
        prelude = self.prelude

        ta1 = Var("tensor_array", prelude.l(adt()))
        value1 = Var('value1', adt())
        offset1 = Var('offset1', scalar_type('int32'))
        current1 = Var('current1', scalar_type('int32'))
        limit1 = Var('limit1', scalar_type('int32'))
        lengths1 = Var('lengths', TensorType([Any()], 'int32'))
        write_var = self.get_var('tensor_array_write')
        take_var = self.get_var('tensor_take')

        # Array produced by handling every position after the current one.
        next_offset = add(offset1, op.take(lengths1, current1))
        remaining = helper_var(ta1, value1, next_offset,
                               add(current1, const(1)), limit1, lengths1)
        # Slice of the value that belongs to the current position.
        piece_end = add(op.take(lengths1, current1), offset1)
        piece = take_var(value1, offset1, piece_end)
        helper_body = If(equal(current1, limit1),
                         ta1,
                         write_var(remaining, current1, piece))
        prelude.mod[helper_var] = Function(
            [ta1, value1, offset1, current1, limit1, lengths1],
            helper_body, prelude.l(adt()), [])

        split_name = self.get_name("tensor_array_split")
        split_var = GlobalVar(split_name)
        setattr(prelude, split_name, split_var)
        tensor_array = Var("tensor_array", prelude.l(adt()))
        value = Var('value', adt())
        lengths = Var('lengths', TensorType([Any()], 'int32'))
        # The number of pieces is the extent of the lengths tensor.
        piece_count = op.take(op.shape_of(lengths), const(0))
        body = helper_var(tensor_array, value, const(0), const(0), piece_count, lengths)
        prelude.mod[split_var] = Function([tensor_array, value, lengths],
                                          body, prelude.l(adt()), [])

    def define_tensor_array_concat(self):
        """Defines a function to return the values in the tensor array as concatenated tensor_t.

        tensor_array_concat(ta) : list[tensor_t] -> tensor_t

        An empty array yields tensor_nil and a single-element array yields that
        element; otherwise the head is concatenated with the recursive result on the
        tail.
        """
        fn_name = self.get_name("tensor_array_concat")
        fn_var = GlobalVar(fn_name)
        setattr(self.prelude, fn_name, fn_var)
        pairwise_concat = self.get_var('tensor_concatenate')
        adt = self.get_var('tensor_t')
        nil_ctor = self.get_var('tensor_nil')
        prelude = self.prelude

        ta = Var("tensor_array", prelude.l(adt()))
        hd = Var("hd")
        tl = Var("tl")
        empty_clause = Clause(PatternConstructor(prelude.nil), nil_ctor())
        # Inspect the tail to tell a one-element list apart from a longer one.
        on_tail = Match(tl,
                        [Clause(PatternConstructor(prelude.nil), hd),
                         Clause(PatternWildcard(), pairwise_concat(hd, fn_var(tl)))],
                        False)
        cons_pattern = PatternConstructor(prelude.cons, [PatternVar(hd), PatternVar(tl)])
        cons_clause = Clause(cons_pattern, on_tail)
        prelude.mod[fn_var] = Function([ta],
                                       Match(ta, [empty_clause, cons_clause], False),
                                       adt(), [])

    def define_tensor_array_gather(self):
        """Defines a function to return the selected values in a tensor array as tensor_t.

        tensor_array_gather(ta, indices) : list[tensor_t] -> Tensor[(Any), int32] -> tensor_t

        A recursive helper, also stored on the prelude, counts ``current`` down to 0
        while consing ``ta[indices[current - 1]]`` onto an accumulator list, then
        stacks the accumulator with ``tensor_array_stack``. The helper's ``limit``
        parameter is only forwarded to the recursive call.
        """
        helper_name = self.get_name("tensor_array_gather_helper")
        helper_var = GlobalVar(helper_name)
        setattr(self.prelude, helper_name, helper_var)
        adt = self.get_var('tensor_t')
        stack_var = self.get_var('tensor_array_stack')
        read_var = self.get_var('tensor_array_read')
        prelude = self.prelude

        ta = Var("ta", prelude.l(adt()))
        accu = Var("accu", prelude.l(adt()))
        current = Var("current", scalar_type('int32'))
        limit = Var("limit", scalar_type('int32'))
        indices_ = Var('indices_', TensorType([Any()], 'int32'))
        # Element selected by the index just below the current counter.
        picked = read_var(ta, op.take(indices_, subtract(current, const(1))))
        recurse = helper_var(ta,
                             prelude.cons(picked, accu),
                             subtract(current, const(1)),
                             limit, indices_)
        helper_body = If(equal(current, const(0)), stack_var(accu), recurse)
        prelude.mod[helper_var] = Function([ta, accu, current, limit, indices_],
                                           helper_body, adt(), [])

        gather_name = self.get_name("tensor_array_gather")
        gather_var = GlobalVar(gather_name)
        setattr(prelude, gather_name, gather_var)
        tensor_array = Var("tensor_array", prelude.l(adt()))
        indices = Var('indices', TensorType([Any()], 'int32'))
        # The counter starts at the extent of the indices tensor; the same expression
        # is passed for both the current and the limit argument.
        count = op.take(op.shape_of(indices), const(0))
        body = helper_var(tensor_array, prelude.nil(), count, count, indices)
        prelude.mod[gather_var] = Function([tensor_array, indices], body, adt(), [])

    def define_tensor_array_stack(self):
        """Defines a function to get the values in the tensor array as a stack tensor_t.

        tensor_array_stack(l) : list[tensor_t] -> tensor_t

        Every element is first passed through ``tensor_expand_dims``; the expanded
        elements are then folded from the left with ``tensor_concatenate``, starting
        from the head of the expanded list.
        """
        fn_name = self.get_name("tensor_array_stack")
        fn_var = GlobalVar(fn_name)
        setattr(self.prelude, fn_name, fn_var)
        adt = self.get_var('tensor_t')
        prelude = self.prelude

        ta = Var("tensor_array", prelude.l(adt()))
        expand_dims_var = self.get_var('tensor_expand_dims')
        concat_var = self.get_var('tensor_concatenate')
        expanded = prelude.map(expand_dims_var, ta)
        first = prelude.hd(expanded)
        others = prelude.tl(expanded)
        stacked = prelude.foldl(concat_var, first, others)
        prelude.mod[fn_var] = Function([ta], stacked, adt(), [])

    def register(self):
        """Register all tensor array ops in Prelude.

        The definitions run in the order listed below; later ones look up vars that
        earlier ones stored on the prelude.
        """
        definers = (
            self.define_tensor_adt,
            self.define_tensor_take,
            self.define_tensor_expand_dims,
            self.define_tensor_concat,
            self.define_tensor_array,
            self.define_tensor_array_read,
            self.define_tensor_array_write,
            self.define_tensor_array_unstack_tensor1,
            self.define_tensor_array_unstack_tensor2,
            self.define_tensor_array_unstack_tensor3,
            self.define_tensor_array_unstack_tensor4,
            self.define_tensor_array_unstack_tensor5,
            self.define_tensor_array_unstack_tensor6,
            self.define_tensor_array_scatter,
            self.define_tensor_array_split,
            self.define_tensor_array_concat,
            self.define_tensor_array_stack,
            # TODO(wweic): Gather fails in PartialEvaluate
            # self.define_tensor_array_gather,
        )
        for define in definers:
            define()

class Prelude:
    """Python-side handle to the Relay prelude.

    Constructing an instance imports ``prelude.rly`` into an IRModule and
    binds the type vars, constructors and global functions listed in
    ``load_prelude`` as attributes of the instance. The per-dtype tensor
    array ops are registered on the same instance.
    """

    def __init__(self, mod=None):
        """Load the prelude into ``mod``, or into a new IRModule when it is None."""
        self.mod = IRModule() if mod is None else mod
        self.load_prelude()

    def get_name(self, canonical, dtype):
        """Return the dtype-specific name for a canonical name.

        ``tensor_t`` is special-cased so the dtype is placed before the
        ``_t`` suffix; every other canonical name gets ``_<dtype>`` appended.
        """
        if canonical != 'tensor_t':
            return "{}_{}".format(canonical, dtype)
        return 'tensor_{}_t'.format(dtype)

    def get_var(self, canonical, dtype):
        """Return the attribute of this prelude stored under the dtype-specific name.

        Raises AttributeError if nothing has been bound under that name.
        """
        return getattr(self, self.get_name(canonical, dtype))

    def load_prelude(self):
        """Import the textual Prelude and bind its definitions on this object."""
        # TODO(@jroesch): we should remove this helper when we port over prelude
        module = self.mod
        module.import_from_std("prelude.rly")

        # List ADT and its first two constructors.
        self.l = module.get_global_type_var("List")
        list_ctors = module[self.l].constructors
        self.cons = list_ctors[0]
        self.nil = list_ctors[1]

        # Option ADT and its first two constructors.
        self.optional = module.get_global_type_var("Option")
        option_ctors = module[self.optional].constructors
        self.some = option_ctors[0]
        self.none = option_ctors[1]

        # Tree ADT and its first constructor.
        self.tree = module.get_global_type_var("Tree")
        self.rose = module[self.tree].constructors[0]

        # Global functions exposed as same-named attributes of the prelude.
        global_def_names = (
            "id", "compose", "flip",
            "hd", "tl", "nth", "update",
            "map", "foldl", "foldr", "foldr1",
            "concat", "filter", "zip", "rev",
            "map_accuml", "map_accumr",
            "unfoldl", "unfoldr",
            "sum", "length",
            "tmap", "size",
            "iterate",
        )
        for name in global_def_names:
            setattr(self, name, module.get_global_var(name))

        # One full set of tensor array ops is registered per element dtype.
        tensor_array_dtypes = (
            'float32', 'float16', 'float64',
            'int32', 'uint8', 'int8',
            'int16', 'uint16', 'int64',
        )
        for dtype in tensor_array_dtypes:
            TensorArrayOps(self, dtype).register()