/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2019 by Contributors
 * \file src/tvm/ir/adt.cc
 * \brief AST nodes for Relay algebraic data types (ADTs).
 */
#include <tvm/relay/type.h>
#include <tvm/relay/adt.h>

namespace tvm {
namespace relay {

/*!
 * \brief Allocate a new PatternWildcardNode and wrap it in a PatternWildcard.
 * \return The new wildcard pattern; it has no fields to initialize.
 */
PatternWildcard PatternWildcardNode::make() {
  NodePtr<PatternWildcardNode> n = make_node<PatternWildcardNode>();
  return PatternWildcard(n);
}

TVM_REGISTER_NODE_TYPE(PatternWildcardNode);

// Register make() as the API function "relay._make.PatternWildcard".
TVM_REGISTER_API("relay._make.PatternWildcard")
.set_body_typed(PatternWildcardNode::make);

// Printer hook: the wildcard is printed as a fixed string with no fields.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PatternWildcardNode>([](const PatternWildcardNode* node,
                                      tvm::IRPrinter* p) {
  p->stream << "PatternWildcardNode()";
});

/*!
 * \brief Allocate a new PatternVarNode holding the given variable.
 * \param var The Var stored in the node's `var` field (moved in).
 * \return A PatternVar reference to the new node.
 */
PatternVar PatternVarNode::make(tvm::relay::Var var) {
  NodePtr<PatternVarNode> n = make_node<PatternVarNode>();
  n->var = std::move(var);
  return PatternVar(n);
}

TVM_REGISTER_NODE_TYPE(PatternVarNode);

// Register make() as the API function "relay._make.PatternVar".
TVM_REGISTER_API("relay._make.PatternVar")
.set_body_typed(PatternVarNode::make);

// Printer hook: prints the wrapped variable, e.g. "PatternVarNode(<var>)".
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PatternVarNode>([](const PatternVarNode* node,
                                 tvm::IRPrinter* p) {
  p->stream << "PatternVarNode(" << node->var << ")";
});

/*!
 * \brief Allocate a new PatternConstructorNode.
 * \param constructor The Constructor stored in the node's `constructor` field.
 * \param patterns The sub-patterns stored in the node's `patterns` field.
 * \return A PatternConstructor reference to the new node.
 */
PatternConstructor PatternConstructorNode::make(Constructor constructor,
                                                tvm::Array<Pattern> patterns) {
  NodePtr<PatternConstructorNode> n = make_node<PatternConstructorNode>();
  n->constructor = std::move(constructor);
  n->patterns = std::move(patterns);
  return PatternConstructor(n);
}

TVM_REGISTER_NODE_TYPE(PatternConstructorNode);

// Register make() as the API function "relay._make.PatternConstructor".
TVM_REGISTER_API("relay._make.PatternConstructor")
.set_body_typed(PatternConstructorNode::make);

// Printer hook: prints the constructor followed by its sub-pattern array.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PatternConstructorNode>([](const PatternConstructorNode* node,
                                         tvm::IRPrinter* p) {
  p->stream << "PatternConstructorNode(" << node->constructor
            << ", " << node->patterns << ")";
});

/*!
 * \brief Allocate a new ConstructorNode.
 * \param name_hint Stored in `name_hint`; printed first by the IRPrinter.
 * \param inputs The types stored in the node's `inputs` field.
 * \param belong_to Stored in `belong_to`; presumably the GlobalTypeVar of the
 *        ADT that owns this constructor (TODO confirm against adt.h).
 * \return A Constructor reference to the new node.
 */
Constructor ConstructorNode::make(std::string name_hint,
                                  tvm::Array<Type> inputs,
                                  GlobalTypeVar belong_to) {
  NodePtr<ConstructorNode> n = make_node<ConstructorNode>();
  n->name_hint = std::move(name_hint);
  n->inputs = std::move(inputs);
  n->belong_to = std::move(belong_to);
  return Constructor(n);
}

TVM_REGISTER_NODE_TYPE(ConstructorNode);

// Register make() as the API function "relay._make.Constructor".
TVM_REGISTER_API("relay._make.Constructor")
.set_body_typed(ConstructorNode::make);

// Printer hook: prints name_hint, inputs and belong_to, comma-separated.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstructorNode>([](const ConstructorNode* node,
                                  tvm::IRPrinter* p) {
  p->stream << "ConstructorNode(" << node->name_hint << ", "
            << node->inputs << ", " << node->belong_to << ")";
});

/*!
 * \brief Allocate a new TypeDataNode.
 * \param header The GlobalTypeVar stored in the node's `header` field.
 * \param type_vars The type variables stored in the node's `type_vars` field.
 * \param constructors The constructors stored in the node's `constructors` field.
 * \return A TypeData reference to the new node.
 */
TypeData TypeDataNode::make(GlobalTypeVar header,
                            tvm::Array<TypeVar> type_vars,
                            tvm::Array<Constructor> constructors) {
  NodePtr<TypeDataNode> n = make_node<TypeDataNode>();
  n->header = std::move(header);
  n->type_vars = std::move(type_vars);
  n->constructors = std::move(constructors);
  return TypeData(n);
}

TVM_REGISTER_NODE_TYPE(TypeDataNode);

// Register make() as the API function "relay._make.TypeData".
TVM_REGISTER_API("relay._make.TypeData")
.set_body_typed(TypeDataNode::make);

// Printer hook: prints header, type_vars and constructors, comma-separated.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeDataNode>([](const TypeDataNode* node,
                               tvm::IRPrinter* p) {
  p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", "
            << node->constructors << ")";
});

/*!
 * \brief Allocate a new ClauseNode.
 * \param lhs The Pattern stored in the node's `lhs` field.
 * \param rhs The Expr stored in the node's `rhs` field.
 * \return A Clause reference to the new node.
 */
Clause ClauseNode::make(Pattern lhs, Expr rhs) {
  NodePtr<ClauseNode> n = make_node<ClauseNode>();
  n->lhs = std::move(lhs);
  n->rhs = std::move(rhs);
  return Clause(n);
}

TVM_REGISTER_NODE_TYPE(ClauseNode);

// Register make() as the API function "relay._make.Clause".
TVM_REGISTER_API("relay._make.Clause")
.set_body_typed(ClauseNode::make);

// Printer hook: prints the clause as "ClauseNode(<lhs>, <rhs>)".
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ClauseNode>([](const ClauseNode* node,
                             tvm::IRPrinter* p) {
  p->stream << "ClauseNode(" << node->lhs << ", "
            << node->rhs << ")";
});

/*!
 * \brief Allocate a new MatchNode.
 * \param data The Expr stored in the node's `data` field.
 * \param clauses The clauses stored in the node's `clauses` field.
 * \return A Match reference to the new node.
 */
Match MatchNode::make(Expr data, tvm::Array<Clause> clauses) {
  NodePtr<MatchNode> n = make_node<MatchNode>();
  n->data = std::move(data);
  n->clauses = std::move(clauses);
  return Match(n);
}

TVM_REGISTER_NODE_TYPE(MatchNode);

// Register make() as the API function "relay._make.Match".
TVM_REGISTER_API("relay._make.Match")
.set_body_typed(MatchNode::make);

// Printer hook: prints the matched expression followed by the clause array.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<MatchNode>([](const MatchNode* node,
                            tvm::IRPrinter* p) {
  p->stream << "MatchNode(" << node->data << ", "
            << node->clauses << ")";
});

}  // namespace relay
}  // namespace tvm