/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

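/*!
 *
 * \file to_a_normal_form.cc
 *
 * \brief Turn implicit sharing into observable sharing.
 *
 * A sub-expression that is referenced from several places is bound once with a let,
 * and every use refers to that binding, making the sharing explicit in the program.
 */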
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/support/logging.h>

#include "let_list.h"
#include "pass_util.h"
#include "../../support/arena.h"
#include "../analysis/dependency_graph.h"

namespace tvm {
namespace relay {

struct ScopeNode;
/*! \brief Scopes are shared by reference: many graph nodes may map to the same scope. */
using Scope = std::shared_ptr<ScopeNode>;

/*!
 * \brief One node of the tree of scopes built by CalcScope.
 *
 * Invariants:
 *  - a scope without a parent (the root) has level 0;
 *  - a scope with a parent has level == 1 + parent->level.
 */
struct ScopeNode {
  /*! \brief Depth in the scope tree; 0 for the root. */
  size_t level = 0;
  /*! \brief Enclosing scope; null for the root. */
  Scope parent;
  /*! \brief Bindings collected for this scope. */
  std::shared_ptr<LetList> ll = std::make_shared<LetList>();

  /*! \brief Construct a root scope. */
  ScopeNode() = default;

  /*! \brief Construct a child of \p parent; \p parent is dereferenced and must be non-null. */
  explicit ScopeNode(const Scope& parent) : level(parent->level + 1), parent(parent) {}
};

/*! \brief Create a new scope nested directly inside \p s (one level deeper). */
Scope ChildScope(const Scope& s) {
  Scope child = std::make_shared<ScopeNode>(s);
  return child;
}

/*!
 * \brief Lowest common ancestor of two scopes in the scope tree.
 *
 * The deeper scope is walked upwards until both are at the same level; then both are
 * walked upwards in lock-step until they meet.
 *
 * \param lhs A scope; must not be null.
 * \param rhs A scope; must not be null.
 * \return The deepest scope that encloses (or equals) both \p lhs and \p rhs.
 *
 * \note If the two scopes belong to different trees, the walk runs off both roots and ends
 *       at null. That is reported here rather than surfacing later as a null dereference.
 */
Scope LCA(Scope lhs, Scope rhs) {
  CHECK(lhs != nullptr && rhs != nullptr) << "LCA: scopes must not be null";
  while (lhs != rhs) {
    if (lhs->level > rhs->level) {
      lhs = lhs->parent;
    } else if (lhs->level < rhs->level) {
      rhs = rhs->parent;
    } else {
      lhs = lhs->parent;
      rhs = rhs->parent;
    }
  }
  CHECK(lhs != nullptr) << "LCA: the two scopes do not share a common root";
  return lhs;
}

std::unordered_map<DependencyGraph::Node*, Scope> CalcScope(const DependencyGraph& dg) {
  std::unordered_map<DependencyGraph::Node*, Scope> expr_scope;
74
  bool global_scope_used = false;
75 76 77
  Scope global_scope = std::make_shared<ScopeNode>();
  for (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) {
    DependencyGraph::Node* n = *it;
78
    auto iit = n->parents.head;
79 80
    Scope s;
    if (iit == nullptr) {
81
      CHECK(!global_scope_used);
82
      s = global_scope;
83
      global_scope_used = true;
84 85 86 87 88 89 90 91 92
    } else {
      s = expr_scope.at(iit->value);
      iit = iit->next;
      for (; iit != nullptr; iit = iit->next) {
        s = LCA(s, expr_scope.at(iit->value));
      }
    }
    expr_scope.insert({n, n->new_scope ? ChildScope(s) : s});
  }
93
  CHECK(global_scope_used);
94 95 96
  return expr_scope;
}

97 98 99 100
/* Special care is needed to handle local recursion.
 * Fill additionally take a (possibly null) Var argument,
 * If it is not null, Fill is required to bind the transformed result to that var.
 */
101 102
class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
 public:
雾雨魔理沙 committed
103 104
  static Expr ToANormalForm(const Expr& e,
                            const DependencyGraph& dg,
105 106
                            std::unordered_map<DependencyGraph::Node*, Scope>* node_scope) {
    Fill fi(dg, node_scope);
107 108 109 110 111 112
    return fi.GetScope(e)->ll->Get(fi.VisitExpr(e));
  }

 private:
  const DependencyGraph& dg_;
  std::unordered_map<DependencyGraph::Node*, Scope>* node_scope_;
113
  std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual> memo;
114

115 116
  Fill(const DependencyGraph& dg,
       std::unordered_map<DependencyGraph::Node*, Scope>* node_scope) :
117
    dg_(dg),
118
    node_scope_(node_scope) { }
119 120 121 122 123 124 125

  Scope GetScope(const Expr& e) {
    return node_scope_->at(dg_.expr_node.at(e));
  }

  Scope GetSubScope(const Expr& e, size_t i) {
    DependencyGraph::Node* n = dg_.expr_node.at(e);
126
    auto h = n->children.head;
127 128 129 130 131 132 133 134 135 136 137 138
    while (i != 0) {
      CHECK(h);
      --i;
      h = h->next;
    }
    CHECK(h);
    return node_scope_->at(h->value);
  }

  Expr VisitExpr(const Expr& e, const Var& v) final {
    if (memo.count(e) == 0) {
      memo.insert({e, ExprFunctor<Expr(const Expr&, const Var&)>::VisitExpr(e, v)});
139 140
    } else if (v.defined()) {
      GetScope(e)->ll->Push(v, memo.at(e));
141
    }
142 143 144
    auto ret = memo.at(e);
    CHECK(IsAtomic(ret));
    return ret;
145 146 147
  }

  Expr VisitExpr(const Expr& e) {
148 149 150
    return this->VisitExpr(e, Var());
  }

151 152
  Expr Atomic(const Expr& e, const Var& v) {
    return v.defined() ? GetScope(e)->ll->Push(v, e) : e;
153 154 155
  }

  Expr Compound(const Expr& orig, const Expr& now, const Var& v) {
156 157
    Var var = v.defined() ?
      v :
158
      Var(std::string("x"), Type());
159
    return GetScope(orig)->ll->Push(var, now);
160 161 162 163 164 165 166 167
  }

  Expr VisitExpr_(const CallNode* c, const Var& v) final {
    Expr e = GetRef<Expr>(c);
    std::vector<Expr> args;
    for (const auto& a : c->args) {
      args.push_back(VisitExpr(a));
    }
168
    return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), v);
169 170 171 172 173 174 175 176
  }

  Expr VisitExpr_(const TupleNode* t, const Var& v) final {
    Expr e = GetRef<Expr>(t);
    std::vector<Expr> fields;
    for (const auto& a : t->fields) {
      fields.push_back(VisitExpr(a));
    }
177
    return Compound(e, Tuple(fields), v);
178 179 180 181
  }

  Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final {
    Expr e = GetRef<Expr>(t);
182
    return Compound(e, TupleGetItem(VisitExpr(t->tuple), t->index), v);
183 184
  }

185 186
  Expr VisitExpr_(const RefCreateNode* r, const Var& v) final {
    Expr e = GetRef<Expr>(r);
187
    return Compound(e, RefCreate(VisitExpr(r->value)), v);
188 189 190 191
  }

  Expr VisitExpr_(const RefReadNode* r, const Var& v) final {
    Expr e = GetRef<Expr>(r);
192
    return Compound(e, RefRead(VisitExpr(r->ref)), v);
193 194 195 196
  }

  Expr VisitExpr_(const RefWriteNode* r, const Var& v) final {
    Expr e = GetRef<Expr>(r);
197
    return Compound(e, RefWrite(VisitExpr(r->ref), VisitExpr(r->value)), v);
198 199
  }

200 201
  Expr VisitExpr_(const IfNode* i, const Var& v) final {
    Expr e = GetRef<Expr>(i);
202
    Expr ret = If(VisitExpr(i->cond),
203 204 205 206 207 208 209 210
                            GetSubScope(e, 1)->ll->Get(VisitExpr(i->true_branch)),
                            GetSubScope(e, 2)->ll->Get(VisitExpr(i->false_branch)));
    return Compound(e, ret, v);
  }

  Expr VisitExpr_(const FunctionNode* f, const Var& v) final {
    Expr e = GetRef<Expr>(f);
    Expr ret;
211
    if (f->HasNonzeroAttr(attr::kPrimitive)) {
212 213
      ret = e;
    } else {
214
      ret = Function(f->params,
215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235
                               GetSubScope(e, 0)->ll->Get(VisitExpr(f->body)),
                               f->ret_type,
                               f->type_params,
                               f->attrs);
    }
    return Compound(e, ret, v);
  }

  Expr VisitExpr_(const LetNode* l, const Var& v) final {
    Expr e = GetRef<Expr>(l);
    VisitExpr(l->value, l->var);
    Expr ret = GetSubScope(e, 0)->ll->Get(VisitExpr(l->body));
    return Compound(e, ret, v);
  }

  Expr VisitExpr_(const ConstantNode* c, const Var& v) final {
    Expr e = GetRef<Expr>(c);
    return Compound(e, e, v);
  }

  Expr VisitExpr_(const VarNode* vn, const Var& v) final {
236
    Expr e = GetRef<Expr>(vn);
237
    return Atomic(e, v);
238 239 240 241
  }

  Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final {
    GlobalVar gv = GetRef<GlobalVar>(gvn);
242
    return Atomic(gv, v);
243 244 245
  }

  Expr VisitExpr_(const OpNode* op, const Var& v) final {
246
    Expr e = GetRef<Expr>(op);
247
    return Atomic(e, v);
248
  }
249 250

  Expr VisitExpr_(const ConstructorNode* c, const Var& v) final {
251
    Expr e = GetRef<Expr>(c);
252
    return Atomic(e, v);
253 254 255 256 257 258 259
  }

  Expr VisitExpr_(const MatchNode* m, const Var& v) final {
    Expr e = GetRef<Expr>(m);
    Expr data = VisitExpr(m->data);
    std::vector<Clause> clauses;
    for (const Clause& c : m->clauses) {
260
      clauses.push_back(Clause(
261 262 263
        c->lhs,
        GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs))));
    }
264
    return Compound(e, Match(data, clauses, m->complete), v);
265
  }
266 267
};

268
Expr ToANormalFormAux(const Expr& e) {
269 270 271 272 273 274 275 276 277
  /* When you lift a lambda, what is inside is also being lift.
   *
   * So we must determine the scope of the lambda before determining the scope of it's body.
   *
   * To make this more principled,
   * we always determine the scope of parent before determining the scope of children.
   *
   * So we calculate all the dependency between nodes.
   */
278
  support::Arena arena;
279 280 281 282 283 284 285 286 287 288 289 290
  DependencyGraph dg = DependencyGraph::Create(&arena, e);
  /* In order to model new subscopes created by lambda, if else and pattern matching,
   * we also assign scope to edge as well.
   * The scope of an edge is either the parent's scope, or a new subscope of the parent's scope.
   *
   * So, the scope of the whole expr is global.
   * The scope of any subexpr, is the lowest common ancestor of all incoming edge.
   *
   * Every scope additionally contain a LetList which collect all value of that scope.
   * We do an additional pass to fill all the LetList and we are done.
   */
  std::unordered_map<DependencyGraph::Node*, Scope> node_scope = CalcScope(dg);
291
  return Fill::ToANormalForm(e, dg, &node_scope);
292 293
}

294
IRModule ToANormalForm(const IRModule& m) {
295 296 297 298 299
  DLOG(INFO) << "ToANF:" << std::endl << m;

  tvm::Map<GlobalVar, Function> updates;
  auto funcs = m->functions;
  for (const auto& it : funcs) {
300
    CHECK_EQ(FreeVars(it.second).size(), 0);
301
    if (const auto* n = it.second.as<FunctionNode>()) {
302
      if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
303
    }
304 305 306 307
    Expr ret =
      TransformF([&](const Expr& e) {
        return ToANormalFormAux(e);
      }, it.second);
308 309 310 311
    CHECK_EQ(FreeVars(ret).size(), 0)
      << AsText(ret)
      << "should not has free vars: "
      << FreeVars(ret);
312 313
    updates.Set(it.first, Downcast<Function>(ret));
  }
314

315 316 317
  for (auto pair : updates) {
    m->Add(pair.first, pair.second, true);
  }
318

319
  DLOG(INFO) << "ToANF: transformed" << std::endl << m;
320

321
  return m;
322 323
}

324 325 326
namespace transform {

Pass ToANormalForm() {
327 328 329
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
    [=](IRModule m, PassContext pc) {
    return relay::ToANormalForm(m);
330
  };
331
  return CreateModulePass(pass_func, 1, "ToANormalForm", {});
332 333
}

334
TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm")
335 336
.set_body_typed(ToANormalForm);

337 338
}  // namespace transform

339 340
}  // namespace relay
}  // namespace tvm