/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/tvm/relay/ir/expr.cc
 * \brief The expression AST nodes of Relay.
 */
#include <tvm/ir/module.h>
#include <tvm/relay/expr.h>

#include <limits>
#include <string>
#include <utility>

namespace tvm {
namespace relay {

30
using tvm::ReprPrinter;
31 32 33
using namespace tvm::runtime;

Constant ConstantNode::make(runtime::NDArray data) {
34
  ObjectPtr<ConstantNode> n = make_object<ConstantNode>();
35 36 37 38
  n->data = std::move(data);
  return Constant(n);
}

39 40
TVM_REGISTER_NODE_TYPE(ConstantNode);

41
TVM_REGISTER_GLOBAL("relay._make.Constant")
42
.set_body_typed(ConstantNode::make);
43

44 45
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ConstantNode>([](const ObjectRef& ref, ReprPrinter* p) {
46
    auto* node = static_cast<const ConstantNode*>(ref.get());
47 48 49 50
    const PackedFunc* fprint = Registry::Get("relay._constant_repr");
    CHECK(fprint) << "unable to find printing function for constants";
    std::string data = (*fprint)(GetRef<Constant>(node));
    p->stream << "Constant(" << data << ")";
51
  });
52 53

/*!
 * \brief Derive the TensorType of this constant from its NDArray.
 *
 * Every dimension extent becomes an Int(32) IntImm; an extent outside the
 * int32 range fails a CHECK. The element dtype is taken from the NDArray.
 */
TensorType ConstantNode::tensor_type() const {
  Array<tvm::PrimExpr> shape;
  for (int i = 0; i < data->ndim; ++i) {
    // Extents are stored as 32-bit immediates below, so reject values that
    // would not fit.
    CHECK_LE(data->shape[i], std::numeric_limits<int32_t>::max());
    CHECK_GE(data->shape[i], std::numeric_limits<int32_t>::min());
    shape.push_back(tvm::IntImm(DataType::Int(32), data->shape[i]));
  }
  return TensorType(shape, DataType(data->dtype));
}

Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
67
  ObjectPtr<TupleNode> n = make_object<TupleNode>();
68 69 70 71
  n->fields = std::move(fields);
  return Tuple(n);
}

72 73
TVM_REGISTER_NODE_TYPE(TupleNode);

74
TVM_REGISTER_GLOBAL("relay._make.Tuple")
75
.set_body_typed(TupleNode::make);
76

77 78
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TupleNode>([](const ObjectRef& ref, ReprPrinter* p) {
79
    auto* node = static_cast<const TupleNode*>(ref.get());
80 81
    p->stream << "Tuple(" << node->fields << ")";
  });
82

83 84

/*!
 * \brief Create a variable bound to an existing Id.
 * \param vid The identifier object for the variable.
 * \param type_annotation The annotated type; may be left undefined.
 */
Var VarNode::make(Id vid, Type type_annotation) {
  auto var = make_object<VarNode>();
  var->vid = std::move(vid);
  var->type_annotation = std::move(type_annotation);
  return Var(var);
}

91
Var VarNode::make(std::string name_hint, Type type_annotation) {
92
  ObjectPtr<IdNode> n = make_object<IdNode>();
93 94 95 96
  n->name_hint = std::move(name_hint);
  return VarNode::make(Id(n), type_annotation);
}

97 98
TVM_REGISTER_NODE_TYPE(VarNode);

99
TVM_REGISTER_GLOBAL("relay._make.Var")
100
.set_body_typed(static_cast<Var (*)(std::string, Type)>(VarNode::make));
101

102 103
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<VarNode>([](const ObjectRef& ref, ReprPrinter* p) {
104
    auto* node = static_cast<const VarNode*>(ref.get());
105
    p->stream << "Var(" << node->name_hint();
106 107
    if (node->type_annotation.defined()) {
      p->stream << ", ty=";
108
      p->Print(node->type_annotation);
109 110
    }
    p->stream << ")";
111
  });
112

113 114
Function FunctionNode::make(tvm::Array<Var> params,
                            Expr body,
115
                            Type ret_type,
116 117
                            tvm::Array<TypeVar> type_params,
                            tvm::Attrs attrs) {
118
  ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
119 120
  CHECK(params.defined());
  CHECK(type_params.defined());
121 122
  n->params = std::move(params);
  n->body = std::move(body);
123
  n->ret_type = std::move(ret_type);
124
  n->type_params = std::move(type_params);
125
  n->attrs = std::move(attrs);
126 127 128
  return Function(n);
}

129
FuncType FunctionNode::func_type_annotation() const {
130 131
  Array<Type> param_types;
  for (auto param : this->params) {
132
    Type param_type = (param->type_annotation.defined()) ? param->type_annotation
133
      : IncompleteType(Kind::kType);
134
    param_types.push_back(param_type);
135
  }
136 137

  Type ret_type = (this->ret_type.defined()) ? this->ret_type
138
    : IncompleteType(Kind::kType);
139
  return FuncType(param_types, ret_type, this->type_params, {});
140 141
}

142
bool FunctionNode::IsPrimitive() const {
143
  ObjectRef res = FunctionGetAttr(GetRef<Function>(this), attr::kPrimitive);
144
  const tir::IntImmNode* pval = res.as<tir::IntImmNode>();
145 146 147
  return pval && pval->value != 0;
}

148
Function FunctionNode::SetParams(const tvm::Map<Var, Constant>& parameters) const {
Zhi committed
149
  return FunctionSetAttr(GetRef<Function>(this), attr::kParams, parameters);
150 151
}

152
TVM_REGISTER_GLOBAL("relay._expr.FunctionSetParams")
153
.set_body_typed(
154 155 156 157 158
  [](const Function& func, const tvm::Map<Var, Constant>& parameters) {
    return func->SetParams(parameters);
});

tvm::Map<Var, Constant> FunctionNode::GetParams() const {
Zhi committed
159
  auto node_ref = FunctionGetAttr(GetRef<Function>(this), attr::kParams);
160 161 162
  return Downcast<tvm::Map<Var, Constant>>(node_ref);
}

163
TVM_REGISTER_GLOBAL("relay._expr.FunctionGetParams")
164
.set_body_typed([](const Function& func) {
165 166 167
  return func->GetParams();
});

Zhi committed
168
bool FunctionNode::UseDefaultCompiler() const {
169
  ObjectRef res = FunctionGetAttr(GetRef<Function>(this), attr::kCompiler);
170
  const tir::StringImmNode* pval = res.as<tir::StringImmNode>();
Zhi committed
171 172 173
  return pval == nullptr || pval->value == "default";
}

174 175
ObjectRef FunctionGetAttr(const Function& func, const std::string& key) {
  if (!func->attrs.defined()) { return ObjectRef(); }
176 177 178 179 180 181 182

  const DictAttrsNode* dict_attrs = func->attrs.as<DictAttrsNode>();
  CHECK(dict_attrs);
  auto it = dict_attrs->dict.find(key);
  if (it != dict_attrs->dict.end()) {
    return (*it).second;
  } else {
183
    return ObjectRef();
184 185 186
  }
}

187
Function FunctionSetAttr(const Function& func, const std::string& key, const ObjectRef& data) {
188 189 190
  const DictAttrsNode* dattrs = func->attrs.as<DictAttrsNode>();
  Attrs func_attrs;
  if (dattrs) {
191
    Map<std::string, ObjectRef> dict = dattrs->dict;
192 193 194
    dict.Set(key, data);
    func_attrs = DictAttrsNode::make(dict);
  } else {
195
    Map<std::string, ObjectRef> dict = {{key, data}};
196 197 198 199 200 201 202 203 204 205 206
    func_attrs = DictAttrsNode::make(dict);
  }

  return FunctionNode::make(
    func->params,
    func->body,
    func->ret_type,
    func->type_params,
    func_attrs);
}

207 208
TVM_REGISTER_NODE_TYPE(FunctionNode);

209
TVM_REGISTER_GLOBAL("relay._make.Function")
210
.set_body_typed(FunctionNode::make);
211

212 213
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
214 215 216 217
  auto* node = static_cast<const FunctionNode*>(ref.get());
  p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
            << ", " << node->body << ", " << node->type_params << ", "
            << node->attrs << ")";
218 219 220 221
});

Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
                    Array<Type> type_args) {
222
  ObjectPtr<CallNode> n = make_object<CallNode>();
223 224 225 226 227 228 229
  n->op = std::move(op);
  n->args = std::move(args);
  n->attrs = std::move(attrs);
  n->type_args = std::move(type_args);
  return Call(n);
}

230 231
TVM_REGISTER_NODE_TYPE(CallNode);

232
TVM_REGISTER_GLOBAL("relay._make.Call")
233
.set_body_typed(CallNode::make);
234

235 236
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<CallNode>([](const ObjectRef& ref, ReprPrinter* p) {
237 238 239 240
  auto* node = static_cast<const CallNode*>(ref.get());
  p->stream << "CallNode(" << node->op << ", " << node->args << ", "
            << node->attrs << ", " << node->type_args << ")";
  });
241

242
Let LetNode::make(Var var, Expr value, Expr body) {
243
  ObjectPtr<LetNode> n = make_object<LetNode>();
244 245 246 247 248 249
  n->var = std::move(var);
  n->value = std::move(value);
  n->body = std::move(body);
  return Let(n);
}

250 251
TVM_REGISTER_NODE_TYPE(LetNode);

252
TVM_REGISTER_GLOBAL("relay._make.Let")
253
.set_body_typed(LetNode::make);
254

255 256
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LetNode>([](const ObjectRef& ref, ReprPrinter* p) {
257
  auto* node = static_cast<const LetNode*>(ref.get());
258
  p->stream << "LetNode(" << node->var << ", " << node->value
259
            << ", " << node->body << ")";
260 261 262
});

If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
263
  ObjectPtr<IfNode> n = make_object<IfNode>();
264 265 266 267 268 269
  n->cond = std::move(cond);
  n->true_branch = std::move(true_branch);
  n->false_branch = std::move(false_branch);
  return If(n);
}

270 271
TVM_REGISTER_NODE_TYPE(IfNode);

272
TVM_REGISTER_GLOBAL("relay._make.If")
273
.set_body_typed(IfNode::make);
274

275 276
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IfNode>([](const ObjectRef& ref, ReprPrinter* p) {
277
  auto* node = static_cast<const IfNode*>(ref.get());
278
  p->stream << "IfNode(" << node->cond << ", " << node->true_branch
279
            << ", " << node->false_branch << ")";
280 281
});

282
TupleGetItem TupleGetItemNode::make(Expr tuple, int index) {
283
  ObjectPtr<TupleGetItemNode> n = make_object<TupleGetItemNode>();
284 285 286 287 288
  n->tuple = std::move(tuple);
  n->index = index;
  return TupleGetItem(n);
}

289 290
TVM_REGISTER_NODE_TYPE(TupleGetItemNode);

291
TVM_REGISTER_GLOBAL("relay._make.TupleGetItem")
292
.set_body_typed(TupleGetItemNode::make);
293

294 295
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TupleGetItemNode>([](const ObjectRef& ref, ReprPrinter* p) {
296
  auto* node = static_cast<const TupleGetItemNode*>(ref.get());
297 298 299
  p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
});

300
RefCreate RefCreateNode::make(Expr value) {
301
  ObjectPtr<RefCreateNode> n = make_object<RefCreateNode>();
302 303 304
  n->value = std::move(value);
  return RefCreate(n);
}
305

306 307
TVM_REGISTER_NODE_TYPE(RefCreateNode);

308
TVM_REGISTER_GLOBAL("relay._make.RefCreate")
309
.set_body_typed(RefCreateNode::make);
310

311 312
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RefCreateNode>([](const ObjectRef& ref, ReprPrinter* p) {
313
  auto* node = static_cast<const RefCreateNode*>(ref.get());
314 315 316 317
  p->stream << "RefCreateNode(" << node->value << ")";
});

RefRead RefReadNode::make(Expr ref) {
318
  ObjectPtr<RefReadNode> n = make_object<RefReadNode>();
319 320 321 322
  n->ref = std::move(ref);
  return RefRead(n);
}

323 324
TVM_REGISTER_NODE_TYPE(RefReadNode);

325
TVM_REGISTER_GLOBAL("relay._make.RefRead")
326
.set_body_typed(RefReadNode::make);
327

328 329
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RefReadNode>([](const ObjectRef& ref, ReprPrinter* p) {
330
  auto* node = static_cast<const RefReadNode*>(ref.get());
331
  p->stream << "RefReadNode(" << node->ref << ")";
332 333
});

334
RefWrite RefWriteNode::make(Expr ref, Expr value) {
335
  ObjectPtr<RefWriteNode> n = make_object<RefWriteNode>();
336 337 338 339 340
  n->ref = std::move(ref);
  n->value = std::move(value);
  return RefWrite(n);
}

341 342
TVM_REGISTER_NODE_TYPE(RefWriteNode);

343
TVM_REGISTER_GLOBAL("relay._make.RefWrite")
344
.set_body_typed(RefWriteNode::make);
345

346 347
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RefWriteNode>([](const ObjectRef& ref, ReprPrinter* p) {
348
  auto* node = static_cast<const RefWriteNode*>(ref.get());
349 350 351
  p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")";
});

352
TVM_REGISTER_GLOBAL("relay._expr.TempExprRealize")
353
.set_body_typed([](TempExpr temp) {
354
  return temp->Realize();
355
});
356

357
TVM_REGISTER_GLOBAL("relay._expr.FunctionSetAttr")
358
.set_body_typed(
359
  [](Function func, std::string name, ObjectRef ref) {
360 361 362
    return FunctionSetAttr(func, name, ref);
});

363 364 365
TVM_REGISTER_GLOBAL("relay._make.Any")
.set_body_typed([]() { return Any::make(); });

366 367
}  // namespace relay
}  // namespace tvm