tsim_driver.cc 5.54 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include <array>
#include <cstdint>
#include <sstream>
#include <string>

#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <vta/driver.h>
#include <vta/dpi/module.h>

#include "../vmem/virtual_memory.h"

27 28 29 30
namespace vta {
namespace tsim {

using tvm::runtime::Module;
31
using vta::dpi::DPIModuleNode;
32

33 34
/*!
 * \brief Process-wide table of simulator event counters.
 *
 *  Counter 0 accumulates the value read back from the accelerator after every
 *  run and is reported as "cycle_count" by AsJSON().
 *
 *  Counters are unsigned 64-bit: each Update() adds a 32-bit value, so a
 *  signed 32-bit accumulator (as previously used) overflows after a few runs.
 *  Storage is a std::array, so copying a Profiler is safe (Rule of Zero).
 */
class Profiler {
 public:
  /*! \brief add value to counter idx; out-of-range indices are ignored */
  void Update(uint32_t idx, uint32_t value) {
    if (idx < counters_.size()) {
      counters_[idx] += value;
    }
  }

  /*! \brief reset counter idx to zero; out-of-range indices are ignored */
  void Clear(uint32_t idx) {
    if (idx < counters_.size()) {
      counters_[idx] = 0;
    }
  }

  /*! \brief reset every counter to zero */
  void ClearAll() {
    counters_.fill(0);
  }

  /*! \brief return counters as a JSON object string */
  std::string AsJSON() const {
    std::ostringstream os;
    os << "{\n"
       << " \"cycle_count\":" << counters_[0] << "\n"
       << "}\n";
    return os.str();
  }

  /*! \brief return the process-wide profiler instance */
  static Profiler* Global() {
    static Profiler inst;
    return &inst;
  }

 private:
  /*! \brief total number of event counters */
  static constexpr uint32_t kNumCounters = 1;
  /*! \brief event counters, zero-initialized */
  std::array<uint64_t, kNumCounters> counters_{};
};

82 83
class DPILoader {
 public:
84 85 86 87 88
  ~DPILoader() {
    dpi_->SimResume();
    dpi_->SimFinish();
  }

89 90
  void Init(Module module) {
    mod_ = module;
91 92 93
    dpi_ = this->Get();
    dpi_->SimLaunch();
    dpi_->SimWait();
94 95 96 97 98 99 100 101 102 103 104
  }

  DPIModuleNode* Get() {
    return static_cast<DPIModuleNode*>(mod_.operator->());
  }

  static DPILoader* Global() {
    static DPILoader inst;
    return &inst;
  }

105
  // TVM module
106
  Module mod_;
107 108
  // DPI Module
  DPIModuleNode* dpi_{nullptr};
109 110 111 112 113
};

class Device {
 public:
  Device() {
114
    loader_ = DPILoader::Global();
115
    prof_ = Profiler::Global();
116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
  }

  int Run(vta_phy_addr_t insn_phy_addr,
          uint32_t insn_count,
          uint32_t wait_cycles) {
    this->Init();
    this->Launch(insn_phy_addr,
                 insn_count,
                 wait_cycles);
    this->WaitForCompletion(wait_cycles);
    return 0;
  }

 private:
  void Init() {
131 132
    dpi_ = loader_->Get();
    dpi_->SimResume();
133 134 135 136 137
  }

  void Launch(vta_phy_addr_t insn_phy_addr,
              uint32_t insn_count,
              uint32_t wait_cycles) {
138 139
    dpi_->WriteReg(0x08, insn_count);
    dpi_->WriteReg(0x0c, insn_phy_addr);
140
    dpi_->WriteReg(0x10, 0);
141
    dpi_->WriteReg(0x14, 0);
142
    dpi_->WriteReg(0x18, 0);
143
    dpi_->WriteReg(0x1c, 0);
144
    dpi_->WriteReg(0x20, 0);
145
    // start
146
    dpi_->WriteReg(0x00, 0x1);
147 148 149 150 151
  }

  void WaitForCompletion(uint32_t wait_cycles) {
    uint32_t i, val;
    for (i = 0; i < wait_cycles; i++) {
152
      val = dpi_->ReadReg(0x00);
153 154 155
      val &= 0x2;
      if (val == 0x2) break;  // finish
    }
156 157
    prof_->Update(0, dpi_->ReadReg(0x04));
    dpi_->SimWait();
158 159
  }

160 161 162
  // Profiler
  Profiler* prof_;
  // DPI loader
163
  DPILoader* loader_;
164
  // DPI Module
165
  DPIModuleNode* dpi_;
166 167 168 169 170
};

using tvm::runtime::TVMRetValue;
using tvm::runtime::TVMArgs;

171
// "vta.tsim.init": take the simulator module from args[0] and hand it to the
// global DPILoader, which launches the simulator.
TVM_REGISTER_GLOBAL("vta.tsim.init")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module sim_module = args[0];
    DPILoader* loader = DPILoader::Global();
    loader->Init(sim_module);
  });

// "vta.tsim.profiler_clear": zero every profiler counter.
TVM_REGISTER_GLOBAL("vta.tsim.profiler_clear")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Profiler* profiler = Profiler::Global();
    profiler->ClearAll();
  });

// "vta.tsim.profiler_status": return the profiler counters as a JSON string.
TVM_REGISTER_GLOBAL("vta.tsim.profiler_status")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Profiler* profiler = Profiler::Global();
    *rv = profiler->AsJSON();
  });

187 188 189 190
}  // namespace tsim
}  // namespace vta

/*!
 * \brief Allocate size bytes of simulated device memory.
 *
 *  The value returned is the buffer's physical address (cast to void*), not a
 *  host pointer; VTAMemFree translates it back with GetAddr().
 *  NOTE(review): `cached` is not used by this backend.
 */
void* VTAMemAlloc(size_t size, int cached) {
  auto vmem = vta::vmem::VirtualMemoryManager::Global();
  void* host_ptr = vmem->Alloc(size);
  return reinterpret_cast<void*>(vmem->GetPhyAddr(host_ptr));
}

/*!
 * \brief Release a buffer obtained from VTAMemAlloc.
 *
 *  buf holds the physical address returned by VTAMemAlloc; it is translated
 *  back to the manager's address with GetAddr() and then freed. As with
 *  free(), passing nullptr is a no-op instead of asking the manager to
 *  resolve physical address 0.
 */
void VTAMemFree(void* buf) {
  if (buf == nullptr) {
    return;
  }
  auto vmem = vta::vmem::VirtualMemoryManager::Global();
  void* host_ptr = vmem->GetAddr(reinterpret_cast<uint64_t>(buf));
  vmem->Free(host_ptr);
}

/*!
 * \brief Return the physical address of a buffer from VTAMemAlloc.
 *
 *  VTAMemAlloc already hands out the physical address as the pointer value,
 *  so the pointer's bits are returned unchanged.
 */
vta_phy_addr_t VTAMemGetPhyAddr(void* buf) {
  const uint64_t phy_addr = reinterpret_cast<uint64_t>(buf);
  return static_cast<vta_phy_addr_t>(phy_addr);
}

204
void VTAMemCopyFromHost(void* dst, const void* src, size_t size) {
205
  vta::vmem::VirtualMemoryManager::Global()->MemCopyFromHost(dst, src, size);
206 207 208
}

void VTAMemCopyToHost(void* dst, const void* src, size_t size) {
209
  vta::vmem::VirtualMemoryManager::Global()->MemCopyToHost(dst, src, size);
210 211
}

212
/*!
 * \brief Cache flush hook; intentionally does nothing in the tsim backend.
 *
 *  NOTE(review): presumably the simulator has no host cache to maintain,
 *  since data moves explicitly via VTAMemCopyFromHost/VTAMemCopyToHost.
 */
void VTAFlushCache(void* vir_addr, vta_phy_addr_t phy_addr, int size) {
  (void)vir_addr;
  (void)phy_addr;
  (void)size;
}

215
/*!
 * \brief Cache invalidate hook; intentionally does nothing in the tsim backend.
 *
 *  NOTE(review): presumably the simulator has no host cache to maintain,
 *  since data moves explicitly via VTAMemCopyFromHost/VTAMemCopyToHost.
 */
void VTAInvalidateCache(void* vir_addr, vta_phy_addr_t phy_addr, int size) {
  (void)vir_addr;
  (void)phy_addr;
  (void)size;
}

/*!
 * \brief Create a simulator-backed device.
 * \return An opaque handle owning a heap-allocated Device; release it with
 *         VTADeviceFree.
 */
VTADeviceHandle VTADeviceAlloc() {
  vta::tsim::Device* device = new vta::tsim::Device();
  return device;
}

/*!
 * \brief Destroy a device created by VTADeviceAlloc.
 * \param handle Handle returned by VTADeviceAlloc.
 */
void VTADeviceFree(VTADeviceHandle handle) {
  auto* device = static_cast<vta::tsim::Device*>(handle);
  delete device;
}

/*!
 * \brief Run an instruction stream on a device from VTADeviceAlloc.
 * \param handle Handle returned by VTADeviceAlloc.
 * \param insn_phy_addr Physical address of the first instruction.
 * \param insn_count Number of instructions to execute.
 * \param wait_cycles Poll budget forwarded to Device::Run.
 * \return The status returned by Device::Run.
 */
int VTADeviceRun(VTADeviceHandle handle,
                 vta_phy_addr_t insn_phy_addr,
                 uint32_t insn_count,
                 uint32_t wait_cycles) {
  auto* device = static_cast<vta::tsim::Device*>(handle);
  const int status = device->Run(insn_phy_addr, insn_count, wait_cycles);
  return status;
}