expr_functor.cc 12.1 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2019 by Contributors
 * \file src/tvm/relay/expr_mutator.cc
 * \brief Default implementations of ExprMutator and ExprVisitor, plus the
 *        PostOrderVisit and Bind utilities built on top of them.
 *
 * ExprMutator is a wrapper around ExprFunctor which functionally updates
 * the AST. It uses memoization and self return in order to amortize
 * the cost of using functional updates.
 */
雾雨魔理沙 committed
28
#include <tvm/relay/analysis.h>
29
#include <tvm/relay/expr_functor.h>
30
#include <tvm/relay/pattern_functor.h>
31
#include "type_functor.h"
32 33 34 35

namespace tvm {
namespace relay {

36
/*!
 * \brief Memoized entry point of the mutator.
 *
 * Results are cached in memo_ keyed by the input expression, so visiting the
 * same expression again returns the previously computed rewrite instead of
 * dispatching a second time.
 */
Expr ExprMutator::VisitExpr(const Expr& expr) {
  auto cached = this->memo_.find(expr);
  if (cached != this->memo_.end()) {
    return cached->second;
  }
  // Cache miss: dispatch to the node-specific VisitExpr_ and remember the result.
  Expr rewritten = ExprFunctor::VisitExpr(expr);
  memo_[expr] = rewritten;
  return rewritten;
}

47
/*!
 * \brief Rewrite a variable. Only its type annotation can change.
 * \return op itself when there is no annotation or the annotation is
 *         unchanged; otherwise a new Var reusing op->vid with the new type.
 */
Expr ExprMutator::VisitExpr_(const VarNode* op) {
  if (!op->type_annotation.defined()) {
    return GetRef<Expr>(op);
  }
  auto new_type = this->VisitType(op->type_annotation);
  if (op->type_annotation.same_as(new_type)) {
    return GetRef<Expr>(op);
  }
  return VarNode::make(op->vid, new_type);
}

58 59
// Leaf nodes: constants, global variables and operators have no
// sub-expressions for the mutator to rewrite, so each returns the node itself.

/*! \brief Constants are returned unchanged. */
Expr ExprMutator::VisitExpr_(const ConstantNode* op) { return GetRef<Expr>(op); }

/*! \brief Global variables are returned unchanged. */
Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { return GetRef<Expr>(op); }

/*! \brief Operators are returned unchanged. */
Expr ExprMutator::VisitExpr_(const OpNode* op) { return GetRef<Expr>(op); }

70
/*!
 * \brief Rewrite every field of a tuple.
 * \return op itself if no field changed, otherwise a freshly built tuple.
 */
Expr ExprMutator::VisitExpr_(const TupleNode* op) {
  tvm::Array<Expr> new_fields;
  bool changed = false;
  for (auto original : op->fields) {
    Expr mutated = this->Mutate(original);
    new_fields.push_back(mutated);
    if (!mutated.same_as(original)) {
      changed = true;
    }
  }

  if (!changed) {
    return GetRef<Expr>(op);
  }
  return TupleNode::make(new_fields);
}

86
/*!
 * \brief Rewrite a function.
 *
 * Visits, in order: type parameters (via VisitType), value parameters (via
 * Mutate), the return type, and the body. Returns op itself when none of
 * them changed.
 */
Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
  // Type parameters must still be TypeVars after VisitType.
  tvm::Array<TypeVar> new_ty_params;
  bool ty_params_same = true;
  for (auto ty_param : op->type_params) {
    TypeVar visited = Downcast<TypeVar>(VisitType(ty_param));
    new_ty_params.push_back(visited);
    if (!visited.same_as(ty_param)) ty_params_same = false;
  }

  // Value parameters are mutated like any expression and must remain Vars.
  tvm::Array<Var> new_params;
  bool params_same = true;
  for (auto param : op->params) {
    Var visited = Downcast<Var>(this->Mutate(param));
    new_params.push_back(visited);
    if (!param.same_as(visited)) params_same = false;
  }

  auto new_ret_type = this->VisitType(op->ret_type);
  auto new_body = this->Mutate(op->body);

  const bool unchanged = ty_params_same && params_same &&
                         new_ret_type.same_as(op->ret_type) &&
                         new_body.same_as(op->body);
  if (unchanged) {
    return GetRef<Expr>(op);
  }
  return FunctionNode::make(new_params, new_body, new_ret_type, new_ty_params, op->attrs);
}

117 118 119
/*!
 * \brief Rewrite a call: the callee first, then the type arguments, then the
 *        value arguments. Attributes are carried over as-is.
 * \return call_node itself if nothing changed, otherwise a new call.
 */
Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
  auto callee = this->Mutate(call_node->op);
  bool changed = !call_node->op.same_as(callee);

  tvm::Array<Type> new_type_args;
  for (auto type_arg : call_node->type_args) {
    auto visited = this->VisitType(type_arg);
    new_type_args.push_back(visited);
    if (!visited.same_as(type_arg)) changed = true;
  }

  tvm::Array<Expr> new_args;
  for (auto arg : call_node->args) {
    auto visited = this->Mutate(arg);
    new_args.push_back(visited);
    if (!visited.same_as(arg)) changed = true;
  }

  if (!changed) {
    return GetRef<Expr>(call_node);
  }
  return CallNode::make(callee, new_args, call_node->attrs, new_type_args);
}

142
/*!
 * \brief Rewrite a let binding: the bound variable (which must stay a Var),
 *        then the bound value, then the body.
 * \return op itself if all three are unchanged, otherwise a new let.
 */
Expr ExprMutator::VisitExpr_(const LetNode* op) {
  Var new_var = Downcast<Var>(this->Mutate(op->var));
  auto new_value = this->Mutate(op->value);
  auto new_body = this->Mutate(op->body);

  const bool unchanged = new_var.same_as(op->var) &&
                         new_value.same_as(op->value) &&
                         new_body.same_as(op->body);
  if (unchanged) return GetRef<Expr>(op);
  return LetNode::make(new_var, new_value, new_body);
}

156
/*!
 * \brief Rewrite a conditional: condition, then-branch, else-branch.
 * \return op itself if no part changed, otherwise a new If node.
 */
Expr ExprMutator::VisitExpr_(const IfNode* op) {
  auto new_cond = this->Mutate(op->cond);
  auto new_then = this->Mutate(op->true_branch);
  auto new_else = this->Mutate(op->false_branch);

  if (!op->cond.same_as(new_cond) ||
      !op->true_branch.same_as(new_then) ||
      !op->false_branch.same_as(new_else)) {
    return IfNode::make(new_cond, new_then, new_else);
  }
  return GetRef<Expr>(op);
}

169 170 171 172 173 174 175
/*!
 * \brief Rewrite a tuple projection. Only the tuple operand is visited; the
 *        index is carried over unchanged.
 */
Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
  auto new_tuple = this->Mutate(g->tuple);
  if (!new_tuple.same_as(g->tuple)) {
    return TupleGetItemNode::make(new_tuple, g->index);
  }
  return GetRef<Expr>(g);
}
177

178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205
/*! \brief Rewrite the initial value of a reference creation. */
Expr ExprMutator::VisitExpr_(const RefCreateNode* op) {
  Expr new_value = this->Mutate(op->value);
  if (!new_value.same_as(op->value)) {
    return RefCreateNode::make(new_value);
  }
  return GetRef<Expr>(op);
}

/*! \brief Rewrite the reference operand of a read. */
Expr ExprMutator::VisitExpr_(const RefReadNode* op) {
  Expr new_ref = this->Mutate(op->ref);
  if (!new_ref.same_as(op->ref)) {
    return RefReadNode::make(new_ref);
  }
  return GetRef<Expr>(op);
}

/*! \brief Rewrite the reference, then the stored value, of a write. */
Expr ExprMutator::VisitExpr_(const RefWriteNode* op) {
  Expr new_ref = this->Mutate(op->ref);
  Expr new_value = this->Mutate(op->value);
  const bool unchanged = new_ref.same_as(op->ref) && new_value.same_as(op->value);
  if (!unchanged) {
    return RefWriteNode::make(new_ref, new_value);
  }
  return GetRef<Expr>(op);
}

206 207 208 209 210 211 212 213 214
/*! \brief Constructors are returned unchanged by the default mutator. */
Expr ExprMutator::VisitExpr_(const ConstructorNode* c) { return GetRef<Expr>(c); }

/*!
 * \brief Rewrite a match expression.
 *
 * The clauses are visited first and the scrutinee (m->data) second, which is
 * the same order as before. Subclasses may observe the visiting order, so it
 * is kept.
 *
 * \return m itself when neither the scrutinee nor any clause changed (self
 *         return, like every other ExprMutator visitor), otherwise a new
 *         Match node with the same `complete` flag.
 */
Expr ExprMutator::VisitExpr_(const MatchNode* m) {
  std::vector<Clause> clauses;
  clauses.reserve(m->clauses.size());
  bool unchanged = true;
  for (const Clause& p : m->clauses) {
    Clause new_clause = VisitClause(p);
    unchanged &= new_clause.same_as(p);
    clauses.push_back(new_clause);
  }

  Expr data = VisitExpr(m->data);
  unchanged &= data.same_as(m->data);

  // Previously a new MatchNode was always allocated. That broke identity-based
  // "unchanged" checks in every enclosing node.
  if (unchanged) {
    return GetRef<Expr>(m);
  }
  return MatchNode::make(data, clauses, m->complete);
}

/*!
 * \brief Rewrite a match clause: the pattern (lhs) first, then the body (rhs).
 * \return c itself when both sides are unchanged, otherwise a new clause.
 */
Clause ExprMutator::VisitClause(const Clause& c) {
  Pattern p = VisitPattern(c->lhs);
  Expr rhs = VisitExpr(c->rhs);
  // Self return keeps clause identity, so callers can detect "no change".
  if (p.same_as(c->lhs) && rhs.same_as(c->rhs)) {
    return c;
  }
  return ClauseNode::make(p, rhs);
}

Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; }

225 226
Type ExprMutator::VisitType(const Type& t) { return t; }

227
/*!
 * \brief Entry point of the visitor.
 *
 * The first time a node (keyed by its address) is seen, it is dispatched to
 * its VisitExpr_ handler and then recorded in visit_counter_ with a count of 1.
 * Later encounters only increment the count and do not traverse again.
 */
void ExprVisitor::VisitExpr(const Expr& expr) {
  auto key = expr.get();
  auto entry = visit_counter_.find(key);
  if (entry == visit_counter_.end()) {
    using TParent = ExprFunctor<void(const Expr&)>;
    TParent::VisitExpr(expr);
    visit_counter_.insert({key, 1});
  } else {
    ++entry->second;
  }
}

238
void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
239 240 241
  if (op->type_annotation.defined()) {
    this->VisitType(op->type_annotation);
  }
242
}
243

244 245
void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) {
}
246

247 248
void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) {
}
249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265

/*! \brief Visit each tuple field in order. */
void ExprVisitor::VisitExpr_(const TupleNode* op) {
  for (size_t i = 0; i < op->fields.size(); ++i) {
    this->VisitExpr(op->fields[i]);
  }
}

/*!
 * \brief Visit a function's parameters, then its body.
 *
 * Type parameters and the return type are not visited by this default.
 */
void ExprVisitor::VisitExpr_(const FunctionNode* op) {
  for (size_t i = 0; i < op->params.size(); ++i) {
    this->VisitExpr(op->params[i]);
  }
  this->VisitExpr(op->body);
}

void ExprVisitor::VisitExpr_(const CallNode* op) {
  this->VisitExpr(op->op);
266

267 268 269 270 271 272 273 274 275 276 277
  for (auto ty_arg : op->type_args) {
    this->VisitType(ty_arg);
  }

  for (auto arg : op->args) {
    this->VisitExpr(arg);
  }
}

void ExprVisitor::VisitExpr_(const LetNode* op) {
  this->VisitExpr(op->value);
278
  this->VisitExpr(op->var);
279 280 281 282 283 284 285 286 287 288 289
  this->VisitExpr(op->body);
}

void ExprVisitor::VisitExpr_(const IfNode* op) {
  this->VisitExpr(op->cond);
  this->VisitExpr(op->true_branch);
  this->VisitExpr(op->false_branch);
}

/*! \brief Operators are leaves: nothing to visit. */
void ExprVisitor::VisitExpr_(const OpNode* op) {}

290 291 292 293
/*! \brief Visit the tuple operand of a projection; the index is not an expression. */
void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { this->VisitExpr(op->tuple); }

294 295 296 297 298 299 300 301 302 303 304 305 306
void ExprVisitor::ExprVisitor::VisitExpr_(const RefCreateNode* op) {
  this->VisitExpr(op->value);
}

void ExprVisitor::ExprVisitor::VisitExpr_(const RefReadNode* op) {
  this->VisitExpr(op->ref);
}

void ExprVisitor::ExprVisitor::VisitExpr_(const RefWriteNode* op) {
  this->VisitExpr(op->ref);
  this->VisitExpr(op->value);
}

307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327
void ExprVisitor::VisitExpr_(const ConstructorNode* op) {
  for (const Type& t : op->inputs) {
    this->VisitType(t);
  }
  this->VisitType(op->belong_to);
}

void ExprVisitor::VisitExpr_(const MatchNode* op) {
  this->VisitExpr(op->data);
  for (const Clause& c : op->clauses) {
    this->VisitClause(c);
  }
}

void ExprVisitor::VisitClause(const Clause& op) {
  this->VisitPattern(op->lhs);
  this->VisitExpr(op->rhs);
}

void ExprVisitor::VisitPattern(const Pattern& p) { return; }

328 329
void ExprVisitor::VisitType(const Type& t) { return; }

ziheng committed
330 331 332 333
// visitor to implement apply
class ExprApplyVisit : public ExprVisitor {
 public:
  explicit ExprApplyVisit(std::function<void(const Expr&)> f) : f_(f) {}
334

ziheng committed
335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350
  void VisitExpr(const Expr& e) final {
    if (visited_.count(e.get()) != 0) return;
    visited_.insert(e.get());
    ExprVisitor::VisitExpr(e);
    f_(e);
  }

 private:
  std::function<void(const Expr&)> f_;
  std::unordered_set<const Node*> visited_;
};

void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) {
  ExprApplyVisit(fvisit).VisitExpr(e);
}

Zhi committed
351
TVM_REGISTER_API("relay._analysis.post_order_visit")
.set_body_typed<void(Expr, PackedFunc)>([](Expr expr, PackedFunc f) {
    // Forward every visited node to the PackedFunc callback.
    auto forward = [f](const Expr& n) { f(n); };
    PostOrderVisit(expr, forward);
  });

358
// Implement bind.
359
class ExprBinder : public ExprMutator, PatternMutator {
360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384
 public:
  explicit ExprBinder(const tvm::Map<Var, Expr>& args_map)
    : args_map_(args_map) {
  }

  Expr VisitExpr_(const LetNode* op) final {
    CHECK(!args_map_.count(op->var))
        << "Cannot bind an internel variable in let";
    return ExprMutator::VisitExpr_(op);
  }

  Expr VisitExpr_(const FunctionNode* op) final {
    for (Var param : op->params) {
      CHECK(!args_map_.count(param))
          << "Cannnot bind an internal function parameter";
    }
    return ExprMutator::VisitExpr_(op);
  }

  Expr VisitExpr_(const VarNode* op) final {
    auto id = GetRef<Var>(op);
    auto it = args_map_.find(id);
    if (it != args_map_.end()) {
      return (*it).second;
    } else {
385
      return std::move(id);
386 387 388
    }
  }

389 390 391 392 393 394 395 396 397 398
  Pattern VisitPattern(const Pattern& p) final {
    return PatternMutator::VisitPattern(p);
  }

  Clause VisitClause(const Clause& c) final {
    Pattern pat = VisitPattern(c->lhs);
    return ClauseNode::make(pat, VisitExpr(c->rhs));
  }

  Var VisitVar(const Var& v) final {
399 400 401
    CHECK(!args_map_.count(v))
      << "Cannnot bind an internal pattern variable";
    return v;
402 403
  }

404 405 406 407 408 409
 private:
  const tvm::Map<Var, Expr>& args_map_;
};

/*!
 * \brief Substitute variables in expr according to args_map.
 *
 * For a function, its body is rewritten and every parameter that appears in
 * args_map is dropped from the signature. Any variable that becomes free only
 * because of the substitution (e.g. one referenced by a substituted value) is
 * appended as a new trailing parameter. The free variable count must therefore
 * match the original function's (checked with CHECK_EQ). For any other
 * expression the substitution is applied directly.
 *
 * \return expr itself if a function is left unchanged; otherwise the rewrite.
 */
Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
  if (const FunctionNode* func = expr.as<FunctionNode>()) {
    Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
    Array<Var> new_params;
    for (Var param : func->params) {
      if (!args_map.count(param)) {
        new_params.push_back(param);
      }
    }
    if (new_body.same_as(func->body) &&
        new_params.size() == func->params.size()) {
      return expr;
    }
    auto ret = FunctionNode::make(new_params,
                                  new_body,
                                  func->ret_type,
                                  func->type_params,
                                  func->attrs);
    // Free variables of the original function. They are computed once and
    // reused for the final consistency check instead of calling FreeVars twice.
    const auto original_free_vars = FreeVars(expr);
    std::unordered_set<Var, NodeHash, NodeEqual> original_free_set;
    for (const auto& v : original_free_vars) {
      original_free_set.insert(v);
    }
    // Lift variables introduced by the substitution into extra parameters.
    for (const auto& v : FreeVars(ret)) {
      if (original_free_set.count(v) == 0) {
        new_params.push_back(v);
      }
    }
    ret = FunctionNode::make(new_params,
                             new_body,
                             func->ret_type,
                             func->type_params,
                             func->attrs);
    CHECK_EQ(original_free_vars.size(), FreeVars(ret).size());
    return std::move(ret);
  } else {
    return ExprBinder(args_map).VisitExpr(expr);
  }
}

TVM_REGISTER_API("relay._expr.Bind")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    NodeRef input = args[0];
    // Expressions and types have separate Bind overloads; anything that is
    // not an expression must be a type.
    if (!input->derived_from<ExprNode>()) {
      CHECK(input->derived_from<TypeNode>());
      *ret = Bind(Downcast<Type>(input), args[1]);
      return;
    }
    *ret = Bind(Downcast<Expr>(input), args[1]);
  });
457 458
}  // namespace relay
}  // namespace tvm