/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2018 by Contributors
 * \file src/tvm/ir/type.cc
 * \brief The type system AST nodes of Relay.
 */
#include <tvm/relay/type.h>

namespace tvm {
namespace relay {

using tvm::IRPrinter;
using namespace tvm::runtime;

/*!
 * \brief Build a TensorType from a shape and an element data type.
 * \param shape The shape stored on the new node.
 * \param dtype The data type stored on the new node.
 * \return A TensorType reference wrapping the freshly allocated node.
 */
TensorType TensorTypeNode::make(Array<IndexExpr> shape, DataType dtype) {
  auto tensor_node = make_node<TensorTypeNode>();
  tensor_node->dtype = std::move(dtype);
  tensor_node->shape = std::move(shape);
  return TensorType(tensor_node);
}

/*!
 * \brief Build a TensorType with an empty shape and the given data type.
 * \param dtype The data type of the scalar.
 * \return The result of TensorTypeNode::make with an empty shape array.
 */
TensorType TensorTypeNode::Scalar(DataType dtype) {
  // An empty shape array denotes a rank-0 tensor.
  return TensorTypeNode::make(Array<IndexExpr>(), dtype);
}

/*!
 * \brief Compute the number of elements described by this tensor type.
 * \return The constant 1 (as Int(64)) when the shape is empty; otherwise the
 *         product of all shape entries, accumulated from left to right.
 */
IndexExpr TensorTypeNode::Size() const {
  const size_t ndim = shape.size();

  // Empty shape: there is nothing to multiply, the size is the constant one.
  if (ndim == 0) {
    return make_const(Int(64), 1);
  }

  // Seed the product with the first dimension, then fold in the rest.
  IndexExpr total = shape[0];
  size_t axis = 1;
  while (axis < ndim) {
    total *= shape[axis];
    ++axis;
  }
  return total;
}

// Node type registration for TensorTypeNode.
TVM_REGISTER_NODE_TYPE(TensorTypeNode);

// Registers TensorTypeNode::make under the API name "relay._make.TensorType".
TVM_REGISTER_API("relay._make.TensorType")
.set_body_typed(TensorTypeNode::make);

// IRPrinter dispatch: writes "TensorType(<shape>, <dtype>)" to the printer stream.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TensorTypeNode>([](const TensorTypeNode* tensor_type,
                                 tvm::IRPrinter* printer) {
  auto& os = printer->stream;
  os << "TensorType(" << tensor_type->shape
     << ", " << tensor_type->dtype
     << ")";
});

/*!
 * \brief Build a TypeVar whose underlying variable is a tvm::Var created
 *        from the given name.
 * \param name The name passed to the tvm::Var constructor.
 * \param kind The kind stored on the new node.
 * \return A TypeVar reference wrapping the freshly allocated node.
 */
TypeVar TypeVarNode::make(std::string name, Kind kind) {
  NodePtr<TypeVarNode> n = make_node<TypeVarNode>();
  // `name` is received by value and not used afterwards, so hand the buffer
  // over to Var instead of making a second copy of the string.
  n->var = tvm::Var(std::move(name));
  n->kind = std::move(kind);
  return TypeVar(n);
}

// Node type registration for TypeVarNode.
TVM_REGISTER_NODE_TYPE(TypeVarNode);

// Registers a constructor under the API name "relay._make.TypeVar".
// The kind arrives as a plain int and is cast to Kind before the node is built.
TVM_REGISTER_API("relay._make.TypeVar")
.set_body_typed<TypeVar(std::string, int)>([](std::string name, int kind) {
    // `name` is a by-value argument that is not reused: move it rather than
    // copying the string a second time when forwarding to make().
    return TypeVarNode::make(std::move(name), static_cast<Kind>(kind));
    });

// IRPrinter dispatch: writes "TypeVarNode(<name_hint>, <kind>)" to the printer stream.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeVarNode>([](const TypeVarNode* node,
                                    tvm::IRPrinter* p) {
  p->stream << "TypeVarNode(" << node->var->name_hint << ", "
    << node->kind << ")";
});

/*!
 * \brief Build a GlobalTypeVar whose underlying variable is a tvm::Var
 *        created from the given name.
 * \param name The name passed to the tvm::Var constructor.
 * \param kind The kind stored on the new node.
 * \return A GlobalTypeVar reference wrapping the freshly allocated node.
 */
GlobalTypeVar GlobalTypeVarNode::make(std::string name, Kind kind) {
  NodePtr<GlobalTypeVarNode> n = make_node<GlobalTypeVarNode>();
  // `name` is received by value and not used afterwards, so hand the buffer
  // over to Var instead of making a second copy of the string.
  n->var = tvm::Var(std::move(name));
  n->kind = std::move(kind);
  return GlobalTypeVar(n);
}

// Node type registration for GlobalTypeVarNode.
TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);

// Registers a constructor under the API name "relay._make.GlobalTypeVar".
// The kind arrives as a plain int and is cast to Kind before the node is built.
TVM_REGISTER_API("relay._make.GlobalTypeVar")
.set_body_typed<GlobalTypeVar(std::string, int)>([](std::string name, int kind) {
    // `name` is a by-value argument that is not reused: move it rather than
    // copying the string a second time when forwarding to make().
    return GlobalTypeVarNode::make(std::move(name), static_cast<Kind>(kind));
    });

// IRPrinter dispatch: writes "GlobalTypeVarNode(<name_hint>, <kind>)" to the
// printer stream.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GlobalTypeVarNode>([](const GlobalTypeVarNode *node,
                                    tvm::IRPrinter *p) {
  p->stream << "GlobalTypeVarNode(" << node->var->name_hint << ", "
            << node->kind << ")";
});

/*!
 * \brief Build a TypeCall from a callee type and its argument types.
 * \param func The type stored in the node's `func` field.
 * \param args The types stored in the node's `args` field.
 * \return A TypeCall reference wrapping the freshly allocated node.
 */
TypeCall TypeCallNode::make(Type func, tvm::Array<Type> args) {
  auto call_node = make_node<TypeCallNode>();
  call_node->args = std::move(args);
  call_node->func = std::move(func);
  return TypeCall(call_node);
}

// Node type registration for TypeCallNode.
TVM_REGISTER_NODE_TYPE(TypeCallNode);

// Registers TypeCallNode::make under the API name "relay._make.TypeCall".
TVM_REGISTER_API("relay._make.TypeCall")
.set_body_typed(TypeCallNode::make);

// IRPrinter dispatch: writes "TypeCallNode(<func>, <args>)" to the printer stream.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeCallNode>([](const TypeCallNode* type_call,
                               tvm::IRPrinter* printer) {
  auto& os = printer->stream;
  os << "TypeCallNode(" << type_call->func
     << ", " << type_call->args
     << ")";
});

/*!
 * \brief Build an IncompleteType carrying the given kind.
 * \param kind The kind stored on the new node.
 * \return An IncompleteType reference wrapping the freshly allocated node.
 */
IncompleteType IncompleteTypeNode::make(Kind kind) {
  NodePtr<IncompleteTypeNode> incomplete = make_node<IncompleteTypeNode>();
  incomplete->kind = std::move(kind);
  return IncompleteType(incomplete);
}

// Node type registration for IncompleteTypeNode.
TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);

// Registers a constructor under the API name "relay._make.IncompleteType".
// The kind arrives as a plain int and is cast to Kind before the node is built.
TVM_REGISTER_API("relay._make.IncompleteType")
.set_body_typed<IncompleteType(int)>([](int kind) {
    const Kind type_kind = static_cast<Kind>(kind);
    return IncompleteTypeNode::make(type_kind);
  });

// IRPrinter dispatch: writes "IncompleteTypeNode(<kind>, <node pointer>)" to
// the printer stream; the raw node pointer is streamed as the second field.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<IncompleteTypeNode>(
    [](const IncompleteTypeNode* incomplete,
       tvm::IRPrinter* printer) {
      auto& os = printer->stream;
      os << "IncompleteTypeNode(" << incomplete->kind
         << ", " << incomplete << ")";
    });

/*!
 * \brief Build a FuncType from its argument types, return type, type
 *        parameters and type constraints.
 * \param arg_types The types stored in the node's `arg_types` field.
 * \param ret_type The type stored in the node's `ret_type` field.
 * \param type_params The type variables stored in `type_params`.
 * \param type_constraints The constraints stored in `type_constraints`.
 * \return A FuncType reference wrapping the freshly allocated node.
 */
FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
                            Type ret_type,
                            tvm::Array<TypeVar> type_params,
                            tvm::Array<TypeConstraint> type_constraints) {
  auto func_node = make_node<FuncTypeNode>();

  // Generic part of the signature.
  func_node->type_params = std::move(type_params);
  func_node->type_constraints = std::move(type_constraints);

  // Value-level part of the signature.
  func_node->arg_types = std::move(arg_types);
  func_node->ret_type = std::move(ret_type);

  return FuncType(func_node);
}

// Node type registration for FuncTypeNode.
TVM_REGISTER_NODE_TYPE(FuncTypeNode);

// Registers FuncTypeNode::make under the API name "relay._make.FuncType".
TVM_REGISTER_API("relay._make.FuncType")
.set_body_typed(FuncTypeNode::make);

// IRPrinter dispatch: writes
// "FuncTypeNode(<type_params>, <arg_types>, <ret_type>, <type_constraints>)"
// to the printer stream.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FuncTypeNode>([](const FuncTypeNode* func_type,
                               tvm::IRPrinter* printer) {
  auto& os = printer->stream;
  os << "FuncTypeNode(" << func_type->type_params;
  os << ", " << func_type->arg_types;
  os << ", " << func_type->ret_type;
  os << ", " << func_type->type_constraints << ")";
});

/*!
 * \brief Build a TypeRelation from a relation function, its type arguments,
 *        an input count and attributes.
 * \param func The relation function stored in the node's `func` field.
 * \param args The types stored in the node's `args` field.
 * \param num_inputs The value stored in the node's `num_inputs` field.
 * \param attrs The attributes stored in the node's `attrs` field.
 * \return A TypeRelation reference wrapping the freshly allocated node.
 */
TypeRelation TypeRelationNode::make(TypeRelationFn func,
                                    Array<Type> args,
                                    int num_inputs,
                                    Attrs attrs) {
  auto relation = make_node<TypeRelationNode>();
  relation->num_inputs = num_inputs;
  relation->func = std::move(func);
  relation->args = std::move(args);
  relation->attrs = std::move(attrs);
  return TypeRelation(relation);
}

// Node type registration for TypeRelationNode.
TVM_REGISTER_NODE_TYPE(TypeRelationNode);

// Registers TypeRelationNode::make under the API name "relay._make.TypeRelation".
TVM_REGISTER_API("relay._make.TypeRelation")
.set_body_typed(TypeRelationNode::make);

// IRPrinter dispatch: writes "TypeRelationNode(<func name>, <args>)" to the
// printer stream. Only the function's `name` is printed, not the function itself.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeRelationNode>([](const TypeRelationNode* relation,
                                   tvm::IRPrinter* printer) {
  auto& os = printer->stream;
  os << "TypeRelationNode(" << relation->func->name;
  os << ", " << relation->args << ")";
});

/*!
 * \brief Build a TupleType from the types of its fields.
 * \param fields The types stored in the node's `fields` field.
 * \return A TupleType reference wrapping the freshly allocated node.
 */
TupleType TupleTypeNode::make(Array<Type> fields) {
  auto tuple_node = make_node<TupleTypeNode>();
  tuple_node->fields = std::move(fields);
  return TupleType(tuple_node);
}

// Node type registration for TupleTypeNode.
TVM_REGISTER_NODE_TYPE(TupleTypeNode);

// Registers TupleTypeNode::make under the API name "relay._make.TupleType".
TVM_REGISTER_API("relay._make.TupleType")
.set_body_typed(TupleTypeNode::make);

// IRPrinter dispatch: writes "TupleTypeNode(<fields>)" to the printer stream.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleTypeNode>([](const TupleTypeNode* tuple_type,
                                tvm::IRPrinter* printer) {
  auto& os = printer->stream;
  os << "TupleTypeNode(" << tuple_type->fields << ")";
});

/*!
 * \brief Build a RefType wrapping the given value type.
 * \param value The type stored in the node's `value` field.
 * \return A RefType reference wrapping the freshly allocated node.
 */
RefType RefTypeNode::make(Type value) {
  auto ref_node = make_node<RefTypeNode>();
  ref_node->value = std::move(value);
  return RefType(ref_node);
}

// Registers RefTypeNode::make under the API name "relay._make.RefType".
TVM_REGISTER_API("relay._make.RefType")
.set_body_typed(RefTypeNode::make);

// Node type registration for RefTypeNode.
TVM_REGISTER_NODE_TYPE(RefTypeNode);

// IRPrinter dispatch: writes "RefTypeNode(<value>)" to the printer stream.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefTypeNode>([](const RefTypeNode* ref_type,
                              tvm::IRPrinter* printer) {
  auto& os = printer->stream;
  os << "RefTypeNode(" << ref_type->value << ")";
});

}  // namespace relay
}  // namespace tvm