Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
9037a4c2
Commit
9037a4c2
authored
Jul 18, 2017
by
Tianqi Chen
Committed by
GitHub
Jul 18, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RUNTIME] Enable injection of some core runtime functions to avoid dynamic lookup (#260)
parent
6196cd50
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
216 additions
and
81 deletions
+216
-81
Makefile
+3
-3
src/codegen/llvm/codegen_llvm.cc
+129
-50
src/codegen/llvm/codegen_llvm.h
+32
-11
src/codegen/llvm/llvm_common.cc
+0
-2
src/codegen/llvm/llvm_module.cc
+7
-6
src/runtime/dso_module.cc
+10
-9
src/runtime/module_util.h
+35
-0
No files found.
Makefile
View file @
9037a4c2
...
...
@@ -154,12 +154,12 @@ verilog: $(VER_LIBS)
# Special rules for LLVM related modules.
build/codegen/llvm/%.o
:
src/codegen/llvm/%.cc
@
mkdir
-p
$
(
@D
)
$(CXX)
$(CFLAGS)
-MM
-MT
build/
$*
.o
$<
>
build
/
$*
.d
$(CXX)
$(CFLAGS)
-MM
-MT
build/
codegen/llvm/
$*
.o
$<
>
build/codegen/llvm
/
$*
.d
$(CXX)
-c
$(CFLAGS)
$(LLVM_CFLAGS)
-c
$<
-o
$@
build/runtime/metal/%.o
:
src/runtime/metal/%.mm
@
mkdir
-p
$
(
@D
)
$(CXX)
$(CFLAGS)
-MM
-MT
build/
$*
.o
$<
>
build
/
$*
.d
$(CXX)
$(CFLAGS)
-MM
-MT
build/
runtime/metal/
$*
.o
$<
>
build/runtime/metal
/
$*
.d
$(CXX)
$(OBJCFLAGS)
-c
$(CFLAGS)
-c
$<
-o
$@
build/%.o
:
src/%.cc
...
...
@@ -199,7 +199,7 @@ pylint:
pylint python/tvm
--rcfile
=
$(ROOTDIR)
/tests/lint/pylintrc
pylint topi/python/topi
--rcfile
=
$(ROOTDIR)
/tests/lint/pylintrc
jnilint
:
jnilint
:
python dmlc-core/scripts/lint.py tvm4j-jni cpp jvm/native/src
lint
:
cpplint pylint jnilint
...
...
src/codegen/llvm/codegen_llvm.cc
View file @
9037a4c2
...
...
@@ -29,7 +29,8 @@ std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine *tm) {
void
CodeGenLLVM
::
Init
(
const
std
::
string
&
module_name
,
llvm
::
TargetMachine
*
tm
,
llvm
::
LLVMContext
*
ctx
,
bool
system_lib
)
{
bool
system_lib
,
bool
dynamic_lookup
)
{
InitializeLLVM
();
static_assert
(
sizeof
(
TVMValue
)
==
sizeof
(
double
),
"invariant"
);
// static_assert(alignof(TVMValue) == alignof(double), "invariant");
...
...
@@ -62,7 +63,7 @@ void CodeGenLLVM::Init(const std::string& module_name,
t_tvm_shape_index_
->
getPointerTo
(),
t_int64_
});
t_tvm_value_
=
llvm
::
StructType
::
create
({
t_float64_
});
t_f
_tvm_par_for_lambda_
=
llvm
::
FunctionType
::
get
(
ftype
_tvm_par_for_lambda_
=
llvm
::
FunctionType
::
get
(
t_int_
,
{
t_int64_
,
t_int64_
,
t_void_p_
},
false
);
md_builder_
.
reset
(
new
llvm
::
MDBuilder
(
*
ctx
));
md_very_likely_branch_
=
...
...
@@ -70,45 +71,56 @@ void CodeGenLLVM::Init(const std::string& module_name,
md_tbaa_root_
=
md_builder_
->
createTBAARoot
(
"tvmtbaa"
);
md_tbaa_alias_set_
=
md_builder_
->
createTBAAScalarTypeNode
(
"alias_set"
,
md_tbaa_root_
);
md_tbaa_ctx_ptr_
=
md_builder_
->
createTBAAScalarTypeNode
(
"ctx_ptr"
,
md_tbaa_root_
);
}
ctx_
=
ctx
;
// initialize
modules
// initialize
Modules and function type
module_
.
reset
(
new
llvm
::
Module
(
module_name
,
*
ctx
));
// initialize TVM runtime API
f_tvm_func_call_
=
llvm
::
Function
::
Create
(
llvm
::
FunctionType
::
get
(
t_int_
,
{
t_tvm_func_handle_
,
t_tvm_value_
->
getPointerTo
(),
t_int_
->
getPointerTo
(),
t_int_
,
t_tvm_value_
->
getPointerTo
(),
t_int_
->
getPointerTo
()},
false
),
llvm
::
Function
::
ExternalLinkage
,
"TVMFuncCall"
,
module_
.
get
());
f_tvm_get_func_from_env_
=
llvm
::
Function
::
Create
(
ftype_tvm_func_call_
=
llvm
::
FunctionType
::
get
(
t_int_
,
{
t_tvm_func_handle_
,
t_tvm_value_
->
getPointerTo
(),
t_int_
->
getPointerTo
(),
t_int_
,
t_tvm_value_
->
getPointerTo
(),
t_int_
->
getPointerTo
()},
false
);
ftype_tvm_get_func_from_env_
=
llvm
::
FunctionType
::
get
(
t_int_
,
{
t_void_p_
,
t_char_
->
getPointerTo
(),
t_tvm_func_handle_
->
getPointerTo
()},
false
);
ftype_tvm_api_set_last_error_
=
llvm
::
FunctionType
::
get
(
t_void_
,
{
t_char_
->
getPointerTo
()},
false
);
ftype_tvm_parallel_for_
=
llvm
::
FunctionType
::
get
(
t_int_
,
{
t_void_p_
,
t_char_
->
getPointerTo
(),
t_tvm_func_handle_
->
getPointerTo
()},
false
),
llvm
::
Function
::
ExternalLinkage
,
"TVMBackendGetFuncFromEnv"
,
module_
.
get
());
f_tvm_api_set_last_error_
=
llvm
::
Function
::
Create
(
llvm
::
FunctionType
::
get
(
t_void_
,
{
t_char_
->
getPointerTo
()},
false
),
llvm
::
Function
::
ExternalLinkage
,
"TVMAPISetLastError"
,
module_
.
get
());
f_tvm_parallel_for_
=
llvm
::
Function
::
Create
(
llvm
::
FunctionType
::
get
(
t_int_
,
{
t_int64_
,
t_int64_
,
t_f_tvm_par_for_lambda_
->
getPointerTo
(),
t_void_p_
}
,
false
),
llvm
::
Function
::
ExternalLinkage
,
"TVMBackendParallelFor"
,
module_
.
get
());
t_int64_
,
t_int64_
,
ftype_tvm_par_for_lambda_
->
getPointerTo
(),
t_void_p_
}
,
false
);
// initialize TVM runtime API
if
(
system_lib
)
{
// We will need this in environment for backward registration.
f_tvm_register_system_symbol_
=
llvm
::
Function
::
Create
(
llvm
::
FunctionType
::
get
(
t_int_
,
{
t_char_
->
getPointerTo
(),
t_void_p_
},
false
),
llvm
::
Function
::
ExternalLinkage
,
"TVMBackendRegisterSystemLibSymbol"
,
module_
.
get
());
}
else
{
f_tvm_register_system_symbol_
=
nullptr
;
}
if
(
dynamic_lookup
||
system_lib
)
{
f_tvm_func_call_
=
llvm
::
Function
::
Create
(
ftype_tvm_func_call_
,
llvm
::
Function
::
ExternalLinkage
,
"TVMFuncCall"
,
module_
.
get
());
f_tvm_get_func_from_env_
=
llvm
::
Function
::
Create
(
ftype_tvm_get_func_from_env_
,
llvm
::
Function
::
ExternalLinkage
,
"TVMBackendGetFuncFromEnv"
,
module_
.
get
());
f_tvm_api_set_last_error_
=
llvm
::
Function
::
Create
(
ftype_tvm_api_set_last_error_
,
llvm
::
Function
::
ExternalLinkage
,
"TVMAPISetLastError"
,
module_
.
get
());
f_tvm_parallel_for_
=
llvm
::
Function
::
Create
(
ftype_tvm_parallel_for_
,
llvm
::
Function
::
ExternalLinkage
,
"TVMBackendParallelFor"
,
module_
.
get
());
}
this
->
InitTarget
(
tm
);
// initialize builder
builder_
.
reset
(
new
IRBuilder
(
*
ctx
));
this
->
InitGlobalContext
();
this
->
InitGlobalContext
(
dynamic_lookup
);
}
void
CodeGenLLVM
::
InitTarget
(
llvm
::
TargetMachine
*
tm
)
{
...
...
@@ -131,17 +143,48 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
}
}
void
CodeGenLLVM
::
InitGlobalContext
()
{
gv_mod_ctx_
=
new
llvm
::
GlobalVariable
(
*
module_
,
t_void_p_
,
false
,
llvm
::
GlobalVariable
*
CodeGenLLVM
::
InitContextPtr
(
llvm
::
Type
*
p_type
,
std
::
string
name
)
{
llvm
::
GlobalVariable
*
gv
=
new
llvm
::
GlobalVariable
(
*
module_
,
p_type
,
false
,
llvm
::
GlobalValue
::
LinkOnceAnyLinkage
,
0
,
tvm
::
runtime
::
symbol
::
tvm_module_ctx
);
gv_mod_ctx_
->
setAlignment
(
data_layout_
->
getTypeAllocSize
(
t_void_p_
));
gv_mod_ctx_
->
setInitializer
(
llvm
::
Constant
::
getNullValue
(
t_void_p_
));
name
);
gv
->
setAlignment
(
data_layout_
->
getTypeAllocSize
(
p_type
));
gv
->
setInitializer
(
llvm
::
Constant
::
getNullValue
(
p_type
));
return
gv
;
}
llvm
::
Value
*
CodeGenLLVM
::
GetContextPtr
(
llvm
::
GlobalVariable
*
gv
)
{
CHECK
(
gv
!=
nullptr
);
llvm
::
LoadInst
*
faddr
=
builder_
->
CreateAlignedLoad
(
gv
,
gv
->
getAlignment
());
faddr
->
setMetadata
(
"tbaa"
,
md_builder_
->
createTBAAStructTagNode
(
md_tbaa_ctx_ptr_
,
md_tbaa_ctx_ptr_
,
0
));
return
faddr
;
}
void
CodeGenLLVM
::
InitGlobalContext
(
bool
dynamic_lookup
)
{
// Module context
gv_mod_ctx_
=
InitContextPtr
(
t_void_p_
,
tvm
::
runtime
::
symbol
::
tvm_module_ctx
);
// Register back the locations.
if
(
f_tvm_register_system_symbol_
!=
nullptr
)
{
export_system_symbols_
.
emplace_back
(
std
::
make_pair
(
tvm
::
runtime
::
symbol
::
tvm_module_ctx
,
gv_mod_ctx_
));
}
else
{
if
(
!
dynamic_lookup
)
{
gv_tvm_func_call_
=
InitContextPtr
(
ftype_tvm_func_call_
->
getPointerTo
(),
"__TVMFuncCall"
);
gv_tvm_get_func_from_env_
=
InitContextPtr
(
ftype_tvm_get_func_from_env_
->
getPointerTo
(),
"__TVMBackendGetFuncFromEnv"
);
gv_tvm_api_set_last_error_
=
InitContextPtr
(
ftype_tvm_api_set_last_error_
->
getPointerTo
(),
"__TVMAPISetLastError"
);
gv_tvm_parallel_for_
=
InitContextPtr
(
ftype_tvm_parallel_for_
->
getPointerTo
(),
"__TVMBackendParallelFor"
);
// Mark as context functions
gv_func_map_
[
"TVMBackendAllocWorkspace"
]
=
nullptr
;
gv_func_map_
[
"TVMBackendFreeWorkspace"
]
=
nullptr
;
}
}
}
...
...
@@ -528,9 +571,13 @@ llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
// Initialize the handle if needed.
builder_
->
SetInsertPoint
(
init_block
);
llvm
::
Value
*
out
=
builder_
->
CreateAlloca
(
t_tvm_func_handle_
);
llvm
::
Value
*
ctx
=
builder_
->
CreateLoad
(
gv_mod_ctx_
);
llvm
::
LoadInst
*
ctx
=
builder_
->
CreateAlignedLoad
(
gv_mod_ctx_
,
gv_mod_ctx_
->
getAlignment
());
ctx
->
setMetadata
(
"tbaa"
,
md_builder_
->
createTBAAStructTagNode
(
md_tbaa_ctx_ptr_
,
md_tbaa_ctx_ptr_
,
0
));
llvm
::
Value
*
retcode
=
builder_
->
CreateCall
(
f_tvm_get_func_from_env_
,
{
ctx
,
GetConstString
(
fname
),
out
});
RuntimeTVMGetFuncFromEnv
()
,
{
ctx
,
GetConstString
(
fname
),
out
});
init_block
=
CheckCallSuccess
(
retcode
);
llvm
::
Value
*
loaded_handle
=
builder_
->
CreateAlignedLoad
(
out
,
align
);
builder_
->
CreateBr
(
end_block
);
...
...
@@ -565,7 +612,7 @@ llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) {
Int
(
32
),
stack_tcode
,
ConstInt32
(
end
));
CheckCallSuccess
(
builder_
->
CreateCall
(
f_tvm_func_call_
,
RuntimeTVMFuncCall
()
,
{
handle
,
arg_value
,
arg_tcode
,
ConstInt32
(
nargs
),
ret_value
,
ret_tcode
}));
Type
r_type
=
op
->
type
;
...
...
@@ -584,17 +631,28 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
arg_values
[
i
]
=
MakeValue
(
op
->
args
[
i
]);
}
if
(
op
->
type
.
is_scalar
())
{
llvm
::
Function
*
f
=
module_
->
getFunction
(
op
->
name
);
if
(
f
==
nullptr
)
{
std
::
vector
<
llvm
::
Type
*>
arg_types
;
for
(
llvm
::
Value
*
v
:
arg_values
)
{
arg_types
.
push_back
(
v
->
getType
());
std
::
vector
<
llvm
::
Type
*>
arg_types
;
for
(
llvm
::
Value
*
v
:
arg_values
)
{
arg_types
.
push_back
(
v
->
getType
());
}
llvm
::
FunctionType
*
ftype
=
llvm
::
FunctionType
::
get
(
LLVMType
(
op
->
type
),
arg_types
,
false
);
// Check if it is available in global function table as injected function.
auto
it
=
gv_func_map_
.
find
(
op
->
name
);
if
(
it
!=
gv_func_map_
.
end
())
{
if
(
it
->
second
==
nullptr
)
{
gv_func_map_
[
op
->
name
]
=
InitContextPtr
(
ftype
->
getPointerTo
(),
"__"
+
op
->
name
);
it
=
gv_func_map_
.
find
(
op
->
name
);
}
f
=
llvm
::
Function
::
Create
(
llvm
::
FunctionType
::
get
(
LLVMType
(
op
->
type
),
arg_types
,
false
),
llvm
::
Function
::
ExternalLinkage
,
op
->
name
,
module_
.
get
());
return
builder_
->
CreateCall
(
GetContextPtr
(
it
->
second
),
arg_values
);
}
else
{
llvm
::
Function
*
f
=
module_
->
getFunction
(
op
->
name
);
if
(
f
==
nullptr
)
{
f
=
llvm
::
Function
::
Create
(
ftype
,
llvm
::
Function
::
ExternalLinkage
,
op
->
name
,
module_
.
get
());
}
return
builder_
->
CreateCall
(
f
,
arg_values
);
}
return
builder_
->
CreateCall
(
f
,
arg_values
);
}
else
{
llvm
::
Function
*
f
=
module_
->
getFunction
(
op
->
name
);
if
(
f
)
{
...
...
@@ -603,6 +661,7 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
LOG
(
FATAL
)
<<
"cannot find function "
<<
op
->
name
;
}
}
LOG
(
FATAL
)
<<
"canot reach here"
;
return
nullptr
;
}
...
...
@@ -630,6 +689,24 @@ llvm::Value* CodeGenLLVM::CreateScalarizedCall(
return
value
;
}
llvm
::
Value
*
CodeGenLLVM
::
RuntimeTVMFuncCall
()
{
if
(
f_tvm_func_call_
!=
nullptr
)
return
f_tvm_func_call_
;
return
GetContextPtr
(
gv_tvm_func_call_
);
}
llvm
::
Value
*
CodeGenLLVM
::
RuntimeTVMGetFuncFromEnv
()
{
if
(
f_tvm_get_func_from_env_
!=
nullptr
)
return
f_tvm_get_func_from_env_
;
return
GetContextPtr
(
gv_tvm_get_func_from_env_
);
}
llvm
::
Value
*
CodeGenLLVM
::
RuntimeTVMAPISetLastError
()
{
if
(
f_tvm_api_set_last_error_
!=
nullptr
)
return
f_tvm_api_set_last_error_
;
return
GetContextPtr
(
gv_tvm_api_set_last_error_
);
}
llvm
::
Value
*
CodeGenLLVM
::
RuntimeTVMParallelFor
()
{
if
(
f_tvm_parallel_for_
!=
nullptr
)
return
f_tvm_parallel_for_
;
return
GetContextPtr
(
gv_tvm_parallel_for_
);
}
llvm
::
Value
*
CodeGenLLVM
::
GetVarValue
(
const
Variable
*
v
)
const
{
auto
it
=
var_map_
.
find
(
v
);
CHECK
(
it
!=
var_map_
.
end
())
...
...
@@ -723,7 +800,7 @@ void CodeGenLLVM::CreateParallelFor(const For* op) {
// closure data
llvm
::
StructType
*
tcdata
=
llvm
::
StructType
::
create
(
fields
);
llvm
::
Function
*
f
=
llvm
::
Function
::
Create
(
t_f
_tvm_par_for_lambda_
,
ftype
_tvm_par_for_lambda_
,
llvm
::
Function
::
PrivateLinkage
,
"__tvm_par_for_lambda"
,
module_
.
get
());
// allocate and setup the closure, call the closure.
...
...
@@ -737,7 +814,7 @@ void CodeGenLLVM::CreateParallelFor(const For* op) {
}
BasicBlock
*
par_for_end
=
CheckCallSuccess
(
builder_
->
CreateCall
(
f_tvm_parallel_for_
,
RuntimeTVMParallelFor
()
,
{
min
,
extent
,
f
,
builder_
->
CreatePointerCast
(
cdata
,
t_void_p_
)}));
// Setup the closure function.
BasicBlock
*
lambda_entry
=
BasicBlock
::
Create
(
*
ctx_
,
"entry"
,
f
);
...
...
@@ -794,8 +871,9 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end,
builder_
->
SetInsertPoint
(
for_end
);
}
llvm
::
Value
*
CodeGenLLVM
::
CreateIntrins
t
ic
(
const
Call
*
op
)
{
llvm
::
Value
*
CodeGenLLVM
::
CreateIntrinsic
(
const
Call
*
op
)
{
if
(
op
->
is_intrinsic
(
"llvm_intrin"
))
{
CHECK_GE
(
op
->
args
.
size
(),
1U
);
std
::
vector
<
llvm
::
Value
*>
arg_values
;
std
::
vector
<
llvm
::
Type
*>
arg_types
;
for
(
size_t
i
=
1
;
i
<
op
->
args
.
size
();
++
i
)
{
...
...
@@ -808,6 +886,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
module_
.
get
(),
id
,
arg_types
);
return
builder_
->
CreateCall
(
f
,
arg_values
);
}
else
if
(
op
->
is_intrinsic
(
"llvm_builtin"
))
{
CHECK_GE
(
op
->
args
.
size
(),
1U
);
std
::
vector
<
llvm
::
Value
*>
arg_values
;
for
(
size_t
i
=
1
;
i
<
op
->
args
.
size
();
++
i
)
{
llvm
::
Value
*
v
=
MakeValue
(
op
->
args
[
i
]);
...
...
@@ -1391,7 +1470,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
return
CreateCallPacked
(
op
);
}
else
if
(
op
->
call_type
==
Call
::
Intrinsic
||
op
->
call_type
==
Call
::
PureIntrinsic
)
{
return
CreateIntrins
t
ic
(
op
);
return
CreateIntrinsic
(
op
);
}
else
{
CHECK
(
op
->
call_type
==
Call
::
Extern
||
op
->
call_type
==
Call
::
PureExtern
);
...
...
@@ -1508,7 +1587,7 @@ void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
builder_
->
CreateCondBr
(
cond
,
end_block
,
fail_block
,
md_very_likely_branch_
);
// fail condition.
builder_
->
SetInsertPoint
(
fail_block
);
builder_
->
CreateCall
(
f_tvm_api_set_last_error_
,
{
msg
});
builder_
->
CreateCall
(
RuntimeTVMAPISetLastError
()
,
{
msg
});
builder_
->
CreateRet
(
ConstInt32
(
-
1
));
// otherwise set it to be new end.
builder_
->
SetInsertPoint
(
end_block
);
...
...
src/codegen/llvm/codegen_llvm.h
View file @
9037a4c2
...
...
@@ -41,11 +41,14 @@ class CodeGenLLVM :
* \param tm Target machine model
* \param ctx The context.
* \param system_lib Whether to insert system library registration.
* \param dynamic_lookup Whether dynamically lookup runtime function
* or use the runtime function table passed by caller.
*/
void
Init
(
const
std
::
string
&
module_name
,
llvm
::
TargetMachine
*
tm
,
llvm
::
LLVMContext
*
ctx
,
bool
system_lib
);
bool
system_lib
,
bool
dynamic_lookup
);
/*!
* \brief Compile and add function f to the current module.
* \param f The function to be added.
...
...
@@ -114,7 +117,7 @@ class CodeGenLLVM :
void
VisitStmt_
(
const
Evaluate
*
op
)
override
;
void
VisitStmt_
(
const
ProducerConsumer
*
op
)
override
;
// create intrinstic given call
virtual
llvm
::
Value
*
CreateIntrins
t
ic
(
const
Call
*
op
);
virtual
llvm
::
Value
*
CreateIntrinsic
(
const
Call
*
op
);
// create extern function call
virtual
llvm
::
Value
*
CreateCallExtern
(
const
Call
*
op
);
// create call into tvm packed function.
...
...
@@ -178,6 +181,7 @@ class CodeGenLLVM :
llvm
::
MDNode
*
md_very_likely_branch_
{
nullptr
};
llvm
::
MDNode
*
md_tbaa_root_
{
nullptr
};
llvm
::
MDNode
*
md_tbaa_alias_set_
{
nullptr
};
llvm
::
MDNode
*
md_tbaa_ctx_ptr_
{
nullptr
};
// TVM related data types
llvm
::
Type
*
t_tvm_shape_index_
{
nullptr
};
llvm
::
Type
*
t_tvm_func_handle_
{
nullptr
};
...
...
@@ -185,13 +189,12 @@ class CodeGenLLVM :
llvm
::
StructType
*
t_tvm_type_
{
nullptr
};
llvm
::
StructType
*
t_tvm_array_
{
nullptr
};
llvm
::
StructType
*
t_tvm_value_
{
nullptr
};
llvm
::
FunctionType
*
t_f_tvm_par_for_lambda_
{
nullptr
};
// tvm api functions
llvm
::
Function
*
f_tvm_func_call_
{
nullptr
};
llvm
::
Function
*
f_tvm_get_func_from_env_
{
nullptr
};
llvm
::
Function
*
f_tvm_api_set_last_error_
{
nullptr
};
llvm
::
Function
*
f_tvm_parallel_for_
{
nullptr
};
llvm
::
Function
*
f_tvm_register_system_symbol_
{
nullptr
};
llvm
::
FunctionType
*
ftype_tvm_par_for_lambda_
{
nullptr
};
llvm
::
FunctionType
*
ftype_tvm_func_call_
{
nullptr
};
llvm
::
FunctionType
*
ftype_tvm_get_func_from_env_
{
nullptr
};
llvm
::
FunctionType
*
ftype_tvm_api_set_last_error_
{
nullptr
};
llvm
::
FunctionType
*
ftype_tvm_parallel_for_
{
nullptr
};
llvm
::
FunctionType
*
ftype_tvm_register_system_symbol_
{
nullptr
};
// The acting body
llvm
::
BasicBlock
*
block_
{
nullptr
};
/*! \brief native vector bits of current targetx*/
...
...
@@ -200,6 +203,13 @@ class CodeGenLLVM :
std
::
unordered_map
<
const
Variable
*
,
StorageInfo
>
alloc_storage_info_
;
private
:
// Get runtime functions
llvm
::
GlobalVariable
*
InitContextPtr
(
llvm
::
Type
*
type
,
std
::
string
name
);
llvm
::
Value
*
GetContextPtr
(
llvm
::
GlobalVariable
*
gv
);
llvm
::
Value
*
RuntimeTVMFuncCall
();
llvm
::
Value
*
RuntimeTVMGetFuncFromEnv
();
llvm
::
Value
*
RuntimeTVMAPISetLastError
();
llvm
::
Value
*
RuntimeTVMParallelFor
();
// comparison op
llvm
::
Value
*
GetVarValue
(
const
Variable
*
v
)
const
;
llvm
::
Value
*
CreateLT
(
Type
t
,
llvm
::
Value
*
a
,
llvm
::
Value
*
b
);
...
...
@@ -232,7 +242,7 @@ class CodeGenLLVM :
// return the end block after the check
llvm
::
BasicBlock
*
CheckCallSuccess
(
llvm
::
Value
*
retcode
);
// Add a function to set global module context
void
InitGlobalContext
();
void
InitGlobalContext
(
bool
dynamic_lookup
);
// Add module startup function if needed.
void
AddStartupFunction
();
// add alias information.
...
...
@@ -247,8 +257,19 @@ class CodeGenLLVM :
bool
is_restricted_
{
true
};
// set of var that are not restricted(can alias)
std
::
unordered_set
<
const
Variable
*>
alias_var_set_
;
//
The local module_context
//
Context for injection lookup
llvm
::
GlobalVariable
*
gv_mod_ctx_
{
nullptr
};
llvm
::
GlobalVariable
*
gv_tvm_func_call_
{
nullptr
};
llvm
::
GlobalVariable
*
gv_tvm_get_func_from_env_
{
nullptr
};
llvm
::
GlobalVariable
*
gv_tvm_api_set_last_error_
{
nullptr
};
llvm
::
GlobalVariable
*
gv_tvm_parallel_for_
{
nullptr
};
std
::
unordered_map
<
std
::
string
,
llvm
::
GlobalVariable
*>
gv_func_map_
;
// context for direct dynamic lookup
llvm
::
Function
*
f_tvm_func_call_
{
nullptr
};
llvm
::
Function
*
f_tvm_get_func_from_env_
{
nullptr
};
llvm
::
Function
*
f_tvm_api_set_last_error_
{
nullptr
};
llvm
::
Function
*
f_tvm_parallel_for_
{
nullptr
};
llvm
::
Function
*
f_tvm_register_system_symbol_
{
nullptr
};
// global to packed function handle
std
::
unordered_map
<
std
::
string
,
llvm
::
GlobalVariable
*>
func_handle_map_
;
// List of symbols to be exported to TVM system lib.
...
...
src/codegen/llvm/llvm_common.cc
View file @
9037a4c2
...
...
@@ -113,8 +113,6 @@ GetLLVMTargetMachine(const std::string& target_str, bool allow_null) {
return
tm
;
}
}
// namespace codegen
}
// namespace tvm
#endif // TVM_LLVM_VERSION
src/codegen/llvm/llvm_module.cc
View file @
9037a4c2
...
...
@@ -104,7 +104,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
ctx_
=
std
::
make_shared
<
llvm
::
LLVMContext
>
();
std
::
unique_ptr
<
CodeGenLLVM
>
cg
=
CodeGenLLVM
::
Create
(
tm_
);
entry_func_
=
funcs
[
0
]
->
name
;
cg
->
Init
(
funcs
[
0
]
->
name
,
tm_
,
ctx_
.
get
(),
system_lib
);
cg
->
Init
(
funcs
[
0
]
->
name
,
tm_
,
ctx_
.
get
(),
system_lib
,
system_lib
);
for
(
LoweredFunc
f
:
funcs
)
{
cg
->
AddFunction
(
f
);
}
...
...
@@ -152,16 +152,17 @@ class LLVMModuleNode final : public runtime::ModuleNode {
<<
"Failed to initialize git engine for "
<<
mptr_
->
getTargetTriple
();
ee_
->
runStaticConstructorsDestructors
(
false
);
// setup context address.
void
**
ctx_addr
=
reinterpret_cast
<
void
**>
(
ee_
->
getGlobalValueAddress
(
runtime
::
symbol
::
tvm_module_ctx
));
// setup context address.
entry_func_
=
reinterpret_cast
<
const
char
*>
(
ee_
->
getGlobalValueAddress
(
runtime
::
symbol
::
tvm_module_main
));
if
(
ctx_addr
!=
nullptr
)
{
if
(
void
**
ctx_addr
=
reinterpret_cast
<
void
**>
(
ee_
->
getGlobalValueAddress
(
runtime
::
symbol
::
tvm_module_ctx
)))
{
*
ctx_addr
=
this
;
}
runtime
::
InitContextFunctions
([
this
](
const
char
*
name
)
{
auto
value
=
ee_
->
getGlobalValueAddress
(
name
);
return
value
;
});
}
// The target configuration string
std
::
string
target_
;
...
...
src/runtime/dso_module.cc
View file @
9037a4c2
...
...
@@ -40,7 +40,7 @@ class DSOModuleNode final : public ModuleNode {
<<
"Symbol "
<<
runtime
::
symbol
::
tvm_module_main
<<
" is not presented"
;
faddr
=
reinterpret_cast
<
BackendPackedCFunc
>
(
GetSymbol
(
entry_name
));
}
else
{
faddr
=
reinterpret_cast
<
BackendPackedCFunc
>
(
GetSymbol
(
name
));
faddr
=
reinterpret_cast
<
BackendPackedCFunc
>
(
GetSymbol
(
name
.
c_str
()
));
}
if
(
faddr
==
nullptr
)
return
PackedFunc
();
return
WrapPackedFunc
(
faddr
,
sptr_to_self
);
...
...
@@ -48,12 +48,13 @@ class DSOModuleNode final : public ModuleNode {
void
Init
(
const
std
::
string
&
name
)
{
Load
(
name
);
void
**
ctx_addr
=
reinterpret_cast
<
void
**>
(
GetSymbol
(
runtime
::
symbol
::
tvm_module_ctx
));
if
(
ctx_addr
!=
nullptr
)
{
if
(
auto
*
ctx_addr
=
reinterpret_cast
<
void
**>
(
GetSymbol
(
runtime
::
symbol
::
tvm_module_ctx
)))
{
*
ctx_addr
=
this
;
}
InitContextFunctions
([
this
](
const
char
*
fname
)
{
return
GetSymbol
(
fname
);
});
// Load the imported modules
const
char
*
dev_mblob
=
reinterpret_cast
<
const
char
*>
(
...
...
@@ -76,9 +77,9 @@ class DSOModuleNode final : public ModuleNode {
CHECK
(
lib_handle_
!=
nullptr
)
<<
"Failed to load dynamic shared library "
<<
name
;
}
void
*
GetSymbol
(
const
std
::
string
&
name
)
{
void
*
GetSymbol
(
const
char
*
name
)
{
return
reinterpret_cast
<
void
*>
(
GetProcAddress
(
lib_handle_
,
(
LPCSTR
)
name
.
c_str
()
));
// NOLINT(*)
GetProcAddress
(
lib_handle_
,
(
LPCSTR
)
name
));
// NOLINT(*)
}
void
Unload
()
{
FreeLibrary
(
lib_handle_
);
...
...
@@ -92,8 +93,8 @@ class DSOModuleNode final : public ModuleNode {
CHECK
(
lib_handle_
!=
nullptr
)
<<
"Failed to load dynamic shared library "
<<
name
;
}
void
*
GetSymbol
(
const
std
::
string
&
name
)
{
return
dlsym
(
lib_handle_
,
name
.
c_str
()
);
void
*
GetSymbol
(
const
char
*
name
)
{
return
dlsym
(
lib_handle_
,
name
);
}
void
Unload
()
{
dlclose
(
lib_handle_
);
...
...
src/runtime/module_util.h
View file @
9037a4c2
...
...
@@ -7,6 +7,8 @@
#define TVM_RUNTIME_MODULE_UTIL_H_
#include <tvm/runtime/module.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h>
#include <vector>
extern
"C"
{
...
...
@@ -30,6 +32,39 @@ PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const std::shared_ptr<Module
* \param module_list The module list to append to
*/
void
ImportModuleBlob
(
const
char
*
mblob
,
std
::
vector
<
Module
>*
module_list
);
/*!
* \brief Utility to initialize conext function symbols during startup
* \param flookup A symbol lookup function.
* \tparam FLookup a function of signature string->void*
*/
template
<
typename
FLookup
>
void
InitContextFunctions
(
FLookup
flookup
)
{
if
(
auto
*
fp
=
reinterpret_cast
<
decltype
(
&
TVMFuncCall
)
*>
(
flookup
(
"__TVMFuncCall"
)))
{
*
fp
=
TVMFuncCall
;
}
if
(
auto
*
fp
=
reinterpret_cast
<
decltype
(
&
TVMAPISetLastError
)
*>
(
flookup
(
"__TVMAPISetLastError"
)))
{
*
fp
=
TVMAPISetLastError
;
}
if
(
auto
*
fp
=
reinterpret_cast
<
decltype
(
&
TVMBackendGetFuncFromEnv
)
*>
(
flookup
(
"__TVMBackendGetFuncFromEnv"
)))
{
*
fp
=
TVMBackendGetFuncFromEnv
;
}
if
(
auto
*
fp
=
reinterpret_cast
<
decltype
(
&
TVMBackendAllocWorkspace
)
*>
(
flookup
(
"__TVMBackendAllocWorkspace"
)))
{
*
fp
=
TVMBackendAllocWorkspace
;
}
if
(
auto
*
fp
=
reinterpret_cast
<
decltype
(
&
TVMBackendFreeWorkspace
)
*>
(
flookup
(
"__TVMBackendFreeWorkspace"
)))
{
*
fp
=
TVMBackendFreeWorkspace
;
}
if
(
auto
*
fp
=
reinterpret_cast
<
decltype
(
&
TVMBackendParallelFor
)
*>
(
flookup
(
"__TVMBackendParallelFor"
)))
{
*
fp
=
TVMBackendParallelFor
;
}
}
}
// namespace runtime
}
// namespace tvm
#endif // TVM_RUNTIME_MODULE_UTIL_H_
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment