transform.cc 15.2 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
9
 *
10
 *   http://www.apache.org/licenses/LICENSE-2.0
11
 *
12 13 14 15 16 17 18 19
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

Zhi committed
20
/*!
21 22
 * \file src/ir/transform.cc
 * \brief Infrastructure for transformation passes.
Zhi committed
23
 */
24
#include <dmlc/thread_local.h>
25
#include <tvm/runtime/registry.h>
26
#include <tvm/runtime/device_api.h>
27
#include <tvm/node/repr_printer.h>
28 29 30
#include <tvm/ir/transform.h>

// TODO(tqchen): Update to use String container after it is merged.
31
#include <tvm/tir/expr.h>
32 33 34

#include <stack>
#include <unordered_set>
Zhi committed
35 36

namespace tvm {
37
namespace transform {
Zhi committed
38

39 40
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue;
41
using tvm::ReprPrinter;
Zhi committed
42

43
struct PassContextThreadLocalEntry {
44 45 46 47 48 49
  /*! \brief The default pass context. */
  PassContext default_context;

  /*! \brief The current pass context. */
  std::stack<PassContext> context_stack;

50
  PassContextThreadLocalEntry() {
51
    default_context = PassContext(make_object<PassContextNode>());
52 53 54 55
  }
};

/*! \brief Thread local store to hold the pass context. */
56
typedef dmlc::ThreadLocalStore<PassContextThreadLocalEntry>
57 58 59
    RelayPassContextThreadLocalStore;

void PassContext::EnterWithScope() {
60
  PassContextThreadLocalEntry* entry =
61 62 63 64 65
      RelayPassContextThreadLocalStore::Get();
  entry->context_stack.push(*this);
}

void PassContext::ExitWithScope() {
66
  PassContextThreadLocalEntry* entry =
67 68 69 70 71 72 73
      RelayPassContextThreadLocalStore::Get();
  CHECK(!entry->context_stack.empty());
  CHECK(entry->context_stack.top().same_as(*this));
  entry->context_stack.pop();
}

/*!
 * \brief Get the context currently in effect on the calling thread.
 *
 * \return The top of the thread-local context stack, or the thread's default
 *         context when no context has been entered.
 */
PassContext PassContext::Current() {
  const auto* tls = RelayPassContextThreadLocalStore::Get();
  if (tls->context_stack.empty()) {
    return tls->default_context;
  }
  return tls->context_stack.top();
}

83
/*! \brief Allocate a new PassContextNode and wrap it in a PassContext. */
PassContext PassContext::Create() {
  auto node = make_object<PassContextNode>();
  return PassContext(std::move(node));
}

87 88 89 90 91 92 93
/*!
 * \brief Forward a trace event to the context's trace function, if one is set.
 *
 * \param module The module handed to the trace function.
 * \param info The pass information handed to the trace function.
 * \param is_before Flag handed through to the trace function unchanged.
 */
void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const {
  const auto* node = this->operator->();
  const bool has_tracer = node->trace_func != nullptr;
  if (has_tracer) {
    node->trace_func(module, info, is_before);
  }
}

Zhi committed
94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
// Forward declaration; the ModulePass reference class is defined after ModulePassNode.
class ModulePass;

/*!
 * \brief Module-level passes are designed to implement global
 * analysis/optimizations, i.e. interprocedural optimizations (IPO), etc. Passes
 * at this level have the full control of a given Relay program including
 * addition and deletion of functions.
 */
class ModulePassNode : public PassNode {
 public:
  /* \brief The pass meta data.*/
  PassInfo pass_info;

  /*! \brief The pass function sketches the real optimization. For example,
   * we may need to perform dead code elimination on the module level. We could
   * implement the algorithm in the `pass_func` and let it run on a module. It
   * will then remove the dead code including the unused functions in the module.
   */
112
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func;
Zhi committed
113 114 115

  ModulePassNode() = default;

116
  void VisitAttrs(tvm::AttrVisitor* v) {
Zhi committed
117 118 119 120
    v->Visit("pass_info", &pass_info);
  }

  /*!
121
   * \brief Run a module pass on given pass context.
Zhi committed
122
   *
123 124
   * \param mod The module that an optimization pass is applied on.
   * \param mod The context that an optimization pass executes on.
Zhi committed
125 126 127
   *
   * \return Return the updated module.
   */
128
  IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
Zhi committed
129 130 131 132

  /*!
   * \brief Get the pass information/meta data.
   */
133
  PassInfo Info() const override { return pass_info; }
Zhi committed
134 135

  static constexpr const char* _type_key = "relay.ModulePass";
136
  TVM_DECLARE_FINAL_OBJECT_INFO(ModulePassNode, PassNode);
Zhi committed
137 138
};

139 140
class ModulePass : public Pass {
 public:
141 142
  ModulePass(runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
             PassInfo pass_info);
Zhi committed
143

144
  TVM_DEFINE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode);
145
};
Zhi committed
146 147

/*!
 * \brief The SequentialNode contains a set of passes that transform Relay
 * programs from one AST to another semantically equivalent one.
 *
 * One example of this level of pass is that the pass manager needs to correctly
 * perform a host of optimizations with a given optimization level and disabled
 * passes.
 */
155
class SequentialNode : public PassNode {
Zhi committed
156 157 158 159
 public:
  /* \brief The pass meta data.*/
  PassInfo pass_info;

160 161
  /*! \brief A list of passes that used to compose a sequential pass. */
  tvm::Array<Pass> passes;
162

163
  void VisitAttrs(tvm::AttrVisitor* v) {
Zhi committed
164 165 166 167 168 169 170
    v->Visit("pass_info", &pass_info);
    v->Visit("passes", &passes);
  }

  /*!
   * \brief Get the pass information/meta data.
   */
171
  PassInfo Info() const override { return pass_info; }
Zhi committed
172 173

  /*!
174 175
   * \brief Check if a pass is enabled.
   *
176
   * \param info The pass information.
177 178 179
   *
   * \return true if the pass is enabled. Otherwise, false.
   */
180
  bool PassEnabled(const PassInfo& info) const;
181 182

  /*!
Zhi committed
183 184 185 186 187 188 189 190 191 192 193
   * \brief Resolve the pass dependency. It globs all required passes by
   *        a given pass and executes them.
   *
   * \param mod The module that an optimization pass runs on.
   *
   * \return The updated module after resolving pass dependencies.
   *
   * TODO(zhiics) Build a dependency graph among the passes using provided
   * metadata, i.e. required_passes. Likely, we can have a data structure, i.e.
   * PassInfo, to store the relevant information including the parent passes.
   */
194
  void ResolveDependency(const IRModule& mod);
Zhi committed
195 196 197 198 199 200 201

  /*!
   * \brief Perform optimizations on a series of passes. The aforementioned
   *        typical pass manager jobs could be done by it. This function could
   *        be overloaded to focus on different metrics, i.e. performance,
   *        memory footprint, etc.
   *
202 203
   * \param mod The module that these passes are applied on.
   * \param pass_ctx The context that these passes execute on.
Zhi committed
204 205 206
   *
   * \return Return the updated module.
   */
207
  IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
Zhi committed
208

209
  static constexpr const char* _type_key = "relay.Sequential";
210
  TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
Zhi committed
211 212
};

213 214 215
/*!
 * \brief Construct pass meta data.
 *
 * \param opt_level Stored as the node's `opt_level`.
 * \param name Stored as the node's `name`.
 * \param required Stored as the node's `required` list.
 */
PassInfo::PassInfo(int opt_level, std::string name,
                   tvm::Array<tvm::PrimExpr> required) {
  auto node = make_object<PassInfoNode>();
  node->required = std::move(required);
  node->name = std::move(name);
  node->opt_level = opt_level;
  data_ = std::move(node);
}

223
/*!
 * \brief Construct a module pass from its function and meta data.
 *
 * \param pass_func Moved into the node's `pass_func`.
 * \param pass_info Moved into the node's `pass_info`.
 */
ModulePass::ModulePass(runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
                       PassInfo pass_info) {
  auto node = make_object<ModulePassNode>();
  node->pass_info = std::move(pass_info);
  node->pass_func = std::move(pass_func);
  data_ = std::move(node);
}

// Module -> Module optimizations.
233
IRModule ModulePassNode::operator()(const IRModule& mod,
234
                                    const PassContext& pass_ctx) const {
235 236 237 238 239
  const PassInfo& pass_info = Info();
  DLOG(INFO) << "Executing module pass : "
             << pass_info->name
             << " with opt level: "
             << pass_info->opt_level;
Zhi committed
240
  CHECK(mod.defined());
241
  pass_ctx.Trace(mod, pass_info, true);
242
  IRModule updated_mod = pass_func(mod, pass_ctx);
Zhi committed
243
  CHECK(updated_mod.defined());
244
  pass_ctx.Trace(updated_mod, pass_info, false);
Zhi committed
245 246 247
  return updated_mod;
}

248
/*!
 * \brief Construct a sequential pass from a pass list and explicit meta data.
 *
 * \param passes Moved into the node's `passes`.
 * \param pass_info Moved into the node's `pass_info`.
 */
Sequential::Sequential(tvm::Array<Pass> passes, PassInfo pass_info) {
  auto node = make_object<SequentialNode>();
  node->pass_info = std::move(pass_info);
  node->passes = std::move(passes);
  data_ = std::move(node);
}

255
Sequential::Sequential(tvm::Array<Pass> passes, std::string name) {
256
  auto n = make_object<SequentialNode>();
257
  n->passes = std::move(passes);
258
  PassInfo pass_info = PassInfo(2, std::move(name), {});
259
  n->pass_info = std::move(pass_info);
260
  data_ = std::move(n);
Zhi committed
261 262
}

263
/*! \brief Access the underlying node as a const SequentialNode pointer. */
const SequentialNode* Sequential::operator->() const {
  const auto* node = static_cast<const SequentialNode*>(get());
  return node;
}

267
void SequentialNode::ResolveDependency(const IRModule& mod) {
Zhi committed
268 269 270 271 272 273 274 275
  // TODO(zhiics) Implement it.
  // 1. Consider the required passes for each pass.
  // 2. Only resolve the enabled passes.
  // 3. Build a dependency graph. Probably we need to update the pass list.
  LOG(FATAL) << "Pass dependency has not been resolved yet."
             << "\n";
}

276
// linearly scan the pass array to match pass_name
277
inline bool PassArrayContains(const Array<tvm::PrimExpr>& pass_array,
278 279
                              const std::string& pass_name) {
  for (auto x : pass_array) {
280
    auto* str_name = x.as<tir::StringImmNode>();
281 282
    CHECK(str_name) << "pass name must be str";
    if (str_name->value == pass_name) return true;
Zhi committed
283
  }
284
  return false;
Zhi committed
285 286
}

287
bool SequentialNode::PassEnabled(const PassInfo& info) const {
288 289
  PassContext ctx = PassContext::Current();

290
  if (PassArrayContains(ctx->disabled_pass, info->name)) {
291 292 293
    return false;
  }

294
  if (PassArrayContains(ctx->required_pass, info->name)) {
295 296
    return true;
  }
297 298

  return ctx->opt_level >= info->opt_level;
299 300
}

301 302
/*!
 * \brief Create a pass by looking up its creator in the global registry.
 *
 * A name containing "transform." is looked up verbatim. Any other name is
 * tried as "transform.<name>" first and as "relay._transform.<name>" second.
 *
 * \param pass_name The (possibly unqualified) name of the pass.
 *
 * \return The result of calling the registered function with no arguments.
 *         The check fails if no registered function is found.
 */
Pass GetPass(const std::string& pass_name) {
  using tvm::runtime::Registry;
  const runtime::PackedFunc* f = nullptr;
  if (pass_name.find("transform.") != std::string::npos) {
    f = Registry::Get(pass_name);
  } else {
    f = Registry::Get("transform." + pass_name);
    if (f == nullptr) {
      f = Registry::Get("relay._transform." + pass_name);
    }
  }
  CHECK(f != nullptr) << "Cannot use " << pass_name
                      << " to create the pass";
  return (*f)();
}

315 316
// TODO(zhiics): we currently only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase
// ordering problem needs to be handled in the future.
318
/*!
 * \brief Apply the enabled passes of this sequence, in order, to a module.
 *
 * For each enabled pass, every pass named in its `required` list is created
 * through GetPass and run first, then the pass itself is run.
 *
 * \param module The input module.
 * \param pass_ctx The context handed to every pass that is run.
 *
 * \return The module produced by the last pass that ran, or the input module
 *         if nothing ran.
 */
IRModule SequentialNode::operator()(const IRModule& module,
                                    const PassContext& pass_ctx) const {
  IRModule mod = module;
  for (const Pass& pass : passes) {
    CHECK(pass.defined()) << "Found undefined pass for optimization.";
    const PassInfo& info = pass->Info();
    if (PassEnabled(info)) {
      // Run the passes this one depends on before the pass itself.
      for (const auto& dep : info->required) {
        const auto* name = dep.as<tvm::tir::StringImmNode>();
        CHECK(name);
        mod = GetPass(name->value)(mod, pass_ctx);
      }
      mod = pass(mod, pass_ctx);
    }
  }
  return mod;
}

/*!
 * \brief Build a ModulePass from a pass function and its meta data fields.
 *
 * \param pass_func The function the pass executes.
 * \param opt_level The opt level recorded in the pass info.
 * \param name The name recorded in the pass info.
 * \param required The required-pass list recorded in the pass info.
 *
 * \return The newly created module pass.
 */
Pass CreateModulePass(
    const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
    int opt_level, const std::string& name,
    const tvm::Array<tvm::PrimExpr>& required) {
  return ModulePass(pass_func, PassInfo(opt_level, name, required));
}

TVM_REGISTER_NODE_TYPE(PassInfoNode);

// Global function "transform.PassInfo": builds a PassInfo from its three fields.
TVM_REGISTER_GLOBAL("transform.PassInfo")
.set_body_typed([](int opt_level, std::string name, tvm::Array<PrimExpr> required) {
  PassInfo info(opt_level, name, required);
  return info;
});
Zhi committed
351

352
// Global function "transform.Info": returns the PassInfo of the pass in args[0].
TVM_REGISTER_GLOBAL("transform.Info")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  const Pass target = args[0];
  *ret = target->Info();
});

358 359
// Printer for PassInfoNode: emits the pass name, its opt level and the list
// of required passes.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PassInfoNode>([](const ObjectRef& ref, tvm::ReprPrinter* p) {
  auto* node = static_cast<const PassInfoNode*>(ref.get());
  p->stream << "The meta data of the pass: ";
  p->stream << "pass name: " << node->name;
  // Separate the fields so they do not run together in the output.
  p->stream << ", opt_level: " << node->opt_level;
  p->stream << ", required passes: [" << "\n";
  for (const auto& it : node->required) {
    // `as<>` yields nullptr for a non-StringImm entry; print the expression
    // itself in that case instead of dereferencing a null pointer.
    if (const auto* str = it.as<tvm::tir::StringImmNode>()) {
      p->stream << str->value << ", ";
    } else {
      p->stream << it << ", ";
    }
  }
  p->stream << "]\n";
});

TVM_REGISTER_NODE_TYPE(ModulePassNode);

// Global function "transform.MakeModulePass": wraps a pass function and its
// meta data into a ModulePass.
TVM_REGISTER_GLOBAL("transform.MakeModulePass")
.set_body_typed(
  [](runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
     PassInfo pass_info) {
  ModulePass module_pass(pass_func, pass_info);
  return module_pass;
});
Zhi committed
380

381
// Global function "transform.RunPass": applies the pass in args[0] to the
// module in args[1] and returns the result.
TVM_REGISTER_GLOBAL("transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  const Pass to_run = args[0];
  const IRModule input = args[1];
  *ret = to_run(input);
});

388 389
// Printer for ModulePassNode: emits the pass name and its opt level.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ModulePassNode>([](const ObjectRef& ref, ReprPrinter* p) {
  const auto* pass_node = static_cast<const ModulePassNode*>(ref.get());
  const PassInfo meta = pass_node->Info();
  p->stream << "Run Module pass: " << meta->name;
  p->stream << " at the optimization level " << meta->opt_level;
});

396
TVM_REGISTER_NODE_TYPE(SequentialNode);

// Global function "transform.Sequential": args are (passes, opt_level, name,
// required); returns a Sequential built from them.
TVM_REGISTER_GLOBAL("transform.Sequential")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  tvm::Array<Pass> passes = args[0];
  int opt_level = args[1];
  std::string name = args[2];
  tvm::Array<tvm::PrimExpr> required = args[3];
  *ret = Sequential(passes, PassInfo(opt_level, name, required));
});

408 409
// Printer for SequentialNode: emits the sequence's name and opt level,
// followed by the names of the passes it contains.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SequentialNode>([](const ObjectRef& ref, ReprPrinter* p) {
  const auto* seq = static_cast<const SequentialNode*>(ref.get());
  const PassInfo seq_info = seq->Info();
  p->stream << "Run Sequential pass: " << seq_info->name;
  p->stream << " at the optimization level " << seq_info->opt_level << ". ";
  p->stream << "The passes will be executed are: [";
  for (const auto& member : seq->passes) {
    p->stream << member->Info()->name << " ";
  }
  p->stream << "]";
});

TVM_REGISTER_NODE_TYPE(PassContextNode);

// Global function "transform.PassContext": args are (opt_level,
// fallback_device, required passes, disabled passes, trace function);
// returns a new PassContext populated with them.
TVM_REGISTER_GLOBAL("transform.PassContext")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  auto pctx = PassContext::Create();

  // Unpack every argument into a typed local before touching the node.
  int opt_level = args[0];
  int fallback_device = args[1];
  tvm::Array<tvm::PrimExpr> required = args[2];
  tvm::Array<tvm::PrimExpr> disabled = args[3];
  TraceFunc trace_func = args[4];

  pctx->trace_func = std::move(trace_func);
  pctx->disabled_pass = std::move(disabled);
  pctx->required_pass = std::move(required);
  pctx->fallback_device = fallback_device;
  pctx->opt_level = opt_level;

  *ret = pctx;
});
Zhi committed
439

440 441
// Printer for PassContextNode: emits the opt level, the fallback device name
// and the required / disabled pass lists.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PassContextNode>([](const ObjectRef& ref, ReprPrinter* p) {
  auto* node = static_cast<const PassContextNode*>(ref.get());
  p->stream << "Pass context information: " << "\n";
  p->stream << "\topt_level: " << node->opt_level << "\n";
  p->stream << "\tfallback device: "
            << runtime::DeviceName(node->fallback_device)
            << "\n";

  // Only the pass entries belong inside the brackets; the opt level is
  // already printed on its own line above.
  p->stream << "\trequired passes: [";
  for (const auto& it : node->required_pass) {
    p->stream << it << " ";
  }
  p->stream << "]\n";

  p->stream << "\tdisabled passes: [";
  for (const auto& it : node->disabled_pass) {
    p->stream << it << " ";
  }
  p->stream << "]";
});

462 463 464 465 466 467 468 469 470 471 472
/*!
 * \brief Static forwarders to PassContext::EnterWithScope / ExitWithScope,
 *        usable as plain functions.
 */
class PassContext::Internal {
 public:
  /*! \brief Call EnterWithScope on the given context. */
  static void EnterScope(PassContext pass_ctx) { pass_ctx.EnterWithScope(); }

  /*! \brief Call ExitWithScope on the given context. */
  static void ExitScope(PassContext pass_ctx) { pass_ctx.ExitWithScope(); }
};

473
// Global functions for querying, entering and exiting the pass context scope.
TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext").set_body_typed(PassContext::Current);

TVM_REGISTER_GLOBAL("transform.EnterPassContext").set_body_typed(PassContext::Internal::EnterScope);

TVM_REGISTER_GLOBAL("transform.ExitPassContext").set_body_typed(PassContext::Internal::ExitScope);

482
}  // namespace transform
Zhi committed
483
}  // namespace tvm