driver.cc 4.17 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <vta/dpi/module.h>

24 25
#include "vmem/virtual_memory.h"

26 27 28 29 30 31
namespace vta {
namespace driver {

using vta::dpi::DPIModuleNode;
using tvm::runtime::Module;

32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
/*!
 * \brief Process-wide holder of the DPI simulation module.
 *
 * Init() stores the module, calls SimLaunch() and then SimWait() on it. When
 * the function-local singleton returned by Global() is destroyed at program
 * exit, the destructor calls SimResume() and SimFinish() on the same node.
 */
class DPILoader {
 public:
  DPILoader() = default;
  // One instance owns the simulation shutdown; copying it would duplicate
  // the SimResume()/SimFinish() sequence in the copy's destructor.
  DPILoader(const DPILoader&) = delete;
  DPILoader& operator=(const DPILoader&) = delete;

  ~DPILoader() {
    // Global() can be reached (e.g. from Device's constructor) without Init()
    // ever running; dpi_ is then still null and there is nothing to finish.
    if (dpi_ == nullptr) return;
    dpi_->SimResume();
    dpi_->SimFinish();
  }

  /*!
   * \brief Take the simulation module and start the simulator.
   * \param module TVM module whose node is used as a DPIModuleNode.
   */
  void Init(Module module) {
    mod_ = module;
    dpi_ = this->Get();
    dpi_->SimLaunch();
    dpi_->SimWait();
  }

  /*! \brief The node behind mod_, viewed as a DPIModuleNode (unchecked cast). */
  DPIModuleNode* Get() {
    return static_cast<DPIModuleNode*>(mod_.operator->());
  }

  /*! \brief The single process-wide loader instance. */
  static DPILoader* Global() {
    static DPILoader inst;
    return &inst;
  }

  // TVM module
  Module mod_;
  // DPI Module (null until Init() has been called)
  DPIModuleNode* dpi_{nullptr};
};

61
class Device {
62
 public:
63 64
  Device() {
    loader_ = DPILoader::Global();
65 66
  }

67
  uint32_t Run(uint32_t c, DLTensor* a, DLTensor* b) {
68
    uint32_t cycles;
69 70 71 72 73
    uint32_t len = a->shape[0];
    size_t size = (a->dtype.bits >> 3) * len;
    a_ = this->MemAlloc(size);
    b_ = this->MemAlloc(size);
    this->MemCopyFromHost(a_, a->data, size);
74
    this->Init();
75
    this->Launch(c, len);
76
    cycles = this->WaitForCompletion();
77 78 79
    this->MemCopyToHost(b->data, b_, size);
    this->MemFree(a_);
    this->MemFree(b_);
80
    return cycles;
81 82 83
  }

 private:
84 85 86 87 88
  void Init() {
    dpi_ = loader_->Get();
    dpi_->SimResume();
  }

89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
  void* MemAlloc(size_t size) {
    void * addr = vta::vmem::VirtualMemoryManager::Global()->Alloc(size);
    return reinterpret_cast<void*>(vta::vmem::VirtualMemoryManager::Global()->GetPhyAddr(addr));
  }

  void MemFree(void* buf) {
    void * addr = vta::vmem::VirtualMemoryManager::Global()->GetAddr(reinterpret_cast<uint64_t>(buf));
    vta::vmem::VirtualMemoryManager::Global()->Free(addr);
  }

  vta_phy_addr_t MemGetPhyAddr(void* buf) {
    return reinterpret_cast<uint64_t>(reinterpret_cast<uint64_t*>(buf));
  }

  void MemCopyFromHost(void* dst, const void* src, size_t size) {
    vta::vmem::VirtualMemoryManager::Global()->MemCopyFromHost(dst, src, size);
  }

  void MemCopyToHost(void* dst, const void* src, size_t size) {
    vta::vmem::VirtualMemoryManager::Global()->MemCopyToHost(dst, src, size);
  }

  void Launch(uint32_t c, uint32_t len) {
112
    dpi_->WriteReg(0x08, c);
113 114 115 116 117
    dpi_->WriteReg(0x0c, len);
    dpi_->WriteReg(0x10, this->MemGetPhyAddr(a_));
    dpi_->WriteReg(0x14, 0);
    dpi_->WriteReg(0x18, this->MemGetPhyAddr(b_));
    dpi_->WriteReg(0x1c, 0);
118
    dpi_->WriteReg(0x00, 0x1); // launch
119 120
  }

121
  uint32_t WaitForCompletion() {
122
    uint32_t i, val;
123
    for (i = 0; i < wait_cycles_; i++) {
124
      val = dpi_->ReadReg(0x00);
125
      if (val == 2) break; // finish
126
    }
127
    val = dpi_->ReadReg(0x04);
128
    dpi_->SimWait();
129
    return val;
130 131
  }

132
  // wait cycles
133
  uint32_t wait_cycles_{100000000};
134 135 136 137
  // DPI loader
  DPILoader* loader_{nullptr};
  // DPI Module
  DPIModuleNode* dpi_{nullptr};
138 139 140 141
  // input vm ptr
  void* a_{nullptr};
  // output vm ptr
  void* b_{nullptr};
142 143 144 145 146
};

using tvm::runtime::TVMRetValue;
using tvm::runtime::TVMArgs;

147 148 149 150 151 152
// Packed function "tvm.vta.tsim.init": args[0] is the simulation Module. It
// is handed to the process-wide DPILoader, whose Init() launches the
// simulator. Nothing is returned through rv.
TVM_REGISTER_GLOBAL("tvm.vta.tsim.init")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    // Convert the argument before touching the singleton, so a bad argument
    // never instantiates the loader.
    Module sim_module = args[0];
    DPILoader* loader = DPILoader::Global();
    loader->Init(sim_module);
  });

153 154
// Packed function "tvm.vta.driver": args are (DLTensor* A, DLTensor* B, c).
// Runs the accelerator once on A, writes the result into B and returns the
// value Device::Run reports (register 0x04 after completion).
TVM_REGISTER_GLOBAL("tvm.vta.driver")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Device dev;
    DLTensor* A = args[0];
    DLTensor* B = args[1];
    uint32_t c = static_cast<int>(args[2]);
    uint32_t cycles = dev.Run(c, A, B);
    // Return as a 64-bit integer so counts >= 2^31 are not wrapped to
    // negative values; callers still receive a plain integer.
    *rv = static_cast<int64_t>(cycles);
  });

}  // namespace driver
}  // namespace vta