expr_functor.cc 12.1 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20
/*!
 * \file src/tvm/relay/expr_functor.cc
 * \brief A wrapper around ExprFunctor which functionally updates the AST.
 *
 * ExprMutator uses memoization and self return in order to amortize
 * the cost of using functional updates.
 */
27
#include <tvm/ir/type_functor.h>
雾雨魔理沙 committed
28
#include <tvm/relay/analysis.h>
29
#include <tvm/relay/expr_functor.h>
30
#include <tvm/relay/pattern_functor.h>
31 32 33 34

namespace tvm {
namespace relay {

35
// Memoized entry point of the mutator.
// An expression already present in memo_ is answered from the cache;
// otherwise the call is forwarded to ExprFunctor::VisitExpr and the
// result is recorded in memo_ before being returned.
Expr ExprMutator::VisitExpr(const Expr& expr) {
  auto cached = this->memo_.find(expr);
  if (cached != this->memo_.end()) {
    return cached->second;
  }
  // Cache miss: compute once, remember, return.
  Expr rewritten = ExprFunctor::VisitExpr(expr);
  memo_[expr] = rewritten;
  return rewritten;
}

46
// Variable case: a new Var is built only when the variable carries a type
// annotation and VisitType returns a different type object for it.
// In every other situation the original node is handed back.
Expr ExprMutator::VisitExpr_(const VarNode* op) {
  if (!op->type_annotation.defined()) {
    return GetRef<Expr>(op);
  }
  auto new_annotation = this->VisitType(op->type_annotation);
  if (op->type_annotation.same_as(new_annotation)) {
    return GetRef<Expr>(op);
  }
  return VarNode::make(op->vid, new_annotation);
}

57 58
// Constant case: nothing is rewritten, the node itself is returned.
Expr ExprMutator::VisitExpr_(const ConstantNode* op) {
  Expr self = GetRef<Expr>(op);
  return self;
}

61 62
// Global variable case: nothing is rewritten, the node itself is returned.
Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) {
  Expr self = GetRef<Expr>(op);
  return self;
}

65 66
// Operator case: nothing is rewritten, the node itself is returned.
Expr ExprMutator::VisitExpr_(const OpNode* op) {
  Expr self = GetRef<Expr>(op);
  return self;
}

69
// Tuple case: every field is passed through Mutate. If each result is the
// same object as the field it came from, the original tuple is returned;
// otherwise a new tuple is built from the mutated fields.
Expr ExprMutator::VisitExpr_(const TupleNode* op) {
  tvm::Array<Expr> new_fields;
  bool any_field_changed = false;
  for (auto old_field : op->fields) {
    auto mutated = this->Mutate(old_field);
    new_fields.push_back(mutated);
    if (!mutated.same_as(old_field)) {
      any_field_changed = true;
    }
  }

  if (!any_field_changed) {
    return GetRef<Expr>(op);
  }
  return TupleNode::make(new_fields);
}

85
// Function case. Components are processed in this order: type parameters
// (via VisitType), value parameters (via Mutate), return type (via VisitType),
// body (via Mutate). The original function is returned when every component
// comes back as the same object; otherwise a new function is assembled from
// the rewritten parts with the original attrs.
Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
  bool changed = false;

  tvm::Array<TypeVar> new_type_params;
  for (auto old_tv : op->type_params) {
    TypeVar new_tv = Downcast<TypeVar>(VisitType(old_tv));
    new_type_params.push_back(new_tv);
    changed |= !new_tv.same_as(old_tv);
  }

  tvm::Array<Var> new_params;
  for (auto old_param : op->params) {
    Var new_param = Downcast<Var>(this->Mutate(old_param));
    new_params.push_back(new_param);
    changed |= !old_param.same_as(new_param);
  }

  auto new_ret_type = this->VisitType(op->ret_type);
  auto new_body = this->Mutate(op->body);
  changed |= !new_ret_type.same_as(op->ret_type);
  changed |= !new_body.same_as(op->body);

  if (!changed) {
    return GetRef<Expr>(op);
  }
  return FunctionNode::make(new_params, new_body, new_ret_type,
                            new_type_params, op->attrs);
}

116 117 118
// Call case: the callee is mutated first, then the type arguments are passed
// through VisitType, then the value arguments through Mutate. A new call is
// created only if at least one of those came back as a different object.
Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
  auto callee = this->Mutate(call_node->op);
  bool changed = !call_node->op.same_as(callee);

  tvm::Array<Type> new_type_args;
  for (auto old_ty : call_node->type_args) {
    auto new_ty = this->VisitType(old_ty);
    new_type_args.push_back(new_ty);
    changed |= !new_ty.same_as(old_ty);
  }

  tvm::Array<Expr> new_args;
  for (auto old_arg : call_node->args) {
    auto mutated_arg = this->Mutate(old_arg);
    new_args.push_back(mutated_arg);
    changed |= !mutated_arg.same_as(old_arg);
  }

  if (!changed) {
    return GetRef<Expr>(call_node);
  }
  return CallNode::make(callee, new_args, call_node->attrs, new_type_args);
}

141
// Let case: the bound variable, the bound value and the body are mutated in
// that order. The original node is returned when all three are unchanged.
Expr ExprMutator::VisitExpr_(const LetNode* op) {
  Var new_var = Downcast<Var>(this->Mutate(op->var));
  auto new_value = this->Mutate(op->value);
  auto new_body = this->Mutate(op->body);

  const bool untouched = new_var.same_as(op->var) &&
                         new_value.same_as(op->value) &&
                         new_body.same_as(op->body);
  if (untouched) {
    return GetRef<Expr>(op);
  }
  return LetNode::make(new_var, new_value, new_body);
}

155
// If case: condition, true branch and false branch are mutated in that order.
// The original node is returned when all three are unchanged.
Expr ExprMutator::VisitExpr_(const IfNode* op) {
  auto new_cond = this->Mutate(op->cond);
  auto new_true_branch = this->Mutate(op->true_branch);
  auto new_false_branch = this->Mutate(op->false_branch);

  const bool untouched = op->cond.same_as(new_cond) &&
                         op->true_branch.same_as(new_true_branch) &&
                         op->false_branch.same_as(new_false_branch);
  if (untouched) {
    return GetRef<Expr>(op);
  }
  return IfNode::make(new_cond, new_true_branch, new_false_branch);
}

168 169 170 171 172 173 174
// Tuple projection case: the tuple operand is mutated; the original node is
// returned when the operand comes back as the same object, otherwise a new
// projection with the same index is built.
// Identity is tested with same_as, matching every other case in this file.
Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
  auto t = this->Mutate(g->tuple);
  if (t.same_as(g->tuple)) {
    return GetRef<Expr>(g);
  } else {
    return TupleGetItemNode::make(t, g->index);
  }
}
176

177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204
// Reference creation case: rebuilds the node only if the initial value
// was changed by Mutate.
Expr ExprMutator::VisitExpr_(const RefCreateNode* op) {
  Expr new_value = this->Mutate(op->value);
  if (!new_value.same_as(op->value)) {
    return RefCreateNode::make(new_value);
  }
  return GetRef<Expr>(op);
}

// Reference read case: rebuilds the node only if the reference operand
// was changed by Mutate.
Expr ExprMutator::VisitExpr_(const RefReadNode* op) {
  Expr new_ref = this->Mutate(op->ref);
  if (!new_ref.same_as(op->ref)) {
    return RefReadNode::make(new_ref);
  }
  return GetRef<Expr>(op);
}

// Reference write case: the reference is mutated first, then the written
// value. The node is rebuilt if either of them changed.
Expr ExprMutator::VisitExpr_(const RefWriteNode* op) {
  Expr new_ref = this->Mutate(op->ref);
  Expr new_value = this->Mutate(op->value);
  const bool untouched = new_ref.same_as(op->ref) && new_value.same_as(op->value);
  if (untouched) {
    return GetRef<Expr>(op);
  }
  return RefWriteNode::make(new_ref, new_value);
}

205 206 207 208 209 210 211 212 213
// Constructor case: nothing is rewritten, the node itself is returned.
Expr ExprMutator::VisitExpr_(const ConstructorNode* c) {
  Expr self = GetRef<Expr>(c);
  return self;
}

// Match case: each clause goes through VisitClause, then the scrutinee
// goes through VisitExpr (same order as before).
// Like the other cases in this mutator, the original node is returned when
// nothing changed: every clause must come back as the very same Clause object
// and the scrutinee as the very same expression. Otherwise a new Match is
// built, preserving the `complete` flag.
Expr ExprMutator::VisitExpr_(const MatchNode* m) {
  bool unchanged = true;
  std::vector<Clause> clauses;
  clauses.reserve(m->clauses.size());
  for (const Clause& p : m->clauses) {
    Clause new_clause = VisitClause(p);
    unchanged &= new_clause.same_as(p);
    clauses.push_back(new_clause);
  }
  Expr data = VisitExpr(m->data);
  unchanged &= data.same_as(m->data);

  if (unchanged) {
    return GetRef<Expr>(m);
  }
  return MatchNode::make(data, clauses, m->complete);
}

// Rewrites one match clause: the pattern goes through VisitPattern, then the
// right-hand side through VisitExpr.
// When both come back as the same objects the incoming clause is returned
// as-is, so callers can detect "no change" by identity; otherwise a new
// clause is created.
Clause ExprMutator::VisitClause(const Clause& c) {
  Pattern p = VisitPattern(c->lhs);
  Expr rhs = VisitExpr(c->rhs);
  if (p.same_as(c->lhs) && rhs.same_as(c->rhs)) {
    return c;
  }
  return ClauseNode::make(p, rhs);
}

// Default pattern hook: the pattern is returned untouched.
Pattern ExprMutator::VisitPattern(const Pattern& p) {
  return p;
}

224 225
// Default type hook: the type is returned untouched.
Type ExprMutator::VisitType(const Type& t) {
  return t;
}

226
void ExprVisitor::VisitExpr(const Expr& expr) {
227 228 229 230 231 232 233 234
  auto it = visit_counter_.find(expr.get());
  if (it != visit_counter_.end()) {
    ++it->second;
  } else {
    using TParent = ExprFunctor<void(const Expr&)>;
    TParent::VisitExpr(expr);
    visit_counter_.insert({expr.get(), 1});
  }
235 236
}

237
void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
238 239 240
  if (op->type_annotation.defined()) {
    this->VisitType(op->type_annotation);
  }
241
}
242

243 244
// Global variable case: intentionally does nothing.
void ExprVisitor::VisitExpr_(const GlobalVarNode* op) {}
245

246 247
// Constant case: intentionally does nothing.
void ExprVisitor::VisitExpr_(const ConstantNode* op) {}
248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264

void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) {
  for (auto field : op->fields) {
    this->VisitExpr(field);
  }
}

void ExprVisitor::ExprVisitor::VisitExpr_(const FunctionNode* op) {
  for (auto param : op->params) {
    this->VisitExpr(param);
  }

  this->VisitExpr(op->body);
}

void ExprVisitor::VisitExpr_(const CallNode* op) {
  this->VisitExpr(op->op);
265

266 267 268 269 270 271 272 273 274 275 276
  for (auto ty_arg : op->type_args) {
    this->VisitType(ty_arg);
  }

  for (auto arg : op->args) {
    this->VisitExpr(arg);
  }
}

void ExprVisitor::VisitExpr_(const LetNode* op) {
  this->VisitExpr(op->value);
277
  this->VisitExpr(op->var);
278 279 280 281 282 283 284 285 286 287 288
  this->VisitExpr(op->body);
}

void ExprVisitor::VisitExpr_(const IfNode* op) {
  this->VisitExpr(op->cond);
  this->VisitExpr(op->true_branch);
  this->VisitExpr(op->false_branch);
}

// Operator case: intentionally does nothing.
void ExprVisitor::VisitExpr_(const OpNode* op) {}

289 290 291 292
void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
  this->VisitExpr(op->tuple);
}

293 294 295 296 297 298 299 300 301 302 303 304 305
void ExprVisitor::ExprVisitor::VisitExpr_(const RefCreateNode* op) {
  this->VisitExpr(op->value);
}

void ExprVisitor::ExprVisitor::VisitExpr_(const RefReadNode* op) {
  this->VisitExpr(op->ref);
}

void ExprVisitor::ExprVisitor::VisitExpr_(const RefWriteNode* op) {
  this->VisitExpr(op->ref);
  this->VisitExpr(op->value);
}

306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326
void ExprVisitor::VisitExpr_(const ConstructorNode* op) {
  for (const Type& t : op->inputs) {
    this->VisitType(t);
  }
  this->VisitType(op->belong_to);
}

void ExprVisitor::VisitExpr_(const MatchNode* op) {
  this->VisitExpr(op->data);
  for (const Clause& c : op->clauses) {
    this->VisitClause(c);
  }
}

void ExprVisitor::VisitClause(const Clause& op) {
  this->VisitPattern(op->lhs);
  this->VisitExpr(op->rhs);
}

// Default pattern hook: intentionally does nothing.
void ExprVisitor::VisitPattern(const Pattern& p) {}

327 328
// Default type hook: intentionally does nothing.
void ExprVisitor::VisitType(const Type& t) {}

ziheng committed
329 330 331 332
// visitor to implement apply
class ExprApplyVisit : public ExprVisitor {
 public:
  explicit ExprApplyVisit(std::function<void(const Expr&)> f) : f_(f) {}
333

ziheng committed
334 335 336 337 338 339 340 341 342
  void VisitExpr(const Expr& e) final {
    if (visited_.count(e.get()) != 0) return;
    visited_.insert(e.get());
    ExprVisitor::VisitExpr(e);
    f_(e);
  }

 private:
  std::function<void(const Expr&)> f_;
343
  std::unordered_set<const Object*> visited_;
ziheng committed
344 345 346 347 348 349
};

void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) {
  ExprApplyVisit(fvisit).VisitExpr(e);
}

350
// Registers PostOrderVisit under a global name; the PackedFunc argument is
// wrapped in a lambda so it can be used as the per-node callback.
TVM_REGISTER_GLOBAL("relay._analysis.post_order_visit")
.set_body_typed([](Expr expr, PackedFunc f) {
    auto per_node = [f](const Expr& node) { f(node); };
    PostOrderVisit(expr, per_node);
  });

357
// Implement bind.
358
class ExprBinder : public ExprMutator, PatternMutator {
359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383
 public:
  explicit ExprBinder(const tvm::Map<Var, Expr>& args_map)
    : args_map_(args_map) {
  }

  Expr VisitExpr_(const LetNode* op) final {
    CHECK(!args_map_.count(op->var))
        << "Cannot bind an internel variable in let";
    return ExprMutator::VisitExpr_(op);
  }

  Expr VisitExpr_(const FunctionNode* op) final {
    for (Var param : op->params) {
      CHECK(!args_map_.count(param))
          << "Cannnot bind an internal function parameter";
    }
    return ExprMutator::VisitExpr_(op);
  }

  Expr VisitExpr_(const VarNode* op) final {
    auto id = GetRef<Var>(op);
    auto it = args_map_.find(id);
    if (it != args_map_.end()) {
      return (*it).second;
    } else {
384
      return std::move(id);
385 386 387
    }
  }

388 389 390 391 392 393 394 395 396 397
  Pattern VisitPattern(const Pattern& p) final {
    return PatternMutator::VisitPattern(p);
  }

  Clause VisitClause(const Clause& c) final {
    Pattern pat = VisitPattern(c->lhs);
    return ClauseNode::make(pat, VisitExpr(c->rhs));
  }

  Var VisitVar(const Var& v) final {
398 399 400
    CHECK(!args_map_.count(v))
      << "Cannnot bind an internal pattern variable";
    return v;
401 402
  }

403 404 405 406 407 408
 private:
  const tvm::Map<Var, Expr>& args_map_;
};

// Substitutes variables in `expr` according to `args_map`.
//
// Non-function input: the whole expression is rewritten by ExprBinder.
//
// Function input: only the body is rewritten. Parameters that appear in
// `args_map` are dropped from the parameter list. If neither the body nor the
// parameter list changed, `expr` itself is returned. Otherwise, any variable
// that is free in the rebuilt function but was not free in `expr` is appended
// as an extra parameter, and a CHECK verifies that the final function has as
// many free variables as the input had.
Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
  if (const FunctionNode* func = expr.as<FunctionNode>()) {
    Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
    Array<Var> new_params;
    for (Var param : func->params) {
      if (!args_map.count(param)) {
        new_params.push_back(param);
      }
    }
    if (new_body.same_as(func->body) &&
        new_params.size() == func->params.size()) {
      return expr;
    }
    // Provisional function, used only to discover newly introduced free vars.
    auto ret = FunctionNode::make(new_params,
                                  new_body,
                                  func->ret_type,
                                  func->type_params,
                                  func->attrs);
    // Computed once and reused for both the membership set and the final
    // size check.
    const auto free_vars_before = FreeVars(expr);
    std::unordered_set<Var, ObjectHash, ObjectEqual> set;
    for (const auto& v : free_vars_before) {
      set.insert(v);
    }
    for (const auto& v : FreeVars(ret)) {
      if (set.count(v) == 0) {
        new_params.push_back(v);
      }
    }
    ret = FunctionNode::make(new_params,
                             new_body,
                             func->ret_type,
                             func->type_params,
                             func->attrs);
    CHECK_EQ(free_vars_before.size(), FreeVars(ret).size());
    return std::move(ret);
  } else {
    return ExprBinder(args_map).VisitExpr(expr);
  }
}

446
// Registers Bind under a global name. The first argument selects the
// overload: an ExprNode instance goes to the expression Bind; anything else
// must be a TypeNode instance (enforced by CHECK) and goes to the type Bind.
TVM_REGISTER_GLOBAL("relay._expr.Bind")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    ObjectRef input = args[0];
    if (!input->IsInstance<ExprNode>()) {
      CHECK(input->IsInstance<TypeNode>());
      *ret = Bind(Downcast<Type>(input), args[1]);
      return;
    }
    *ret = Bind(Downcast<Expr>(input), args[1]);
  });
456 457
}  // namespace relay
}  // namespace tvm