/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/ir/adt.cc
 * \brief AST nodes for Relay algebraic data types (ADTs).
 */
#include <tvm/relay/type.h>
#include <tvm/relay/adt.h>

namespace tvm {
namespace relay {

30
PatternWildcard::PatternWildcard() {
31
  ObjectPtr<PatternWildcardNode> n = make_object<PatternWildcardNode>();
32
  data_ = std::move(n);
33 34 35 36
}

// Register PatternWildcardNode as a node type, and expose the nullary
// PatternWildcard constructor under the global name "relay.ir.PatternWildcard".
TVM_REGISTER_NODE_TYPE(PatternWildcardNode);

TVM_REGISTER_GLOBAL("relay.ir.PatternWildcard")
.set_body_typed([]() {
  return PatternWildcard();
});

42 43
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PatternWildcardNode>([](const ObjectRef& ref, ReprPrinter* p) {
44 45 46
  p->stream << "PatternWildcardNode()";
});

47
PatternVar::PatternVar(tvm::relay::Var var) {
48
  ObjectPtr<PatternVarNode> n = make_object<PatternVarNode>();
49
  n->var = std::move(var);
50
  data_ = std::move(n);
51 52 53 54
}

// Register PatternVarNode as a node type, and expose the PatternVar
// constructor under the global name "relay.ir.PatternVar".
TVM_REGISTER_NODE_TYPE(PatternVarNode);

TVM_REGISTER_GLOBAL("relay.ir.PatternVar")
.set_body_typed([](tvm::relay::Var var) {
  // `var` is a by-value copy owned by this lambda; hand it over without an
  // extra reference-count round trip.
  return PatternVar(std::move(var));
});

60 61
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PatternVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
62
  auto* node = static_cast<const PatternVarNode*>(ref.get());
63 64 65
  p->stream << "PatternVarNode(" << node->var << ")";
});

66 67
PatternConstructor::PatternConstructor(Constructor constructor,
                                       tvm::Array<Pattern> patterns) {
68
  ObjectPtr<PatternConstructorNode> n = make_object<PatternConstructorNode>();
69 70
  n->constructor = std::move(constructor);
  n->patterns = std::move(patterns);
71
  data_ = std::move(n);
72 73 74 75
}

// Register PatternConstructorNode as a node type, and expose the
// PatternConstructor constructor under the global name
// "relay.ir.PatternConstructor".
TVM_REGISTER_NODE_TYPE(PatternConstructorNode);

TVM_REGISTER_GLOBAL("relay.ir.PatternConstructor")
.set_body_typed([](Constructor constructor, tvm::Array<Pattern> patterns) {
  // Both arguments are by-value copies owned by this lambda; move them along.
  return PatternConstructor(std::move(constructor), std::move(patterns));
});

81 82
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PatternConstructorNode>([](const ObjectRef& ref, ReprPrinter* p) {
83
  auto* node = static_cast<const PatternConstructorNode*>(ref.get());
84 85 86 87
  p->stream << "PatternConstructorNode(" << node->constructor
            << ", " << node->patterns << ")";
});

88
PatternTuple::PatternTuple(tvm::Array<Pattern> patterns) {
89
  ObjectPtr<PatternTupleNode> n = make_object<PatternTupleNode>();
90
  n->patterns = std::move(patterns);
91
  data_ = std::move(n);
92 93 94 95
}

// Register PatternTupleNode as a node type, and expose the PatternTuple
// constructor under the global name "relay.ir.PatternTuple".
TVM_REGISTER_NODE_TYPE(PatternTupleNode);

TVM_REGISTER_GLOBAL("relay.ir.PatternTuple")
.set_body_typed([](tvm::Array<Pattern> patterns) {
  // `patterns` is a by-value copy owned by this lambda; move it along.
  return PatternTuple(std::move(patterns));
});

101 102
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PatternTupleNode>([](const ObjectRef& ref, ReprPrinter* p) {
103
  auto* node = static_cast<const PatternTupleNode*>(ref.get());
104 105 106
  p->stream << "PatternTupleNode(" << node->patterns << ")";
});

107
Clause::Clause(Pattern lhs, Expr rhs) {
108
  ObjectPtr<ClauseNode> n = make_object<ClauseNode>();
109 110
  n->lhs = std::move(lhs);
  n->rhs = std::move(rhs);
111
  data_ = std::move(n);
112 113 114 115
}

// Register ClauseNode as a node type, and expose the Clause constructor
// under the global name "relay.ir.Clause".
TVM_REGISTER_NODE_TYPE(ClauseNode);

TVM_REGISTER_GLOBAL("relay.ir.Clause")
.set_body_typed([](Pattern lhs, Expr rhs) {
  // Both arguments are by-value copies owned by this lambda; move them along.
  return Clause(std::move(lhs), std::move(rhs));
});

121 122
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ClauseNode>([](const ObjectRef& ref, ReprPrinter* p) {
123
    auto* node = static_cast<const ClauseNode*>(ref.get());
124 125 126 127
  p->stream << "ClauseNode(" << node->lhs << ", "
            << node->rhs << ")";
  });

128
Match::Match(Expr data, tvm::Array<Clause> clauses, bool complete) {
129
  ObjectPtr<MatchNode> n = make_object<MatchNode>();
130 131
  n->data = std::move(data);
  n->clauses = std::move(clauses);
132
  n->complete = complete;
133
  data_ = std::move(n);
134 135 136 137
}

// Register MatchNode as a node type, and expose the Match constructor
// under the global name "relay.ir.Match".
TVM_REGISTER_NODE_TYPE(MatchNode);

TVM_REGISTER_GLOBAL("relay.ir.Match")
.set_body_typed([](Expr data, tvm::Array<Clause> clauses, bool complete) {
  // `data` and `clauses` are by-value copies owned by this lambda; move them.
  return Match(std::move(data), std::move(clauses), complete);
});

143 144
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<MatchNode>([](const ObjectRef& ref, ReprPrinter* p) {
145
  auto* node = static_cast<const MatchNode*>(ref.get());
146
  p->stream << "MatchNode(" << node->data << ", "
147
            << node->clauses << ", " << node->complete << ")";
148 149 150 151
});

}  // namespace relay
}  // namespace tvm