/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2019 by Contributors
 * \file src/relay/backend/vm/compiler.cc
 * \brief A compiler from relay::Module to the VM byte code.
 */

#include <tvm/relay/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/logging.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "../../../runtime/vm/naive_allocator.h"
#include "../../backend/compile_engine.h"
#include "../../pass/pass_util.h"

namespace tvm {
namespace relay {

namespace transform {

Pass LambdaLift();
Pass InlinePrimitives();

}  // namespace transform

namespace vm {

using namespace tvm::runtime;
using namespace tvm::runtime::vm;
using namespace relay::transform;

// (@jroesch): VM passes; eventually declare these as proper passes.
bool IsClosure(const Function& func);

template <typename T, typename U>
using NodeMap = std::unordered_map<T, U, NodeHash, NodeEqual>;
using TagMap = NodeMap<tvm::relay::Constructor, Index>;
using TagNameMap = std::unordered_map<size_t, tvm::relay::Constructor>;
using GlobalMap = NodeMap<GlobalVar, Index>;
using ConstMap = NodeMap<Constant, Index>;
using ConstTensorShapeMap = NodeMap<TensorType, std::pair<Index, NDArray>>;
using TargetsMap = Map<tvm::Integer, tvm::Target>;

struct VMCompilerContext {
  // The module context for the compilation
  Module module;
  // Error reporter
  ErrorReporter err_reporter;
  // Map from a unique integer to ADT constructor tag
  TagNameMap tag_index_map;
  // Map from ADT constructor tag to a unique integer
  TagMap tag_map;
  // Map from global var to a unique integer
  GlobalMap global_map;
  // Map from Const object to its index in const pool
  ConstMap const_map;
  // Map from Const tensor shape to its index in const pool
  ConstTensorShapeMap const_tensor_shape_map;
  // List of lowered functions
  std::vector<LoweredFunc> lowered_funcs;
  // Map from a lowered function to its index in the lowered function list (to avoid duplicates)
  std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;
};

// Compute the constant pool, i.e., a mapping from Constant node to constant index.
struct ConstantPool : ExprVisitor {
  std::set<GlobalVar> visited;
  Module module;
  ConstMap const_map;
  ConstTensorShapeMap const_tensor_shape_map;

  size_t index;

  explicit ConstantPool(const Module& mod) : module(mod), const_map(), index(0) {}

  void VisitExpr_(const GlobalVarNode* var_node) {
    auto gvar = GetRef<GlobalVar>(var_node);
    if (visited.find(gvar) == visited.end()) {
      visited.insert(gvar);
      this->VisitExpr(this->module->Lookup(gvar));
    }
  }

  void VisitExpr_(const ConstantNode* const_node) {
    auto konst = GetRef<Constant>(const_node);
    auto it = this->const_map.find(konst);
    if (it == this->const_map.end()) {
      this->const_map.insert({konst, index++});
    }
  }
};

std::tuple<ConstMap, ConstTensorShapeMap> LayoutConstantPool(const Module& module) {
  auto cp = ConstantPool(module);
  for (auto& func : module->functions) {
    cp.VisitExpr(func.first);
  }
  return std::make_tuple(cp.const_map, cp.const_tensor_shape_map);
}

void InstructionPrint(std::ostream& os, const Instruction& instr);

// Represents a runtime object that is going to be matched by pattern-match expressions
struct MatchValue {
  virtual ~MatchValue() {}
};
using MatchValuePtr = std::shared_ptr<MatchValue>;

// A runtime object that resides in a register
struct RegisterValue : MatchValue {
  // The register num
  RegName register_num;

  explicit RegisterValue(RegName reg) : register_num(reg) {}

  ~RegisterValue() {}
};

// The value is a field of another runtime object
struct AccessField : MatchValue {
  MatchValuePtr parent;
  // Field index
  size_t index;
  // Runtime register num after compiling the access field path
  RegName reg{-1};

  AccessField(MatchValuePtr parent, size_t index)
  : parent(parent), index(index) {}

  ~AccessField() {}
};

/*!
 * \brief Condition in a decision tree
 */
struct ConditionNode {
  virtual ~ConditionNode() {}
};

using ConditionNodePtr = std::shared_ptr<ConditionNode>;

/*!
 * \brief A var binding condition
 */
struct VarBinding : ConditionNode {
  Var var;
  MatchValuePtr val;

  VarBinding(Var var, MatchValuePtr val)
          : var(var), val(val) {}

  ~VarBinding() {}
};

/*!
 * \brief Compare the tag of the object
 */
struct TagCompare : ConditionNode {
  /*! \brief The object to be examined */
  MatchValuePtr obj;

  /*! \brief The expected tag */
  int target_tag;

  TagCompare(MatchValuePtr obj, size_t target)
          : obj(obj), target_tag(target) {
  }

  ~TagCompare() {}
};

using TreeNodePtr = typename relay::TreeNode<ConditionNodePtr>::pointer;
using TreeLeafNode = relay::TreeLeafNode<ConditionNodePtr>;
using TreeLeafFatalNode = relay::TreeLeafFatalNode<ConditionNodePtr>;
using TreeBranchNode = relay::TreeBranchNode<ConditionNodePtr>;

TreeNodePtr BuildDecisionTreeFromPattern(MatchValuePtr data,
                                         Pattern pattern,
                                         TreeNodePtr then_branch,
                                         TreeNodePtr else_branch) {
  if (pattern.as<PatternWildcardNode>()) {
    // Wildcard patterns bind nothing, so just continue with the then-branch
    return then_branch;
  } else if (pattern.as<PatternVarNode>()) {
    auto pat = pattern.as<PatternVarNode>();
    auto pattern = GetRef<PatternVar>(pat);
    auto cond = std::make_shared<VarBinding>(pattern->var, data);
    return TreeBranchNode::Make(cond, then_branch, else_branch);
  } else {
    auto pat = pattern.as<PatternConstructorNode>();
    auto pattern = GetRef<PatternConstructor>(pat);
    auto tag = pattern->constructor->tag;

    size_t field_index = 0;
    for (auto& p : pattern->patterns) {
      auto d = std::make_shared<AccessField>(data, field_index);
      then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
      field_index++;
    }
    auto cond = std::make_shared<TagCompare>(data, tag);
    return TreeBranchNode::Make(cond, then_branch, else_branch);
  }
}

TreeNodePtr BuildDecisionTreeFromClause(MatchValuePtr data,
                                        Clause clause,
                                        TreeNodePtr else_branch) {
  return BuildDecisionTreeFromPattern(data, clause->lhs,
                                      TreeLeafNode::Make(clause->rhs), else_branch);
}

TreeNodePtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array<Clause> clauses) {
  // When nothing matches, the VM throws a fatal error
  TreeNodePtr else_branch = TreeLeafFatalNode::Make();
  // Start from the last clause
  for (auto it = clauses.rbegin(); it != clauses.rend(); ++it) {
    else_branch = BuildDecisionTreeFromClause(data, *it, else_branch);
  }
  return else_branch;
}
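// As an illustration only (not used by the compiler): for a hypothetical list
// match such as
//
//   match xs {
//     Cons(x, rest) => body1,
//     Nil           => body2,
//   }
//
// the clauses are folded from last to first, so BuildDecisionTreeFromClauses
// produces roughly:
//
//   TagCompare(xs, Cons.tag)
//     then: VarBinding(rest, xs.field[1]) -> VarBinding(x, xs.field[0]) -> Leaf(body1)
//     else: TagCompare(xs, Nil.tag)
//             then: Leaf(body2)
//             else: Fatal
//
// Here Cons/Nil, xs, and the bodies are assumed names used purely for this sketch.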

class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
 public:
  VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets)
      : last_register_(0),
        registers_num_(0),
        engine_(CompileEngine::Global()),
        context_(context),
        targets_(targets) {}

  VMFunction Compile(const GlobalVar& var, const Function& func) {
    size_t i = 0;
    // Assign a register to each parameter; for closures these are the captured free variables
    for (auto param : func->params) {
      auto arg_register = NewRegister();
      CHECK_EQ(i, arg_register);
      var_register_map_.insert({param, arg_register});
      params_.push_back(param->name_hint());
      ++i;
    }

    if (IsClosure(func)) {
      Function inner_func = Downcast<Function>(func->body);
      for (auto param : inner_func->params) {
        auto arg_register = NewRegister();
        CHECK_EQ(i, arg_register);
        var_register_map_.insert({param, arg_register});
        params_.push_back(param->name_hint());
        ++i;
      }
      this->VisitExpr(inner_func->body);
    } else {
      this->VisitExpr(func->body);
    }
    instructions_.push_back(Instruction::Ret(last_register_));
    return VMFunction(var->name_hint, params_, instructions_, registers_num_);
  }

 protected:
  size_t NewRegister() { return registers_num_++; }

  inline void Emit(const Instruction& instr) {
    DLOG(INFO) << "VMCompiler::Emit: instr=" << instr;
    CHECK((int)instr.op < 100) << "Invalid opcode " << (int)instr.op;
    switch (instr.op) {
      case Opcode::AllocDatatype:
      case Opcode::AllocTensor:
      case Opcode::AllocTensorReg:
      case Opcode::GetField:
      case Opcode::GetTag:
      case Opcode::LoadConst:
      case Opcode::LoadConsti:
      case Opcode::Invoke:
      case Opcode::AllocClosure:
      case Opcode::Move:
      case Opcode::InvokeClosure:
        last_register_ = instr.dst;
        break;
      case Opcode::InvokePacked:
        last_register_ = instr.packed_args[instr.arity - 1];
        break;
      case Opcode::If:
      case Opcode::Ret:
      case Opcode::Goto:
      case Opcode::Fatal:
        break;
    }
    instructions_.push_back(instr);
  }

  void VisitExpr_(const ConstantNode* const_node) {
    auto rconst = GetRef<Constant>(const_node);
    auto it = this->context_->const_map.find(rconst);
    CHECK(it != this->context_->const_map.end());
    Emit(Instruction::LoadConst(it->second, NewRegister()));
  }

  void VisitExpr_(const VarNode* var_node) {
    auto var = GetRef<Var>(var_node);
    auto reg_it = this->var_register_map_.find(var);
    CHECK(reg_it != this->var_register_map_.end());
    last_register_ = reg_it->second;
  }

  void VisitExpr_(const TupleNode* tuple_node) {
    auto tuple = GetRef<Tuple>(tuple_node);
    std::vector<Index> fields_registers;

    for (auto& field : tuple->fields) {
      this->VisitExpr(field);
      fields_registers.push_back(last_register_);
    }

    // TODO(@jroesch): use correct tag
    Emit(Instruction::AllocDatatype(
      0,
      tuple->fields.size(),
      fields_registers,
      NewRegister()));
  }

  void VisitExpr_(const MatchNode* match_node) {
    auto match = GetRef<Match>(match_node);

    this->VisitExpr(match->data);
    CompileMatch(match);
  }

  void VisitExpr_(const LetNode* let_node) {
    DLOG(INFO) << let_node->value;
    this->VisitExpr(let_node->value);
    var_register_map_.insert({let_node->var, this->last_register_});
    this->VisitExpr(let_node->body);
  }

  void VisitExpr_(const TupleGetItemNode* get_node) {
    auto get = GetRef<TupleGetItem>(get_node);
    this->VisitExpr(get->tuple);
    auto tuple_register = last_register_;
    Emit(Instruction::GetField(tuple_register, get->index, NewRegister()));
  }

  void VisitExpr_(const GlobalVarNode* gvar) {
    auto var = GetRef<GlobalVar>(gvar);
    auto func = context_->module->Lookup(var);
    auto it = context_->global_map.find(var);
    CHECK(it != context_->global_map.end());
    // Allocate closure with zero free vars
    Emit(Instruction::AllocClosure(it->second, 0, {}, NewRegister()));
  }

  void VisitExpr_(const IfNode* if_node) {
    this->VisitExpr(if_node->cond);

    size_t test_register = last_register_;

    this->Emit(Instruction::LoadConsti(1, NewRegister()));
    auto after_cond = instructions_.size();
    auto target_register = last_register_;
    this->Emit(Instruction::If(test_register, target_register, 0, 0));
    this->VisitExpr(if_node->true_branch);

    size_t true_register = last_register_;
    Emit(Instruction::Goto(0));

    // Record how many instructions there are in the
    // true branch.
    auto after_true = this->instructions_.size();

    this->VisitExpr(if_node->false_branch);

    size_t false_register = last_register_;

    // In the else-branch, move the result into the then-branch register
    // so both branches leave their result in the same place.
    Emit(Instruction::Move(false_register, true_register));

    // Compute the total number of instructions
    // after generating the false branch.
    auto after_false = this->instructions_.size();

    // Now that both bodies have been emitted, compute the jump targets
    // and patch the If and Goto instructions with the requisite offsets.
    auto true_offset = 1;
    auto false_offset = after_true - after_cond;
    instructions_[after_cond].if_op.true_offset = true_offset;
    instructions_[after_cond].if_op.false_offset = false_offset;

    // Patch the Goto.
    this->instructions_[after_true - 1].pc_offset = (after_false - after_true) + 1;

    this->last_register_ = true_register;
  }
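  // A rough, illustrative picture of the bytecode laid out by the visitor above
  // (register and offset names refer to the locals used in the code):
  //
  //                    LoadConsti 1                       -> target_register
  //   after_cond:      If   test_register, target_register, 1, (after_true - after_cond)
  //                    ... true branch ...                 ; result in true_register
  //   after_true - 1:  Goto (after_false - after_true) + 1
  //                    ... false branch ...                ; result in false_register
  //   after_false - 1: Move false_register -> true_register
  //   after_false:     (execution resumes here; the Goto from the true branch lands here)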

  Instruction AllocTensorFromType(const TensorTypeNode* ttype) {
    TVMType dltype = Type2TVMType(ttype->dtype);
    auto tensor_type = GetRef<TensorType>(ttype);
    std::vector<int64_t> shape;
    for (auto dim : tensor_type->shape) {
      shape.push_back(Downcast<tvm::Integer>(dim)->value);
    }
    return Instruction::AllocTensor(shape, dltype, NewRegister());
  }

  void EmitInvokePrimitive(const Function& func,
                           const std::vector<Index>& args_registers,
                           const Type& ret_type) {
    std::vector<Index> unpacked_arg_regs;
    std::vector<Instruction> allocs;

    // Arity calculation must flatten tuples.
    size_t arity = 0;
    CHECK_EQ(func->params.size(), args_registers.size());
    for (size_t i = 0; i < func->params.size(); i++) {
      auto ty = func->params[i]->checked_type();
      if (ty.as<TensorTypeNode>()) {
        unpacked_arg_regs.push_back(args_registers[i]);
        arity += 1;
      } else if (auto tuple_ty = ty.as<TupleTypeNode>()) {
        for (size_t f = 0; f < tuple_ty->fields.size(); f++) {
          const auto& field = tuple_ty->fields[f];
          CHECK(field.as<TensorTypeNode>())
            << "only supports non-nested tuples currently "
            << "found " << field;
          auto dst =  NewRegister();
          Emit(Instruction::GetField(args_registers[i], f, dst));
          unpacked_arg_regs.push_back(dst);
        }
        arity += tuple_ty->fields.size();
      } else {
        LOG(FATAL) << "unsupported parameter type " << ty;
      }
    }

    size_t return_val_count = 0;
    if (const TensorTypeNode* ttype = ret_type.as<TensorTypeNode>()) {
      // Allocate space for the return tensor.
      auto alloc = AllocTensorFromType(ttype);
      allocs.push_back(alloc);
      return_val_count = 1;
    } else if (const TupleTypeNode* ttype = ret_type.as<TupleTypeNode>()) {
      std::vector<Index> fields_registers;

      for (size_t i = 0; i < ttype->fields.size(); ++i) {
        auto f = ttype->fields[i];
        auto f_type = f.as<TensorTypeNode>();
        allocs.push_back(AllocTensorFromType(f_type));
        fields_registers.push_back(allocs.back().dst);
      }
      return_val_count = ttype->fields.size();
    } else {
      LOG(FATAL) << "Unsupported return value type";
    }

    arity += return_val_count;

    for (auto& alloc : allocs) {
      Emit(alloc);
      unpacked_arg_regs.push_back(alloc.dst);
    }

    // Next generate the invoke instruction.
    CHECK(func->IsPrimitive());
    Target target;
    if (targets_.size() == 1) {
      // homogeneous execution.
      for (auto kv : targets_) {
        target = kv.second;
      }
    } else {
      // heterogeneous execution.
      LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation";
    }
    auto key = CCacheKeyNode::make(func, target);
    auto cfunc = engine_->Lower(key);
    // TODO(jroesch): support lowered funcs for multiple targets
    CHECK_EQ(cfunc->funcs.size(), 1);
    auto op_index = -1;
    if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
      op_index = context_->lowered_funcs.size();
      context_->lowered_funcs.push_back(cfunc->funcs[0]);
      context_->seen_funcs[cfunc->funcs[0]] = op_index;
    } else {
      op_index = context_->seen_funcs[cfunc->funcs[0]];
    }

    Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs));

    if (return_val_count > 1) {
      // The return value is a tuple, so pack the output registers into a datatype object.
      std::vector<Index> fields_registers;
      for (size_t i = arity - return_val_count; i < arity; ++i) {
        fields_registers.push_back(unpacked_arg_regs[i]);
      }
      Emit(Instruction::AllocDatatype(0, return_val_count, fields_registers, NewRegister()));
    }
  }
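  // Worked example of the arity flattening above (illustrative; the signature is
  // hypothetical): for a fused primitive of type
  //
  //   fn (%a: Tensor[(4), float32], %b: (Tensor[(4), float32], Tensor[(4), float32]))
  //      -> (Tensor[(4), float32], Tensor[(4), float32])
  //
  // the tuple argument %b is unpacked with GetField and two output tensors are
  // allocated, so arity = 1 (%a) + 2 (%b unpacked) + 2 (outputs) = 5. InvokePacked
  // then sees the registers [a, b.0, b.1, out0, out1] with return_val_count = 2,
  // and the two outputs are re-packed into a datatype object by the code above.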

  void VisitExpr_(const CallNode* call_node) {
    std::vector<Index> args_registers;

    for (auto arg : call_node->args) {
      this->VisitExpr(arg);
      args_registers.push_back(last_register_);
    }

    Expr op = call_node->op;

    if (auto func_node = op.as<FunctionNode>()) {
      CHECK(func_node->IsPrimitive());
      EmitInvokePrimitive(GetRef<Function>(func_node), args_registers, call_node->checked_type());
    } else if (auto global_node = op.as<GlobalVarNode>()) {
      auto global = GetRef<GlobalVar>(global_node);
      auto it = context_->global_map.find(global);
      CHECK(it != context_->global_map.end());
      DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint
                      << " with func_index=" << it->second;

      auto func = context_->module->Lookup(global);
      if (IsClosure(func)) {
        auto arity = func->params.size();
        Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister()));
      } else {
        Emit(Instruction::Invoke(it->second, args_registers, NewRegister()));
      }
    } else if (auto constructor_node = op.as<ConstructorNode>()) {
      auto constructor = GetRef<Constructor>(constructor_node);
      Emit(Instruction::AllocDatatype(constructor->tag, call_node->args.size(), args_registers,
                                      NewRegister()));
    } else if (auto var_node = op.as<VarNode>()) {
      VisitExpr(GetRef<Var>(var_node));
      Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister()));
    } else {
      LOG(FATAL) << "unsupported case in vm compiler: " << op;
    }
  }

  void VisitExpr_(const FunctionNode* func_node) {
    if (!func_node->IsPrimitive()) {
      LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl
                 << "Program: " << AsText(GetRef<Function>(func_node), false) << std::endl
                 << "AST: " << GetRef<Function>(func_node);
    }
  }

  /*!
   * \brief Compile a match value
   * Generate bytecode that computes the value specified in val
   *
   * \return The register number assigned for the final value
   */
  RegName CompileMatchValue(MatchValuePtr val) {
    if (std::dynamic_pointer_cast<RegisterValue>(val)) {
      auto r = std::dynamic_pointer_cast<RegisterValue>(val);
      return r->register_num;
    } else {
      auto path = std::dynamic_pointer_cast<AccessField>(val);
      auto p = CompileMatchValue(path->parent);
      Emit(Instruction::GetField(p, path->index, NewRegister()));
      path->reg = last_register_;
      return path->reg;
    }
  }

  void CompileTreeNode(TreeNodePtr tree) {
    if (std::dynamic_pointer_cast<TreeLeafNode>(tree)) {
      auto node = std::dynamic_pointer_cast<TreeLeafNode>(tree);
      VisitExpr(node->body);
    } else if (std::dynamic_pointer_cast<TreeLeafFatalNode>(tree)) {
      Emit(Instruction::Fatal());
    } else if (std::dynamic_pointer_cast<TreeBranchNode>(tree)) {
      auto node = std::dynamic_pointer_cast<TreeBranchNode>(tree);
      if (std::dynamic_pointer_cast<TagCompare>(node->cond)) {
        // For a tag comparison, generate branches
        auto cond = std::dynamic_pointer_cast<TagCompare>(node->cond);
        auto r = CompileMatchValue(cond->obj);
        Emit(Instruction::GetTag(r, NewRegister()));
        auto operand1 = last_register_;
        Emit(Instruction::LoadConsti(cond->target_tag, NewRegister()));
        auto operand2 = last_register_;

        Emit(Instruction::If(operand1, operand2, 1, 0));
        auto cond_offset = instructions_.size() - 1;
        CompileTreeNode(node->then_branch);
        auto if_reg = last_register_;
        Emit(Instruction::Goto(1));
        auto goto_offset = instructions_.size() - 1;
        CompileTreeNode(node->else_branch);
        auto else_reg = last_register_;
        Emit(Instruction::Move(else_reg, if_reg));
        last_register_ = if_reg;
        auto else_offset = instructions_.size() - 1;
        // Fixing offsets
        instructions_[cond_offset].if_op.false_offset = goto_offset - cond_offset + 1;
        instructions_[goto_offset].pc_offset = else_offset - goto_offset + 1;
      } else {
        // For a var-binding condition, bind the variable and continue directly with the then-branch
        auto cond = std::dynamic_pointer_cast<VarBinding>(node->cond);
        var_register_map_[cond->var] = CompileMatchValue(cond->val);
        CompileTreeNode(node->then_branch);
      }
    }
  }
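  // Rough layout of the bytecode emitted for a TagCompare branch above (illustrative;
  // names refer to the locals in CompileTreeNode):
  //
  //                  GetTag     obj                      -> operand1
  //                  LoadConsti target_tag               -> operand2
  //   cond_offset:   If   operand1, operand2, 1, (goto_offset - cond_offset + 1)
  //                  ... then-branch ...                  ; result in if_reg
  //   goto_offset:   Goto (else_offset - goto_offset + 1)
  //                  ... else-branch ...                  ; result in else_reg
  //   else_offset:   Move else_reg -> if_reg
  //                  (the Goto from the then-branch lands just past the Move)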

  /*!
   * \brief Compile a pattern match expression
   * It first converts the pattern match expression into a decision tree, where each condition
   * is either an object tag comparison or a variable binding. If any condition in a clause fails,
   * the decision tree switches to checking the conditions of the next clause, and so on. If no
   * clause matches the value, a fatal node is inserted.
   *
   * After the decision tree is built, we convert it into bytecode using If/Goto.
   */
  void CompileMatch(Match match) {
    auto data = std::make_shared<RegisterValue>(last_register_);
    auto decision_tree = BuildDecisionTreeFromClauses(data, match->clauses);
    CompileTreeNode(decision_tree);
  }

 protected:
  /*! \brief Store the expression a variable points to. */
  std::unordered_map<Var, Expr, NodeHash, NodeEqual> expr_map_;
  /*! \brief Instructions in the VMFunction. */
  std::vector<Instruction> instructions_;
  /*! \brief Parameter names of the function. */
  std::vector<std::string> params_;
  /*! \brief Map from var to register number. */
  std::unordered_map<Var, RegName, NodeHash, NodeEqual> var_register_map_;
  /*! \brief Last used register number. */
  size_t last_register_;
  /*! \brief Total number of virtual registers allocated. */
  size_t registers_num_;
  /*! \brief Compiler engine to lower primitive functions. */
  CompileEngine engine_;
  /*! \brief Global shared meta data */
  VMCompilerContext* context_;
  /*! \brief Target devices. */
  TargetsMap targets_;
};


class VMCompiler : public runtime::ModuleNode {
 public:
  PackedFunc GetFunction(const std::string& name,
                         const std::shared_ptr<ModuleNode>& sptr_to_self) final {
    if (name == "compile") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        CHECK_EQ(args.num_args, 3);
        this->Compile(args[0], args[1], args[2]);
      });
    } else if (name == "get_vm") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        *rv = runtime::Module(vm_);
      });
    } else {
      LOG(FATAL) << "Unknown packed function: " << name;
      return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
    }
  }

  const char* type_key() const final {
    return "VMCompiler";
  }

  std::shared_ptr<VirtualMachine> GetVirtualMachine() const {
    return vm_;
  }

  void Compile(const Module& mod_ref,
               const TargetsMap& targets,
               const tvm::Target& target_host) {
    CHECK_EQ(targets.size(), 1)
      << "Currently VM compiler doesn't support heterogeneous compilation";
    targets_ = targets;
    target_host_ = target_host;
    vm_ = std::make_shared<VirtualMachine>();

    // Run some optimizations first, this code should
    // be moved to pass manager.
    context_.module = OptimizeModule(mod_ref);

    // Populate the global map.
    //
    // This maps global variables to a global index
    // in the VMFunction table.
    PopulateGlobalMap();

    // Next we populate constant map.
    auto constant_analysis_result = LayoutConstantPool(context_.module);
    context_.const_map = std::get<0>(constant_analysis_result);
    context_.const_tensor_shape_map = std::get<1>(constant_analysis_result);

    // Next we get ready by allocating space for
    // the global state.
    vm_->functions.resize(context_.module->functions.size());
    vm_->constants.resize(context_.const_map.size() + context_.const_tensor_shape_map.size());

    for (auto pair : context_.const_map) {
      vm_->constants[pair.second] = Object::Tensor(pair.first->data);
    }

    for (auto pair : context_.const_tensor_shape_map) {
      vm_->constants[pair.second.first] = Object::Tensor(pair.second.second);
    }

    for (auto named_func : context_.module->functions) {
      auto gvar = named_func.first;
      auto func = named_func.second;
      VMFunctionCompiler func_compiler(&context_, targets_);
      auto vm_func = func_compiler.Compile(gvar, func);

      size_t func_index = context_.global_map.at(gvar);
      CHECK(func_index < vm_->functions.size());
      vm_->functions[func_index] = vm_func;
    }

#if USE_RELAY_DEBUG
    for (auto vm_func : vm_->functions) {
      DLOG(INFO) << vm_func << "-------------";
    }
#endif  // USE_RELAY_DEBUG

    LibraryCodegen();

    for (auto gv : context_.global_map) {
      vm_->global_map.insert({gv.first->name_hint, gv.second});
    }
  }

 protected:
  Module OptimizeModule(const Module& mod) {
    // TODO(@icemelon9): check number of targets and build config, add more optimization pass
    transform::Sequential seq({transform::SimplifyInference(),
                               transform::ToANormalForm(),
                               transform::InlinePrimitives(),
                               transform::LambdaLift(),
                               transform::InlinePrimitives(),
                               transform::FuseOps()});
    auto pass_ctx = transform::PassContext::Create();
    tvm::With<relay::transform::PassContext> ctx(pass_ctx);
    return seq(mod);
  }

  void PopulateGlobalMap() {
    // First we populate global map.
    size_t global_index = 0;
    for (auto named_func : context_.module->functions) {
      auto gvar = named_func.first;
      context_.global_map.insert({gvar, global_index++});
    }
  }

  void LibraryCodegen() {
    auto const& lowered_funcs = context_.lowered_funcs;
    if (lowered_funcs.size() == 0) {
      return;
    }
    // TODO(@icemelon9): support heterogeneous targets
    Target target;
    for (auto kv : targets_) {
      target = kv.second;
    }
    if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
      runtime::Module mod =
          (*f)(tvm::Array<LoweredFunc>(lowered_funcs.begin(), lowered_funcs.end()), target,
               target_host_);
      CHECK(mod.operator->());
      vm_->lib = mod;
    } else {
      LOG(FATAL) << "relay.backend.build is not registered";
    }
    size_t primitive_index = 0;
    for (auto lfunc : lowered_funcs) {
      vm_->primitive_map.insert({lfunc->name, primitive_index++});
    }
  }

 protected:
  /*! \brief Target devices. */
  TargetsMap targets_;
  /*! \brief Target host device. */
  tvm::Target target_host_;
  /*! \brief Global shared meta data */
  VMCompilerContext context_;
  /*! \brief Compiled virtual machine. */
  std::shared_ptr<VirtualMachine> vm_;
};

runtime::Module CreateVMCompiler() {
  std::shared_ptr<VMCompiler> exec = std::make_shared<VMCompiler>();
  return runtime::Module(exec);
}

TVM_REGISTER_GLOBAL("relay._vm._VMCompiler")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  *rv = CreateVMCompiler();
});
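// Illustrative sketch of driving the compiler through the packed-function interface
// registered above; `mod`, `targets`, and `target_host` are assumed to be constructed
// elsewhere:
//
//   const PackedFunc* create = runtime::Registry::Get("relay._vm._VMCompiler");
//   runtime::Module compiler = (*create)();
//   compiler.GetFunction("compile")(mod, targets, target_host);
//   runtime::Module vm = compiler.GetFunction("get_vm")();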

}  // namespace vm
}  // namespace relay
}  // namespace tvm