/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2018 by Contributors
 *
 * \file to_a_normal_form.cc
 *
 * \brief Turn implicit sharing into observable sharing.
 */
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/logging.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "let_list.h"
#include "../../common/arena.h"
#include "pass_util.h"
#include "dependency_graph.h"

namespace tvm {
namespace relay {

38 39 40
Expr ToANormalForm(const Expr& e,
                   const Module& m,
                   std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv);
41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79

struct ScopeNode;
/* Scopes are shared between the nodes that live in them. */
using Scope = std::shared_ptr<ScopeNode>;

/* A node in the tree of lexical scopes.
 *
 * Invariant: a scope without a parent has level 0.
 *
 * Invariant: a scope with a parent has level 1 + parent->level.
 */
struct ScopeNode {
  /* Depth of this scope in the scope tree; the root is at depth 0. */
  size_t level = 0;
  /* Enclosing scope; null for a root scope. */
  Scope parent;
  /* Let list owned by this scope; every scope starts with a fresh one. */
  std::shared_ptr<LetList> ll = std::make_shared<LetList>();

  /* Build a root scope. */
  ScopeNode() = default;

  /* Build a scope nested directly inside enclosing, which must not be null. */
  explicit ScopeNode(const Scope& enclosing)
      : level(enclosing->level + 1), parent(enclosing) {}
};

/* Create a fresh scope nested directly inside s. */
Scope ChildScope(const Scope& s) {
  Scope child = std::make_shared<ScopeNode>(s);
  return child;
}

/* Lowest common ancestor of two scopes.
 *
 * On every step the deeper scope climbs to its parent; when both are at the
 * same level but still differ, both climb. The walk stops as soon as the two
 * handles are equal, and that handle is returned.
 */
Scope LCA(Scope lhs, Scope rhs) {
  while (lhs != rhs) {
    // Read both levels before moving anything so the two tests below see
    // the same snapshot.
    const size_t lhs_level = lhs->level;
    const size_t rhs_level = rhs->level;
    if (lhs_level >= rhs_level) {
      lhs = lhs->parent;
    }
    if (lhs_level <= rhs_level) {
      rhs = rhs->parent;
    }
  }
  return lhs;
}

std::unordered_map<DependencyGraph::Node*, Scope> CalcScope(const DependencyGraph& dg) {
  std::unordered_map<DependencyGraph::Node*, Scope> expr_scope;
  Scope global_scope = std::make_shared<ScopeNode>();
  for (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) {
    DependencyGraph::Node* n = *it;
80
    auto iit = n->parents.head;
81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
    Scope s;
    if (iit == nullptr) {
      s = global_scope;
    } else {
      s = expr_scope.at(iit->value);
      iit = iit->next;
      for (; iit != nullptr; iit = iit->next) {
        s = LCA(s, expr_scope.at(iit->value));
      }
    }
    expr_scope.insert({n, n->new_scope ? ChildScope(s) : s});
  }
  return expr_scope;
}

/* True iff e is a function node and that function reports itself as primitive. */
bool IsPrimitiveFunction(const Expr& e) {
  if (!e.as<FunctionNode>()) {
    return false;
  }
  return Downcast<Function>(e)->IsPrimitive();
}

100 101 102 103
/* Special care is needed to handle local recursion.
 * Fill additionally take a (possibly null) Var argument,
 * If it is not null, Fill is required to bind the transformed result to that var.
 */
104 105
class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
 public:
雾雨魔理沙 committed
106 107 108 109
  static Expr ToANormalForm(const Expr& e,
                            const Module& m,
                            const DependencyGraph& dg,
                            std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
110
                            std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
111 112 113 114 115 116 117 118
    Fill fi(m, dg, node_scope, gv);
    return fi.GetScope(e)->ll->Get(fi.VisitExpr(e));
  }

 private:
  Module mod_;
  const DependencyGraph& dg_;
  std::unordered_map<DependencyGraph::Node*, Scope>* node_scope_;
119
  std::unordered_set<GlobalVar, NodeHash, NodeEqual>* visited_;
120 121 122 123 124
  std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo;

  Fill(Module mod,
       const DependencyGraph& dg,
       std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
125
       std::unordered_set<GlobalVar, NodeHash, NodeEqual>* visited) :
126 127 128 129 130 131 132 133 134 135 136
    mod_(mod),
    dg_(dg),
    node_scope_(node_scope),
    visited_(visited) { }

  Scope GetScope(const Expr& e) {
    return node_scope_->at(dg_.expr_node.at(e));
  }

  Scope GetSubScope(const Expr& e, size_t i) {
    DependencyGraph::Node* n = dg_.expr_node.at(e);
137
    auto h = n->children.head;
138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
    while (i != 0) {
      CHECK(h);
      --i;
      h = h->next;
    }
    CHECK(h);
    return node_scope_->at(h->value);
  }

  Expr VisitExpr(const Expr& e, const Var& v) final {
    if (memo.count(e) == 0) {
      memo.insert({e, ExprFunctor<Expr(const Expr&, const Var&)>::VisitExpr(e, v)});
    }
    return memo.at(e);
  }

  Expr VisitExpr(const Expr& e) {
155 156 157 158 159
    return this->VisitExpr(e, Var());
  }

  Expr Atomic(const Expr& orig, const Expr& now, const Var& v) {
    return v.defined() ? GetScope(orig)->ll->Push(v, now) : now;
160 161 162
  }

  Expr Compound(const Expr& orig, const Expr& now, const Var& v) {
163 164 165 166
    Var var = v.defined() ?
      v :
      VarNode::make(std::string("x"), IncompleteTypeNode::make(Kind::kType));
    return GetScope(orig)->ll->Push(var, now);
167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191
  }

  Expr VisitExpr_(const CallNode* c, const Var& v) final {
    Expr e = GetRef<Expr>(c);
    std::vector<Expr> args;
    for (const auto& a : c->args) {
      args.push_back(VisitExpr(a));
    }
    return Compound(e, CallNode::make(VisitExpr(c->op), args, c->attrs, c->type_args), v);
  }

  Expr VisitExpr_(const TupleNode* t, const Var& v) final {
    Expr e = GetRef<Expr>(t);
    std::vector<Expr> fields;
    for (const auto& a : t->fields) {
      fields.push_back(VisitExpr(a));
    }
    return Compound(e, TupleNode::make(fields), v);
  }

  Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final {
    Expr e = GetRef<Expr>(t);
    return Compound(e, TupleGetItemNode::make(VisitExpr(t->tuple), t->index), v);
  }

192 193 194 195 196 197 198 199 200 201 202 203 204 205 206
  Expr VisitExpr_(const RefCreateNode* r, const Var& v) final {
    Expr e = GetRef<Expr>(r);
    return Compound(e, RefCreateNode::make(VisitExpr(r->value)), v);
  }

  Expr VisitExpr_(const RefReadNode* r, const Var& v) final {
    Expr e = GetRef<Expr>(r);
    return Compound(e, RefReadNode::make(VisitExpr(r->ref)), v);
  }

  Expr VisitExpr_(const RefWriteNode* r, const Var& v) final {
    Expr e = GetRef<Expr>(r);
    return Compound(e, RefWriteNode::make(VisitExpr(r->ref), VisitExpr(r->value)), v);
  }

207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242
  Expr VisitExpr_(const IfNode* i, const Var& v) final {
    Expr e = GetRef<Expr>(i);
    Expr ret = IfNode::make(VisitExpr(i->cond),
                            GetSubScope(e, 1)->ll->Get(VisitExpr(i->true_branch)),
                            GetSubScope(e, 2)->ll->Get(VisitExpr(i->false_branch)));
    return Compound(e, ret, v);
  }

  Expr VisitExpr_(const FunctionNode* f, const Var& v) final {
    Expr e = GetRef<Expr>(f);
    Expr ret;
    if (IsPrimitiveFunction(e)) {
      ret = e;
    } else {
      ret = FunctionNode::make(f->params,
                               GetSubScope(e, 0)->ll->Get(VisitExpr(f->body)),
                               f->ret_type,
                               f->type_params,
                               f->attrs);
    }
    return Compound(e, ret, v);
  }

  Expr VisitExpr_(const LetNode* l, const Var& v) final {
    Expr e = GetRef<Expr>(l);
    VisitExpr(l->value, l->var);
    Expr ret = GetSubScope(e, 0)->ll->Get(VisitExpr(l->body));
    return Compound(e, ret, v);
  }

  Expr VisitExpr_(const ConstantNode* c, const Var& v) final {
    Expr e = GetRef<Expr>(c);
    return Compound(e, e, v);
  }

  Expr VisitExpr_(const VarNode* vn, const Var& v) final {
243 244
    Expr e = GetRef<Expr>(vn);
    return Atomic(e, e, v);
245 246 247 248 249 250
  }

  Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final {
    GlobalVar gv = GetRef<GlobalVar>(gvn);
    if (visited_->count(gv) == 0) {
      visited_->insert(gv);
雾雨魔理沙 committed
251
      mod_->Update(gv, Downcast<Function>(relay::ToANormalForm(mod_->Lookup(gv), mod_, visited_)));
252
    }
253
    return Atomic(gv, gv, v);
254 255 256
  }

  Expr VisitExpr_(const OpNode* op, const Var& v) final {
257 258
    Expr e = GetRef<Expr>(op);
    return Atomic(e, e, v);
259
  }
260 261

  Expr VisitExpr_(const ConstructorNode* c, const Var& v) final {
262 263
    Expr e = GetRef<Expr>(c);
    return Atomic(e, e, v);
264 265 266 267 268 269 270 271 272 273 274
  }

  Expr VisitExpr_(const MatchNode* m, const Var& v) final {
    Expr e = GetRef<Expr>(m);
    Expr data = VisitExpr(m->data);
    std::vector<Clause> clauses;
    for (const Clause& c : m->clauses) {
      clauses.push_back(ClauseNode::make(
        c->lhs,
        GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs))));
    }
275
    return Compound(e, MatchNode::make(data, clauses), v);
276
  }
277 278
};

279 280 281
Expr ToANormalFormAux(const Expr& e,
                      const Module& m,
                      std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303
  /* When you lift a lambda, what is inside is also being lift.
   *
   * So we must determine the scope of the lambda before determining the scope of it's body.
   *
   * To make this more principled,
   * we always determine the scope of parent before determining the scope of children.
   *
   * So we calculate all the dependency between nodes.
   */
  common::Arena arena;
  DependencyGraph dg = DependencyGraph::Create(&arena, e);
  /* In order to model new subscopes created by lambda, if else and pattern matching,
   * we also assign scope to edge as well.
   * The scope of an edge is either the parent's scope, or a new subscope of the parent's scope.
   *
   * So, the scope of the whole expr is global.
   * The scope of any subexpr, is the lowest common ancestor of all incoming edge.
   *
   * Every scope additionally contain a LetList which collect all value of that scope.
   * We do an additional pass to fill all the LetList and we are done.
   */
  std::unordered_map<DependencyGraph::Node*, Scope> node_scope = CalcScope(dg);
雾雨魔理沙 committed
304
  return Fill::ToANormalForm(e, m, dg, &node_scope, gv);
305 306
}

307 308 309
Expr ToANormalForm(const Expr& e,
                   const Module& m,
                   std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325
  DLOG(INFO)
  << "ToANF:" << std::endl
  << AsText(e, false);

  Expr ret =
    TransformF([&](const Expr& e) {
      return ToANormalFormAux(e, m, gv);
    }, e);

  CHECK_EQ(FreeVars(ret).size(), 0);

  DLOG(INFO)
    << "ToANF: transformed" << std::endl
    << AsText(ret, false);

  return ret;
326 327
}

雾雨魔理沙 committed
328
Expr ToANormalForm(const Expr& e, const Module& m) {
329
  std::unordered_set<GlobalVar, NodeHash, NodeEqual> gv;
雾雨魔理沙 committed
330
  return ToANormalForm(e, m, &gv);
331 332
}

雾雨魔理沙 committed
333
TVM_REGISTER_API("relay._ir_pass.to_a_normal_form")
334
.set_body_typed(static_cast<Expr (*)(const Expr&, const Module&)>(ToANormalForm));
335 336 337

}  // namespace relay
}  // namespace tvm