vm.cc 4.91 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2019 by Contributors
 * \file src/runtime/vm/profiler/vm.cc
 * \brief The Relay debug virtual machine.
 */

#include <tvm/runtime/registry.h>
#include <tvm/runtime/vm.h>

#include <algorithm>
#include <chrono>
#include <iomanip>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include "vm.h"

namespace tvm {
namespace runtime {
namespace vm {

PackedFunc VirtualMachineDebug::GetFunction(
    const std::string& name, const std::shared_ptr<ModuleNode>& sptr_to_self) {
  if (name == "get_stat") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      double total_duration = 0.0;
      std::ostringstream os;
      os << std::setw(30) << std::left << "#OpName"
         << "\t" << std::setw(10) << std::left << "#InvokeCount"
         << "\t"
         << "#Duration(us): Sum/Mean/Min/Max" << std::endl;

      for (auto kv : op_durations) {
        auto vals = op_durations[kv.first];
        auto sum = std::accumulate(vals.begin(), vals.end(), 0.0);;
        auto mean = sum / static_cast<double>(vals.size());
        auto min_value = *std::min_element(vals.begin(), vals.end());
        auto max_value = *std::max_element(vals.begin(), vals.end());

        os << std::setw(30) << std::left << packed_index_map[kv.first] << "\t"
           << std::setw(10) << std::left << op_invokes[kv.first] << "\t"
           <<  sum << "/" << mean << "/" << min_value << "/" << max_value << std::endl;

        total_duration += sum;
      }
      os << "Total Duration " << total_duration << " us" << std::endl;
      *rv = os.str();
    });
  } else if (name == "init") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      CHECK_EQ(args.size() % 2, 0);
      std::vector<TVMContext> contexts;
      for (int i = 0; i < args.size() / 2; ++i) {
        TVMContext ctx;
        int device_type = args[i * 2];
        ctx.device_type = DLDeviceType(device_type);
        ctx.device_id = args[i * 2 + 1];
        contexts.push_back(ctx);
      }
      this->Init(contexts);
    });
  } else {
    return VirtualMachine::GetFunction(name, sptr_to_self);
  }
}

88 89 90 91
void VirtualMachineDebug::LoadExecutable(const Executable* exec) {
  VirtualMachine::LoadExecutable(exec);
  CHECK(this->exec);
  for (auto kv : this->exec->primitive_map) {
92 93 94 95 96
    packed_index_map[kv.second] = kv.first;
    op_invokes[kv.second] = 0;
  }
}

97 98 99 100
void VirtualMachineDebug::Init(const std::vector<TVMContext>& ctxs) {
  VirtualMachine::Init(ctxs);
}

101 102 103
void VirtualMachineDebug::InvokePacked(Index packed_index,
                                       const PackedFunc& func, Index arg_count,
                                       Index output_size,
104
                                       const std::vector<ObjectRef>& args) {
105 106
  CHECK(this->exec);
  auto ctx = this->GetParamsContext();
107 108 109 110 111
  // warmup
  VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
                               args);
  TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);

112 113 114 115 116 117 118 119 120 121 122 123 124 125
  auto op_begin = std::chrono::high_resolution_clock::now();
  VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
                               args);
  TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
  auto op_end = std::chrono::high_resolution_clock::now();
  double op_duration =
      std::chrono::duration_cast<std::chrono::duration<double> >(op_end -
                                                                 op_begin)
          .count();

  op_durations[packed_index].push_back(op_duration * 1e6);
  op_invokes[packed_index] += 1;
}

126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
runtime::Module CreateVirtualMachineDebug(const Executable* exec) {
  std::shared_ptr<VirtualMachineDebug> vm = std::make_shared<VirtualMachineDebug>();
  vm->LoadExecutable(exec);
  return runtime::Module(vm);
}

TVM_REGISTER_GLOBAL("relay._vm._VirtualMachineDebug")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  runtime::Module mod = args[0];
  const auto* exec = dynamic_cast<Executable*>(mod.operator->());
  CHECK(exec) << "Virtual machine has not been defined yet."
              << "\n";
  *rv = CreateVirtualMachineDebug(exec);
});

141 142 143
}  // namespace vm
}  // namespace runtime
}  // namespace tvm