Commit 9037a4c2 by Tianqi Chen Committed by GitHub

[RUNTIME] Enable injection of some core runtime functions to avoid dynamic lookup (#260)

parent 6196cd50
...@@ -154,12 +154,12 @@ verilog: $(VER_LIBS) ...@@ -154,12 +154,12 @@ verilog: $(VER_LIBS)
# Special rules for LLVM related modules. # Special rules for LLVM related modules.
build/codegen/llvm/%.o: src/codegen/llvm/%.cc build/codegen/llvm/%.o: src/codegen/llvm/%.cc
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d $(CXX) $(CFLAGS) -MM -MT build/codegen/llvm/$*.o $< >build/codegen/llvm/$*.d
$(CXX) -c $(CFLAGS) $(LLVM_CFLAGS) -c $< -o $@ $(CXX) -c $(CFLAGS) $(LLVM_CFLAGS) -c $< -o $@
build/runtime/metal/%.o: src/runtime/metal/%.mm build/runtime/metal/%.o: src/runtime/metal/%.mm
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d $(CXX) $(CFLAGS) -MM -MT build/runtime/metal/$*.o $< >build/runtime/metal/$*.d
$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@ $(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
build/%.o: src/%.cc build/%.o: src/%.cc
...@@ -199,7 +199,7 @@ pylint: ...@@ -199,7 +199,7 @@ pylint:
pylint python/tvm --rcfile=$(ROOTDIR)/tests/lint/pylintrc pylint python/tvm --rcfile=$(ROOTDIR)/tests/lint/pylintrc
pylint topi/python/topi --rcfile=$(ROOTDIR)/tests/lint/pylintrc pylint topi/python/topi --rcfile=$(ROOTDIR)/tests/lint/pylintrc
jnilint: jnilint:
python dmlc-core/scripts/lint.py tvm4j-jni cpp jvm/native/src python dmlc-core/scripts/lint.py tvm4j-jni cpp jvm/native/src
lint: cpplint pylint jnilint lint: cpplint pylint jnilint
......
...@@ -29,7 +29,8 @@ std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine *tm) { ...@@ -29,7 +29,8 @@ std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine *tm) {
void CodeGenLLVM::Init(const std::string& module_name, void CodeGenLLVM::Init(const std::string& module_name,
llvm::TargetMachine* tm, llvm::TargetMachine* tm,
llvm::LLVMContext* ctx, llvm::LLVMContext* ctx,
bool system_lib) { bool system_lib,
bool dynamic_lookup) {
InitializeLLVM(); InitializeLLVM();
static_assert(sizeof(TVMValue) == sizeof(double), "invariant"); static_assert(sizeof(TVMValue) == sizeof(double), "invariant");
// static_assert(alignof(TVMValue) == alignof(double), "invariant"); // static_assert(alignof(TVMValue) == alignof(double), "invariant");
...@@ -62,7 +63,7 @@ void CodeGenLLVM::Init(const std::string& module_name, ...@@ -62,7 +63,7 @@ void CodeGenLLVM::Init(const std::string& module_name,
t_tvm_shape_index_->getPointerTo(), t_tvm_shape_index_->getPointerTo(),
t_int64_}); t_int64_});
t_tvm_value_ = llvm::StructType::create({t_float64_}); t_tvm_value_ = llvm::StructType::create({t_float64_});
t_f_tvm_par_for_lambda_ = llvm::FunctionType::get( ftype_tvm_par_for_lambda_ = llvm::FunctionType::get(
t_int_, {t_int64_, t_int64_, t_void_p_}, false); t_int_, {t_int64_, t_int64_, t_void_p_}, false);
md_builder_.reset(new llvm::MDBuilder(*ctx)); md_builder_.reset(new llvm::MDBuilder(*ctx));
md_very_likely_branch_ = md_very_likely_branch_ =
...@@ -70,45 +71,56 @@ void CodeGenLLVM::Init(const std::string& module_name, ...@@ -70,45 +71,56 @@ void CodeGenLLVM::Init(const std::string& module_name,
md_tbaa_root_ = md_builder_->createTBAARoot("tvmtbaa"); md_tbaa_root_ = md_builder_->createTBAARoot("tvmtbaa");
md_tbaa_alias_set_ = md_builder_->createTBAAScalarTypeNode( md_tbaa_alias_set_ = md_builder_->createTBAAScalarTypeNode(
"alias_set", md_tbaa_root_); "alias_set", md_tbaa_root_);
md_tbaa_ctx_ptr_ = md_builder_->createTBAAScalarTypeNode(
"ctx_ptr", md_tbaa_root_);
} }
ctx_ = ctx; ctx_ = ctx;
// initialize modules // initialize Modules and function type
module_.reset(new llvm::Module(module_name, *ctx)); module_.reset(new llvm::Module(module_name, *ctx));
// initialize TVM runtime API ftype_tvm_func_call_ = llvm::FunctionType::get(t_int_, {
f_tvm_func_call_ = llvm::Function::Create( t_tvm_func_handle_,
llvm::FunctionType::get(t_int_, { t_tvm_value_->getPointerTo(),
t_tvm_func_handle_, t_int_->getPointerTo(),
t_tvm_value_->getPointerTo(), t_int_,
t_int_->getPointerTo(), t_tvm_value_->getPointerTo(),
t_int_, t_int_->getPointerTo()}, false);
t_tvm_value_->getPointerTo(), ftype_tvm_get_func_from_env_ = llvm::FunctionType::get(t_int_, {
t_int_->getPointerTo()}, false), t_void_p_,
llvm::Function::ExternalLinkage, "TVMFuncCall", module_.get()); t_char_->getPointerTo(),
f_tvm_get_func_from_env_ = llvm::Function::Create( t_tvm_func_handle_->getPointerTo()}, false);
ftype_tvm_api_set_last_error_ = llvm::FunctionType::get(
t_void_, {t_char_->getPointerTo()}, false);
ftype_tvm_parallel_for_ =
llvm::FunctionType::get(t_int_, { llvm::FunctionType::get(t_int_, {
t_void_p_, t_int64_, t_int64_, ftype_tvm_par_for_lambda_->getPointerTo(), t_void_p_}
t_char_->getPointerTo(), , false);
t_tvm_func_handle_->getPointerTo()}, false), // initialize TVM runtime API
llvm::Function::ExternalLinkage, "TVMBackendGetFuncFromEnv", module_.get());
f_tvm_api_set_last_error_ = llvm::Function::Create(
llvm::FunctionType::get(t_void_, {t_char_->getPointerTo()}, false),
llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get());
f_tvm_parallel_for_ = llvm::Function::Create(
llvm::FunctionType::get(t_int_, {
t_int64_, t_int64_, t_f_tvm_par_for_lambda_->getPointerTo(), t_void_p_}
, false),
llvm::Function::ExternalLinkage, "TVMBackendParallelFor", module_.get());
if (system_lib) { if (system_lib) {
// We will need this in environment for backward registration.
f_tvm_register_system_symbol_ = llvm::Function::Create( f_tvm_register_system_symbol_ = llvm::Function::Create(
llvm::FunctionType::get(t_int_, {t_char_->getPointerTo(), t_void_p_}, false), llvm::FunctionType::get(t_int_, {t_char_->getPointerTo(), t_void_p_}, false),
llvm::Function::ExternalLinkage, "TVMBackendRegisterSystemLibSymbol", module_.get()); llvm::Function::ExternalLinkage, "TVMBackendRegisterSystemLibSymbol", module_.get());
} else { } else {
f_tvm_register_system_symbol_ = nullptr; f_tvm_register_system_symbol_ = nullptr;
} }
if (dynamic_lookup || system_lib) {
f_tvm_func_call_ = llvm::Function::Create(
ftype_tvm_func_call_,
llvm::Function::ExternalLinkage, "TVMFuncCall", module_.get());
f_tvm_get_func_from_env_ = llvm::Function::Create(
ftype_tvm_get_func_from_env_,
llvm::Function::ExternalLinkage, "TVMBackendGetFuncFromEnv", module_.get());
f_tvm_api_set_last_error_ = llvm::Function::Create(
ftype_tvm_api_set_last_error_,
llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get());
f_tvm_parallel_for_ = llvm::Function::Create(
ftype_tvm_parallel_for_,
llvm::Function::ExternalLinkage, "TVMBackendParallelFor", module_.get());
}
this->InitTarget(tm); this->InitTarget(tm);
// initialize builder // initialize builder
builder_.reset(new IRBuilder(*ctx)); builder_.reset(new IRBuilder(*ctx));
this->InitGlobalContext(); this->InitGlobalContext(dynamic_lookup);
} }
void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) { void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
...@@ -131,17 +143,48 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) { ...@@ -131,17 +143,48 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
} }
} }
void CodeGenLLVM::InitGlobalContext() {
gv_mod_ctx_ = new llvm::GlobalVariable( llvm::GlobalVariable* CodeGenLLVM::InitContextPtr(
*module_, t_void_p_, false, llvm::Type* p_type, std::string name) {
llvm::GlobalVariable* gv = new llvm::GlobalVariable(
*module_, p_type, false,
llvm::GlobalValue::LinkOnceAnyLinkage, 0, llvm::GlobalValue::LinkOnceAnyLinkage, 0,
tvm::runtime::symbol::tvm_module_ctx); name);
gv_mod_ctx_->setAlignment(data_layout_->getTypeAllocSize(t_void_p_)); gv->setAlignment(data_layout_->getTypeAllocSize(p_type));
gv_mod_ctx_->setInitializer(llvm::Constant::getNullValue(t_void_p_)); gv->setInitializer(llvm::Constant::getNullValue(p_type));
return gv;
}
llvm::Value* CodeGenLLVM::GetContextPtr(llvm::GlobalVariable* gv) {
CHECK(gv != nullptr);
llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, gv->getAlignment());
faddr->setMetadata(
"tbaa",
md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0));
return faddr;
}
void CodeGenLLVM::InitGlobalContext(bool dynamic_lookup) {
// Module context
gv_mod_ctx_ = InitContextPtr(t_void_p_, tvm::runtime::symbol::tvm_module_ctx);
// Register back the locations.
if (f_tvm_register_system_symbol_ != nullptr) { if (f_tvm_register_system_symbol_ != nullptr) {
export_system_symbols_.emplace_back( export_system_symbols_.emplace_back(
std::make_pair(tvm::runtime::symbol::tvm_module_ctx, gv_mod_ctx_)); std::make_pair(tvm::runtime::symbol::tvm_module_ctx, gv_mod_ctx_));
} else {
if (!dynamic_lookup) {
gv_tvm_func_call_ = InitContextPtr(
ftype_tvm_func_call_->getPointerTo(), "__TVMFuncCall");
gv_tvm_get_func_from_env_ = InitContextPtr(
ftype_tvm_get_func_from_env_->getPointerTo(), "__TVMBackendGetFuncFromEnv");
gv_tvm_api_set_last_error_ = InitContextPtr(
ftype_tvm_api_set_last_error_->getPointerTo(), "__TVMAPISetLastError");
gv_tvm_parallel_for_ = InitContextPtr(
ftype_tvm_parallel_for_->getPointerTo(), "__TVMBackendParallelFor");
// Mark as context functions
gv_func_map_["TVMBackendAllocWorkspace"] = nullptr;
gv_func_map_["TVMBackendFreeWorkspace"] = nullptr;
}
} }
} }
...@@ -528,9 +571,13 @@ llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) { ...@@ -528,9 +571,13 @@ llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
// Initialize the handle if needed. // Initialize the handle if needed.
builder_->SetInsertPoint(init_block); builder_->SetInsertPoint(init_block);
llvm::Value* out = builder_->CreateAlloca(t_tvm_func_handle_); llvm::Value* out = builder_->CreateAlloca(t_tvm_func_handle_);
llvm::Value* ctx = builder_->CreateLoad(gv_mod_ctx_); llvm::LoadInst* ctx = builder_->CreateAlignedLoad(
gv_mod_ctx_, gv_mod_ctx_->getAlignment());
ctx->setMetadata(
"tbaa",
md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0));
llvm::Value* retcode = builder_->CreateCall( llvm::Value* retcode = builder_->CreateCall(
f_tvm_get_func_from_env_, {ctx, GetConstString(fname), out}); RuntimeTVMGetFuncFromEnv(), {ctx, GetConstString(fname), out});
init_block = CheckCallSuccess(retcode); init_block = CheckCallSuccess(retcode);
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align); llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align);
builder_->CreateBr(end_block); builder_->CreateBr(end_block);
...@@ -565,7 +612,7 @@ llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) { ...@@ -565,7 +612,7 @@ llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) {
Int(32), stack_tcode, ConstInt32(end)); Int(32), stack_tcode, ConstInt32(end));
CheckCallSuccess( CheckCallSuccess(
builder_->CreateCall( builder_->CreateCall(
f_tvm_func_call_, RuntimeTVMFuncCall(),
{handle, arg_value, arg_tcode, ConstInt32(nargs), {handle, arg_value, arg_tcode, ConstInt32(nargs),
ret_value, ret_tcode})); ret_value, ret_tcode}));
Type r_type = op->type; Type r_type = op->type;
...@@ -584,17 +631,28 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) { ...@@ -584,17 +631,28 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
arg_values[i] = MakeValue(op->args[i]); arg_values[i] = MakeValue(op->args[i]);
} }
if (op->type.is_scalar()) { if (op->type.is_scalar()) {
llvm::Function* f = module_->getFunction(op->name); std::vector<llvm::Type*> arg_types;
if (f == nullptr) { for (llvm::Value* v : arg_values) {
std::vector<llvm::Type*> arg_types; arg_types.push_back(v->getType());
for (llvm::Value* v : arg_values) { }
arg_types.push_back(v->getType()); llvm::FunctionType* ftype = llvm::FunctionType::get(
LLVMType(op->type), arg_types, false);
// Check if it is available in global function table as injected function.
auto it = gv_func_map_.find(op->name);
if (it != gv_func_map_.end()) {
if (it->second == nullptr) {
gv_func_map_[op->name] = InitContextPtr(ftype->getPointerTo(), "__" + op->name);
it = gv_func_map_.find(op->name);
} }
f = llvm::Function::Create( return builder_->CreateCall(GetContextPtr(it->second), arg_values);
llvm::FunctionType::get(LLVMType(op->type), arg_types, false), } else {
llvm::Function::ExternalLinkage, op->name, module_.get()); llvm::Function* f = module_->getFunction(op->name);
if (f == nullptr) {
f = llvm::Function::Create(
ftype, llvm::Function::ExternalLinkage, op->name, module_.get());
}
return builder_->CreateCall(f, arg_values);
} }
return builder_->CreateCall(f, arg_values);
} else { } else {
llvm::Function* f = module_->getFunction(op->name); llvm::Function* f = module_->getFunction(op->name);
if (f) { if (f) {
...@@ -603,6 +661,7 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) { ...@@ -603,6 +661,7 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
LOG(FATAL) << "cannot find function " << op->name; LOG(FATAL) << "cannot find function " << op->name;
} }
} }
LOG(FATAL) << "canot reach here";
return nullptr; return nullptr;
} }
...@@ -630,6 +689,24 @@ llvm::Value* CodeGenLLVM::CreateScalarizedCall( ...@@ -630,6 +689,24 @@ llvm::Value* CodeGenLLVM::CreateScalarizedCall(
return value; return value;
} }
llvm::Value* CodeGenLLVM::RuntimeTVMFuncCall() {
if (f_tvm_func_call_ != nullptr) return f_tvm_func_call_;
return GetContextPtr(gv_tvm_func_call_);
}
llvm::Value* CodeGenLLVM::RuntimeTVMGetFuncFromEnv() {
if (f_tvm_get_func_from_env_ != nullptr) return f_tvm_get_func_from_env_;
return GetContextPtr(gv_tvm_get_func_from_env_);
}
llvm::Value* CodeGenLLVM::RuntimeTVMAPISetLastError() {
if (f_tvm_api_set_last_error_ != nullptr) return f_tvm_api_set_last_error_;
return GetContextPtr(gv_tvm_api_set_last_error_);
}
llvm::Value* CodeGenLLVM::RuntimeTVMParallelFor() {
if (f_tvm_parallel_for_ != nullptr) return f_tvm_parallel_for_;
return GetContextPtr(gv_tvm_parallel_for_);
}
llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const { llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
auto it = var_map_.find(v); auto it = var_map_.find(v);
CHECK(it != var_map_.end()) CHECK(it != var_map_.end())
...@@ -723,7 +800,7 @@ void CodeGenLLVM::CreateParallelFor(const For* op) { ...@@ -723,7 +800,7 @@ void CodeGenLLVM::CreateParallelFor(const For* op) {
// closure data // closure data
llvm::StructType* tcdata = llvm::StructType::create(fields); llvm::StructType* tcdata = llvm::StructType::create(fields);
llvm::Function* f = llvm::Function::Create( llvm::Function* f = llvm::Function::Create(
t_f_tvm_par_for_lambda_, ftype_tvm_par_for_lambda_,
llvm::Function::PrivateLinkage, llvm::Function::PrivateLinkage,
"__tvm_par_for_lambda", module_.get()); "__tvm_par_for_lambda", module_.get());
// allocate and setup the closure, call the closure. // allocate and setup the closure, call the closure.
...@@ -737,7 +814,7 @@ void CodeGenLLVM::CreateParallelFor(const For* op) { ...@@ -737,7 +814,7 @@ void CodeGenLLVM::CreateParallelFor(const For* op) {
} }
BasicBlock* par_for_end = CheckCallSuccess( BasicBlock* par_for_end = CheckCallSuccess(
builder_->CreateCall( builder_->CreateCall(
f_tvm_parallel_for_, RuntimeTVMParallelFor(),
{min, extent, f, builder_->CreatePointerCast(cdata, t_void_p_)})); {min, extent, f, builder_->CreatePointerCast(cdata, t_void_p_)}));
// Setup the closure function. // Setup the closure function.
BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f); BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
...@@ -794,8 +871,9 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, ...@@ -794,8 +871,9 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end,
builder_->SetInsertPoint(for_end); builder_->SetInsertPoint(for_end);
} }
llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) { llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
if (op->is_intrinsic("llvm_intrin")) { if (op->is_intrinsic("llvm_intrin")) {
CHECK_GE(op->args.size(), 1U);
std::vector<llvm::Value*> arg_values; std::vector<llvm::Value*> arg_values;
std::vector<llvm::Type*> arg_types; std::vector<llvm::Type*> arg_types;
for (size_t i = 1; i < op->args.size(); ++i) { for (size_t i = 1; i < op->args.size(); ++i) {
...@@ -808,6 +886,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) { ...@@ -808,6 +886,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
module_.get(), id, arg_types); module_.get(), id, arg_types);
return builder_->CreateCall(f, arg_values); return builder_->CreateCall(f, arg_values);
} else if (op->is_intrinsic("llvm_builtin")) { } else if (op->is_intrinsic("llvm_builtin")) {
CHECK_GE(op->args.size(), 1U);
std::vector<llvm::Value*> arg_values; std::vector<llvm::Value*> arg_values;
for (size_t i = 1; i < op->args.size(); ++i) { for (size_t i = 1; i < op->args.size(); ++i) {
llvm::Value* v = MakeValue(op->args[i]); llvm::Value* v = MakeValue(op->args[i]);
...@@ -1391,7 +1470,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) { ...@@ -1391,7 +1470,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
return CreateCallPacked(op); return CreateCallPacked(op);
} else if (op->call_type == Call::Intrinsic || } else if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) { op->call_type == Call::PureIntrinsic) {
return CreateIntrinstic(op); return CreateIntrinsic(op);
} else { } else {
CHECK(op->call_type == Call::Extern || CHECK(op->call_type == Call::Extern ||
op->call_type == Call::PureExtern); op->call_type == Call::PureExtern);
...@@ -1508,7 +1587,7 @@ void CodeGenLLVM::VisitStmt_(const AssertStmt* op) { ...@@ -1508,7 +1587,7 @@ void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
builder_->CreateCondBr(cond, end_block, fail_block, md_very_likely_branch_); builder_->CreateCondBr(cond, end_block, fail_block, md_very_likely_branch_);
// fail condition. // fail condition.
builder_->SetInsertPoint(fail_block); builder_->SetInsertPoint(fail_block);
builder_->CreateCall(f_tvm_api_set_last_error_, {msg}); builder_->CreateCall(RuntimeTVMAPISetLastError(), {msg});
builder_->CreateRet(ConstInt32(-1)); builder_->CreateRet(ConstInt32(-1));
// otherwise set it to be new end. // otherwise set it to be new end.
builder_->SetInsertPoint(end_block); builder_->SetInsertPoint(end_block);
......
...@@ -41,11 +41,14 @@ class CodeGenLLVM : ...@@ -41,11 +41,14 @@ class CodeGenLLVM :
* \param tm Target machine model * \param tm Target machine model
* \param ctx The context. * \param ctx The context.
* \param system_lib Whether to insert system library registration. * \param system_lib Whether to insert system library registration.
* \param dynamic_lookup Whether dynamically lookup runtime function
* or use the runtime function table passed by caller.
*/ */
void Init(const std::string& module_name, void Init(const std::string& module_name,
llvm::TargetMachine* tm, llvm::TargetMachine* tm,
llvm::LLVMContext* ctx, llvm::LLVMContext* ctx,
bool system_lib); bool system_lib,
bool dynamic_lookup);
/*! /*!
* \brief Compile and add function f to the current module. * \brief Compile and add function f to the current module.
* \param f The function to be added. * \param f The function to be added.
...@@ -114,7 +117,7 @@ class CodeGenLLVM : ...@@ -114,7 +117,7 @@ class CodeGenLLVM :
void VisitStmt_(const Evaluate* op) override; void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const ProducerConsumer* op) override; void VisitStmt_(const ProducerConsumer* op) override;
// create intrinstic given call // create intrinstic given call
virtual llvm::Value* CreateIntrinstic(const Call* op); virtual llvm::Value* CreateIntrinsic(const Call* op);
// create extern function call // create extern function call
virtual llvm::Value* CreateCallExtern(const Call* op); virtual llvm::Value* CreateCallExtern(const Call* op);
// create call into tvm packed function. // create call into tvm packed function.
...@@ -178,6 +181,7 @@ class CodeGenLLVM : ...@@ -178,6 +181,7 @@ class CodeGenLLVM :
llvm::MDNode* md_very_likely_branch_{nullptr}; llvm::MDNode* md_very_likely_branch_{nullptr};
llvm::MDNode* md_tbaa_root_{nullptr}; llvm::MDNode* md_tbaa_root_{nullptr};
llvm::MDNode* md_tbaa_alias_set_{nullptr}; llvm::MDNode* md_tbaa_alias_set_{nullptr};
llvm::MDNode* md_tbaa_ctx_ptr_{nullptr};
// TVM related data types // TVM related data types
llvm::Type* t_tvm_shape_index_{nullptr}; llvm::Type* t_tvm_shape_index_{nullptr};
llvm::Type* t_tvm_func_handle_{nullptr}; llvm::Type* t_tvm_func_handle_{nullptr};
...@@ -185,13 +189,12 @@ class CodeGenLLVM : ...@@ -185,13 +189,12 @@ class CodeGenLLVM :
llvm::StructType* t_tvm_type_{nullptr}; llvm::StructType* t_tvm_type_{nullptr};
llvm::StructType* t_tvm_array_{nullptr}; llvm::StructType* t_tvm_array_{nullptr};
llvm::StructType* t_tvm_value_{nullptr}; llvm::StructType* t_tvm_value_{nullptr};
llvm::FunctionType* t_f_tvm_par_for_lambda_{nullptr}; llvm::FunctionType* ftype_tvm_par_for_lambda_{nullptr};
// tvm api functions llvm::FunctionType* ftype_tvm_func_call_{nullptr};
llvm::Function* f_tvm_func_call_{nullptr}; llvm::FunctionType* ftype_tvm_get_func_from_env_{nullptr};
llvm::Function* f_tvm_get_func_from_env_{nullptr}; llvm::FunctionType* ftype_tvm_api_set_last_error_{nullptr};
llvm::Function* f_tvm_api_set_last_error_{nullptr}; llvm::FunctionType* ftype_tvm_parallel_for_{nullptr};
llvm::Function* f_tvm_parallel_for_{nullptr}; llvm::FunctionType* ftype_tvm_register_system_symbol_{nullptr};
llvm::Function* f_tvm_register_system_symbol_{nullptr};
// The acting body // The acting body
llvm::BasicBlock* block_{nullptr}; llvm::BasicBlock* block_{nullptr};
/*! \brief native vector bits of current targetx*/ /*! \brief native vector bits of current targetx*/
...@@ -200,6 +203,13 @@ class CodeGenLLVM : ...@@ -200,6 +203,13 @@ class CodeGenLLVM :
std::unordered_map<const Variable*, StorageInfo> alloc_storage_info_; std::unordered_map<const Variable*, StorageInfo> alloc_storage_info_;
private: private:
// Get runtime functions
llvm::GlobalVariable* InitContextPtr(llvm::Type* type, std::string name);
llvm::Value* GetContextPtr(llvm::GlobalVariable* gv);
llvm::Value* RuntimeTVMFuncCall();
llvm::Value* RuntimeTVMGetFuncFromEnv();
llvm::Value* RuntimeTVMAPISetLastError();
llvm::Value* RuntimeTVMParallelFor();
// comparison op // comparison op
llvm::Value* GetVarValue(const Variable* v) const; llvm::Value* GetVarValue(const Variable* v) const;
llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b);
...@@ -232,7 +242,7 @@ class CodeGenLLVM : ...@@ -232,7 +242,7 @@ class CodeGenLLVM :
// return the end block after the check // return the end block after the check
llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode); llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode);
// Add a function to set global module context // Add a function to set global module context
void InitGlobalContext(); void InitGlobalContext(bool dynamic_lookup);
// Add module startup function if needed. // Add module startup function if needed.
void AddStartupFunction(); void AddStartupFunction();
// add alias information. // add alias information.
...@@ -247,8 +257,19 @@ class CodeGenLLVM : ...@@ -247,8 +257,19 @@ class CodeGenLLVM :
bool is_restricted_{true}; bool is_restricted_{true};
// set of var that are not restricted(can alias) // set of var that are not restricted(can alias)
std::unordered_set<const Variable*> alias_var_set_; std::unordered_set<const Variable*> alias_var_set_;
// The local module_context // Context for injection lookup
llvm::GlobalVariable* gv_mod_ctx_{nullptr}; llvm::GlobalVariable* gv_mod_ctx_{nullptr};
llvm::GlobalVariable* gv_tvm_func_call_{nullptr};
llvm::GlobalVariable* gv_tvm_get_func_from_env_{nullptr};
llvm::GlobalVariable* gv_tvm_api_set_last_error_{nullptr};
llvm::GlobalVariable* gv_tvm_parallel_for_{nullptr};
std::unordered_map<std::string, llvm::GlobalVariable*> gv_func_map_;
// context for direct dynamic lookup
llvm::Function* f_tvm_func_call_{nullptr};
llvm::Function* f_tvm_get_func_from_env_{nullptr};
llvm::Function* f_tvm_api_set_last_error_{nullptr};
llvm::Function* f_tvm_parallel_for_{nullptr};
llvm::Function* f_tvm_register_system_symbol_{nullptr};
// global to packed function handle // global to packed function handle
std::unordered_map<std::string, llvm::GlobalVariable*> func_handle_map_; std::unordered_map<std::string, llvm::GlobalVariable*> func_handle_map_;
// List of symbols to be exported to TVM system lib. // List of symbols to be exported to TVM system lib.
......
...@@ -113,8 +113,6 @@ GetLLVMTargetMachine(const std::string& target_str, bool allow_null) { ...@@ -113,8 +113,6 @@ GetLLVMTargetMachine(const std::string& target_str, bool allow_null) {
return tm; return tm;
} }
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
#endif // TVM_LLVM_VERSION #endif // TVM_LLVM_VERSION
...@@ -104,7 +104,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -104,7 +104,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
ctx_ = std::make_shared<llvm::LLVMContext>(); ctx_ = std::make_shared<llvm::LLVMContext>();
std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_); std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_);
entry_func_ = funcs[0]->name; entry_func_ = funcs[0]->name;
cg->Init(funcs[0]->name, tm_, ctx_.get(), system_lib); cg->Init(funcs[0]->name, tm_, ctx_.get(), system_lib, system_lib);
for (LoweredFunc f : funcs) { for (LoweredFunc f : funcs) {
cg->AddFunction(f); cg->AddFunction(f);
} }
...@@ -152,16 +152,17 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -152,16 +152,17 @@ class LLVMModuleNode final : public runtime::ModuleNode {
<< "Failed to initialize git engine for " << mptr_->getTargetTriple(); << "Failed to initialize git engine for " << mptr_->getTargetTriple();
ee_->runStaticConstructorsDestructors(false); ee_->runStaticConstructorsDestructors(false);
// setup context address. // setup context address.
void** ctx_addr =
reinterpret_cast<void**>(
ee_->getGlobalValueAddress(runtime::symbol::tvm_module_ctx));
// setup context address.
entry_func_ = entry_func_ =
reinterpret_cast<const char*>( reinterpret_cast<const char*>(
ee_->getGlobalValueAddress(runtime::symbol::tvm_module_main)); ee_->getGlobalValueAddress(runtime::symbol::tvm_module_main));
if (ctx_addr != nullptr) { if (void** ctx_addr = reinterpret_cast<void**>(
ee_->getGlobalValueAddress(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = this; *ctx_addr = this;
} }
runtime::InitContextFunctions([this](const char *name) {
auto value = ee_->getGlobalValueAddress(name);
return value;
});
} }
// The target configuration string // The target configuration string
std::string target_; std::string target_;
......
...@@ -40,7 +40,7 @@ class DSOModuleNode final : public ModuleNode { ...@@ -40,7 +40,7 @@ class DSOModuleNode final : public ModuleNode {
<< "Symbol " << runtime::symbol::tvm_module_main << " is not presented"; << "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(entry_name)); faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(entry_name));
} else { } else {
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name)); faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name.c_str()));
} }
if (faddr == nullptr) return PackedFunc(); if (faddr == nullptr) return PackedFunc();
return WrapPackedFunc(faddr, sptr_to_self); return WrapPackedFunc(faddr, sptr_to_self);
...@@ -48,12 +48,13 @@ class DSOModuleNode final : public ModuleNode { ...@@ -48,12 +48,13 @@ class DSOModuleNode final : public ModuleNode {
void Init(const std::string& name) { void Init(const std::string& name) {
Load(name); Load(name);
void** ctx_addr = if (auto *ctx_addr =
reinterpret_cast<void**>( reinterpret_cast<void**>(GetSymbol(runtime::symbol::tvm_module_ctx))) {
GetSymbol(runtime::symbol::tvm_module_ctx));
if (ctx_addr != nullptr) {
*ctx_addr = this; *ctx_addr = this;
} }
InitContextFunctions([this](const char* fname) {
return GetSymbol(fname);
});
// Load the imported modules // Load the imported modules
const char* dev_mblob = const char* dev_mblob =
reinterpret_cast<const char*>( reinterpret_cast<const char*>(
...@@ -76,9 +77,9 @@ class DSOModuleNode final : public ModuleNode { ...@@ -76,9 +77,9 @@ class DSOModuleNode final : public ModuleNode {
CHECK(lib_handle_ != nullptr) CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name; << "Failed to load dynamic shared library " << name;
} }
void* GetSymbol(const std::string& name) { void* GetSymbol(const char* name) {
return reinterpret_cast<void*>( return reinterpret_cast<void*>(
GetProcAddress(lib_handle_, (LPCSTR)name.c_str())); // NOLINT(*) GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
} }
void Unload() { void Unload() {
FreeLibrary(lib_handle_); FreeLibrary(lib_handle_);
...@@ -92,8 +93,8 @@ class DSOModuleNode final : public ModuleNode { ...@@ -92,8 +93,8 @@ class DSOModuleNode final : public ModuleNode {
CHECK(lib_handle_ != nullptr) CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name; << "Failed to load dynamic shared library " << name;
} }
void* GetSymbol(const std::string& name) { void* GetSymbol(const char* name) {
return dlsym(lib_handle_, name.c_str()); return dlsym(lib_handle_, name);
} }
void Unload() { void Unload() {
dlclose(lib_handle_); dlclose(lib_handle_);
......
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
#define TVM_RUNTIME_MODULE_UTIL_H_ #define TVM_RUNTIME_MODULE_UTIL_H_
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h>
#include <vector> #include <vector>
extern "C" { extern "C" {
...@@ -30,6 +32,39 @@ PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const std::shared_ptr<Module ...@@ -30,6 +32,39 @@ PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const std::shared_ptr<Module
* \param module_list The module list to append to * \param module_list The module list to append to
*/ */
void ImportModuleBlob(const char* mblob, std::vector<Module>* module_list); void ImportModuleBlob(const char* mblob, std::vector<Module>* module_list);
/*!
* \brief Utility to initialize conext function symbols during startup
* \param flookup A symbol lookup function.
* \tparam FLookup a function of signature string->void*
*/
template<typename FLookup>
void InitContextFunctions(FLookup flookup) {
if (auto *fp = reinterpret_cast<decltype(&TVMFuncCall)*>
(flookup("__TVMFuncCall"))) {
*fp = TVMFuncCall;
}
if (auto *fp = reinterpret_cast<decltype(&TVMAPISetLastError)*>
(flookup("__TVMAPISetLastError"))) {
*fp = TVMAPISetLastError;
}
if (auto *fp = reinterpret_cast<decltype(&TVMBackendGetFuncFromEnv)*>
(flookup("__TVMBackendGetFuncFromEnv"))) {
*fp = TVMBackendGetFuncFromEnv;
}
if (auto *fp = reinterpret_cast<decltype(&TVMBackendAllocWorkspace)*>
(flookup("__TVMBackendAllocWorkspace"))) {
*fp = TVMBackendAllocWorkspace;
}
if (auto *fp = reinterpret_cast<decltype(&TVMBackendFreeWorkspace)*>
(flookup("__TVMBackendFreeWorkspace"))) {
*fp = TVMBackendFreeWorkspace;
}
if (auto *fp = reinterpret_cast<decltype(&TVMBackendParallelFor)*>
(flookup("__TVMBackendParallelFor"))) {
*fp = TVMBackendParallelFor;
}
}
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_MODULE_UTIL_H_ #endif // TVM_RUNTIME_MODULE_UTIL_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment