expr.cc 10 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/*!
 *  Copyright (c) 2018 by Contributors
 * \file src/tvm/ir/expr.cc
 * \brief The expression AST nodes of Relay.
 */
#include <tvm/relay/expr.h>

namespace tvm {
namespace relay {

using tvm::IRPrinter;
using namespace tvm::runtime;

/*!
 * \brief Create a constant expression holding the given tensor.
 * \param data The tensor value stored in the constant.
 * \return The newly created Constant.
 */
Constant ConstantNode::make(runtime::NDArray data) {
  NodePtr<ConstantNode> n = make_node<ConstantNode>();
  n->data = std::move(data);
  return Constant(n);
}

TVM_REGISTER_NODE_TYPE(ConstantNode);

// Packed-function binding: relay._make.Constant(data) -> Constant.
TVM_REGISTER_API("relay._make.Constant")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = ConstantNode::make(args[0]);
  });

// Printing delegates the tensor rendering to the packed function registered
// as "relay._constant_repr"; a CHECK fails if it has not been registered.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstantNode>([](const ConstantNode* node, tvm::IRPrinter* p) {
    const PackedFunc* fprint = Registry::Get("relay._constant_repr");
    CHECK(fprint) << "unable to find printing function for constants";
    std::string data = (*fprint)(GetRef<Constant>(node));
    p->stream << "Constant(" << data << ")";
  });

/*!
 * \brief Derive the tensor type of this constant from its NDArray.
 *
 * Each dimension is emitted as an Int(32) immediate, so every extent must
 * fit in int32; a CHECK fails (naming the offending dimension) otherwise.
 * \return A TensorType with the constant's shape and dtype.
 */
TensorType ConstantNode::tensor_type() const {
  auto dtype = TVMType2Type(data->dtype);
  Array<tvm::Expr> shape;
  for (int i = 0; i < data->ndim; i++) {
    CHECK_LE(data->shape[i], std::numeric_limits<int32_t>::max())
        << "dimension " << i << " of constant does not fit in int32";
    CHECK_GE(data->shape[i], std::numeric_limits<int32_t>::min())
        << "dimension " << i << " of constant does not fit in int32";
    shape.push_back(
        tvm::ir::IntImm::make(Int(32), data->shape[i]));
  }

  return TensorTypeNode::make(shape, dtype);
}

/*!
 * \brief Create a tuple expression from its field expressions.
 * \param fields The tuple elements, in order.
 * \return The newly created Tuple.
 */
Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
  NodePtr<TupleNode> n = make_node<TupleNode>();
  n->fields = std::move(fields);
  return Tuple(n);
}

TVM_REGISTER_NODE_TYPE(TupleNode);

// Packed-function binding: relay._make.Tuple(fields) -> Tuple.
TVM_REGISTER_API("relay._make.Tuple")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = TupleNode::make(args[0]);
  });

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleNode>([](const TupleNode* node, tvm::IRPrinter* p) {
    p->stream << "Tuple(" << node->fields << ")";
  });

/*!
 * \brief Create a local variable from an existing identifier.
 * \param vid The identifier of the variable.
 * \param type_annotation Optional type annotation; may be left undefined.
 * \return The newly created Var.
 */
Var VarNode::make(Id vid, Type type_annotation) {
  NodePtr<VarNode> n = make_node<VarNode>();
  n->vid = std::move(vid);
  n->type_annotation = std::move(type_annotation);
  return Var(n);
}

/*!
 * \brief Convenience overload: build a fresh Id carrying `name_hint`, then
 *        create a Var bound to it.
 * \param name_hint Human-readable name stored on the new identifier.
 * \param type_annotation Optional type annotation forwarded unchanged.
 * \return The newly created Var.
 */
Var VarNode::make(std::string name_hint, Type type_annotation) {
  NodePtr<IdNode> id_node = make_node<IdNode>();
  id_node->name_hint = std::move(name_hint);
  Id fresh_id(id_node);
  return VarNode::make(fresh_id, type_annotation);
}

80 81
TVM_REGISTER_NODE_TYPE(VarNode);

82
TVM_REGISTER_API("relay._make.Var")
83
.set_body([](TVMArgs args, TVMRetValue* ret) {
84
    *ret = VarNode::make(args[0].operator std::string(), args[1]);
85
  });
86 87

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
88
.set_dispatch<VarNode>([](const VarNode* node, tvm::IRPrinter* p) {
89
    p->stream << "Var(" << node->name_hint();
90 91 92 93 94
    if (node->type_annotation.defined()) {
      p->stream << ", ty=";
      p->print(node->type_annotation);
    }
    p->stream << ")";
95
  });
96 97

/*!
 * \brief Create a global variable with the given name.
 * \param name_hint The name stored on the global variable.
 * \return The newly created GlobalVar.
 */
GlobalVar GlobalVarNode::make(std::string name_hint) {
  NodePtr<GlobalVarNode> n = make_node<GlobalVarNode>();
  n->name_hint = std::move(name_hint);
  return GlobalVar(n);
}

TVM_REGISTER_NODE_TYPE(GlobalVarNode);

// Packed-function binding: relay._make.GlobalVar(name_hint) -> GlobalVar.
TVM_REGISTER_API("relay._make.GlobalVar")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = GlobalVarNode::make(args[0]);
  });

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GlobalVarNode>([](const GlobalVarNode* node, tvm::IRPrinter* p) {
    p->stream << "GlobalVar(" << node->name_hint << ")";
  });

/*!
 * \brief Create a function expression.
 * \param params The formal parameters.
 * \param body The function body.
 * \param ret_type The return type; may be undefined (see func_type_annotation).
 * \param type_params The type parameters of the function.
 * \param attrs Function attributes (e.g. a DictAttrs); may be undefined.
 * \return The newly created Function.
 */
Function FunctionNode::make(tvm::Array<Var> params,
                            Expr body,
                            Type ret_type,
                            tvm::Array<TypeVar> type_params,
                            tvm::Attrs attrs) {
  NodePtr<FunctionNode> n = make_node<FunctionNode>();
  n->params = std::move(params);
  n->body = std::move(body);
  n->ret_type = std::move(ret_type);
  n->type_params = std::move(type_params);
  n->attrs = std::move(attrs);
  return Function(n);
}

/*!
 * \brief Build the function type implied by the user-written annotations.
 *
 * Any parameter without a type annotation, and the return type if it is
 * undefined, is replaced by a fresh IncompleteType of kind Kind::kType.
 * \return A FuncType over the (possibly incomplete) parameter and return
 *         types, carrying this function's type parameters and no constraints.
 */
FuncType FunctionNode::func_type_annotation() const {
  Array<Type> param_types;
  for (const auto& param : this->params) {
    Type param_type = (param->type_annotation.defined()) ? param->type_annotation
      : IncompleteTypeNode::make(Kind::kType);
    param_types.push_back(param_type);
  }

  Type ret_type = (this->ret_type.defined()) ? this->ret_type
    : IncompleteTypeNode::make(Kind::kType);
  return FuncTypeNode::make(param_types, ret_type, this->type_params, {});
}

/*!
 * \brief Check whether this function is marked as primitive.
 * \return True iff the "Primitive" attribute is an IntImm with a nonzero value.
 */
bool FunctionNode::IsPrimitive() const {
  NodeRef res = FunctionGetAttr(GetRef<Function>(this), "Primitive");
  const ir::IntImm* pval = res.as<ir::IntImm>();
  return pval && pval->value != 0;
}

/*!
 * \brief Look up an attribute attached to a function.
 * \param func The function whose attributes are queried.
 * \param key The attribute name.
 * \return The stored value, or an undefined NodeRef when the function has
 *         no attributes or does not contain `key`.
 * \note A CHECK fails if the attributes are defined but are not DictAttrs.
 */
NodeRef FunctionGetAttr(const Function& func, const std::string& key) {
  if (func->attrs.defined()) {
    const DictAttrsNode* dict_node = func->attrs.as<DictAttrsNode>();
    CHECK(dict_node);
    auto entry = dict_node->dict.find(key);
    if (entry != dict_node->dict.end()) {
      return (*entry).second;
    }
  }
  return NodeRef();
}

/*!
 * \brief Return a copy of `func` with attribute `key` set to `data`.
 *
 * If the function already carries DictAttrs, the new dictionary is those
 * entries with `key` inserted or overwritten; otherwise the new dictionary
 * contains only `key`. The input function is not modified.
 * \param func The source function.
 * \param key The attribute name.
 * \param data The attribute value.
 * \return A new Function with the same params, body, return type and type
 *         params, and the updated attributes.
 */
Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data) {
  const DictAttrsNode* existing = func->attrs.as<DictAttrsNode>();
  Attrs updated_attrs;
  if (existing == nullptr) {
    Map<std::string, NodeRef> fresh = {{key, data}};
    updated_attrs = DictAttrsNode::make(fresh);
  } else {
    Map<std::string, NodeRef> merged = existing->dict;
    merged.Set(key, data);
    updated_attrs = DictAttrsNode::make(merged);
  }

  return FunctionNode::make(func->params, func->body, func->ret_type,
                            func->type_params, updated_attrs);
}

TVM_REGISTER_NODE_TYPE(FunctionNode);

// Packed-function binding:
// relay._make.Function(params, body, ret_type, type_params, attrs) -> Function.
TVM_REGISTER_API("relay._make.Function")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  *ret = FunctionNode::make(args[0], args[1], args[2], args[3], args[4]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FunctionNode>([](const FunctionNode* node,
                                   tvm::IRPrinter* p) {
      p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
                << ", " << node->body << ", " << node->type_params << ", "
                << node->attrs << ")";
});

/*!
 * \brief Create a call expression.
 * \param op The callee (operator or function expression).
 * \param args The call arguments.
 * \param attrs Call attributes; may be undefined.
 * \param type_args Explicit type arguments.
 * \return The newly created Call.
 */
Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
                    Array<Type> type_args) {
  NodePtr<CallNode> n = make_node<CallNode>();
  n->op = std::move(op);
  n->args = std::move(args);
  n->attrs = std::move(attrs);
  n->type_args = std::move(type_args);
  return Call(n);
}

TVM_REGISTER_NODE_TYPE(CallNode);

// Packed-function binding: relay._make.Call(op, args, attrs, type_args) -> Call.
TVM_REGISTER_API("relay._make.Call")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  *ret = CallNode::make(args[0], args[1], args[2], args[3]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
  p->stream << "CallNode(" << node->op << ", " << node->args << ", "
    << node->attrs << ", " << node->type_args << ")";
});

/*!
 * \brief Create a let-binding expression `let var = value; body`.
 * \param var The bound variable.
 * \param value The value bound to `var`.
 * \param body The expression in which the binding is visible.
 * \return The newly created Let.
 */
Let LetNode::make(Var var, Expr value, Expr body) {
  NodePtr<LetNode> n = make_node<LetNode>();
  n->var = std::move(var);
  n->value = std::move(value);
  n->body = std::move(body);
  return Let(n);
}

TVM_REGISTER_NODE_TYPE(LetNode);

// Packed-function binding: relay._make.Let(var, value, body) -> Let.
TVM_REGISTER_API("relay._make.Let")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = LetNode::make(args[0], args[1], args[2]);
  });

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<LetNode>([](const LetNode* node, tvm::IRPrinter* p) {
  p->stream << "LetNode(" << node->var << ", " << node->value
            << ", " << node->body << ")";
});

/*!
 * \brief Create a conditional expression.
 * \param cond The condition.
 * \param true_branch Expression used when the condition holds.
 * \param false_branch Expression used otherwise.
 * \return The newly created If.
 */
If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
  NodePtr<IfNode> n = make_node<IfNode>();
  n->cond = std::move(cond);
  n->true_branch = std::move(true_branch);
  n->false_branch = std::move(false_branch);
  return If(n);
}

TVM_REGISTER_NODE_TYPE(IfNode);

// Packed-function binding: relay._make.If(cond, true_branch, false_branch) -> If.
TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue* ret) {
  *ret = IfNode::make(args[0], args[1], args[2]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<IfNode>([](const IfNode* node, tvm::IRPrinter* p) {
  p->stream << "IfNode(" << node->cond << ", " << node->true_branch
            << ", " << node->false_branch << ")";
});

/*!
 * \brief Create a projection of field `index` out of a tuple expression.
 * \param tuple The tuple-valued expression.
 * \param index Position of the field to extract.
 * \return The newly created TupleGetItem.
 */
TupleGetItem TupleGetItemNode::make(Expr tuple, int index) {
  auto node = make_node<TupleGetItemNode>();
  node->index = index;
  node->tuple = std::move(tuple);
  return TupleGetItem(node);
}

TVM_REGISTER_NODE_TYPE(TupleGetItemNode);

// Packed-function binding: relay._make.TupleGetItem(tuple, index) -> TupleGetItem.
TVM_REGISTER_API("relay._make.TupleGetItem").set_body([](TVMArgs args, TVMRetValue* ret) {
  *ret = TupleGetItemNode::make(args[0], args[1]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleGetItemNode>([](const TupleGetItemNode* node, tvm::IRPrinter* p) {
  p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
});

/*!
 * \brief Create an expression that allocates a new reference cell
 *        initialised with `value`.
 * \param value The initial contents of the reference.
 * \return The newly created RefCreate.
 */
RefCreate RefCreateNode::make(Expr value) {
  auto node = make_node<RefCreateNode>();
  node->value = std::move(value);
  RefCreate result(node);
  return result;
}

// Register the node type like every other Relay expression node so that
// reflection-based creation (e.g. JSON load) can reconstruct RefCreate nodes.
TVM_REGISTER_NODE_TYPE(RefCreateNode);

// Packed-function binding: relay._make.RefCreate(value) -> RefCreate.
TVM_REGISTER_API("relay._make.RefCreate").set_body([](TVMArgs args, TVMRetValue* ret) {
  *ret = RefCreateNode::make(args[0]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefCreateNode>([](const RefCreateNode* node, tvm::IRPrinter* p) {
  p->stream << "RefCreateNode(" << node->value << ")";
});

/*!
 * \brief Create an expression that reads the current contents of a reference.
 * \param ref The reference-valued expression to dereference.
 * \return The newly created RefRead.
 */
RefRead RefReadNode::make(Expr ref) {
  auto node = make_node<RefReadNode>();
  node->ref = std::move(ref);
  RefRead result(node);
  return result;
}

// Register the node type like every other Relay expression node so that
// reflection-based creation (e.g. JSON load) can reconstruct RefRead nodes.
TVM_REGISTER_NODE_TYPE(RefReadNode);

// Packed-function binding: relay._make.RefRead(ref) -> RefRead.
TVM_REGISTER_API("relay._make.RefRead")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  *ret = RefReadNode::make(args[0]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefReadNode>([](const RefReadNode* node, tvm::IRPrinter* p) {
  p->stream << "RefReadNode(" << node->ref << ")";
});

/*!
 * \brief Create an expression that stores `value` into the reference `ref`.
 * \param ref The reference-valued expression being written.
 * \param value The new contents.
 * \return The newly created RefWrite.
 */
RefWrite RefWriteNode::make(Expr ref, Expr value) {
  auto node = make_node<RefWriteNode>();
  node->value = std::move(value);
  node->ref = std::move(ref);
  RefWrite result(node);
  return result;
}

// Register the node type like every other Relay expression node so that
// reflection-based creation (e.g. JSON load) can reconstruct RefWrite nodes.
TVM_REGISTER_NODE_TYPE(RefWriteNode);

// Packed-function binding: relay._make.RefWrite(ref, value) -> RefWrite.
TVM_REGISTER_API("relay._make.RefWrite")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  *ret = RefWriteNode::make(args[0], args[1]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefWriteNode>([](const RefWriteNode* node, tvm::IRPrinter* p) {
  p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")";
});

// Packed-function binding: relay._expr.TempExprRealize(temp) returns the
// result of calling Realize() on the given TempExpr.
TVM_REGISTER_API("relay._expr.TempExprRealize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  TempExpr temp = args[0];
  *ret = temp->Realize();
});

}  // namespace relay
}  // namespace tvm