Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
9037a4c2
Commit
9037a4c2
authored
Jul 18, 2017
by
Tianqi Chen
Committed by
GitHub
Jul 18, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RUNTIME] Enable injection of some core runtime functions to avoid dynamic lookup (#260)
parent
6196cd50
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
216 additions
and
81 deletions
+216
-81
Makefile
+3
-3
src/codegen/llvm/codegen_llvm.cc
+129
-50
src/codegen/llvm/codegen_llvm.h
+32
-11
src/codegen/llvm/llvm_common.cc
+0
-2
src/codegen/llvm/llvm_module.cc
+7
-6
src/runtime/dso_module.cc
+10
-9
src/runtime/module_util.h
+35
-0
No files found.
Makefile
View file @
9037a4c2
...
@@ -154,12 +154,12 @@ verilog: $(VER_LIBS)
...
@@ -154,12 +154,12 @@ verilog: $(VER_LIBS)
# Special rules for LLVM related modules.
# Special rules for LLVM related modules.
build/codegen/llvm/%.o
:
src/codegen/llvm/%.cc
build/codegen/llvm/%.o
:
src/codegen/llvm/%.cc
@
mkdir
-p
$
(
@D
)
@
mkdir
-p
$
(
@D
)
$(CXX)
$(CFLAGS)
-MM
-MT
build/
$*
.o
$<
>
build
/
$*
.d
$(CXX)
$(CFLAGS)
-MM
-MT
build/
codegen/llvm/
$*
.o
$<
>
build/codegen/llvm
/
$*
.d
$(CXX)
-c
$(CFLAGS)
$(LLVM_CFLAGS)
-c
$<
-o
$@
$(CXX)
-c
$(CFLAGS)
$(LLVM_CFLAGS)
-c
$<
-o
$@
build/runtime/metal/%.o
:
src/runtime/metal/%.mm
build/runtime/metal/%.o
:
src/runtime/metal/%.mm
@
mkdir
-p
$
(
@D
)
@
mkdir
-p
$
(
@D
)
$(CXX)
$(CFLAGS)
-MM
-MT
build/
$*
.o
$<
>
build
/
$*
.d
$(CXX)
$(CFLAGS)
-MM
-MT
build/
runtime/metal/
$*
.o
$<
>
build/runtime/metal
/
$*
.d
$(CXX)
$(OBJCFLAGS)
-c
$(CFLAGS)
-c
$<
-o
$@
$(CXX)
$(OBJCFLAGS)
-c
$(CFLAGS)
-c
$<
-o
$@
build/%.o
:
src/%.cc
build/%.o
:
src/%.cc
...
@@ -199,7 +199,7 @@ pylint:
...
@@ -199,7 +199,7 @@ pylint:
pylint python/tvm
--rcfile
=
$(ROOTDIR)
/tests/lint/pylintrc
pylint python/tvm
--rcfile
=
$(ROOTDIR)
/tests/lint/pylintrc
pylint topi/python/topi
--rcfile
=
$(ROOTDIR)
/tests/lint/pylintrc
pylint topi/python/topi
--rcfile
=
$(ROOTDIR)
/tests/lint/pylintrc
jnilint
:
jnilint
:
python dmlc-core/scripts/lint.py tvm4j-jni cpp jvm/native/src
python dmlc-core/scripts/lint.py tvm4j-jni cpp jvm/native/src
lint
:
cpplint pylint jnilint
lint
:
cpplint pylint jnilint
...
...
src/codegen/llvm/codegen_llvm.cc
View file @
9037a4c2
...
@@ -29,7 +29,8 @@ std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine *tm) {
...
@@ -29,7 +29,8 @@ std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine *tm) {
void
CodeGenLLVM
::
Init
(
const
std
::
string
&
module_name
,
void
CodeGenLLVM
::
Init
(
const
std
::
string
&
module_name
,
llvm
::
TargetMachine
*
tm
,
llvm
::
TargetMachine
*
tm
,
llvm
::
LLVMContext
*
ctx
,
llvm
::
LLVMContext
*
ctx
,
bool
system_lib
)
{
bool
system_lib
,
bool
dynamic_lookup
)
{
InitializeLLVM
();
InitializeLLVM
();
static_assert
(
sizeof
(
TVMValue
)
==
sizeof
(
double
),
"invariant"
);
static_assert
(
sizeof
(
TVMValue
)
==
sizeof
(
double
),
"invariant"
);
// static_assert(alignof(TVMValue) == alignof(double), "invariant");
// static_assert(alignof(TVMValue) == alignof(double), "invariant");
...
@@ -62,7 +63,7 @@ void CodeGenLLVM::Init(const std::string& module_name,
...
@@ -62,7 +63,7 @@ void CodeGenLLVM::Init(const std::string& module_name,
t_tvm_shape_index_
->
getPointerTo
(),
t_tvm_shape_index_
->
getPointerTo
(),
t_int64_
});
t_int64_
});
t_tvm_value_
=
llvm
::
StructType
::
create
({
t_float64_
});
t_tvm_value_
=
llvm
::
StructType
::
create
({
t_float64_
});
t_f
_tvm_par_for_lambda_
=
llvm
::
FunctionType
::
get
(
ftype
_tvm_par_for_lambda_
=
llvm
::
FunctionType
::
get
(
t_int_
,
{
t_int64_
,
t_int64_
,
t_void_p_
},
false
);
t_int_
,
{
t_int64_
,
t_int64_
,
t_void_p_
},
false
);
md_builder_
.
reset
(
new
llvm
::
MDBuilder
(
*
ctx
));
md_builder_
.
reset
(
new
llvm
::
MDBuilder
(
*
ctx
));
md_very_likely_branch_
=
md_very_likely_branch_
=
...
@@ -70,45 +71,56 @@ void CodeGenLLVM::Init(const std::string& module_name,
...
@@ -70,45 +71,56 @@ void CodeGenLLVM::Init(const std::string& module_name,
md_tbaa_root_
=
md_builder_
->
createTBAARoot
(
"tvmtbaa"
);
md_tbaa_root_
=
md_builder_
->
createTBAARoot
(
"tvmtbaa"
);
md_tbaa_alias_set_
=
md_builder_
->
createTBAAScalarTypeNode
(
md_tbaa_alias_set_
=
md_builder_
->
createTBAAScalarTypeNode
(
"alias_set"
,
md_tbaa_root_
);
"alias_set"
,
md_tbaa_root_
);
md_tbaa_ctx_ptr_
=
md_builder_
->
createTBAAScalarTypeNode
(
"ctx_ptr"
,
md_tbaa_root_
);
}
}
ctx_
=
ctx
;
ctx_
=
ctx
;
// initialize
modules
// initialize
Modules and function type
module_
.
reset
(
new
llvm
::
Module
(
module_name
,
*
ctx
));
module_
.
reset
(
new
llvm
::
Module
(
module_name
,
*
ctx
));
// initialize TVM runtime API
ftype_tvm_func_call_
=
llvm
::
FunctionType
::
get
(
t_int_
,
{
f_tvm_func_call_
=
llvm
::
Function
::
Create
(
t_tvm_func_handle_
,
llvm
::
FunctionType
::
get
(
t_int_
,
{
t_tvm_value_
->
getPointerTo
(),
t_tvm_func_handle_
,
t_int_
->
getPointerTo
(),
t_tvm_value_
->
getPointerTo
(),
t_int_
,
t_int_
->
getPointerTo
(),
t_tvm_value_
->
getPointerTo
(),
t_int_
,
t_int_
->
getPointerTo
()},
false
);
t_tvm_value_
->
getPointerTo
(),
ftype_tvm_get_func_from_env_
=
llvm
::
FunctionType
::
get
(
t_int_
,
{
t_int_
->
getPointerTo
()},
false
),
t_void_p_
,
llvm
::
Function
::
ExternalLinkage
,
"TVMFuncCall"
,
module_
.
get
());
t_char_
->
getPointerTo
(),
f_tvm_get_func_from_env_
=
llvm
::
Function
::
Create
(
t_tvm_func_handle_
->
getPointerTo
()},
false
);
ftype_tvm_api_set_last_error_
=
llvm
::
FunctionType
::
get
(
t_void_
,
{
t_char_
->
getPointerTo
()},
false
);
ftype_tvm_parallel_for_
=
llvm
::
FunctionType
::
get
(
t_int_
,
{
llvm
::
FunctionType
::
get
(
t_int_
,
{
t_void_p_
,
t_int64_
,
t_int64_
,
ftype_tvm_par_for_lambda_
->
getPointerTo
(),
t_void_p_
}
t_char_
->
getPointerTo
(),
,
false
);
t_tvm_func_handle_
->
getPointerTo
()},
false
),
// initialize TVM runtime API
llvm
::
Function
::
ExternalLinkage
,
"TVMBackendGetFuncFromEnv"
,
module_
.
get
());
f_tvm_api_set_last_error_
=
llvm
::
Function
::
Create
(
llvm
::
FunctionType
::
get
(
t_void_
,
{
t_char_
->
getPointerTo
()},
false
),
llvm
::
Function
::
ExternalLinkage
,
"TVMAPISetLastError"
,
module_
.
get
());
f_tvm_parallel_for_
=
llvm
::
Function
::
Create
(
llvm
::
FunctionType
::
get
(
t_int_
,
{
t_int64_
,
t_int64_
,
t_f_tvm_par_for_lambda_
->
getPointerTo
(),
t_void_p_
}
,
false
),
llvm
::
Function
::
ExternalLinkage
,
"TVMBackendParallelFor"
,
module_
.
get
());
if
(
system_lib
)
{
if
(
system_lib
)
{
// We will need this in environment for backward registration.
f_tvm_register_system_symbol_
=
llvm
::
Function
::
Create
(
f_tvm_register_system_symbol_
=
llvm
::
Function
::
Create
(
llvm
::
FunctionType
::
get
(
t_int_
,
{
t_char_
->
getPointerTo
(),
t_void_p_
},
false
),
llvm
::
FunctionType
::
get
(
t_int_
,
{
t_char_
->
getPointerTo
(),
t_void_p_
},
false
),
llvm
::
Function
::
ExternalLinkage
,
"TVMBackendRegisterSystemLibSymbol"
,
module_
.
get
());
llvm
::
Function
::
ExternalLinkage
,
"TVMBackendRegisterSystemLibSymbol"
,
module_
.
get
());
}
else
{
}
else
{
f_tvm_register_system_symbol_
=
nullptr
;
f_tvm_register_system_symbol_
=
nullptr
;
}
}
if
(
dynamic_lookup
||
system_lib
)
{
f_tvm_func_call_
=
llvm
::
Function
::
Create
(
ftype_tvm_func_call_
,
llvm
::
Function
::
ExternalLinkage
,
"TVMFuncCall"
,
module_
.
get
());
f_tvm_get_func_from_env_
=
llvm
::
Function
::
Create
(
ftype_tvm_get_func_from_env_
,
llvm
::
Function
::
ExternalLinkage
,
"TVMBackendGetFuncFromEnv"
,
module_
.
get
());
f_tvm_api_set_last_error_
=
llvm
::
Function
::
Create
(
ftype_tvm_api_set_last_error_
,
llvm
::
Function
::
ExternalLinkage
,
"TVMAPISetLastError"
,
module_
.
get
());
f_tvm_parallel_for_
=
llvm
::
Function
::
Create
(
ftype_tvm_parallel_for_
,
llvm
::
Function
::
ExternalLinkage
,
"TVMBackendParallelFor"
,
module_
.
get
());
}
this
->
InitTarget
(
tm
);
this
->
InitTarget
(
tm
);
// initialize builder
// initialize builder
builder_
.
reset
(
new
IRBuilder
(
*
ctx
));
builder_
.
reset
(
new
IRBuilder
(
*
ctx
));
this
->
InitGlobalContext
();
this
->
InitGlobalContext
(
dynamic_lookup
);
}
}
void
CodeGenLLVM
::
InitTarget
(
llvm
::
TargetMachine
*
tm
)
{
void
CodeGenLLVM
::
InitTarget
(
llvm
::
TargetMachine
*
tm
)
{
...
@@ -131,17 +143,48 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
...
@@ -131,17 +143,48 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
}
}
}
}
void
CodeGenLLVM
::
InitGlobalContext
()
{
gv_mod_ctx_
=
new
llvm
::
GlobalVariable
(
llvm
::
GlobalVariable
*
CodeGenLLVM
::
InitContextPtr
(
*
module_
,
t_void_p_
,
false
,
llvm
::
Type
*
p_type
,
std
::
string
name
)
{
llvm
::
GlobalVariable
*
gv
=
new
llvm
::
GlobalVariable
(
*
module_
,
p_type
,
false
,
llvm
::
GlobalValue
::
LinkOnceAnyLinkage
,
0
,
llvm
::
GlobalValue
::
LinkOnceAnyLinkage
,
0
,
tvm
::
runtime
::
symbol
::
tvm_module_ctx
);
name
);
gv_mod_ctx_
->
setAlignment
(
data_layout_
->
getTypeAllocSize
(
t_void_p_
));
gv
->
setAlignment
(
data_layout_
->
getTypeAllocSize
(
p_type
));
gv_mod_ctx_
->
setInitializer
(
llvm
::
Constant
::
getNullValue
(
t_void_p_
));
gv
->
setInitializer
(
llvm
::
Constant
::
getNullValue
(
p_type
));
return
gv
;
}
llvm
::
Value
*
CodeGenLLVM
::
GetContextPtr
(
llvm
::
GlobalVariable
*
gv
)
{
CHECK
(
gv
!=
nullptr
);
llvm
::
LoadInst
*
faddr
=
builder_
->
CreateAlignedLoad
(
gv
,
gv
->
getAlignment
());
faddr
->
setMetadata
(
"tbaa"
,
md_builder_
->
createTBAAStructTagNode
(
md_tbaa_ctx_ptr_
,
md_tbaa_ctx_ptr_
,
0
));
return
faddr
;
}
void
CodeGenLLVM
::
InitGlobalContext
(
bool
dynamic_lookup
)
{
// Module context
gv_mod_ctx_
=
InitContextPtr
(
t_void_p_
,
tvm
::
runtime
::
symbol
::
tvm_module_ctx
);
// Register back the locations.
if
(
f_tvm_register_system_symbol_
!=
nullptr
)
{
if
(
f_tvm_register_system_symbol_
!=
nullptr
)
{
export_system_symbols_
.
emplace_back
(
export_system_symbols_
.
emplace_back
(
std
::
make_pair
(
tvm
::
runtime
::
symbol
::
tvm_module_ctx
,
gv_mod_ctx_
));
std
::
make_pair
(
tvm
::
runtime
::
symbol
::
tvm_module_ctx
,
gv_mod_ctx_
));
}
else
{
if
(
!
dynamic_lookup
)
{
gv_tvm_func_call_
=
InitContextPtr
(
ftype_tvm_func_call_
->
getPointerTo
(),
"__TVMFuncCall"
);
gv_tvm_get_func_from_env_
=
InitContextPtr
(
ftype_tvm_get_func_from_env_
->
getPointerTo
(),
"__TVMBackendGetFuncFromEnv"
);
gv_tvm_api_set_last_error_
=
InitContextPtr
(
ftype_tvm_api_set_last_error_
->
getPointerTo
(),
"__TVMAPISetLastError"
);
gv_tvm_parallel_for_
=
InitContextPtr
(
ftype_tvm_parallel_for_
->
getPointerTo
(),
"__TVMBackendParallelFor"
);
// Mark as context functions
gv_func_map_
[
"TVMBackendAllocWorkspace"
]
=
nullptr
;
gv_func_map_
[
"TVMBackendFreeWorkspace"
]
=
nullptr
;
}
}
}
}
}
...
@@ -528,9 +571,13 @@ llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
...
@@ -528,9 +571,13 @@ llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
// Initialize the handle if needed.
// Initialize the handle if needed.
builder_
->
SetInsertPoint
(
init_block
);
builder_
->
SetInsertPoint
(
init_block
);
llvm
::
Value
*
out
=
builder_
->
CreateAlloca
(
t_tvm_func_handle_
);
llvm
::
Value
*
out
=
builder_
->
CreateAlloca
(
t_tvm_func_handle_
);
llvm
::
Value
*
ctx
=
builder_
->
CreateLoad
(
gv_mod_ctx_
);
llvm
::
LoadInst
*
ctx
=
builder_
->
CreateAlignedLoad
(
gv_mod_ctx_
,
gv_mod_ctx_
->
getAlignment
());
ctx
->
setMetadata
(
"tbaa"
,
md_builder_
->
createTBAAStructTagNode
(
md_tbaa_ctx_ptr_
,
md_tbaa_ctx_ptr_
,
0
));
llvm
::
Value
*
retcode
=
builder_
->
CreateCall
(
llvm
::
Value
*
retcode
=
builder_
->
CreateCall
(
f_tvm_get_func_from_env_
,
{
ctx
,
GetConstString
(
fname
),
out
});
RuntimeTVMGetFuncFromEnv
()
,
{
ctx
,
GetConstString
(
fname
),
out
});
init_block
=
CheckCallSuccess
(
retcode
);
init_block
=
CheckCallSuccess
(
retcode
);
llvm
::
Value
*
loaded_handle
=
builder_
->
CreateAlignedLoad
(
out
,
align
);
llvm
::
Value
*
loaded_handle
=
builder_
->
CreateAlignedLoad
(
out
,
align
);
builder_
->
CreateBr
(
end_block
);
builder_
->
CreateBr
(
end_block
);
...
@@ -565,7 +612,7 @@ llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) {
...
@@ -565,7 +612,7 @@ llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) {
Int
(
32
),
stack_tcode
,
ConstInt32
(
end
));
Int
(
32
),
stack_tcode
,
ConstInt32
(
end
));
CheckCallSuccess
(
CheckCallSuccess
(
builder_
->
CreateCall
(
builder_
->
CreateCall
(
f_tvm_func_call_
,
RuntimeTVMFuncCall
()
,
{
handle
,
arg_value
,
arg_tcode
,
ConstInt32
(
nargs
),
{
handle
,
arg_value
,
arg_tcode
,
ConstInt32
(
nargs
),
ret_value
,
ret_tcode
}));
ret_value
,
ret_tcode
}));
Type
r_type
=
op
->
type
;
Type
r_type
=
op
->
type
;
...
@@ -584,17 +631,28 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
...
@@ -584,17 +631,28 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
arg_values
[
i
]
=
MakeValue
(
op
->
args
[
i
]);
arg_values
[
i
]
=
MakeValue
(
op
->
args
[
i
]);
}
}
if
(
op
->
type
.
is_scalar
())
{
if
(
op
->
type
.
is_scalar
())
{
llvm
::
Function
*
f
=
module_
->
getFunction
(
op
->
name
);
std
::
vector
<
llvm
::
Type
*>
arg_types
;
if
(
f
==
nullptr
)
{
for
(
llvm
::
Value
*
v
:
arg_values
)
{
std
::
vector
<
llvm
::
Type
*>
arg_types
;
arg_types
.
push_back
(
v
->
getType
());
for
(
llvm
::
Value
*
v
:
arg_values
)
{
}
arg_types
.
push_back
(
v
->
getType
());
llvm
::
FunctionType
*
ftype
=
llvm
::
FunctionType
::
get
(
LLVMType
(
op
->
type
),
arg_types
,
false
);
// Check if it is available in global function table as injected function.
auto
it
=
gv_func_map_
.
find
(
op
->
name
);
if
(
it
!=
gv_func_map_
.
end
())
{
if
(
it
->
second
==
nullptr
)
{
gv_func_map_
[
op
->
name
]
=
InitContextPtr
(
ftype
->
getPointerTo
(),
"__"
+
op
->
name
);
it
=
gv_func_map_
.
find
(
op
->
name
);
}
}
f
=
llvm
::
Function
::
Create
(
return
builder_
->
CreateCall
(
GetContextPtr
(
it
->
second
),
arg_values
);
llvm
::
FunctionType
::
get
(
LLVMType
(
op
->
type
),
arg_types
,
false
),
}
else
{
llvm
::
Function
::
ExternalLinkage
,
op
->
name
,
module_
.
get
());
llvm
::
Function
*
f
=
module_
->
getFunction
(
op
->
name
);
if
(
f
==
nullptr
)
{
f
=
llvm
::
Function
::
Create
(
ftype
,
llvm
::
Function
::
ExternalLinkage
,
op
->
name
,
module_
.
get
());
}
return
builder_
->
CreateCall
(
f
,
arg_values
);
}
}
return
builder_
->
CreateCall
(
f
,
arg_values
);
}
else
{
}
else
{
llvm
::
Function
*
f
=
module_
->
getFunction
(
op
->
name
);
llvm
::
Function
*
f
=
module_
->
getFunction
(
op
->
name
);
if
(
f
)
{
if
(
f
)
{
...
@@ -603,6 +661,7 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
...
@@ -603,6 +661,7 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
LOG
(
FATAL
)
<<
"cannot find function "
<<
op
->
name
;
LOG
(
FATAL
)
<<
"cannot find function "
<<
op
->
name
;
}
}
}
}
LOG
(
FATAL
)
<<
"canot reach here"
;
return
nullptr
;
return
nullptr
;
}
}
...
@@ -630,6 +689,24 @@ llvm::Value* CodeGenLLVM::CreateScalarizedCall(
...
@@ -630,6 +689,24 @@ llvm::Value* CodeGenLLVM::CreateScalarizedCall(
return
value
;
return
value
;
}
}
llvm
::
Value
*
CodeGenLLVM
::
RuntimeTVMFuncCall
()
{
if
(
f_tvm_func_call_
!=
nullptr
)
return
f_tvm_func_call_
;
return
GetContextPtr
(
gv_tvm_func_call_
);
}
llvm
::
Value
*
CodeGenLLVM
::
RuntimeTVMGetFuncFromEnv
()
{
if
(
f_tvm_get_func_from_env_
!=
nullptr
)
return
f_tvm_get_func_from_env_
;
return
GetContextPtr
(
gv_tvm_get_func_from_env_
);
}
llvm
::
Value
*
CodeGenLLVM
::
RuntimeTVMAPISetLastError
()
{
if
(
f_tvm_api_set_last_error_
!=
nullptr
)
return
f_tvm_api_set_last_error_
;
return
GetContextPtr
(
gv_tvm_api_set_last_error_
);
}
llvm
::
Value
*
CodeGenLLVM
::
RuntimeTVMParallelFor
()
{
if
(
f_tvm_parallel_for_
!=
nullptr
)
return
f_tvm_parallel_for_
;
return
GetContextPtr
(
gv_tvm_parallel_for_
);
}
llvm
::
Value
*
CodeGenLLVM
::
GetVarValue
(
const
Variable
*
v
)
const
{
llvm
::
Value
*
CodeGenLLVM
::
GetVarValue
(
const
Variable
*
v
)
const
{
auto
it
=
var_map_
.
find
(
v
);
auto
it
=
var_map_
.
find
(
v
);
CHECK
(
it
!=
var_map_
.
end
())
CHECK
(
it
!=
var_map_
.
end
())
...
@@ -723,7 +800,7 @@ void CodeGenLLVM::CreateParallelFor(const For* op) {
...
@@ -723,7 +800,7 @@ void CodeGenLLVM::CreateParallelFor(const For* op) {
// closure data
// closure data
llvm
::
StructType
*
tcdata
=
llvm
::
StructType
::
create
(
fields
);
llvm
::
StructType
*
tcdata
=
llvm
::
StructType
::
create
(
fields
);
llvm
::
Function
*
f
=
llvm
::
Function
::
Create
(
llvm
::
Function
*
f
=
llvm
::
Function
::
Create
(
t_f
_tvm_par_for_lambda_
,
ftype
_tvm_par_for_lambda_
,
llvm
::
Function
::
PrivateLinkage
,
llvm
::
Function
::
PrivateLinkage
,
"__tvm_par_for_lambda"
,
module_
.
get
());
"__tvm_par_for_lambda"
,
module_
.
get
());
// allocate and setup the closure, call the closure.
// allocate and setup the closure, call the closure.
...
@@ -737,7 +814,7 @@ void CodeGenLLVM::CreateParallelFor(const For* op) {
...
@@ -737,7 +814,7 @@ void CodeGenLLVM::CreateParallelFor(const For* op) {
}
}
BasicBlock
*
par_for_end
=
CheckCallSuccess
(
BasicBlock
*
par_for_end
=
CheckCallSuccess
(
builder_
->
CreateCall
(
builder_
->
CreateCall
(
f_tvm_parallel_for_
,
RuntimeTVMParallelFor
()
,
{
min
,
extent
,
f
,
builder_
->
CreatePointerCast
(
cdata
,
t_void_p_
)}));
{
min
,
extent
,
f
,
builder_
->
CreatePointerCast
(
cdata
,
t_void_p_
)}));
// Setup the closure function.
// Setup the closure function.
BasicBlock
*
lambda_entry
=
BasicBlock
::
Create
(
*
ctx_
,
"entry"
,
f
);
BasicBlock
*
lambda_entry
=
BasicBlock
::
Create
(
*
ctx_
,
"entry"
,
f
);
...
@@ -794,8 +871,9 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end,
...
@@ -794,8 +871,9 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end,
builder_
->
SetInsertPoint
(
for_end
);
builder_
->
SetInsertPoint
(
for_end
);
}
}
llvm
::
Value
*
CodeGenLLVM
::
CreateIntrins
t
ic
(
const
Call
*
op
)
{
llvm
::
Value
*
CodeGenLLVM
::
CreateIntrinsic
(
const
Call
*
op
)
{
if
(
op
->
is_intrinsic
(
"llvm_intrin"
))
{
if
(
op
->
is_intrinsic
(
"llvm_intrin"
))
{
CHECK_GE
(
op
->
args
.
size
(),
1U
);
std
::
vector
<
llvm
::
Value
*>
arg_values
;
std
::
vector
<
llvm
::
Value
*>
arg_values
;
std
::
vector
<
llvm
::
Type
*>
arg_types
;
std
::
vector
<
llvm
::
Type
*>
arg_types
;
for
(
size_t
i
=
1
;
i
<
op
->
args
.
size
();
++
i
)
{
for
(
size_t
i
=
1
;
i
<
op
->
args
.
size
();
++
i
)
{
...
@@ -808,6 +886,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
...
@@ -808,6 +886,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
module_
.
get
(),
id
,
arg_types
);
module_
.
get
(),
id
,
arg_types
);
return
builder_
->
CreateCall
(
f
,
arg_values
);
return
builder_
->
CreateCall
(
f
,
arg_values
);
}
else
if
(
op
->
is_intrinsic
(
"llvm_builtin"
))
{
}
else
if
(
op
->
is_intrinsic
(
"llvm_builtin"
))
{
CHECK_GE
(
op
->
args
.
size
(),
1U
);
std
::
vector
<
llvm
::
Value
*>
arg_values
;
std
::
vector
<
llvm
::
Value
*>
arg_values
;
for
(
size_t
i
=
1
;
i
<
op
->
args
.
size
();
++
i
)
{
for
(
size_t
i
=
1
;
i
<
op
->
args
.
size
();
++
i
)
{
llvm
::
Value
*
v
=
MakeValue
(
op
->
args
[
i
]);
llvm
::
Value
*
v
=
MakeValue
(
op
->
args
[
i
]);
...
@@ -1391,7 +1470,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
...
@@ -1391,7 +1470,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
return
CreateCallPacked
(
op
);
return
CreateCallPacked
(
op
);
}
else
if
(
op
->
call_type
==
Call
::
Intrinsic
||
}
else
if
(
op
->
call_type
==
Call
::
Intrinsic
||
op
->
call_type
==
Call
::
PureIntrinsic
)
{
op
->
call_type
==
Call
::
PureIntrinsic
)
{
return
CreateIntrins
t
ic
(
op
);
return
CreateIntrinsic
(
op
);
}
else
{
}
else
{
CHECK
(
op
->
call_type
==
Call
::
Extern
||
CHECK
(
op
->
call_type
==
Call
::
Extern
||
op
->
call_type
==
Call
::
PureExtern
);
op
->
call_type
==
Call
::
PureExtern
);
...
@@ -1508,7 +1587,7 @@ void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
...
@@ -1508,7 +1587,7 @@ void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
builder_
->
CreateCondBr
(
cond
,
end_block
,
fail_block
,
md_very_likely_branch_
);
builder_
->
CreateCondBr
(
cond
,
end_block
,
fail_block
,
md_very_likely_branch_
);
// fail condition.
// fail condition.
builder_
->
SetInsertPoint
(
fail_block
);
builder_
->
SetInsertPoint
(
fail_block
);
builder_
->
CreateCall
(
f_tvm_api_set_last_error_
,
{
msg
});
builder_
->
CreateCall
(
RuntimeTVMAPISetLastError
()
,
{
msg
});
builder_
->
CreateRet
(
ConstInt32
(
-
1
));
builder_
->
CreateRet
(
ConstInt32
(
-
1
));
// otherwise set it to be new end.
// otherwise set it to be new end.
builder_
->
SetInsertPoint
(
end_block
);
builder_
->
SetInsertPoint
(
end_block
);
...
...
src/codegen/llvm/codegen_llvm.h
View file @
9037a4c2
...
@@ -41,11 +41,14 @@ class CodeGenLLVM :
...
@@ -41,11 +41,14 @@ class CodeGenLLVM :
* \param tm Target machine model
* \param tm Target machine model
* \param ctx The context.
* \param ctx The context.
* \param system_lib Whether to insert system library registration.
* \param system_lib Whether to insert system library registration.
* \param dynamic_lookup Whether dynamically lookup runtime function
* or use the runtime function table passed by caller.
*/
*/
void
Init
(
const
std
::
string
&
module_name
,
void
Init
(
const
std
::
string
&
module_name
,
llvm
::
TargetMachine
*
tm
,
llvm
::
TargetMachine
*
tm
,
llvm
::
LLVMContext
*
ctx
,
llvm
::
LLVMContext
*
ctx
,
bool
system_lib
);
bool
system_lib
,
bool
dynamic_lookup
);
/*!
/*!
* \brief Compile and add function f to the current module.
* \brief Compile and add function f to the current module.
* \param f The function to be added.
* \param f The function to be added.
...
@@ -114,7 +117,7 @@ class CodeGenLLVM :
...
@@ -114,7 +117,7 @@ class CodeGenLLVM :
void
VisitStmt_
(
const
Evaluate
*
op
)
override
;
void
VisitStmt_
(
const
Evaluate
*
op
)
override
;
void
VisitStmt_
(
const
ProducerConsumer
*
op
)
override
;
void
VisitStmt_
(
const
ProducerConsumer
*
op
)
override
;
// create intrinstic given call
// create intrinstic given call
virtual
llvm
::
Value
*
CreateIntrins
t
ic
(
const
Call
*
op
);
virtual
llvm
::
Value
*
CreateIntrinsic
(
const
Call
*
op
);
// create extern function call
// create extern function call
virtual
llvm
::
Value
*
CreateCallExtern
(
const
Call
*
op
);
virtual
llvm
::
Value
*
CreateCallExtern
(
const
Call
*
op
);
// create call into tvm packed function.
// create call into tvm packed function.
...
@@ -178,6 +181,7 @@ class CodeGenLLVM :
...
@@ -178,6 +181,7 @@ class CodeGenLLVM :
llvm
::
MDNode
*
md_very_likely_branch_
{
nullptr
};
llvm
::
MDNode
*
md_very_likely_branch_
{
nullptr
};
llvm
::
MDNode
*
md_tbaa_root_
{
nullptr
};
llvm
::
MDNode
*
md_tbaa_root_
{
nullptr
};
llvm
::
MDNode
*
md_tbaa_alias_set_
{
nullptr
};
llvm
::
MDNode
*
md_tbaa_alias_set_
{
nullptr
};
llvm
::
MDNode
*
md_tbaa_ctx_ptr_
{
nullptr
};
// TVM related data types
// TVM related data types
llvm
::
Type
*
t_tvm_shape_index_
{
nullptr
};
llvm
::
Type
*
t_tvm_shape_index_
{
nullptr
};
llvm
::
Type
*
t_tvm_func_handle_
{
nullptr
};
llvm
::
Type
*
t_tvm_func_handle_
{
nullptr
};
...
@@ -185,13 +189,12 @@ class CodeGenLLVM :
...
@@ -185,13 +189,12 @@ class CodeGenLLVM :
llvm
::
StructType
*
t_tvm_type_
{
nullptr
};
llvm
::
StructType
*
t_tvm_type_
{
nullptr
};
llvm
::
StructType
*
t_tvm_array_
{
nullptr
};
llvm
::
StructType
*
t_tvm_array_
{
nullptr
};
llvm
::
StructType
*
t_tvm_value_
{
nullptr
};
llvm
::
StructType
*
t_tvm_value_
{
nullptr
};
llvm
::
FunctionType
*
t_f_tvm_par_for_lambda_
{
nullptr
};
llvm
::
FunctionType
*
ftype_tvm_par_for_lambda_
{
nullptr
};
// tvm api functions
llvm
::
FunctionType
*
ftype_tvm_func_call_
{
nullptr
};
llvm
::
Function
*
f_tvm_func_call_
{
nullptr
};
llvm
::
FunctionType
*
ftype_tvm_get_func_from_env_
{
nullptr
};
llvm
::
Function
*
f_tvm_get_func_from_env_
{
nullptr
};
llvm
::
FunctionType
*
ftype_tvm_api_set_last_error_
{
nullptr
};
llvm
::
Function
*
f_tvm_api_set_last_error_
{
nullptr
};
llvm
::
FunctionType
*
ftype_tvm_parallel_for_
{
nullptr
};
llvm
::
Function
*
f_tvm_parallel_for_
{
nullptr
};
llvm
::
FunctionType
*
ftype_tvm_register_system_symbol_
{
nullptr
};
llvm
::
Function
*
f_tvm_register_system_symbol_
{
nullptr
};
// The acting body
// The acting body
llvm
::
BasicBlock
*
block_
{
nullptr
};
llvm
::
BasicBlock
*
block_
{
nullptr
};
/*! \brief native vector bits of current targetx*/
/*! \brief native vector bits of current targetx*/
...
@@ -200,6 +203,13 @@ class CodeGenLLVM :
...
@@ -200,6 +203,13 @@ class CodeGenLLVM :
std
::
unordered_map
<
const
Variable
*
,
StorageInfo
>
alloc_storage_info_
;
std
::
unordered_map
<
const
Variable
*
,
StorageInfo
>
alloc_storage_info_
;
private
:
private
:
// Get runtime functions
llvm
::
GlobalVariable
*
InitContextPtr
(
llvm
::
Type
*
type
,
std
::
string
name
);
llvm
::
Value
*
GetContextPtr
(
llvm
::
GlobalVariable
*
gv
);
llvm
::
Value
*
RuntimeTVMFuncCall
();
llvm
::
Value
*
RuntimeTVMGetFuncFromEnv
();
llvm
::
Value
*
RuntimeTVMAPISetLastError
();
llvm
::
Value
*
RuntimeTVMParallelFor
();
// comparison op
// comparison op
llvm
::
Value
*
GetVarValue
(
const
Variable
*
v
)
const
;
llvm
::
Value
*
GetVarValue
(
const
Variable
*
v
)
const
;
llvm
::
Value
*
CreateLT
(
Type
t
,
llvm
::
Value
*
a
,
llvm
::
Value
*
b
);
llvm
::
Value
*
CreateLT
(
Type
t
,
llvm
::
Value
*
a
,
llvm
::
Value
*
b
);
...
@@ -232,7 +242,7 @@ class CodeGenLLVM :
...
@@ -232,7 +242,7 @@ class CodeGenLLVM :
// return the end block after the check
// return the end block after the check
llvm
::
BasicBlock
*
CheckCallSuccess
(
llvm
::
Value
*
retcode
);
llvm
::
BasicBlock
*
CheckCallSuccess
(
llvm
::
Value
*
retcode
);
// Add a function to set global module context
// Add a function to set global module context
void
InitGlobalContext
();
void
InitGlobalContext
(
bool
dynamic_lookup
);
// Add module startup function if needed.
// Add module startup function if needed.
void
AddStartupFunction
();
void
AddStartupFunction
();
// add alias information.
// add alias information.
...
@@ -247,8 +257,19 @@ class CodeGenLLVM :
...
@@ -247,8 +257,19 @@ class CodeGenLLVM :
bool
is_restricted_
{
true
};
bool
is_restricted_
{
true
};
// set of var that are not restricted(can alias)
// set of var that are not restricted(can alias)
std
::
unordered_set
<
const
Variable
*>
alias_var_set_
;
std
::
unordered_set
<
const
Variable
*>
alias_var_set_
;
//
The local module_context
//
Context for injection lookup
llvm
::
GlobalVariable
*
gv_mod_ctx_
{
nullptr
};
llvm
::
GlobalVariable
*
gv_mod_ctx_
{
nullptr
};
llvm
::
GlobalVariable
*
gv_tvm_func_call_
{
nullptr
};
llvm
::
GlobalVariable
*
gv_tvm_get_func_from_env_
{
nullptr
};
llvm
::
GlobalVariable
*
gv_tvm_api_set_last_error_
{
nullptr
};
llvm
::
GlobalVariable
*
gv_tvm_parallel_for_
{
nullptr
};
std
::
unordered_map
<
std
::
string
,
llvm
::
GlobalVariable
*>
gv_func_map_
;
// context for direct dynamic lookup
llvm
::
Function
*
f_tvm_func_call_
{
nullptr
};
llvm
::
Function
*
f_tvm_get_func_from_env_
{
nullptr
};
llvm
::
Function
*
f_tvm_api_set_last_error_
{
nullptr
};
llvm
::
Function
*
f_tvm_parallel_for_
{
nullptr
};
llvm
::
Function
*
f_tvm_register_system_symbol_
{
nullptr
};
// global to packed function handle
// global to packed function handle
std
::
unordered_map
<
std
::
string
,
llvm
::
GlobalVariable
*>
func_handle_map_
;
std
::
unordered_map
<
std
::
string
,
llvm
::
GlobalVariable
*>
func_handle_map_
;
// List of symbols to be exported to TVM system lib.
// List of symbols to be exported to TVM system lib.
...
...
src/codegen/llvm/llvm_common.cc
View file @
9037a4c2
...
@@ -113,8 +113,6 @@ GetLLVMTargetMachine(const std::string& target_str, bool allow_null) {
...
@@ -113,8 +113,6 @@ GetLLVMTargetMachine(const std::string& target_str, bool allow_null) {
return
tm
;
return
tm
;
}
}
}
// namespace codegen
}
// namespace codegen
}
// namespace tvm
}
// namespace tvm
#endif // TVM_LLVM_VERSION
#endif // TVM_LLVM_VERSION
src/codegen/llvm/llvm_module.cc
View file @
9037a4c2
...
@@ -104,7 +104,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
...
@@ -104,7 +104,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
ctx_
=
std
::
make_shared
<
llvm
::
LLVMContext
>
();
ctx_
=
std
::
make_shared
<
llvm
::
LLVMContext
>
();
std
::
unique_ptr
<
CodeGenLLVM
>
cg
=
CodeGenLLVM
::
Create
(
tm_
);
std
::
unique_ptr
<
CodeGenLLVM
>
cg
=
CodeGenLLVM
::
Create
(
tm_
);
entry_func_
=
funcs
[
0
]
->
name
;
entry_func_
=
funcs
[
0
]
->
name
;
cg
->
Init
(
funcs
[
0
]
->
name
,
tm_
,
ctx_
.
get
(),
system_lib
);
cg
->
Init
(
funcs
[
0
]
->
name
,
tm_
,
ctx_
.
get
(),
system_lib
,
system_lib
);
for
(
LoweredFunc
f
:
funcs
)
{
for
(
LoweredFunc
f
:
funcs
)
{
cg
->
AddFunction
(
f
);
cg
->
AddFunction
(
f
);
}
}
...
@@ -152,16 +152,17 @@ class LLVMModuleNode final : public runtime::ModuleNode {
...
@@ -152,16 +152,17 @@ class LLVMModuleNode final : public runtime::ModuleNode {
<<
"Failed to initialize git engine for "
<<
mptr_
->
getTargetTriple
();
<<
"Failed to initialize git engine for "
<<
mptr_
->
getTargetTriple
();
ee_
->
runStaticConstructorsDestructors
(
false
);
ee_
->
runStaticConstructorsDestructors
(
false
);
// setup context address.
// setup context address.
void
**
ctx_addr
=
reinterpret_cast
<
void
**>
(
ee_
->
getGlobalValueAddress
(
runtime
::
symbol
::
tvm_module_ctx
));
// setup context address.
entry_func_
=
entry_func_
=
reinterpret_cast
<
const
char
*>
(
reinterpret_cast
<
const
char
*>
(
ee_
->
getGlobalValueAddress
(
runtime
::
symbol
::
tvm_module_main
));
ee_
->
getGlobalValueAddress
(
runtime
::
symbol
::
tvm_module_main
));
if
(
ctx_addr
!=
nullptr
)
{
if
(
void
**
ctx_addr
=
reinterpret_cast
<
void
**>
(
ee_
->
getGlobalValueAddress
(
runtime
::
symbol
::
tvm_module_ctx
)))
{
*
ctx_addr
=
this
;
*
ctx_addr
=
this
;
}
}
runtime
::
InitContextFunctions
([
this
](
const
char
*
name
)
{
auto
value
=
ee_
->
getGlobalValueAddress
(
name
);
return
value
;
});
}
}
// The target configuration string
// The target configuration string
std
::
string
target_
;
std
::
string
target_
;
...
...
src/runtime/dso_module.cc
View file @
9037a4c2
...
@@ -40,7 +40,7 @@ class DSOModuleNode final : public ModuleNode {
...
@@ -40,7 +40,7 @@ class DSOModuleNode final : public ModuleNode {
<<
"Symbol "
<<
runtime
::
symbol
::
tvm_module_main
<<
" is not presented"
;
<<
"Symbol "
<<
runtime
::
symbol
::
tvm_module_main
<<
" is not presented"
;
faddr
=
reinterpret_cast
<
BackendPackedCFunc
>
(
GetSymbol
(
entry_name
));
faddr
=
reinterpret_cast
<
BackendPackedCFunc
>
(
GetSymbol
(
entry_name
));
}
else
{
}
else
{
faddr
=
reinterpret_cast
<
BackendPackedCFunc
>
(
GetSymbol
(
name
));
faddr
=
reinterpret_cast
<
BackendPackedCFunc
>
(
GetSymbol
(
name
.
c_str
()
));
}
}
if
(
faddr
==
nullptr
)
return
PackedFunc
();
if
(
faddr
==
nullptr
)
return
PackedFunc
();
return
WrapPackedFunc
(
faddr
,
sptr_to_self
);
return
WrapPackedFunc
(
faddr
,
sptr_to_self
);
...
@@ -48,12 +48,13 @@ class DSOModuleNode final : public ModuleNode {
...
@@ -48,12 +48,13 @@ class DSOModuleNode final : public ModuleNode {
void
Init
(
const
std
::
string
&
name
)
{
void
Init
(
const
std
::
string
&
name
)
{
Load
(
name
);
Load
(
name
);
void
**
ctx_addr
=
if
(
auto
*
ctx_addr
=
reinterpret_cast
<
void
**>
(
reinterpret_cast
<
void
**>
(
GetSymbol
(
runtime
::
symbol
::
tvm_module_ctx
)))
{
GetSymbol
(
runtime
::
symbol
::
tvm_module_ctx
));
if
(
ctx_addr
!=
nullptr
)
{
*
ctx_addr
=
this
;
*
ctx_addr
=
this
;
}
}
InitContextFunctions
([
this
](
const
char
*
fname
)
{
return
GetSymbol
(
fname
);
});
// Load the imported modules
// Load the imported modules
const
char
*
dev_mblob
=
const
char
*
dev_mblob
=
reinterpret_cast
<
const
char
*>
(
reinterpret_cast
<
const
char
*>
(
...
@@ -76,9 +77,9 @@ class DSOModuleNode final : public ModuleNode {
...
@@ -76,9 +77,9 @@ class DSOModuleNode final : public ModuleNode {
CHECK
(
lib_handle_
!=
nullptr
)
CHECK
(
lib_handle_
!=
nullptr
)
<<
"Failed to load dynamic shared library "
<<
name
;
<<
"Failed to load dynamic shared library "
<<
name
;
}
}
void
*
GetSymbol
(
const
std
::
string
&
name
)
{
void
*
GetSymbol
(
const
char
*
name
)
{
return
reinterpret_cast
<
void
*>
(
return
reinterpret_cast
<
void
*>
(
GetProcAddress
(
lib_handle_
,
(
LPCSTR
)
name
.
c_str
()
));
// NOLINT(*)
GetProcAddress
(
lib_handle_
,
(
LPCSTR
)
name
));
// NOLINT(*)
}
}
void
Unload
()
{
void
Unload
()
{
FreeLibrary
(
lib_handle_
);
FreeLibrary
(
lib_handle_
);
...
@@ -92,8 +93,8 @@ class DSOModuleNode final : public ModuleNode {
...
@@ -92,8 +93,8 @@ class DSOModuleNode final : public ModuleNode {
CHECK
(
lib_handle_
!=
nullptr
)
CHECK
(
lib_handle_
!=
nullptr
)
<<
"Failed to load dynamic shared library "
<<
name
;
<<
"Failed to load dynamic shared library "
<<
name
;
}
}
void
*
GetSymbol
(
const
std
::
string
&
name
)
{
void
*
GetSymbol
(
const
char
*
name
)
{
return
dlsym
(
lib_handle_
,
name
.
c_str
()
);
return
dlsym
(
lib_handle_
,
name
);
}
}
void
Unload
()
{
void
Unload
()
{
dlclose
(
lib_handle_
);
dlclose
(
lib_handle_
);
...
...
src/runtime/module_util.h
View file @
9037a4c2
...
@@ -7,6 +7,8 @@
...
@@ -7,6 +7,8 @@
#define TVM_RUNTIME_MODULE_UTIL_H_
#define TVM_RUNTIME_MODULE_UTIL_H_
#include <tvm/runtime/module.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h>
#include <vector>
#include <vector>
extern
"C"
{
extern
"C"
{
...
@@ -30,6 +32,39 @@ PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const std::shared_ptr<Module
...
@@ -30,6 +32,39 @@ PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const std::shared_ptr<Module
* \param module_list The module list to append to
* \param module_list The module list to append to
*/
*/
void
ImportModuleBlob
(
const
char
*
mblob
,
std
::
vector
<
Module
>*
module_list
);
void
ImportModuleBlob
(
const
char
*
mblob
,
std
::
vector
<
Module
>*
module_list
);
/*!
* \brief Utility to initialize conext function symbols during startup
* \param flookup A symbol lookup function.
* \tparam FLookup a function of signature string->void*
*/
template
<
typename
FLookup
>
void
InitContextFunctions
(
FLookup
flookup
)
{
if
(
auto
*
fp
=
reinterpret_cast
<
decltype
(
&
TVMFuncCall
)
*>
(
flookup
(
"__TVMFuncCall"
)))
{
*
fp
=
TVMFuncCall
;
}
if
(
auto
*
fp
=
reinterpret_cast
<
decltype
(
&
TVMAPISetLastError
)
*>
(
flookup
(
"__TVMAPISetLastError"
)))
{
*
fp
=
TVMAPISetLastError
;
}
if
(
auto
*
fp
=
reinterpret_cast
<
decltype
(
&
TVMBackendGetFuncFromEnv
)
*>
(
flookup
(
"__TVMBackendGetFuncFromEnv"
)))
{
*
fp
=
TVMBackendGetFuncFromEnv
;
}
if
(
auto
*
fp
=
reinterpret_cast
<
decltype
(
&
TVMBackendAllocWorkspace
)
*>
(
flookup
(
"__TVMBackendAllocWorkspace"
)))
{
*
fp
=
TVMBackendAllocWorkspace
;
}
if
(
auto
*
fp
=
reinterpret_cast
<
decltype
(
&
TVMBackendFreeWorkspace
)
*>
(
flookup
(
"__TVMBackendFreeWorkspace"
)))
{
*
fp
=
TVMBackendFreeWorkspace
;
}
if
(
auto
*
fp
=
reinterpret_cast
<
decltype
(
&
TVMBackendParallelFor
)
*>
(
flookup
(
"__TVMBackendParallelFor"
)))
{
*
fp
=
TVMBackendParallelFor
;
}
}
}
// namespace runtime
}
// namespace runtime
}
// namespace tvm
}
// namespace tvm
#endif // TVM_RUNTIME_MODULE_UTIL_H_
#endif // TVM_RUNTIME_MODULE_UTIL_H_
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment