/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/tvm/ir/adt.cc
 * \brief AST nodes for Relay algebraic data types (ADTs).
 */
#include <tvm/relay/type.h>
#include <tvm/relay/adt.h>

namespace tvm {
namespace relay {

PatternWildcard PatternWildcardNode::make() {
  ObjectPtr<PatternWildcardNode> n = make_object<PatternWildcardNode>();
  return PatternWildcard(n);
}

TVM_REGISTER_NODE_TYPE(PatternWildcardNode);

TVM_REGISTER_GLOBAL("relay._make.PatternWildcard")
.set_body_typed(PatternWildcardNode::make);

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PatternWildcardNode>([](const ObjectRef& ref, NodePrinter* p) {
  p->stream << "PatternWildcardNode()";
});

PatternVar PatternVarNode::make(tvm::relay::Var var) {
  ObjectPtr<PatternVarNode> n = make_object<PatternVarNode>();
  n->var = std::move(var);
  return PatternVar(n);
}

TVM_REGISTER_NODE_TYPE(PatternVarNode);

TVM_REGISTER_GLOBAL("relay._make.PatternVar")
.set_body_typed(PatternVarNode::make);

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PatternVarNode>([](const ObjectRef& ref, NodePrinter* p) {
  auto* node = static_cast<const PatternVarNode*>(ref.get());
  p->stream << "PatternVarNode(" << node->var << ")";
});

PatternConstructor PatternConstructorNode::make(Constructor constructor,
                                                tvm::Array<Pattern> patterns) {
  ObjectPtr<PatternConstructorNode> n = make_object<PatternConstructorNode>();
  n->constructor = std::move(constructor);
  n->patterns = std::move(patterns);
  return PatternConstructor(n);
}

TVM_REGISTER_NODE_TYPE(PatternConstructorNode);

TVM_REGISTER_GLOBAL("relay._make.PatternConstructor")
.set_body_typed(PatternConstructorNode::make);

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PatternConstructorNode>([](const ObjectRef& ref, NodePrinter* p) {
  auto* node = static_cast<const PatternConstructorNode*>(ref.get());
  p->stream << "PatternConstructorNode(" << node->constructor
            << ", " << node->patterns << ")";
});

PatternTuple PatternTupleNode::make(tvm::Array<Pattern> patterns) {
  ObjectPtr<PatternTupleNode> n = make_object<PatternTupleNode>();
  n->patterns = std::move(patterns);
  return PatternTuple(n);
}

TVM_REGISTER_NODE_TYPE(PatternTupleNode);

TVM_REGISTER_GLOBAL("relay._make.PatternTuple")
.set_body_typed(PatternTupleNode::make);

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PatternTupleNode>([](const ObjectRef& ref, NodePrinter* p) {
  auto* node = static_cast<const PatternTupleNode*>(ref.get());
  p->stream << "PatternTupleNode(" << node->patterns << ")";
});

Constructor ConstructorNode::make(std::string name_hint,
                                  tvm::Array<Type> inputs,
                                  GlobalTypeVar belong_to) {
  ObjectPtr<ConstructorNode> n = make_object<ConstructorNode>();
  n->name_hint = std::move(name_hint);
  n->inputs = std::move(inputs);
  n->belong_to = std::move(belong_to);
  return Constructor(n);
}

TVM_REGISTER_NODE_TYPE(ConstructorNode);

TVM_REGISTER_GLOBAL("relay._make.Constructor")
.set_body_typed(ConstructorNode::make);

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstructorNode>([](const ObjectRef& ref, NodePrinter* p) {
  auto* node = static_cast<const ConstructorNode*>(ref.get());
  p->stream << "ConstructorNode(" << node->name_hint << ", "
            << node->inputs << ", " << node->belong_to << ")";
});

TypeData TypeDataNode::make(GlobalTypeVar header,
                            tvm::Array<TypeVar> type_vars,
                            tvm::Array<Constructor> constructors) {
  ObjectPtr<TypeDataNode> n = make_object<TypeDataNode>();
  n->header = std::move(header);
  n->type_vars = std::move(type_vars);
  n->constructors = std::move(constructors);
  return TypeData(n);
}

TVM_REGISTER_NODE_TYPE(TypeDataNode);

TVM_REGISTER_GLOBAL("relay._make.TypeData")
.set_body_typed(TypeDataNode::make);

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeDataNode>([](const ObjectRef& ref, NodePrinter* p) {
  auto* node = static_cast<const TypeDataNode*>(ref.get());
  p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", "
            << node->constructors << ")";
});

Clause ClauseNode::make(Pattern lhs, Expr rhs) {
  ObjectPtr<ClauseNode> n = make_object<ClauseNode>();
  n->lhs = std::move(lhs);
  n->rhs = std::move(rhs);
  return Clause(n);
}

TVM_REGISTER_NODE_TYPE(ClauseNode);

TVM_REGISTER_GLOBAL("relay._make.Clause")
.set_body_typed(ClauseNode::make);

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ClauseNode>([](const ObjectRef& ref, NodePrinter* p) {
    auto* node = static_cast<const ClauseNode*>(ref.get());
  p->stream << "ClauseNode(" << node->lhs << ", "
            << node->rhs << ")";
  });

Match MatchNode::make(Expr data, tvm::Array<Clause> clauses, bool complete) {
  ObjectPtr<MatchNode> n = make_object<MatchNode>();
  n->data = std::move(data);
  n->clauses = std::move(clauses);
  n->complete = complete;
  return Match(n);
}

TVM_REGISTER_NODE_TYPE(MatchNode);

TVM_REGISTER_GLOBAL("relay._make.Match")
.set_body_typed(MatchNode::make);

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<MatchNode>([](const ObjectRef& ref, NodePrinter* p) {
  auto* node = static_cast<const MatchNode*>(ref.get());
  p->stream << "MatchNode(" << node->data << ", "
            << node->clauses << ", " << node->complete << ")";
});

}  // namespace relay
}  // namespace tvm