Unverified Commit ec3f09b3 by Tianqi Chen Committed by GitHub

[RUNTIME] Refactor to enable stackvm in runtime. (#1588)

parent edda6cc1
...@@ -29,6 +29,7 @@ tvm_option(USE_ROCM "Build with ROCM" OFF) ...@@ -29,6 +29,7 @@ tvm_option(USE_ROCM "Build with ROCM" OFF)
tvm_option(ROCM_PATH "The path to rocm" /opt/rocm) tvm_option(ROCM_PATH "The path to rocm" /opt/rocm)
tvm_option(USE_RPC "Build with RPC" ON) tvm_option(USE_RPC "Build with RPC" ON)
tvm_option(USE_LLVM "Build with LLVM, can be set to specific llvm-config path" OFF) tvm_option(USE_LLVM "Build with LLVM, can be set to specific llvm-config path" OFF)
tvm_option(USE_STACKVM_RUNTIME "Include stackvm into the runtime" OFF)
tvm_option(USE_GRAPH_RUNTIME "Build with tiny graph runtime" ON) tvm_option(USE_GRAPH_RUNTIME "Build with tiny graph runtime" ON)
tvm_option(USE_GRAPH_RUNTIME_DEBUG "Build with tiny graph runtime debug mode" OFF) tvm_option(USE_GRAPH_RUNTIME_DEBUG "Build with tiny graph runtime debug mode" OFF)
tvm_option(USE_RTTI "Build with RTTI" ON) tvm_option(USE_RTTI "Build with RTTI" ON)
...@@ -97,7 +98,6 @@ file(GLOB COMPILER_SRCS ...@@ -97,7 +98,6 @@ file(GLOB COMPILER_SRCS
src/arithmetic/*.cc src/arithmetic/*.cc
src/autotvm/*.cc src/autotvm/*.cc
src/codegen/*.cc src/codegen/*.cc
src/codegen/stack_vm/*.cc
src/lang/*.cc src/lang/*.cc
src/pass/*.cc src/pass/*.cc
src/op/*.cc src/op/*.cc
...@@ -135,6 +135,16 @@ if(USE_RPC) ...@@ -135,6 +135,16 @@ if(USE_RPC)
list(APPEND RUNTIME_SRCS ${RUNTIME_RPC_SRCS}) list(APPEND RUNTIME_SRCS ${RUNTIME_RPC_SRCS})
endif(USE_RPC) endif(USE_RPC)
# StackVM sources. The code generator always goes into the compiler;
# the interpreter goes into the runtime when USE_STACKVM_RUNTIME is set,
# and into the compiler otherwise.
file(GLOB STACKVM_CODEGEN_SRCS src/codegen/stackvm/*.cc)
file(GLOB STACKVM_RUNTIME_SRCS src/runtime/stackvm/*.cc)
list(APPEND COMPILER_SRCS ${STACKVM_CODEGEN_SRCS})
if(NOT USE_STACKVM_RUNTIME)
  list(APPEND COMPILER_SRCS ${STACKVM_RUNTIME_SRCS})
else()
  message(STATUS "Build with stackvm support in runtime...")
  list(APPEND RUNTIME_SRCS ${STACKVM_RUNTIME_SRCS})
endif()
if(USE_GRAPH_RUNTIME) if(USE_GRAPH_RUNTIME)
message(STATUS "Build with Graph runtime support...") message(STATUS "Build with Graph runtime support...")
file(GLOB RUNTIME_GRAPH_SRCS src/runtime/graph/*.cc) file(GLOB RUNTIME_GRAPH_SRCS src/runtime/graph/*.cc)
......
...@@ -96,6 +96,7 @@ stage('Build') { ...@@ -96,6 +96,7 @@ stage('Build') {
echo set\\(USE_RPC ON\\) >> config.cmake echo set\\(USE_RPC ON\\) >> config.cmake
echo set\\(USE_SORT ON\\) >> config.cmake echo set\\(USE_SORT ON\\) >> config.cmake
echo set\\(USE_GRAPH_RUNTIME ON\\) >> config.cmake echo set\\(USE_GRAPH_RUNTIME ON\\) >> config.cmake
echo set\\(USE_STACKVM_RUNTIME ON\\) >> config.cmake
echo set\\(USE_BLAS openblas\\) >> config.cmake echo set\\(USE_BLAS openblas\\) >> config.cmake
echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake
echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake
......
...@@ -65,6 +65,9 @@ set(USE_OPENGL OFF) ...@@ -65,6 +65,9 @@ set(USE_OPENGL OFF)
# Whether enable RPC runtime # Whether enable RPC runtime
set(USE_RPC ON) set(USE_RPC ON)
# Whether to embed stackvm into the runtime
set(USE_STACKVM_RUNTIME OFF)
# Whether enable tiny embedded graph runtime. # Whether enable tiny embedded graph runtime.
set(USE_GRAPH_RUNTIME ON) set(USE_GRAPH_RUNTIME ON)
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <string> #include <string>
#include "./base.h" #include "./base.h"
#include "./expr.h" #include "./expr.h"
#include "./runtime/util.h"
namespace tvm { namespace tvm {
namespace ir { namespace ir {
...@@ -449,25 +450,6 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit"; ...@@ -449,25 +450,6 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
*/ */
constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce"; constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
/*!
 * \brief The kind of structure field info
 *
 * The enumerators have no explicit initializers, so they take
 * consecutive integer values starting from 0 in declaration order.
 * NOTE(review): presumably these values are stored in generated code,
 * so reordering entries may break compatibility -- confirm before editing.
 */
enum TVMStructFieldKind : int {
// array head address
kArrAddr,
kArrData,
kArrShape,
kArrStrides,
kArrNDim,
kArrTypeCode,
kArrTypeBits,
kArrTypeLanes,
kArrByteOffset,
kArrDeviceId,
kArrDeviceType,
// NOTE(review): looks like an end marker for the kArr* kinds -- confirm.
kArrKindBound_,
// TVMValue field
kTVMValueContent,
// NOTE(review): looks like an end marker for the TVMValue kinds -- confirm.
kTVMValueKindBound_
};
} // namespace intrinsic } // namespace intrinsic
// Reuse IR node defintiion from HalideIR // Reuse IR node defintiion from HalideIR
......
...@@ -21,7 +21,33 @@ namespace runtime { ...@@ -21,7 +21,33 @@ namespace runtime {
inline bool TypeMatch(TVMType t, int code, int bits, int lanes = 1) { inline bool TypeMatch(TVMType t, int code, int bits, int lanes = 1) {
return t.code == code && t.bits == bits && t.lanes == lanes; return t.code == code && t.bits == bits && t.lanes == lanes;
} }
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
// Define the struct field kind ids used by the intrinsics here,
// so that structure fetch works when stackvm is built into the runtime.
namespace tvm {
namespace ir {
namespace intrinsic {
/*!
 * \brief The kind of structure field info used in intrinsic
 *
 * The enumerators have no explicit initializers, so they take
 * consecutive integer values starting from 0 in declaration order.
 * NOTE(review): presumably these values are stored in generated code,
 * so reordering entries may break compatibility -- confirm before editing.
 */
enum TVMStructFieldKind : int {
// array head address
kArrAddr,
kArrData,
kArrShape,
kArrStrides,
kArrNDim,
kArrTypeCode,
kArrTypeBits,
kArrTypeLanes,
kArrByteOffset,
kArrDeviceId,
kArrDeviceType,
// NOTE(review): looks like an end marker for the kArr* kinds -- confirm.
kArrKindBound_,
// TVMValue field
kTVMValueContent,
// NOTE(review): looks like an end marker for the TVMValue kinds -- confirm.
kTVMValueKindBound_
};
} // namespace intrinsic
} // namespace ir
} // namespace tvm
#endif // TVM_RUNTIME_UTIL_H_ #endif // TVM_RUNTIME_UTIL_H_
...@@ -90,9 +90,12 @@ class Module(ModuleBase): ...@@ -90,9 +90,12 @@ class Module(ModuleBase):
kwargs : dict, optiona; kwargs : dict, optiona;
Additional arguments passed to fcompile Additional arguments passed to fcompile
""" """
if self.type_key == "stacktvm": if self.type_key == "stackvm":
raise ValueError("Module[%s]: export_library requires llvm module," if not file_name.endswith(".stackvm"):
" did you build with LLVM enabled?" % self.type_key) raise ValueError("Module[%s]: can only be saved as stackvm format."
"did you build with LLVM enabled?" % self.type_key)
self.save(file_name)
return
if self.type_key != "llvm": if self.type_key != "llvm":
raise ValueError("Module[%s]: Only llvm support export shared" % self.type_key) raise ValueError("Module[%s]: Only llvm support export shared" % self.type_key)
......
...@@ -40,7 +40,6 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) { ...@@ -40,7 +40,6 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
CHECK_EQ(im->imports().size(), 0U) CHECK_EQ(im->imports().size(), 0U)
<< "Only support simply one-level hierarchy"; << "Only support simply one-level hierarchy";
std::string tkey = im->type_key(); std::string tkey = im->type_key();
std::string bin;
stream->Write(tkey); stream->Write(tkey);
im->SaveToBinary(stream); im->SaveToBinary(stream);
} }
......
/*!
* Copyright (c) 2017 by Contributors
* \file stack_vm_module.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
#include <tvm/codegen.h>
#include "./codegen_stack_vm.h"
namespace tvm {
namespace codegen {
/*!
 * \brief Module holding the StackVM programs compiled from lowered functions.
 */
class StackVMModuleNode : public runtime::ModuleNode {
 public:
  /*! \return The type key of this module, "stackvm". */
  const char* type_key() const {
    return "stackvm";
  }

  /*!
   * \brief Look up a function by name and wrap it as a PackedFunc.
   * \param name The function name; runtime::symbol::tvm_module_main is
   *  redirected to the entry function.
   * \param sptr_to_self Shared pointer to this module, captured by the
   *  returned closure so the module stays alive.
   * \return The wrapped function, or an empty PackedFunc if not found.
   */
  PackedFunc GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
    if (name == runtime::symbol::tvm_module_main) {
      return GetFunction(entry_func_, sptr_to_self);
    }
    auto it = fmap_.find(name);
    if (it == fmap_.end()) return PackedFunc();
    const StackVM& vm = it->second;
    // capture sptr_to_self to keep module node alive.
    return PackedFunc([vm, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
        vm(args);
      });
  }

  /*!
   * \brief Print every stored function, one after another.
   * \param format Unused.
   * \return The concatenated textual dump.
   */
  std::string GetSource(const std::string& format) final {
    std::ostringstream os;
    for (const auto& kv : fmap_) {
      os << "Function: " << kv.first << '\n';
      os << kv.second;
    }
    return os.str();
  }

  /*!
   * \brief Compile the lowered functions into a StackVM module.
   * \param funcs The functions to compile; must be non-empty and have
   *  unique names. The first one becomes the entry function.
   * \return The created module.
   */
  static runtime::Module Build(const Array<LoweredFunc>& funcs) {
    CHECK_NE(funcs.size(), 0U);
    std::shared_ptr<StackVMModuleNode> n =
        std::make_shared<StackVMModuleNode>();
    for (LoweredFunc f : funcs) {
      StackVM vm = codegen::CodeGenStackVM().Compile(f);
      // Separate the name from the rest of the message with a space.
      CHECK(!n->fmap_.count(f->name))
          << "Function name " << f->name << " already exists in list";
      vm.mod_ctx = n.get();
      n->fmap_[f->name] = std::move(vm);
    }
    n->entry_func_ = funcs[0]->name;
    return runtime::Module(n);
  }

 private:
  // entry function.
  std::string entry_func_;
  // internal function map
  std::unordered_map<std::string, StackVM> fmap_;
};
// Register StackVMModuleNode::Build under the API name
// "codegen.build_stackvm"; args[0] is the list of lowered functions.
TVM_REGISTER_API("codegen.build_stackvm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = StackVMModuleNode::Build(args[0]);
});
} // namespace codegen
} // namespace tvm
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file codegen_stack_vm.cc * \file codegen_stackvm.cc
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <limits> #include <limits>
#include "./codegen_stack_vm.h" #include "./codegen_stackvm.h"
#include "../../runtime/stackvm/stackvm_module.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
...@@ -19,6 +20,7 @@ StackVM CodeGenStackVM::Compile(LoweredFunc f) { ...@@ -19,6 +20,7 @@ StackVM CodeGenStackVM::Compile(LoweredFunc f) {
CHECK_EQ(static_cast<size_t>(vid), i); CHECK_EQ(static_cast<size_t>(vid), i);
} }
this->Push(f->body); this->Push(f->body);
vm_.InitCache();
return std::move(vm_); return std::move(vm_);
} }
...@@ -486,5 +488,22 @@ void CodeGenStackVM::VisitExpr_(const Let *op) { ...@@ -486,5 +488,22 @@ void CodeGenStackVM::VisitExpr_(const Let *op) {
this->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid)); this->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid));
this->Push(op->body); this->Push(op->body);
} }
/*!
 * \brief Compile lowered functions into a StackVM runtime module.
 * \param funcs The functions to compile; must be non-empty and have
 *  unique names. The first one becomes the entry function.
 * \return The created module.
 */
runtime::Module BuildStackVM(const Array<LoweredFunc>& funcs) {
  CHECK_NE(funcs.size(), 0U);
  std::unordered_map<std::string, StackVM> fmap;
  for (LoweredFunc f : funcs) {
    StackVM vm = codegen::CodeGenStackVM().Compile(f);
    // Separate the name from the rest of the message with a space.
    CHECK(!fmap.count(f->name))
        << "Function name " << f->name << " already exists in list";
    fmap[f->name] = std::move(vm);
  }
  // The map is not used afterwards, so hand it over instead of copying it.
  return runtime::StackVMModuleCreate(std::move(fmap), funcs[0]->name);
}
// Register BuildStackVM under the API name "codegen.build_stackvm";
// args[0] is the list of lowered functions to compile.
TVM_REGISTER_API("codegen.build_stackvm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildStackVM(args[0]);
});
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
* \file codegen_stack_vm.h * \file codegen_stack_vm.h
* \brief Codegen into Simple Stack VM. * \brief Codegen into Simple Stack VM.
*/ */
#ifndef TVM_CODEGEN_STACK_VM_CODEGEN_STACK_VM_H_ #ifndef TVM_CODEGEN_STACKVM_CODEGEN_STACKVM_H_
#define TVM_CODEGEN_STACK_VM_CODEGEN_STACK_VM_H_ #define TVM_CODEGEN_STACKVM_CODEGEN_STACKVM_H_
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
...@@ -14,12 +14,14 @@ ...@@ -14,12 +14,14 @@
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include "./stack_vm.h" #include "../../runtime/stackvm/stackvm.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
using namespace ir; using namespace ir;
using runtime::StackVM;
/*! /*!
* \brief A base class to generate a stack VM. * \brief A base class to generate a stack VM.
* This module is used to generate host wrapper * This module is used to generate host wrapper
...@@ -145,4 +147,4 @@ class CodeGenStackVM ...@@ -145,4 +147,4 @@ class CodeGenStackVM
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
#endif // TVM_CODEGEN_STACK_VM_CODEGEN_STACK_VM_H_ #endif // TVM_CODEGEN_STACKVM_CODEGEN_STACKVM_H_
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* Implementation stack VM. * Implementation stack VM.
* \file stack_vm.cc * \file stackvm.cc
*/ */
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <tvm/ir.h> #include <tvm/runtime/util.h>
#include <tvm/runtime/c_backend_api.h> #include <tvm/runtime/c_backend_api.h>
#include "./stack_vm.h" #include <algorithm>
#include "./stackvm.h"
namespace tvm { namespace tvm {
namespace codegen { namespace runtime {
typedef dmlc::ThreadLocalStore<StackVM::State> StackVMStateStore; typedef dmlc::ThreadLocalStore<StackVM::State> StackVMStateStore;
...@@ -172,28 +173,64 @@ std::ostream& operator<<(std::ostream& os, const StackVM& vm) { // NOLINT(*) ...@@ -172,28 +173,64 @@ std::ostream& operator<<(std::ostream& os, const StackVM& vm) { // NOLINT(*)
return os; return os;
} }
void StackVM::operator()(const runtime::TVMArgs& args) const { void StackVM::Run(const runtime::TVMArgs& args,
runtime::ModuleNode* mod_ctx) const {
StackVM::State* s = StackVM::ThreadLocalState(); StackVM::State* s = StackVM::ThreadLocalState();
if (s->heap.size() < heap_size) {
s->heap.resize(heap_size);
}
s->sp = 0; s->sp = 0;
s->pc = 0; s->pc = 0;
if (s->heap.size() < this->heap_size) { s->mod_ctx = mod_ctx;
s->heap.resize(this->heap_size);
}
s->heap[0].v_handle = (void*)args.values; // NOLINT(*) s->heap[0].v_handle = (void*)args.values; // NOLINT(*)
s->heap[1].v_handle = (void*)args.type_codes; // NOLINT(*) s->heap[1].v_handle = (void*)args.type_codes; // NOLINT(*)
s->heap[2].v_int64 = args.num_args; s->heap[2].v_int64 = args.num_args;
this->Run(s); this->Run(s);
} }
// Reset the extern function cache to one null entry per extern
// function name.
void StackVM::InitCache() {
  extern_func_cache_.assign(extern_func_name.size(), PackedFunc(nullptr));
}
// Write the program to strm: the instructions first, then the string
// table, extern function names, heap id names, heap size and stack size.
void StackVM::Save(dmlc::Stream* strm) const {
  // Store each instruction as a 32-bit integer to be endian invariant.
  std::vector<int32_t> raw_code;
  raw_code.reserve(code.size());
  for (const Code& inst : code) {
    raw_code.push_back(inst.v_int);
  }
  strm->Write(raw_code);
  strm->Write(str_data);
  strm->Write(extern_func_name);
  strm->Write(heap_id_name);
  strm->Write(heap_size);
  strm->Write(stack_size);
}
// Read a program from strm in the order Save writes it. Returns false as
// soon as one field fails to read; fields read before that point keep
// their new values. On success the extern function cache is re-initialized.
bool StackVM::Load(dmlc::Stream* strm) {
  // Instructions are stored as 32-bit integers to be endian invariant.
  std::vector<int32_t> raw_code;
  if (!strm->Read(&raw_code)) return false;
  code.resize(raw_code.size());
  for (size_t i = 0; i < raw_code.size(); ++i) {
    Code inst;
    inst.v_int = raw_code[i];
    code[i] = inst;
  }
  // Short-circuit evaluation stops at the first failing read.
  const bool fields_ok =
      strm->Read(&str_data) &&
      strm->Read(&extern_func_name) &&
      strm->Read(&heap_id_name) &&
      strm->Read(&heap_size) &&
      strm->Read(&stack_size);
  if (!fields_ok) return false;
  this->InitCache();
  return true;
}
void StackVM::Run(State* s) const { void StackVM::Run(State* s) const {
int64_t sp = s->sp; int64_t sp = s->sp;
int64_t pc = s->pc; int64_t pc = s->pc;
int64_t alloca_sp = s->sp; int64_t alloca_sp = s->sp;
std::vector<TVMValue>& stack = s->stack; std::vector<TVMValue>& stack = s->stack;
std::vector<TVMValue>& heap = s->heap; std::vector<TVMValue>& heap = s->heap;
s->extern_func.clear();
s->extern_func.resize(extern_func_name.size());
if (stack.size() < stack_size) { if (stack.size() < stack_size) {
stack.resize(stack_size); stack.resize(stack_size);
} }
...@@ -488,17 +525,19 @@ void StackVM::Run(State* s) const { ...@@ -488,17 +525,19 @@ void StackVM::Run(State* s) const {
} }
const PackedFunc& StackVM::GetExtern(State* s, int fid) const { const PackedFunc& StackVM::GetExtern(State* s, int fid) const {
PackedFunc& f = s->extern_func[fid]; CHECK_LT(static_cast<size_t>(fid), extern_func_cache_.size());
// allow race write in this, since write is idempotent
PackedFunc& f = extern_func_cache_[fid];
if (f == nullptr) { if (f == nullptr) {
CHECK(mod_ctx != nullptr) CHECK(s->mod_ctx != nullptr)
<< "No local context is set in stackvm"; << "No local context is set in stackvm";
const PackedFunc* pf = mod_ctx->GetFuncFromEnv(extern_func_name[fid]); CHECK(s->mod_ctx != nullptr);
const PackedFunc* pf = s->mod_ctx->GetFuncFromEnv(extern_func_name[fid]);
CHECK(pf != nullptr); CHECK(pf != nullptr);
f = *pf; f = *pf;
CHECK(f != nullptr);
} }
return f; return f;
} }
} // namespace codegen } // namespace runtime
} // namespace tvm } // namespace tvm
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* \file stack_vm.h * \file stackvm.h
* \brief A simple stack-based virtual machine. * \brief A simple stack-based virtual machine.
* *
* This can be used to interepret host side code * This can be used to interepret host side code
* to setup calls into device functions * to setup calls into device functions
* when only Runtime compilation for device is available(via NVRTC or OpenCL). * when only Runtime compilation for device is available(via NVRTC or OpenCL).
*/ */
#ifndef TVM_CODEGEN_STACK_VM_STACK_VM_H_ #ifndef TVM_RUNTIME_STACKVM_STACKVM_H_
#define TVM_CODEGEN_STACK_VM_STACK_VM_H_ #define TVM_RUNTIME_STACKVM_STACKVM_H_
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/packed_func_ext.h>
#include <string> #include <string>
#include <vector> #include <vector>
namespace tvm { namespace tvm {
namespace codegen { namespace runtime {
using runtime::operator<<; using runtime::operator<<;
/*! /*!
* \brief A simple stack-based virtual machine. * \brief A simple stack-based virtual machine program.
*/ */
class StackVM { class StackVM {
public: public:
/*! /*!
* \brief Invoke the StackVM as PackedFunc * \brief Invoke the StackVM program.
* \param args The arguments to the StackVM. * \param args The arguments to the StackVM.
* \param mod_ctx The module context used in running.
*/ */
void operator()(const TVMArgs& args) const; void Run(const TVMArgs& args, runtime::ModuleNode* mod_ctx) const;
/*! /*!
* \brief The opcode of stack vm * \brief The opcode of stack vm
* \note Notation * \note Notation
...@@ -276,21 +276,25 @@ class StackVM { ...@@ -276,21 +276,25 @@ class StackVM {
std::vector<TVMValue> stack; std::vector<TVMValue> stack;
/*! \brief The global heap space */ /*! \brief The global heap space */
std::vector<TVMValue> heap; std::vector<TVMValue> heap;
/*! \brief extern functions */
std::vector<PackedFunc> extern_func;
/*! \brief stack pointer */ /*! \brief stack pointer */
int64_t sp{0}; int64_t sp{0};
/*! \brief program counter */ /*! \brief program counter */
int64_t pc{0}; int64_t pc{0};
/*! \brief The current module context of stackvm */
runtime::ModuleNode* mod_ctx{nullptr};
}; };
/*! \brief The external function entries. */ /*! \brief Initialize local cache*/
struct ExternFuncEntry { void InitCache();
std::string name; /*!
runtime::PackedFunc func; * \brief Save stackvm program to an output stream
}; * \param strm The output stream
*/
/*! \brief execute the stack vm with given state */ void Save(dmlc::Stream* strm) const;
void Run(State* state) const; /*!
* \brief Load stackvm program from an input stream
* \param strm The input stream
* \return Whether the program was read successfully
*/
bool Load(dmlc::Stream* strm);
/*! /*!
* \brief Print instruction at location pc * \brief Print instruction at location pc
* \param os The ostream * \param os The ostream
...@@ -300,12 +304,11 @@ class StackVM { ...@@ -300,12 +304,11 @@ class StackVM {
int64_t PrintCode(std::ostream&os, int64_t pc) const; // NOLINT(*) int64_t PrintCode(std::ostream&os, int64_t pc) const; // NOLINT(*)
/*! \brief Get thread local state of the stack VM */ /*! \brief Get thread local state of the stack VM */
static State* ThreadLocalState(); static State* ThreadLocalState();
// The code below are programs
/*! \brief The instructions */ /*! \brief The instructions */
std::vector<Code> code; std::vector<Code> code;
/*! \brief constant error messages */ /*! \brief constant error messages */
std::vector<std::string> str_data; std::vector<std::string> str_data;
/*! \brief The current module context of stackvm */
runtime::ModuleNode* mod_ctx{nullptr};
/*! \brief Extern functions */ /*! \brief Extern functions */
std::vector<std::string> extern_func_name; std::vector<std::string> extern_func_name;
/*! \brief name of each heap id */ /*! \brief name of each heap id */
...@@ -385,10 +388,18 @@ class StackVM { ...@@ -385,10 +388,18 @@ class StackVM {
friend std::ostream& operator<<(std::ostream& os, const StackVM& vm); // NOLINT(*) friend std::ostream& operator<<(std::ostream& os, const StackVM& vm); // NOLINT(*)
private: private:
// execute the stack vm with given state
void Run(State* state) const;
// get extern function. // get extern function.
const PackedFunc& GetExtern(State* s, int fid) const; const PackedFunc& GetExtern(State* s, int fid) const;
// cached extern function
mutable std::vector<PackedFunc> extern_func_cache_;
}; };
} // namespace codegen } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_CODEGEN_STACK_VM_STACK_VM_H_
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::StackVM, true);
}
#endif // TVM_RUNTIME_STACKVM_STACKVM_H_
/*!
* Copyright (c) 2017 by Contributors
* \file stackvm_module.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
#include <dmlc/memory_io.h>
#include "./stackvm_module.h"
#include "../file_util.h"
#include "../module_util.h"
namespace tvm {
namespace runtime {
/*!
 * \brief Runtime module that stores StackVM programs keyed by function name.
 */
class StackVMModuleNode : public runtime::ModuleNode {
 public:
  /*! \return The type key of this module, "stackvm". */
  const char* type_key() const {
    return "stackvm";
  }

  /*!
   * \brief Look up a function by name and wrap it as a PackedFunc.
   * \param name The function name; runtime::symbol::tvm_module_main is
   *  redirected to the entry function.
   * \param sptr_to_self Shared pointer to this module, captured by the
   *  returned closure so the module stays alive.
   * \return The wrapped function, or an empty PackedFunc if not found.
   */
  PackedFunc GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
    if (name == runtime::symbol::tvm_module_main) {
      return GetFunction(entry_func_, sptr_to_self);
    }
    auto it = fmap_.find(name);
    if (it == fmap_.end()) return PackedFunc();
    const StackVM& vm = it->second;
    // capture sptr_to_self to keep module node alive.
    return PackedFunc([vm, sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        vm.Run(args, this);
      });
  }

  /*!
   * \brief Print every stored function, one after another.
   * \param format Unused.
   * \return The concatenated textual dump.
   */
  std::string GetSource(const std::string& format) final {
    std::ostringstream os;
    for (const auto& kv : fmap_) {
      os << "Function: " << kv.first << '\n';
      os << kv.second;
    }
    return os.str();
  }

  /*!
   * \brief Serialize the module and its direct imports into a file.
   * \param file_name The path of the file to write.
   * \param format Unused.
   */
  void SaveToFile(const std::string& file_name,
                  const std::string& format) final {
    std::string data;
    dmlc::MemoryStringStream writer(&data);
    dmlc::Stream* strm = &writer;
    strm->Write(fmap_);
    strm->Write(entry_func_);
    // also save imports
    uint64_t num_imports = static_cast<uint64_t>(imports_.size());
    strm->Write(num_imports);
    for (runtime::Module im : imports_) {
      CHECK_EQ(im->imports().size(), 0U)
          << "Only support simply one-level hierarchy";
      // The type key is written first; Load uses it to pick the loader.
      std::string tkey = im->type_key();
      strm->Write(tkey);
      im->SaveToBinary(strm);
    }
    SaveBinaryToFile(file_name, data);
  }

  /*!
   * \brief Create a module from a function map.
   * \param fmap The map from name to function.
   * \param entry_func The entry function name.
   * \return The created module.
   */
  static Module Create(std::unordered_map<std::string, StackVM> fmap,
                       std::string entry_func) {
    std::shared_ptr<StackVMModuleNode> n =
        std::make_shared<StackVMModuleNode>();
    n->fmap_ = std::move(fmap);
    n->entry_func_ = std::move(entry_func);
    return Module(n);
  }

  /*!
   * \brief Load a module from a stream in the layout SaveToFile writes.
   *
   * Each read is checked, so a truncated or malformed stream fails
   * instead of producing a partially initialized module.
   *
   * \param strm The input stream.
   * \return The loaded module, including its imports.
   */
  static Module Load(dmlc::Stream* strm) {
    std::unordered_map<std::string, StackVM> fmap;
    std::string entry_func;
    CHECK(strm->Read(&fmap))
        << "Invalid stackvm module: cannot read function map";
    CHECK(strm->Read(&entry_func))
        << "Invalid stackvm module: cannot read entry function name";
    std::shared_ptr<StackVMModuleNode> n =
        std::make_shared<StackVMModuleNode>();
    n->fmap_ = std::move(fmap);
    n->entry_func_ = std::move(entry_func);
    uint64_t num_imports;
    CHECK(strm->Read(&num_imports))
        << "Invalid stackvm module: cannot read number of imports";
    for (uint64_t i = 0; i < num_imports; ++i) {
      std::string tkey;
      CHECK(strm->Read(&tkey));
      // Each import is restored by the loader registered for its type key.
      std::string fkey = "module.loadbinary_" + tkey;
      const PackedFunc* f = Registry::Get(fkey);
      CHECK(f != nullptr)
          << "Loader of " << tkey << "("
          << fkey << ") is not presented.";
      Module m = (*f)(static_cast<void*>(strm));
      n->imports_.emplace_back(std::move(m));
    }
    return Module(n);
  }

  /*!
   * \brief Load a module from a file written by SaveToFile.
   * \param file_name The path of the file to read.
   * \param format Unused.
   * \return The loaded module.
   */
  static Module LoadFromFile(std::string file_name,
                             std::string format) {
    std::string data;
    LoadBinaryFromFile(file_name, &data);
    dmlc::MemoryStringStream reader(&data);
    return Load(&reader);
  }

 private:
  // internal function map
  std::unordered_map<std::string, StackVM> fmap_;
  // entry function.
  std::string entry_func_;
};
// Create a stackvm module from a name-to-function map and the name of
// the entry function. Both arguments are taken by value, so they are
// moved on to avoid a second copy of the map and the string.
Module StackVMModuleCreate(std::unordered_map<std::string, StackVM> fmap,
                           std::string entry_func) {
  return StackVMModuleNode::Create(std::move(fmap), std::move(entry_func));
}
// Register the file loader under the global name
// "module.loadfile_stackvm"; args[0] is the file name, args[1] the format.
TVM_REGISTER_GLOBAL("module.loadfile_stackvm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = StackVMModuleNode::LoadFromFile(args[0], args[1]);
});
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file stackvm_module.h
* \brief StackVM module
*/
#ifndef TVM_RUNTIME_STACKVM_STACKVM_MODULE_H_
#define TVM_RUNTIME_STACKVM_STACKVM_MODULE_H_
#include <tvm/runtime/packed_func.h>
#include <string>
#include <unordered_map>
#include "./stackvm.h"
namespace tvm {
namespace runtime {
/*!
* \brief create a stackvm module
*
* \param fmap The map from name to function
* \param entry_func The entry function name.
* \return The created module
*/
Module StackVMModuleCreate(std::unordered_map<std::string, StackVM> fmap,
std::string entry_func);
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_STACKVM_STACKVM_MODULE_H_
...@@ -109,11 +109,25 @@ def test_device_module_dump(): ...@@ -109,11 +109,25 @@ def test_device_module_dump():
f2[name](a, b) f2[name](a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_device("cuda") def check_stackvm(device):
check_device("vulkan") ctx = tvm.context(device, 0)
check_device("opencl") if not ctx.exist:
check_device("metal") print("Skip because %s is not enabled" % device)
return
temp = util.tempdir()
name = "myadd_%s" % device
f = tvm.build(s, [A, B], device, "stackvm", name=name)
path_dso = temp.relpath("dev_lib.stackvm")
#f.export_library(path_dso)
#f1 = tvm.module.load(path_dso)
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
f(a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
for device in ["cuda", "vulkan", "opencl", "metal"]:
check_device(device)
check_stackvm(device)
def test_combine_module_llvm(): def test_combine_module_llvm():
"""Test combine multiple module into one shared lib.""" """Test combine multiple module into one shared lib."""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment