tsim_driver.cc 6.31 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <vta/driver.h>
#include <vta/dpi/module.h>

#include <array>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <string>

namespace vta {
namespace tsim {

using tvm::runtime::Module;
29
using vta::dpi::DPIModuleNode;
30

31 32
/*!
 * \brief Accumulates event counters and reports them as JSON.
 *
 * Counter 0 is the only counter and is reported as "cycle_count".
 * Counters are 64-bit so that values accumulated over many updates do not
 * wrap into negative numbers in the JSON report.
 */
class Profiler {
 public:
  /*!
   * \brief update one event counter
   * \param idx index of the counter to update.
   * \param value amount added to the counter.
   * \throws std::out_of_range if idx is not a valid counter index.
   */
  void Update(uint32_t idx, uint32_t value) {
    counters_.at(idx) += value;
  }

  /*!
   * \brief clear one event counter
   * \param idx index of the counter to reset to zero.
   * \throws std::out_of_range if idx is not a valid counter index.
   */
  void Clear(uint32_t idx) {
    counters_.at(idx) = 0;
  }

  /*! \brief clear all event counters */
  void ClearAll() {
    counters_.fill(0);
  }

  /*!
   * \brief return counters as json
   * \return a JSON object with a single "cycle_count" field (counter 0).
   */
  std::string AsJSON() const {
    std::ostringstream os;
    os << "{\n"
       << " \"cycle_count\":" << counters_[0] << "\n"
       <<"}\n";
    return os.str();
  }

  /*! \brief return the process-wide profiler instance */
  static Profiler* Global() {
    static Profiler inst;
    return &inst;
  }

 private:
  /*! \brief total number of event counters */
  static constexpr uint32_t kNumCounters = 1;
  /*! \brief event counters, all zero on construction */
  std::array<uint64_t, kNumCounters> counters_{};
};

80 81
class DPILoader {
 public:
82 83 84 85 86
  ~DPILoader() {
    dpi_->SimResume();
    dpi_->SimFinish();
  }

87 88
  void Init(Module module) {
    mod_ = module;
89 90 91
    dpi_ = this->Get();
    dpi_->SimLaunch();
    dpi_->SimWait();
92 93 94 95 96 97 98 99 100 101 102
  }

  DPIModuleNode* Get() {
    return static_cast<DPIModuleNode*>(mod_.operator->());
  }

  static DPILoader* Global() {
    static DPILoader inst;
    return &inst;
  }

103
  // TVM module
104
  Module mod_;
105 106
  // DPI Module
  DPIModuleNode* dpi_{nullptr};
107 108 109 110 111
};

class Device {
 public:
  Device() {
112
    loader_ = DPILoader::Global();
113
    prof_ = Profiler::Global();
114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
  }

  int Run(vta_phy_addr_t insn_phy_addr,
          vta_phy_addr_t uop_phy_addr,
          vta_phy_addr_t inp_phy_addr,
          vta_phy_addr_t wgt_phy_addr,
          vta_phy_addr_t acc_phy_addr,
          vta_phy_addr_t out_phy_addr,
          uint32_t insn_count,
          uint32_t wait_cycles) {
    this->Init();
    this->Launch(insn_phy_addr,
                 uop_phy_addr,
                 inp_phy_addr,
                 wgt_phy_addr,
                 acc_phy_addr,
                 out_phy_addr,
                 insn_count,
                 wait_cycles);
    this->WaitForCompletion(wait_cycles);
    return 0;
  }

 private:
  void Init() {
139 140
    dpi_ = loader_->Get();
    dpi_->SimResume();
141 142 143 144 145 146 147 148 149 150
  }

  void Launch(vta_phy_addr_t insn_phy_addr,
              vta_phy_addr_t uop_phy_addr,
              vta_phy_addr_t inp_phy_addr,
              vta_phy_addr_t wgt_phy_addr,
              vta_phy_addr_t acc_phy_addr,
              vta_phy_addr_t out_phy_addr,
              uint32_t insn_count,
              uint32_t wait_cycles) {
151 152 153 154 155 156 157 158 159 160 161 162 163 164
    dpi_->WriteReg(0x04, 0);
    dpi_->WriteReg(0x08, insn_count);
    dpi_->WriteReg(0x0c, insn_phy_addr);
    dpi_->WriteReg(0x10, insn_phy_addr >> 32);
    dpi_->WriteReg(0x14, 0);
    dpi_->WriteReg(0x18, uop_phy_addr >> 32);
    dpi_->WriteReg(0x1c, 0);
    dpi_->WriteReg(0x20, inp_phy_addr >> 32);
    dpi_->WriteReg(0x24, 0);
    dpi_->WriteReg(0x28, wgt_phy_addr >> 32);
    dpi_->WriteReg(0x2c, 0);
    dpi_->WriteReg(0x30, acc_phy_addr >> 32);
    dpi_->WriteReg(0x34, 0);
    dpi_->WriteReg(0x38, out_phy_addr >> 32);
165
    // start
166
    dpi_->WriteReg(0x00, 0x1);
167 168 169 170 171
  }

  void WaitForCompletion(uint32_t wait_cycles) {
    uint32_t i, val;
    for (i = 0; i < wait_cycles; i++) {
172
      val = dpi_->ReadReg(0x00);
173 174 175
      val &= 0x2;
      if (val == 0x2) break;  // finish
    }
176 177
    prof_->Update(0, dpi_->ReadReg(0x04));
    dpi_->SimWait();
178 179
  }

180 181 182
  // Profiler
  Profiler* prof_;
  // DPI loader
183
  DPILoader* loader_;
184
  // DPI Module
185
  DPIModuleNode* dpi_;
186 187 188 189 190
};

using tvm::runtime::TVMRetValue;
using tvm::runtime::TVMArgs;

191
/*! \brief Packed function "vta.tsim.init": passes its first argument, a Module, to the global DPILoader. */
TVM_REGISTER_GLOBAL("vta.tsim.init")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module dpi_module = args[0];
    DPILoader::Global()->Init(dpi_module);
  });

197 198 199 200 201 202
/*! \brief Packed function "vta.tsim.profiler_clear": resets every counter of the global Profiler. */
TVM_REGISTER_GLOBAL("vta.tsim.profiler_clear")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Profiler* profiler = Profiler::Global();
    profiler->ClearAll();
  });

/*! \brief Packed function "vta.tsim.profiler_status": returns the global Profiler counters as a JSON string. */
TVM_REGISTER_GLOBAL("vta.tsim.profiler_status")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Profiler* profiler = Profiler::Global();
    *rv = profiler->AsJSON();
  });

207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222
}  // namespace tsim
}  // namespace vta

/*!
 * \brief Allocate a buffer with malloc.
 * \param size number of bytes to allocate.
 * \param cached not used by this driver.
 * \return the pointer returned by malloc.
 */
void* VTAMemAlloc(size_t size, int cached) {
  return malloc(size);
}

/*!
 * \brief Release a buffer by passing it to free().
 * \param buf pointer to release.
 */
void VTAMemFree(void* buf) {
  free(buf);
}

/*!
 * \brief Return the address of a buffer as an integer.
 * \param buf buffer pointer.
 * \return the pointer value reinterpreted as a 64-bit unsigned integer and
 *         converted to vta_phy_addr_t.
 */
vta_phy_addr_t VTAMemGetPhyAddr(void* buf) {
  // The pointer value itself is used as the address; no translation is done here.
  return reinterpret_cast<uint64_t>(buf);
}

223 224 225 226 227 228 229 230
/*!
 * \brief Copy size bytes from src to dst with memcpy.
 * \param dst destination buffer.
 * \param src source buffer.
 * \param size number of bytes to copy; a zero-size copy is a no-op.
 */
void VTAMemCopyFromHost(void* dst, const void* src, size_t size) {
  // memcpy has undefined behavior for null pointers even when size is 0,
  // so skip the call for empty copies.
  if (size == 0) return;
  memcpy(dst, src, size);
}

/*!
 * \brief Copy size bytes from src to dst with memcpy.
 * \param dst destination buffer.
 * \param src source buffer.
 * \param size number of bytes to copy; a zero-size copy is a no-op.
 */
void VTAMemCopyToHost(void* dst, const void* src, size_t size) {
  // memcpy has undefined behavior for null pointers even when size is 0,
  // so skip the call for empty copies.
  if (size == 0) return;
  memcpy(dst, src, size);
}

231
/*!
 * \brief Cache flush hook; this driver performs no work here.
 * \param vir_addr unused.
 * \param phy_addr unused.
 * \param size unused.
 */
void VTAFlushCache(void* vir_addr, vta_phy_addr_t phy_addr, int size) {
  // intentionally empty
}

234
/*!
 * \brief Cache invalidate hook; this driver performs no work here.
 * \param vir_addr unused.
 * \param phy_addr unused.
 * \param size unused.
 */
void VTAInvalidateCache(void* vir_addr, vta_phy_addr_t phy_addr, int size) {
  // intentionally empty
}

/*!
 * \brief Create a new tsim Device on the heap.
 * \return the new device as a handle; it is released with VTADeviceFree.
 */
VTADeviceHandle VTADeviceAlloc() {
  vta::tsim::Device* device = new vta::tsim::Device();
  return device;
}

/*!
 * \brief Destroy a device created by VTADeviceAlloc.
 * \param handle device handle; it is cast back to vta::tsim::Device and deleted.
 */
void VTADeviceFree(VTADeviceHandle handle) {
  vta::tsim::Device* device = static_cast<vta::tsim::Device*>(handle);
  delete device;
}

/*!
 * \brief Forward a run request to the tsim Device behind the handle.
 * \param handle device handle, cast to vta::tsim::Device.
 * \param insn_phy_addr instruction buffer address.
 * \param uop_phy_addr micro-op buffer address.
 * \param inp_phy_addr input buffer address.
 * \param wgt_phy_addr weight buffer address.
 * \param acc_phy_addr accumulator buffer address.
 * \param out_phy_addr output buffer address.
 * \param insn_count instruction count forwarded to the device.
 * \param wait_cycles poll limit forwarded to the device.
 * \return the value returned by Device::Run.
 */
int VTADeviceRun(VTADeviceHandle handle,
                 vta_phy_addr_t insn_phy_addr,
                 vta_phy_addr_t uop_phy_addr,
                 vta_phy_addr_t inp_phy_addr,
                 vta_phy_addr_t wgt_phy_addr,
                 vta_phy_addr_t acc_phy_addr,
                 vta_phy_addr_t out_phy_addr,
                 uint32_t insn_count,
                 uint32_t wait_cycles) {
  vta::tsim::Device* device = static_cast<vta::tsim::Device*>(handle);
  return device->Run(insn_phy_addr, uop_phy_addr, inp_phy_addr, wgt_phy_addr,
                     acc_phy_addr, out_phy_addr, insn_count, wait_cycles);
}