compiler.cc 35.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/backend/vm/compiler.cc
 * \brief A compiler from relay::Module to the VM byte code.
 */

25
#include <tvm/te/operation.h>
26
#include <tvm/ir/error.h>
27 28
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
29
#include <tvm/relay/qnn/transform.h>
30
#include <tvm/support/logging.h>
31
#include <tvm/relay/transform.h>
32
#include <tvm/runtime/vm.h>
33
#include <tvm/relay/attrs/memory.h>
34
#include <tvm/driver/driver_api.h>
35

36
#include <iostream>
37 38 39
#include <memory>
#include <string>
#include <tuple>
40 41 42 43
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "../../backend/compile_engine.h"
44
#include "../../pass/pass_util.h"
45
#include "../../op/op_common.h"
46
#include "compiler.h"
47 48 49

namespace tvm {
namespace relay {
50 51 52 53 54

namespace transform {

Pass LambdaLift();
Pass InlinePrimitives();
55
Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions);
56

57 58 59 60 61 62
Pass ManifestAlloc(Target target_host) {
  // The pass itself is registered from the Python side; dispatch through
  // the global registry.
  const auto* pf = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc");
  CHECK(pf != nullptr) << "could not load memory allocation pass";
  return (*pf)(target_host);
}

63 64
}  // namespace transform

65 66 67 68
namespace vm {

using namespace tvm::runtime;
using namespace tvm::runtime::vm;
69
using namespace relay::transform;
70 71 72 73

// (@jroesch): VM passes, eventually declare as passes.
bool IsClosure(const Function& func);

74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
// Represent a runtime object that's going to be matched by pattern match expressions
struct MatchValue {
  virtual ~MatchValue() = default;
};
using MatchValuePtr = std::shared_ptr<MatchValue>;

// A runtime object that resides in a register
struct RegisterValue : MatchValue {
  // The register num
  // NOTE(review): the member name is misspelled ("rergister"), but it is
  // referenced elsewhere in this file, so it is kept for compatibility.
  RegName rergister_num;

  explicit RegisterValue(RegName reg) : rergister_num(reg) {}

  ~RegisterValue() = default;
};

// The value is a field of another runtime object
struct AccessField : MatchValue {
  // The object whose field is being accessed.
  MatchValuePtr parent;
  // Field index
  size_t index;
  // Runtime register num after compiling the access field path
  RegName reg{-1};

  AccessField(MatchValuePtr parent, size_t index) : parent(parent), index(index) {}

  ~AccessField() = default;
};

104 105 106 107 108 109 110
/*!
 * \brief Condition in a decision tree
 */
struct ConditionNode {
  virtual ~ConditionNode() = default;
};

using ConditionObjectPtr = std::shared_ptr<ConditionNode>;
112 113

/*!
114
 * \brief A var binding condition
115
 */
116 117 118
struct VarBinding : ConditionNode {
  Var var;
  MatchValuePtr val;
119

120 121
  VarBinding(Var var, MatchValuePtr val)
          : var(var), val(val) {}
122

123 124
  ~VarBinding() {}
};
125

126 127 128 129 130 131
/*!
 * \brief Compare the tag of the object
 */
struct TagCompare : ConditionNode {
  /*! \brief The object to be examined */
  MatchValuePtr obj;
132

133 134
  /*! \brief The expected tag */
  int target_tag;
135

136 137 138
  TagCompare(MatchValuePtr obj, size_t target)
          : obj(obj), target_tag(target) {
  }
139

140 141 142
  ~TagCompare() {}
};

143 144 145 146
// Shorthands for the generic decision-tree node templates (presumably from
// the included pass utilities — confirm against ../../pass/pass_util.h),
// instantiated with the condition pointer type used by this compiler.
using TreeObjectPtr = typename relay::TreeNode<ConditionObjectPtr>::pointer;
using TreeLeafNode = relay::TreeLeafNode<ConditionObjectPtr>;
using TreeLeafFatalNode = relay::TreeLeafFatalNode<ConditionObjectPtr>;
using TreeBranchNode = relay::TreeBranchNode<ConditionObjectPtr>;
147

148
// Lower a single pattern into a decision-tree fragment.
//
// \param data The runtime value being matched.
// \param pattern The pattern to test `data` against.
// \param then_branch Tree executed when the pattern matches.
// \param else_branch Tree executed when the pattern fails to match.
// \return The decision tree implementing the pattern test.
TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data,
                                         Pattern pattern,
                                         TreeObjectPtr then_branch,
                                         TreeObjectPtr else_branch) {
  if (pattern.as<PatternWildcardNode>()) {
    // We ignore wildcard binding since it's not producing new vars
    return then_branch;
  }

  if (const auto* pvn = pattern.as<PatternVarNode>()) {
    // A variable pattern always succeeds and introduces a binding.
    auto binding = std::make_shared<VarBinding>(pvn->var, data);
    return TreeBranchNode::Make(binding, then_branch, else_branch);
  }

  if (const auto* pcn = pattern.as<PatternConstructorNode>()) {
    // Recurse into each constructor field, threading the accumulated
    // then-branch, then guard the whole thing with a tag comparison.
    auto tag = pcn->constructor->tag;

    size_t field_index = 0;
    for (auto& sub_pattern : pcn->patterns) {
      auto field = std::make_shared<AccessField>(data, field_index);
      then_branch = BuildDecisionTreeFromPattern(field, sub_pattern, then_branch, else_branch);
      field_index++;
    }
    auto tag_cond = std::make_shared<TagCompare>(data, tag);
    return TreeBranchNode::Make(tag_cond, then_branch, else_branch);
  }

  // Tuple patterns need no tag test: only the fields are matched.
  const auto* pt = pattern.as<PatternTupleNode>();
  CHECK(pt) << "unhandled case: " << AsText(pattern, false);
  size_t field_index = 0;
  for (auto& sub_pattern : pt->patterns) {
    auto field = std::make_shared<AccessField>(data, field_index++);
    then_branch = BuildDecisionTreeFromPattern(field, sub_pattern, then_branch, else_branch);
  }
  return then_branch;
}

181
// Lower one match clause: on success evaluate the clause body,
// otherwise fall through to `else_branch`.
TreeObjectPtr BuildDecisionTreeFromClause(MatchValuePtr data,
                                        Clause clause,
                                        TreeObjectPtr else_branch) {
  auto body_leaf = TreeLeafNode::Make(clause->rhs);
  return BuildDecisionTreeFromPattern(data, clause->lhs, body_leaf, else_branch);
}

188
// Lower a full clause list into a decision tree.
// Clauses are folded right-to-left so that earlier clauses are tried first.
TreeObjectPtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array<Clause> clauses) {
  // When nothing matches, the VM throws fatal error
  TreeObjectPtr tree = TreeLeafFatalNode::Make();
  // Start from the last clause
  for (auto it = clauses.rbegin(); it != clauses.rend(); ++it) {
    tree = BuildDecisionTreeFromClause(data, *it, tree);
  }
  return tree;
}

198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230
/*!
 * \brief Convert a rank-1 int64 shape tensor into a plain shape vector.
 *
 * \param shape A 1-D NDArray of int64 values.
 * \return The shape values as a std::vector<int64_t>.
 */
std::vector<int64_t> ToAllocTensorShape64(NDArray shape) {
  std::vector<int64_t> raw_shape;
  // ToDLPack() transfers ownership of a DLManagedTensor to the caller; the
  // original code dropped the pointer and leaked it. Keep it and release it
  // through its deleter once we are done reading.
  DLManagedTensor* dlm_tensor = shape.ToDLPack();
  const DLTensor& tensor = dlm_tensor->dl_tensor;
  CHECK_EQ(tensor.ndim, 1u);
  CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code;

  // TODO(@jroesch): we really need to standaridize the bit width of
  // all of the shape manipulating code.
  CHECK_EQ(tensor.dtype.bits, 64) << "found " << tensor.dtype.bits;
  const int64_t* int_ptr = reinterpret_cast<int64_t*>(tensor.data);
  // Use an int64_t index to match the type of tensor.shape[0].
  for (int64_t i = 0; i < tensor.shape[0]; i++) {
    raw_shape.push_back(int_ptr[i]);
  }
  if (dlm_tensor->deleter != nullptr) {
    dlm_tensor->deleter(dlm_tensor);
  }
  return raw_shape;
}


/*!
 * \brief Convert a rank-1 int32 shape tensor into a 64-bit shape vector.
 *
 * \param shape A 1-D NDArray of int32 values.
 * \return The shape values widened to a std::vector<int64_t>.
 */
std::vector<int64_t> ToAllocTensorShape32(NDArray shape) {
  std::vector<int64_t> raw_shape;
  // ToDLPack() transfers ownership of a DLManagedTensor to the caller;
  // release it through its deleter to avoid leaking the NDArray.
  DLManagedTensor* dlm_tensor = shape.ToDLPack();
  const DLTensor& tensor = dlm_tensor->dl_tensor;
  CHECK_EQ(tensor.ndim, 1u);
  CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code;

  // TODO(@jroesch): we really need to standaridize the bit width of
  // all of the shape manipulating code.
  // The previous CHECK_LE admitted 8/16-bit data that would then be
  // misinterpreted by the int32_t read below; the sole caller dispatches
  // here only when bits == 32, so require exactly 32.
  CHECK_EQ(tensor.dtype.bits, 32) << "found " << tensor.dtype.bits;
  const int32_t* int_ptr = reinterpret_cast<int32_t*>(tensor.data);
  for (int64_t i = 0; i < tensor.shape[0]; i++) {
    raw_shape.push_back(static_cast<int64_t>(int_ptr[i]));
  }
  if (dlm_tensor->deleter != nullptr) {
    dlm_tensor->deleter(dlm_tensor);
  }
  return raw_shape;
}

231 232
class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
 public:
233
  /*!
   * \brief Construct a per-function compiler.
   * \param context Shared compilation state (constants, global map, caches).
   * \param targets Map of device targets to compile operators for.
   * \param target_host The host target, used for shape functions.
   */
  VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host)
      : last_register_(0),
        registers_num_(0),
        engine_(CompileEngine::Global()),
        context_(context),
        targets_(targets),
        target_host_(target_host) {}
240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270

  /*!
   * \brief Compile a Relay function into a VMFunction.
   *
   * Parameters are assigned the first registers (0..n-1); for closures the
   * inner function's parameters follow immediately after.
   *
   * \param var The global variable naming the function.
   * \param func The function to compile.
   * \return The compiled VM function.
   */
  VMFunction Compile(const GlobalVar& var, const Function& func) {
    size_t i = 0;
    // We then assign register num to the free variables
    for (auto param : func->params) {
      auto arg_register = NewRegister();
      // Registers are handed out sequentially, so parameter i must land
      // in register i.
      CHECK_EQ(i, arg_register);
      var_register_map_.insert({param, arg_register});
      params_.push_back(param->name_hint());
      ++i;
    }

    if (IsClosure(func)) {
      // For closures the body is itself a function; its parameters also get
      // registers, and we compile the inner body directly.
      Function inner_func = Downcast<Function>(func->body);
      for (auto param : inner_func->params) {
        auto arg_register = NewRegister();
        CHECK_EQ(i, arg_register);
        var_register_map_.insert({param, arg_register});
        params_.push_back(param->name_hint());
        ++i;
      }
      this->VisitExpr(inner_func->body);
    } else {
      this->VisitExpr(func->body);
    }
    // The register holding the body's result becomes the return value.
    instructions_.push_back(Instruction::Ret(last_register_));
    return VMFunction(var->name_hint, params_, instructions_, registers_num_);
  }

 protected:
  /*! \brief Allocate a fresh virtual register and return its index. */
  size_t NewRegister() { return registers_num_++; }
271 272 273 274 275

  inline void Emit(const Instruction& instr) {
    DLOG(INFO) << "VMCompiler::Emit: instr=" << instr;
    CHECK((int)instr.op < 100) << "Invalid opcode " << (int)instr.op;
    switch (instr.op) {
276
      case Opcode::AllocADT:
277
      case Opcode::AllocTensor:
278
      case Opcode::AllocTensorReg:
279
      case Opcode::GetField:
280
      case Opcode::GetTag:
281
      case Opcode::LoadConst:
282
      case Opcode::LoadConsti:
283 284
      case Opcode::Invoke:
      case Opcode::AllocClosure:
285
      case Opcode::AllocStorage:
286 287
      case Opcode::Move:
      case Opcode::InvokeClosure:
288
        last_register_ = instr.dst;
289 290 291 292 293
        break;
      case Opcode::InvokePacked:
      case Opcode::If:
      case Opcode::Ret:
      case Opcode::Goto:
294
      case Opcode::Fatal:
295 296
        break;
    }
297
    instructions_.push_back(instr);
298 299 300
  }

  void VisitExpr_(const ConstantNode* const_node) {
301 302 303
    size_t konst_idx = context_->constants.size();
    context_->constants.push_back(const_node->data);
    Emit(Instruction::LoadConst(konst_idx, NewRegister()));
304 305 306 307
  }

  // A variable reference simply resolves to the register it was
  // previously assigned; no instruction is emitted.
  void VisitExpr_(const VarNode* var_node) {
    auto var = GetRef<Var>(var_node);
    auto it = this->var_register_map_.find(var);
    CHECK(it != this->var_register_map_.end());
    last_register_ = it->second;
  }

  void VisitExpr_(const TupleNode* tuple_node) {
    auto tuple = GetRef<Tuple>(tuple_node);
    std::vector<Index> fields_registers;

    for (auto& field : tuple->fields) {
      this->VisitExpr(field);
319
      fields_registers.push_back(last_register_);
320 321 322
    }

    // TODO(@jroesch): use correct tag
323
    Emit(Instruction::AllocADT(
324 325 326 327 328 329 330 331
      0,
      tuple->fields.size(),
      fields_registers,
      NewRegister()));
  }

  void VisitExpr_(const MatchNode* match_node) {
    auto match = GetRef<Match>(match_node);
332 333

    this->VisitExpr(match->data);
334
    CompileMatch(match);
335 336 337
  }

  // Compile the bound value, record which register holds it, then continue
  // with the let body in the extended register environment.
  void VisitExpr_(const LetNode* let_node) {
    DLOG(INFO) << PrettyPrint(let_node->value);
    this->VisitExpr(let_node->value);
    var_register_map_.insert({let_node->var, this->last_register_});
    this->VisitExpr(let_node->body);
  }

  void VisitExpr_(const TupleGetItemNode* get_node) {
    auto get = GetRef<TupleGetItem>(get_node);
    this->VisitExpr(get->tuple);
347
    auto tuple_register = last_register_;
348 349 350 351
    Emit(Instruction::GetField(tuple_register, get->index, NewRegister()));
  }

  void VisitExpr_(const GlobalVarNode* gvar) {
352
    auto var = GetRef<GlobalVar>(gvar);
353 354 355
    auto func = context_->module->Lookup(var);
    auto it = context_->global_map.find(var);
    CHECK(it != context_->global_map.end());
356 357
    // Allocate closure with zero free vars
    Emit(Instruction::AllocClosure(it->second, 0, {}, NewRegister()));
358 359 360 361 362
  }

  // Lower an if-expression to: LoadConsti(1); If; <true body>; Goto;
  // <false body>; Move, then back-patch the If/Goto jump offsets.
  // The statement order here is load-bearing: offsets are computed from
  // instruction indices captured as the code is emitted.
  void VisitExpr_(const IfNode* if_node) {
    this->VisitExpr(if_node->cond);

    size_t test_register = last_register_;

    // Materialize the constant 1 so the If can compare the condition
    // register against "true".
    this->Emit(Instruction::LoadConsti(1, NewRegister()));
    auto after_cond = instructions_.size();
    auto target_register = last_register_;
    // Emit the If with placeholder offsets; they are patched below.
    this->Emit(Instruction::If(test_register, target_register, 0, 0));
    this->VisitExpr(if_node->true_branch);

    size_t true_register = last_register_;
    // Placeholder Goto that will jump over the false branch.
    Emit(Instruction::Goto(0));

    // Finally store how many instructions there are in the
    // true branch.
    auto after_true = this->instructions_.size();

    this->VisitExpr(if_node->false_branch);

    size_t false_register = last_register_;

    // In else-branch, override the then-branch register
    Emit(Instruction::Move(false_register, true_register));
    // Compute the total number of instructions
    // after generating false.
    auto after_false = this->instructions_.size();

    // Now we will compute the jump targets in order
    // to properly patch the instruction with the
    // the requisite targets.

    // After we emit the true body, and false body,
    // we patch up the if instruction, and goto.
    auto true_offset = 1;
    auto false_offset = after_true - after_cond;
    instructions_[after_cond].if_op.true_offset = true_offset;
    instructions_[after_cond].if_op.false_offset = false_offset;

    // Patch the Goto.
    this->instructions_[after_true - 1].pc_offset = (after_false - after_true) + 1;

    // Both branches leave their result in true_register (the Move above
    // unifies them), so that is the expression's result register.
    this->last_register_ = true_register;
  }

405
  /*!
   * \brief Lower a shape function and emit an InvokePacked call to it.
   *
   * \param func The shape function to lower (lowered for the host target).
   * \param inputs Input variables; must already have registers assigned.
   * \param outputs Output variables; must already have registers assigned.
   */
  void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {
    // Lower shape function
    auto key = CCacheKeyNode::make(func, target_host_);
    auto cfunc = engine_->LowerShapeFunc(key);
    // Deduplicate: reuse the cached packed-func index if we have already
    // lowered an identical function.
    int op_index = -1;
    if (context_->seen_funcs.count(cfunc->funcs[0]) == 0) {
      op_index = context_->cached_funcs.size();
      context_->cached_funcs.push_back(cfunc);
      context_->seen_funcs[cfunc->funcs[0]] = op_index;
    } else {
      op_index = context_->seen_funcs[cfunc->funcs[0]];
    }

    // Prepare input and output registers
    std::vector<Index> argument_registers;
    for (auto input : inputs) {
      auto reg = var_register_map_.find(Downcast<Var>(input));
      CHECK(reg != var_register_map_.end())
        << "internal error: all variables should be in the register mapping";
      argument_registers.push_back(reg->second);
    }

    for (auto output : outputs) {
      auto reg = var_register_map_.find(Downcast<Var>(output));
      CHECK(reg != var_register_map_.end())
        << "internal error: all variables should be in the register mapping";
      argument_registers.push_back(reg->second);
    }

    Emit(Instruction::InvokePacked(op_index,
      argument_registers.size(),
      outputs.size(),
      argument_registers));
  }

  /*!
   * \brief Lower a primitive function and emit an InvokePacked call to it.
   *
   * \param func The primitive (fused) function to lower.
   * \param inputs Tuple of input variables, already register-allocated.
   * \param outputs Tuple of output variables, already register-allocated.
   */
  void EmitInvokeTVMOp(const Function& func,
                       const Expr& inputs,
                       const Expr& outputs) {
    std::vector<Index> argument_registers;

    CHECK(func->IsPrimitive())
      << "internal error: invoke_tvm_op requires the first argument to be a relay::Function";

    auto input_tuple = inputs.as<TupleNode>();
    CHECK(input_tuple)
      << "internal error: invoke_tvm_op inputs must be a tuple,"
      << "please file a bug in the memory manifestation pass";

    auto output_tuple = outputs.as<TupleNode>();
    CHECK(output_tuple)
      << "internal error: invoke_tvm_op outputs must be a tuple,"
      << "please file a bug in the memory manifestation pass";

    // Packed functions take inputs followed by outputs in one flat
    // argument list.
    for (auto input : input_tuple->fields) {
      auto reg = var_register_map_.find(Downcast<Var>(input));
      CHECK(reg != var_register_map_.end())
        << "internal error: all variables should be in the register mapping";
      argument_registers.push_back(reg->second);
    }

    for (auto output : output_tuple->fields) {
      auto reg = var_register_map_.find(Downcast<Var>(output));
      CHECK(reg != var_register_map_.end())
        << "internal error: all variables should be in the register mapping";
      argument_registers.push_back(reg->second);
    }

    Target target;

    // Functions handled by an external (non-default) compiler are lowered
    // for the external-device pseudo-target.
    if (!func->UseDefaultCompiler()) {
      target = tvm::target::ext_dev();
    } else {
      // Next generate the invoke instruction.
      if (targets_.size() == 1) {
        // homogeneous execution.
        const auto& it = targets_.begin();
        target = (*it).second;
      } else {
        // heterogeneous execution.
        LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation";
      }
    }

    auto key = CCacheKeyNode::make(func, target);
    auto cfunc = engine_->Lower(key);

    auto op_index = -1;
    if (!func->UseDefaultCompiler()) {
      // Externally compiled functions are not deduplicated.
      op_index = context_->cached_funcs.size();
      context_->cached_funcs.push_back(cfunc);
    } else {
      // TODO(jroesch): support lowered funcs for multiple targets
      CHECK_EQ(cfunc->funcs.size(), 1);
      // Deduplicate against previously lowered functions.
      if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
        op_index = context_->cached_funcs.size();
        context_->cached_funcs.push_back(cfunc);
        context_->seen_funcs[cfunc->funcs[0]] = op_index;
      } else {
        op_index = context_->seen_funcs[cfunc->funcs[0]];
      }
    }

    Emit(Instruction::InvokePacked(op_index,
      argument_registers.size(),
      output_tuple->fields.size(),
      argument_registers));
  }

  // Dispatch a call node. Calls to the "memory" dialect operators are
  // compiled via dedicated handlers; everything else (globals, constructors,
  // closure-valued variables) uses the standard calling conventions.
  void VisitExpr_(const CallNode* call_node) {
    Expr op = call_node->op;

    // First we handle the case in which we are using an opaque
    // operator used to define a sub-dialect, such as memory
    // allocation operations.
    if (op.as<OpNode>()) {
      OpMatch<void> matcher;
      matcher.Match("memory.invoke_tvm_op",
        [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
          CHECK_EQ(args.size(), 3);
          // args = (primitive function, input tuple, output tuple).
          EmitInvokeTVMOp(Downcast<Function>(args[0]), args[1], args[2]);
      }).Match("memory.alloc_tensor",
        [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
          CHECK_EQ(args.size(), 2);

          // Get the attributes.
          auto alloc_attrs = attrs.as<AllocTensorAttrs>();
          CHECK(alloc_attrs != nullptr)
              << "must be the alloc tensor attrs";
          auto dtype = alloc_attrs->dtype;

          // The storage will be passed dynamically.
          this->VisitExpr(args[0]);
          auto storage_register = last_register_;

          // If the shape is constant then we will emit a static tensor allocation instruction.
          auto const_shape = args[1].as<ConstantNode>();

          if (const_shape) {
            NDArray shape = const_shape->data;
            std::vector<int64_t> raw_shape;
            DLTensor tensor = shape.ToDLPack()->dl_tensor;
            // TODO(@jroesch): we need to get an RFC done to standarize this
            if (tensor.dtype.bits == 64) {
              raw_shape = ToAllocTensorShape64(shape);
            } else if (tensor.dtype.bits == 32) {
              raw_shape = ToAllocTensorShape32(shape);
            } else {
              LOG(FATAL) << "unsupported bitwidth: " << tensor.dtype.bits;
            }

            // Add context field.
            Emit(Instruction::AllocTensor(storage_register, raw_shape, dtype, NewRegister()));
          } else {
            // Dynamic shape: compile the shape expression and allocate
            // from the register holding it.
            this->VisitExpr(args[1]);
            auto shape_register = last_register_;
            Emit(Instruction::AllocTensorReg(
              storage_register,
              shape_register,
              dtype,
              NewRegister()));
          }
      }).Match("memory.alloc_storage",
        [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
          CHECK_EQ(args.size(), 2);
          // Compute the size of the allocation.
          this->VisitExpr(args[0]);
          auto size_register = last_register_;

          this->VisitExpr(args[1]);
          auto alignment_register = last_register_;

          // Get the dtype hint from the attributes.
          auto alloc_attrs = attrs.as<AllocTensorAttrs>();
          CHECK(alloc_attrs != nullptr)
              << "must be the alloc tensor attrs";
          auto dtype = alloc_attrs->dtype;

          Emit(Instruction::AllocStorage(size_register, alignment_register, dtype, NewRegister()));
      }).Match("memory.shape_func",
        [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
          CHECK_EQ(args.size(), 3);
          auto shape_func = Downcast<Function>(args[0]);
          auto inputs = Downcast<Tuple>(args[1]);
          auto outputs = Downcast<Tuple>(args[2]);
          EmitShapeFunc(shape_func, inputs->fields, outputs->fields);
      }).Match("memory.kill",
        [](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
          LOG(FATAL) << "memory.kill is not yet supported";
      });
      matcher(GetRef<Call>(call_node));
      return;
    }

    // In the case its not one of these specialized operators we will generate code
    // for one of the "standard" cases.
    std::vector<Index> args_registers;

    // Compile every argument; each leaves its result register in
    // last_register_.
    for (auto arg : call_node->args) {
      this->VisitExpr(arg);
      args_registers.push_back(last_register_);
    }

    if (auto global_node = op.as<GlobalVarNode>()) {
      // In the case we are invoking a global we need to find its
      // global ID, and then check whether it is closure invocation
      // or whether it is a standard global, and emit the correct
      // calling convention.
      auto global = GetRef<GlobalVar>(global_node);
      auto it = context_->global_map.find(global);
      CHECK(it != context_->global_map.end());
      DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint
                      << " with func_index=" << it->second;

      // TODO(tvm-team):
      // Think about mixed call into global that is not a relay::Function
      // perhaps establish as an invariance(all functions in mod must be relay::Function)
      auto func = Downcast<Function>(context_->module->Lookup(global));


      if (IsClosure(func)) {
        auto arity = func->params.size();
        Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister()));
      } else {
        Emit(Instruction::Invoke(it->second, args_registers, NewRegister()));
      }
    } else if (auto constructor_node = op.as<ConstructorNode>()) {
      // In the constructor case, we simply need to find its tag
      // and emit a call to allocate the data structure.
      auto constructor = GetRef<Constructor>(constructor_node);
      Emit(Instruction::AllocADT(constructor->tag, call_node->args.size(), args_registers,
                                 NewRegister()));
    } else if (auto var_node = op.as<VarNode>()) {
      // If we are calling a variable, it must be the case that it is a closure so we
      // emit invoke closure here.
      VisitExpr(GetRef<Var>(var_node));
      Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister()));
    } else {
      // Finally if there are any other cases this is a bug.
      LOG(FATAL) << "internal error: unreachable code,"
                 << "should be transformed away by previous passes"
                 << PrettyPrint(GetRef<Expr>(call_node));
    }
  }

  // Only primitive (fused) functions may appear as bare function nodes here;
  // any other local function should have been eliminated by lambda lifting.
  void VisitExpr_(const FunctionNode* func_node) {
    if (!func_node->IsPrimitive()) {
      LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl
                 << "Program: " << AsText(GetRef<Function>(func_node), false) << std::endl
                 << "AST: " << GetRef<Function>(func_node);
    }
  }

657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672
  /*!
   * \brief Compile a match value
   * Generate byte code that computes the value specified in val
   *
   * \return The register number assigned for the final value
   */
  RegName CompileMatchValue(MatchValuePtr val) {
    // Single dynamic_pointer_cast instead of the original duplicated cast.
    if (auto r = std::dynamic_pointer_cast<RegisterValue>(val)) {
      // The value is already materialized in a register.
      return r->rergister_num;
    } else {
      // Otherwise it is a field-access path: compile the parent first,
      // then emit a GetField to extract the requested field.
      auto path = std::dynamic_pointer_cast<AccessField>(val);
      auto p = CompileMatchValue(path->parent);
      Emit(Instruction::GetField(p, path->index, NewRegister()));
      path->reg = last_register_;
      return path->reg;
    }
  }

676
  // Emit bytecode for a decision tree: leaves evaluate clause bodies,
  // fatal leaves abort, and branches compile to GetTag/If/Goto sequences
  // whose jump offsets are back-patched after both arms are emitted.
  void CompileTreeNode(TreeObjectPtr tree) {
    if (auto node = std::dynamic_pointer_cast<TreeLeafNode>(tree)) {
      // Leaf: compile the clause body.
      VisitExpr(node->body);
    } else if (std::dynamic_pointer_cast<TreeLeafFatalNode>(tree)) {
      // No clause matched: runtime failure.
      Emit(Instruction::Fatal());
    } else if (auto node = std::dynamic_pointer_cast<TreeBranchNode>(tree)) {
      if (auto cond = std::dynamic_pointer_cast<TagCompare>(node->cond)) {
        // For Tag comparison, generate branches
        auto r = CompileMatchValue(cond->obj);
        Emit(Instruction::GetTag(r, NewRegister()));
        auto operand1 = last_register_;
        Emit(Instruction::LoadConsti(cond->target_tag, NewRegister()));
        auto operand2 = last_register_;

        // Placeholder offsets; false_offset is patched below.
        Emit(Instruction::If(operand1, operand2, 1, 0));
        auto cond_offset = instructions_.size() - 1;
        CompileTreeNode(node->then_branch);
        auto if_reg = last_register_;
        Emit(Instruction::Goto(1));
        auto goto_offset = instructions_.size() - 1;
        CompileTreeNode(node->else_branch);
        auto else_reg = last_register_;
        // Unify both arms' results into the then-arm's register.
        Emit(Instruction::Move(else_reg, if_reg));
        last_register_ = if_reg;
        auto else_offset = instructions_.size() - 1;
        // Fixing offsets
        instructions_[cond_offset].if_op.false_offset = goto_offset - cond_offset + 1;
        instructions_[goto_offset].pc_offset = else_offset - goto_offset + 1;
      } else {
        // For other non-branch conditions, move to then_branch directly
        auto var_bind = std::dynamic_pointer_cast<VarBinding>(node->cond);
        var_register_map_[var_bind->var] = CompileMatchValue(var_bind->val);
        CompileTreeNode(node->then_branch);
      }
    }
  }

713 714 715 716 717 718 719 720 721 722 723 724 725
  /*!
   * \brief Compile a pattern match expression
   * It first converts the pattern match expression into a decision tree, the condition
   * could be object comparison or variable binding. If any of the condition fails in a clause,
   * the decision tree switches to check the conditions of next clause and so on. If no clause
   * matches the value, a fatal node is inserted.
   *
   * After the decision tree is built, we convert it into bytecodes using If/Goto.
   *
   * Precondition: the scrutinee has already been compiled and its result
   * register is in last_register_.
   */
  void CompileMatch(Match match) {
    auto data = std::make_shared<RegisterValue>(last_register_);
    auto decision_tree = BuildDecisionTreeFromClauses(data, match->clauses);
    CompileTreeNode(decision_tree);
  }

728 729
 protected:
  /*! \brief Store the expression a variable points to. */
  std::unordered_map<Var, Expr, ObjectHash, ObjectEqual> expr_map_;
  /*! \brief Instructions in the VMFunction. */
  std::vector<Instruction> instructions_;
  /*! \brief Parameter names of the function. */
  std::vector<std::string> params_;
  /*! \brief Map from var to register number. */
  std::unordered_map<Var, RegName, ObjectHash, ObjectEqual> var_register_map_;
  /*! \brief Last used register number. */
  size_t last_register_;
  /*! \brief Total number of virtual registers allocated. */
  size_t registers_num_;
  /*! \brief Compiler engine to lower primitive functions. */
  CompileEngine engine_;
  /*! \brief Global shared meta data */
  VMCompilerContext* context_;
  /*! \brief Target devices. */
  TargetsMap targets_;
  /*! \brief Host target. */
  Target target_host_;
749 750 751
};


752
/*!
 * \brief Expose the compiler's entry points ("lower", "codegen",
 * "get_executable", "set_params") as packed functions.
 *
 * \param name The packed-function name requested.
 * \param sptr_to_self Shared pointer keeping this module alive while the
 *        returned closure is in use.
 * \return The packed function, or a fatal error for unknown names.
 */
PackedFunc VMCompiler::GetFunction(const std::string& name,
                                   const ObjectPtr<Object>& sptr_to_self) {
  if (name == "lower") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      // args = (module, targets, target_host).
      CHECK_EQ(args.num_args, 3);
      IRModule mod = args[0];
      this->Lower(mod, args[1], args[2]);
    });
  } else if (name == "codegen") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      CHECK_EQ(args.num_args, 0);
      this->Codegen();
    });
  } else if (name == "get_executable") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      *rv = runtime::Module(exec_);
    });
  } else if (name == "set_params") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      Map<std::string, Constant> params = args[0];
      for (const auto& kv : params) {
        this->SetParam(kv.first, kv.second->data);
      }
    });
  } else {
    LOG(FATAL) << "Unknown packed function: " << name;
    return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
  }
}
781

782 783 784 785 786 787 788 789
/*!
 * \brief Record a named parameter to be bound into the module before lowering.
 * \param name The parameter name (matched against "main" argument names).
 * \param data_in The constant value to bind; overwrites any previous entry.
 */
void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) {
  params_[name] = data_in;
}

/*!
 * \brief Bind constant values to function parameters, matched by name.
 *
 * Parameters whose names appear in \p params are replaced by constants;
 * it is a fatal error if a supplied name matches more than one parameter.
 *
 * \param func The function whose parameters are to be bound.
 * \param params Name-to-value map of constants.
 * \return The function with matching parameters bound.
 */
relay::Function VMCompiler::BindParamsByName(
    relay::Function func,
    const std::unordered_map<std::string, runtime::NDArray>& params) {
  // Index parameters by name, remembering any duplicated names.
  std::unordered_map<std::string, relay::Var> name_dict;
  std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
  for (auto arg : func->params) {
    const auto& name = arg->name_hint();
    if (name_dict.count(name)) {
      repeat_var.insert(arg);
    } else {
      name_dict[name] = arg;
    }
  }

  // Build the substitution map from supplied parameter values.
  std::unordered_map<relay::Var, Expr, ObjectHash, ObjectEqual> bind_dict;
  for (auto& kv : params) {
    auto it = name_dict.find(kv.first);
    if (it == name_dict.end()) {
      continue;
    }
    if (repeat_var.count(it->second)) {
      LOG(FATAL) << "Multiple args in the function have name " << kv.first;
    }
    bind_dict[it->second] = ConstantNode::make(kv.second);
  }

  Expr bound_expr = relay::Bind(func, bind_dict);
  Function ret = Downcast<Function>(bound_expr);
  CHECK(ret.defined())
      << "The returning type is expected to be a Relay Function."
      << "\n";
  return ret;
}

818
/*!
 * \brief Lower a Relay module into VM functions stored in `exec_`.
 *
 * Binds any parameters registered via SetParam into "main", runs the VM
 * optimization pipeline, then compiles every Relay function in the
 * optimized module into a VMFunction and populates the executable's
 * constant, global, and primitive maps.
 *
 * \param mod The input module; its "main" may be rewritten in place.
 * \param targets Per-device compilation targets (currently exactly one).
 * \param target_host The host compilation target.
 */
void VMCompiler::Lower(IRModule mod,
                       const TargetsMap& targets,
                       const tvm::Target& target_host) {
  CHECK_EQ(targets.size(), 1)
    << "Currently VM compiler doesn't support heterogeneous compilation";
  if (params_.size()) {
    // Bind registered parameters into "main" before optimization so the
    // constants can take part in passes such as FoldConstant.
    BaseFunc base_func = mod->Lookup("main");
    CHECK(base_func->IsInstance<FunctionNode>())
        << "VM compiler expects to compile relay::Function";
    auto f = BindParamsByName(Downcast<Function>(base_func), params_);
    auto gvar = mod->GetGlobalVar("main");
    mod->Add(gvar, f);
  }

  exec_ = make_object<Executable>();
  targets_ = targets;
  target_host_ = target_host;

  // Run the optimizations necessary to target the VM.
  context_.module = OptimizeModule(mod, targets_);

  // Populate the global map.
  //
  // This maps global variables to a global index
  // in the VMFunction table.
  PopulateGlobalMap();

  // Next we get ready by allocating space for
  // the global state.
  exec_->functions.resize(context_.module->functions.size());

  for (auto named_func : context_.module->functions) {
    auto gvar = named_func.first;
    if (auto* n = named_func.second.as<FunctionNode>()) {
      auto func = GetRef<Function>(n);
      VMFunctionCompiler func_compiler(&context_, targets_, target_host_);
      auto vm_func = func_compiler.Compile(gvar, func);

      // Place the compiled function at the slot reserved by PopulateGlobalMap.
      size_t func_index = context_.global_map.at(gvar);
      CHECK(func_index < exec_->functions.size());
      exec_->functions[func_index] = vm_func;
    }
  }

#if USE_RELAY_DEBUG
  for (auto vm_func : exec_->functions) {
    DLOG(INFO) << vm_func << "-------------";
  }
#endif  // USE_RELAY_DEBUG

  // populate constants
  for (auto data : context_.constants) {
    exec_->constants.push_back(data);
  }

  // update global function map
  for (auto gv : context_.global_map) {
    exec_->global_map.insert({gv.first->name_hint, gv.second});
  }

  // update primitive function map
  size_t primitive_index = 0;
  for (const auto& cfunc : context_.cached_funcs) {
    if (cfunc->target->str() == "ext_dev") {
      // External functions are keyed by their Relay-level name.
      exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
    } else {
      exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
    }
  }
}
/*!
 * \brief Run the VM-targeted optimization pipeline on a module.
 *
 * \param mod The input module.
 * \param targets Per-device targets; some passes only run when there is a
 *        single (homogeneous) target.
 * \return The optimized module.
 */
IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets) {
  Array<Pass> pass_seqs;
  Array<tvm::PrimExpr> entry_functions{tvm::PrimExpr{"main"}};
  pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
  // Run all dialect legalization passes.
  pass_seqs.push_back(relay::qnn::transform::Legalize());

  // Legalize pass is restricted to homogeneous execution for now.
  if (targets.size() == 1) {
    pass_seqs.push_back(transform::Legalize());
  }

  // eta expand to support constructors in argument position
  pass_seqs.push_back(transform::EtaExpand(
    /* expand_constructor */ true, /* expand_global_var */ false));

  pass_seqs.push_back(transform::SimplifyInference());
  // Predicate telling CSE which calls to skip: keep int32 casts distinct.
  // NOTE(fix): the original set *rv = true and then unconditionally
  // overwrote it with false, so the skip never fired; it also dereferenced
  // op_node/attrs without null checks, crashing on non-Op calls.
  PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
    Expr expr = args[0];
    *rv = false;
    if (const auto* call_node = expr.as<CallNode>()) {
      const auto* op_node = call_node->op.as<OpNode>();
      if (op_node != nullptr && op_node->name == "cast") {
        const auto* attrs = call_node->attrs.as<CastAttrs>();
        if (attrs != nullptr && attrs->dtype == DataType::Int(32)) {
          *rv = true;
        }
      }
    }
  });
  pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
  pass_seqs.push_back(transform::InlinePrimitives());

  pass_seqs.push_back(transform::CombineParallelConv2D(3));
  pass_seqs.push_back(transform::CombineParallelDense(3));
  pass_seqs.push_back(transform::FoldConstant());
  pass_seqs.push_back(transform::FoldScaleAxis());
  pass_seqs.push_back(transform::CanonicalizeCast());
  pass_seqs.push_back(transform::CanonicalizeOps());

  // Alter layout transformation is only applied to homogeneous execution yet.
  if (targets.size() == 1) {
    pass_seqs.push_back(transform::AlterOpLayout());
  }

  pass_seqs.push_back(transform::FoldConstant());

  pass_seqs.push_back(transform::FuseOps());
  pass_seqs.push_back(transform::ToANormalForm());
  pass_seqs.push_back(transform::LambdaLift());
  pass_seqs.push_back(transform::InlinePrimitives());

  // Manifest the allocations.
  pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
  // Compute away possibly introduced constant computation.
  pass_seqs.push_back(transform::FoldConstant());
  // Fuse the shape functions.
  pass_seqs.push_back(transform::FuseOps());
  // Manifest the allocations needed for the shape functions.
  pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));

  transform::Sequential seq(pass_seqs);
  transform::PassContext pass_ctx = PassContext::Current();
  // TODO(wweic): Support heterogenous execution
  tvm::With<relay::transform::PassContext> ctx(pass_ctx);
  if (targets.size() == 1) {
    // Enter the single target's scope so target-aware passes pick it up.
    const auto& it = targets.begin();
    With<Target> tctx((*it).second);
    return seq(mod);
  }
  return seq(mod);
}
/*!
 * \brief Assign each global variable in the module a dense index into the
 * VMFunction table; the indices follow the module's iteration order.
 */
void VMCompiler::PopulateGlobalMap() {
  // First we populate global map.
  size_t global_index = 0;
  for (auto named_func : context_.module->functions) {
    auto gvar = named_func.first;
    context_.global_map.insert({gvar, global_index++});
  }
}
void VMCompiler::Codegen() {
973 974
  using tir::LoweredFunc;

975 976 977 978
  if (!context_.module.defined()) {
    LOG(WARNING) << "Did you forget to call VMCompiler::Lower?";
    return;
  }
979 980
  auto const &cached_funcs = context_.cached_funcs;
  if (cached_funcs.size() == 0) {
981 982
    return;
  }
Zhi committed
983 984
  std::unordered_map<std::string, Array<LoweredFunc>> funcs;
  for (auto& cfunc : cached_funcs) {
985
    std::string target_str = cfunc->target->str();
Zhi committed
986 987 988 989
    if (target_str == "ext_dev") {
      continue;
    } else if (funcs.count(target_str) == 0) {
      funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]});
990
    } else {
Zhi committed
991
      funcs[target_str].push_back(cfunc->funcs[0]);
992
    }
993
  }
994

Zhi committed
995 996 997 998 999
  auto compile_engine = CompileEngine::Global();
  auto ext_mods = compile_engine->LowerExternalFunctions();
  runtime::Module mod;
  if (funcs.size() > 0) {
    mod = tvm::build(funcs, target_host_, tvm::BuildConfig::Current());
1000 1001
    CHECK(mod.operator->());
  } else {
Zhi committed
1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013
    CHECK_EQ(ext_mods.size(), 1U)
        << "Expect to have a TVM DSOModule when multiple runtime modules exist";
  }
  if (!ext_mods.empty()) {
    if (funcs.size() == 0) {
      mod = ext_mods[0];
    } else {
      // Import all external runtime modules.
      for (auto it : ext_mods) {
        mod.Import(it);
      }
    }
1014
  }
Zhi committed
1015
  exec_->lib = mod;
1016
}
/*!
 * \brief Create a VMCompiler wrapped as a runtime::Module so it can be
 * driven through the PackedFunc interface (see GetFunction above).
 */
runtime::Module CreateVMCompiler() {
  auto exec = make_object<VMCompiler>();
  return runtime::Module(exec);
}

// Expose the VM compiler constructor to the Python frontend.
TVM_REGISTER_GLOBAL("relay._vm._VMCompiler")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  *rv = CreateVMCompiler();
});

}  // namespace vm
}  // namespace relay
}  // namespace tvm