Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
b40d43c4
Commit
b40d43c4
authored
Aug 08, 2017
by
Tianqi Chen
Committed by
GitHub
Aug 08, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS][RUNTIME] Support attr scope lift and runonce (#303)
parent
7d67e473
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
354 additions
and
28 deletions
+354
-28
include/tvm/ir.h
+5
-0
include/tvm/ir_pass.h
+9
-0
include/tvm/runtime/c_backend_api.h
+17
-0
python/tvm/build_module.py
+14
-7
src/api/api_pass.cc
+1
-0
src/codegen/llvm/codegen_llvm.cc
+85
-20
src/codegen/llvm/codegen_llvm.h
+9
-0
src/pass/combine_context_call.cc
+2
-1
src/pass/lift_attr_scope.cc
+151
-0
src/runtime/c_runtime_api.cc
+11
-0
tests/python/unittest/test_codegen_static_init.py
+27
-0
tests/python/unittest/test_pass_lift_attr_scope.py
+23
-0
No files found.
include/tvm/ir.h
View file @
b40d43c4
...
@@ -145,6 +145,11 @@ constexpr const char* thread_extent = "thread_extent";
...
@@ -145,6 +145,11 @@ constexpr const char* thread_extent = "thread_extent";
constexpr
const
char
*
virtual_thread
=
"virtual_thread"
;
constexpr
const
char
*
virtual_thread
=
"virtual_thread"
;
/*! \brief Mark region is processed by a co-proccesor */
/*! \brief Mark region is processed by a co-proccesor */
constexpr
const
char
*
coproc_scope
=
"coproc_scope"
;
constexpr
const
char
*
coproc_scope
=
"coproc_scope"
;
/*!
* \brief Mark region creates coprocessor micro ops,
* can be reused if corresponding variable is independent.
*/
constexpr
const
char
*
coproc_uop_scope
=
"coproc_uop_scope"
;
/*! \brief Mark the scope as volatile access for certain handle. */
/*! \brief Mark the scope as volatile access for certain handle. */
constexpr
const
char
*
volatile_scope
=
"volatile_scope"
;
constexpr
const
char
*
volatile_scope
=
"volatile_scope"
;
/*!
/*!
...
...
include/tvm/ir_pass.h
View file @
b40d43c4
...
@@ -258,6 +258,15 @@ Stmt LoopPartition(Stmt stmt);
...
@@ -258,6 +258,15 @@ Stmt LoopPartition(Stmt stmt);
Stmt
CoProcSync
(
Stmt
stmt
);
Stmt
CoProcSync
(
Stmt
stmt
);
/*!
/*!
* \brief Lift common attrs with attr_key to outer scope.
*
* \param stmt The stmt to be trasnformed
* \param attr_key The attribute key to be checked.
* \return Transformed stmt.
*/
Stmt
LiftAttrScope
(
Stmt
stmt
,
std
::
string
attr_key
);
/*!
* \brief Make an user callable API LoweredFunc.
* \brief Make an user callable API LoweredFunc.
*
*
* The main task of this function is to create code to :
* The main task of this function is to create code to :
...
...
include/tvm/runtime/c_backend_api.h
View file @
b40d43c4
...
@@ -110,6 +110,23 @@ TVM_DLL int TVMBackendParallelLaunch(FTVMParallelLambda flambda,
...
@@ -110,6 +110,23 @@ TVM_DLL int TVMBackendParallelLaunch(FTVMParallelLambda flambda,
*/
*/
TVM_DLL
int
TVMBackendParallelBarrier
(
int
task_id
,
TVMParallelGroupEnv
*
penv
);
TVM_DLL
int
TVMBackendParallelBarrier
(
int
task_id
,
TVMParallelGroupEnv
*
penv
);
/*!
* \brief Simple static initialization fucntion.
* Run f once and set handle to be not null.
* This function is mainly used for test purpose.
*
* \param handle An global address to indicate f
* \param f The function to be ran
* \param cdata The closure data to pass to the function.
* \param nbytes Number of bytes in the closure data.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL
int
TVMBackendRunOnce
(
void
**
handle
,
int
(
*
f
)(
void
*
),
void
*
cdata
,
int
nbytes
);
#ifdef __cplusplus
#ifdef __cplusplus
}
// TVM_EXTERN_C
}
// TVM_EXTERN_C
#endif
#endif
...
...
python/tvm/build_module.py
View file @
b40d43c4
...
@@ -24,13 +24,14 @@ class BuildConfig(object):
...
@@ -24,13 +24,14 @@ class BuildConfig(object):
"""
"""
current
=
None
current
=
None
defaults
=
{
defaults
=
{
'auto_unroll_max_step'
:
0
,
"auto_unroll_max_step"
:
0
,
'auto_unroll_min_depth'
:
1
,
"auto_unroll_min_depth"
:
1
,
'unroll_explicit'
:
True
,
"unroll_explicit"
:
True
,
'detect_global_barrier'
:
False
,
"detect_global_barrier"
:
False
,
'offset_factor'
:
0
,
"offset_factor"
:
0
,
'data_alignment'
:
-
1
,
"data_alignment"
:
-
1
,
'restricted_func'
:
True
"restricted_func"
:
True
,
"add_lower_pass"
:
None
}
}
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
**
kwargs
):
self
.
_old_scope
=
None
self
.
_old_scope
=
None
...
@@ -94,6 +95,9 @@ def build_config(**kwargs):
...
@@ -94,6 +95,9 @@ def build_config(**kwargs):
not to overlap. This enables more optimization.
not to overlap. This enables more optimization.
Corresponds to restricted keyword in C99
Corresponds to restricted keyword in C99
add_lower_pass: list of function(Stmt->Stmt), default=None
Additional lowering passes to be applied before make_api.
Returns
Returns
-------
-------
config: BuildConfig
config: BuildConfig
...
@@ -200,6 +204,9 @@ def lower(sch,
...
@@ -200,6 +204,9 @@ def lower(sch,
cfg
.
auto_unroll_max_step
,
cfg
.
auto_unroll_max_step
,
cfg
.
auto_unroll_min_depth
,
cfg
.
auto_unroll_min_depth
,
cfg
.
unroll_explicit
)
cfg
.
unroll_explicit
)
if
cfg
.
add_lower_pass
:
for
f
in
cfg
.
add_lower_pass
:
stmt
=
f
(
stmt
)
stmt
=
ir_pass
.
Simplify
(
stmt
)
stmt
=
ir_pass
.
Simplify
(
stmt
)
if
simple_mode
:
if
simple_mode
:
return
stmt
return
stmt
...
...
src/api/api_pass.cc
View file @
b40d43c4
...
@@ -100,6 +100,7 @@ REGISTER_PASS1(InjectPrefetch);
...
@@ -100,6 +100,7 @@ REGISTER_PASS1(InjectPrefetch);
REGISTER_PASS1
(
LoopPartition
);
REGISTER_PASS1
(
LoopPartition
);
REGISTER_PASS1
(
RemoveNoOp
);
REGISTER_PASS1
(
RemoveNoOp
);
REGISTER_PASS2
(
SplitPipeline
);
REGISTER_PASS2
(
SplitPipeline
);
REGISTER_PASS2
(
LiftAttrScope
);
REGISTER_PASS1
(
NarrowChannelAccess
);
REGISTER_PASS1
(
NarrowChannelAccess
);
REGISTER_PASS2
(
LowerThreadAllreduce
);
REGISTER_PASS2
(
LowerThreadAllreduce
);
REGISTER_PASS2
(
LowerIntrin
);
REGISTER_PASS2
(
LowerIntrin
);
...
...
src/codegen/llvm/codegen_llvm.cc
View file @
b40d43c4
...
@@ -104,6 +104,14 @@ void CodeGenLLVM::Init(const std::string& module_name,
...
@@ -104,6 +104,14 @@ void CodeGenLLVM::Init(const std::string& module_name,
llvm
::
FunctionType
::
get
(
t_int_
,
{
llvm
::
FunctionType
::
get
(
t_int_
,
{
t_int_
,
t_tvm_parallel_group_env_
->
getPointerTo
()}
t_int_
,
t_tvm_parallel_group_env_
->
getPointerTo
()}
,
false
);
,
false
);
ftype_tvm_static_init_callback_
=
llvm
::
FunctionType
::
get
(
t_int_
,
{
t_void_p_
},
false
);
ftype_tvm_static_init_
=
llvm
::
FunctionType
::
get
(
t_int_
,
{
t_void_p_
->
getPointerTo
(),
ftype_tvm_static_init_callback_
->
getPointerTo
(),
t_void_p_
,
t_int_
}
,
false
);
// initialize TVM runtime API
// initialize TVM runtime API
if
(
system_lib
)
{
if
(
system_lib
)
{
// We will need this in environment for backward registration.
// We will need this in environment for backward registration.
...
@@ -802,30 +810,44 @@ void CodeGenLLVM::CreateComputeScope(const AttrStmt* op) {
...
@@ -802,30 +810,44 @@ void CodeGenLLVM::CreateComputeScope(const AttrStmt* op) {
builder_
->
SetInsertPoint
(
compute_call_end
);
builder_
->
SetInsertPoint
(
compute_call_end
);
}
}
void
CodeGenLLVM
::
CreateParallelLaunch
(
const
Stmt
&
body
,
int
num_task
)
{
llvm
::
Value
*
CodeGenLLVM
::
PackClosureData
(
const
Array
<
Var
>&
vfields
)
{
using
llvm
::
BasicBlock
;
Array
<
Var
>
vfields
=
ir
::
UndefinedVars
(
body
,
{});
std
::
vector
<
llvm
::
Type
*>
fields
;
std
::
vector
<
llvm
::
Type
*>
fields
;
for
(
Var
v
:
vfields
)
{
for
(
Var
v
:
vfields
)
{
auto
it
=
var_map_
.
find
(
v
.
get
());
auto
it
=
var_map_
.
find
(
v
.
get
());
CHECK
(
it
!=
var_map_
.
end
());
CHECK
(
it
!=
var_map_
.
end
());
fields
.
push_back
(
it
->
second
->
getType
());
fields
.
push_back
(
it
->
second
->
getType
());
}
}
// closure data
llvm
::
StructType
*
tcdata
=
llvm
::
StructType
::
create
(
fields
);
llvm
::
StructType
*
tcdata
=
llvm
::
StructType
::
create
(
fields
);
llvm
::
Function
*
f
=
llvm
::
Function
::
Create
(
ftype_tvm_parallel_lambda_
,
llvm
::
Function
::
PrivateLinkage
,
"__tvm_parallel_lambda"
,
module_
.
get
());
// allocate and setup the closure, call the closure.
llvm
::
Value
*
cdata
=
builder_
->
CreateAlloca
(
tcdata
,
ConstInt32
(
1
));
llvm
::
Value
*
cdata
=
builder_
->
CreateAlloca
(
tcdata
,
ConstInt32
(
1
));
llvm
::
Value
*
zero
=
ConstInt32
(
0
);
llvm
::
Value
*
zero
=
ConstInt32
(
0
);
for
(
size_t
i
=
0
;
i
<
vfields
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
vfields
.
size
();
++
i
)
{
builder_
->
CreateStore
(
builder_
->
CreateStore
(
var_map_
.
at
(
vfields
[
i
].
get
()),
var_map_
.
at
(
vfields
[
i
].
get
()),
builder_
->
CreateInBoundsGEP
(
cdata
,
{
zero
,
ConstInt32
(
i
)}));
builder_
->
CreateInBoundsGEP
(
cdata
,
{
zero
,
ConstInt32
(
i
)}));
}
return
cdata
;
}
void
CodeGenLLVM
::
UnpackClosureData
(
llvm
::
Value
*
cdata
,
const
Array
<
Var
>&
vfields
,
std
::
unordered_map
<
const
Variable
*
,
llvm
::
Value
*>*
vmap
)
{
for
(
size_t
i
=
0
;
i
<
vfields
.
size
();
++
i
)
{
(
*
vmap
)[
vfields
[
i
].
get
()]
=
builder_
->
CreateLoad
(
builder_
->
CreateInBoundsGEP
(
cdata
,
{
ConstInt32
(
0
),
ConstInt32
(
i
)}));
}
}
}
void
CodeGenLLVM
::
CreateParallelLaunch
(
const
Stmt
&
body
,
int
num_task
)
{
using
llvm
::
BasicBlock
;
// closure data
llvm
::
Function
*
f
=
llvm
::
Function
::
Create
(
ftype_tvm_parallel_lambda_
,
llvm
::
Function
::
PrivateLinkage
,
"__tvm_parallel_lambda"
,
module_
.
get
());
// allocate and setup the closure, call the closure.
Array
<
Var
>
vfields
=
ir
::
UndefinedVars
(
body
,
{});
llvm
::
Value
*
cdata
=
PackClosureData
(
vfields
);
BasicBlock
*
par_launch_end
=
CheckCallSuccess
(
BasicBlock
*
par_launch_end
=
CheckCallSuccess
(
builder_
->
CreateCall
(
builder_
->
CreateCall
(
RuntimeTVMParallelLaunch
(),
RuntimeTVMParallelLaunch
(),
...
@@ -836,15 +858,10 @@ void CodeGenLLVM::CreateParallelLaunch(const Stmt& body, int num_task) {
...
@@ -836,15 +858,10 @@ void CodeGenLLVM::CreateParallelLaunch(const Stmt& body, int num_task) {
auto
it
=
f
->
arg_begin
();
auto
it
=
f
->
arg_begin
();
llvm
::
Value
*
task_id
=
&
(
*
it
++
);
llvm
::
Value
*
task_id
=
&
(
*
it
++
);
llvm
::
Value
*
penv
=
&
(
*
it
++
);
llvm
::
Value
*
penv
=
&
(
*
it
++
);
cdata
=
&
(
*
it
++
);
cdata
=
builder_
->
CreatePointerCast
(
&
(
*
it
++
),
cdata
->
getType
());
cdata
=
builder_
->
CreatePointerCast
(
cdata
,
tcdata
->
getPointerTo
());
// setup new variable map, swap it with current var context.
// setup new variable map, swap it with current var context.
std
::
unordered_map
<
const
Variable
*
,
llvm
::
Value
*>
new_vmap
;
std
::
unordered_map
<
const
Variable
*
,
llvm
::
Value
*>
new_vmap
;
for
(
size_t
i
=
0
;
i
<
vfields
.
size
();
++
i
)
{
UnpackClosureData
(
cdata
,
vfields
,
&
new_vmap
);
new_vmap
[
vfields
[
i
].
get
()]
=
builder_
->
CreateLoad
(
builder_
->
CreateInBoundsGEP
(
cdata
,
{
zero
,
ConstInt32
(
i
)}));
}
// setup parallel env
// setup parallel env
ParallelEnv
par_env
;
ParallelEnv
par_env
;
par_env
.
task_id
=
Var
(
"task_id"
,
Int
(
32
));
par_env
.
task_id
=
Var
(
"task_id"
,
Int
(
32
));
...
@@ -852,7 +869,7 @@ void CodeGenLLVM::CreateParallelLaunch(const Stmt& body, int num_task) {
...
@@ -852,7 +869,7 @@ void CodeGenLLVM::CreateParallelLaunch(const Stmt& body, int num_task) {
new_vmap
[
par_env
.
task_id
.
get
()]
=
task_id
;
new_vmap
[
par_env
.
task_id
.
get
()]
=
task_id
;
new_vmap
[
par_env
.
num_task
.
get
()]
=
builder_
->
CreateLoad
(
new_vmap
[
par_env
.
num_task
.
get
()]
=
builder_
->
CreateLoad
(
builder_
->
CreateInBoundsGEP
(
builder_
->
CreateInBoundsGEP
(
penv
,
{
zero
,
ConstInt32
(
1
)}));
penv
,
{
ConstInt32
(
0
)
,
ConstInt32
(
1
)}));
par_env
.
penv
=
penv
;
par_env
.
penv
=
penv
;
std
::
swap
(
function_
,
f
);
std
::
swap
(
function_
,
f
);
std
::
swap
(
parallel_env_
,
par_env
);
std
::
swap
(
parallel_env_
,
par_env
);
...
@@ -868,6 +885,52 @@ void CodeGenLLVM::CreateParallelLaunch(const Stmt& body, int num_task) {
...
@@ -868,6 +885,52 @@ void CodeGenLLVM::CreateParallelLaunch(const Stmt& body, int num_task) {
builder_
->
SetInsertPoint
(
par_launch_end
);
builder_
->
SetInsertPoint
(
par_launch_end
);
}
}
void
CodeGenLLVM
::
CreateStaticInit
(
const
std
::
string
&
init_fname
,
const
Stmt
&
body
)
{
using
llvm
::
BasicBlock
;
// closure data
llvm
::
Function
*
f
=
llvm
::
Function
::
Create
(
ftype_tvm_static_init_callback_
,
llvm
::
Function
::
PrivateLinkage
,
"__tvm_static_init_lambda"
,
module_
.
get
());
llvm
::
GlobalVariable
*
gv
=
new
llvm
::
GlobalVariable
(
*
module_
,
t_void_p_
,
false
,
llvm
::
GlobalValue
::
PrivateLinkage
,
0
,
"__tvm_static_handle"
);
gv
->
setAlignment
(
data_layout_
->
getTypeAllocSize
(
t_void_p_
));
gv
->
setInitializer
(
llvm
::
Constant
::
getNullValue
(
t_void_p_
));
llvm
::
Function
*
finit
=
module_
->
getFunction
(
init_fname
);
if
(
finit
==
nullptr
)
{
finit
=
llvm
::
Function
::
Create
(
ftype_tvm_static_init_
,
llvm
::
Function
::
ExternalLinkage
,
init_fname
,
module_
.
get
());
}
// allocate and setup the closure, call the closure.
Array
<
Var
>
vfields
=
ir
::
UndefinedVars
(
body
,
{});
llvm
::
Value
*
cdata
=
PackClosureData
(
vfields
);
llvm
::
Value
*
nbytes
=
ConstInt32
(
data_layout_
->
getTypeAllocSize
(
llvm
::
cast
<
llvm
::
PointerType
>
(
cdata
->
getType
())
->
getElementType
()));
BasicBlock
*
init_end
=
CheckCallSuccess
(
builder_
->
CreateCall
(
finit
,
{
gv
,
f
,
builder_
->
CreatePointerCast
(
cdata
,
t_void_p_
),
nbytes
}));
// Setup the closure function.
BasicBlock
*
lambda_entry
=
BasicBlock
::
Create
(
*
ctx_
,
"entry"
,
f
);
builder_
->
SetInsertPoint
(
lambda_entry
);
auto
it
=
f
->
arg_begin
();
cdata
=
builder_
->
CreatePointerCast
(
&
(
*
it
++
),
cdata
->
getType
());
// setup new variable map, swap it with current var context.
std
::
unordered_map
<
const
Variable
*
,
llvm
::
Value
*>
new_vmap
;
UnpackClosureData
(
cdata
,
vfields
,
&
new_vmap
);
CHECK
(
parallel_env_
.
penv
==
nullptr
);
std
::
swap
(
function_
,
f
);
std
::
swap
(
var_map_
,
new_vmap
);
this
->
VisitStmt
(
body
);
builder_
->
CreateRet
(
ConstInt32
(
0
));
// swap the var map back, now we are back on track.
std
::
swap
(
var_map_
,
new_vmap
);
std
::
swap
(
function_
,
f
);
builder_
->
SetInsertPoint
(
init_end
);
}
void
CodeGenLLVM
::
CreateSerialFor
(
llvm
::
Value
*
begin
,
void
CodeGenLLVM
::
CreateSerialFor
(
llvm
::
Value
*
begin
,
llvm
::
Value
*
end
,
llvm
::
Value
*
end
,
llvm
::
Value
*
stride
,
llvm
::
Value
*
stride
,
...
@@ -1626,6 +1689,8 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
...
@@ -1626,6 +1689,8 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
alloc_storage_info_
[
v
].
alignment
=
alloc_storage_info_
[
v
].
alignment
=
static_cast
<
int
>
(
op
->
value
.
as
<
IntImm
>
()
->
value
);
static_cast
<
int
>
(
op
->
value
.
as
<
IntImm
>
()
->
value
);
this
->
VisitStmt
(
op
->
body
);
this
->
VisitStmt
(
op
->
body
);
}
else
if
(
op
->
attr_key
==
ir
::
attr
::
coproc_uop_scope
)
{
this
->
CreateStaticInit
(
op
->
value
.
as
<
StringImm
>
()
->
value
,
op
->
body
);
}
else
if
(
op
->
attr_key
==
ir
::
attr
::
compute_scope
)
{
}
else
if
(
op
->
attr_key
==
ir
::
attr
::
compute_scope
)
{
this
->
CreateComputeScope
(
op
);
this
->
CreateComputeScope
(
op
);
}
else
if
(
op
->
attr_key
==
ir
::
attr
::
pragma_scope
)
{
}
else
if
(
op
->
attr_key
==
ir
::
attr
::
pragma_scope
)
{
...
...
src/codegen/llvm/codegen_llvm.h
View file @
b40d43c4
...
@@ -197,6 +197,9 @@ class CodeGenLLVM :
...
@@ -197,6 +197,9 @@ class CodeGenLLVM :
llvm
::
FunctionType
*
ftype_tvm_parallel_launch_
{
nullptr
};
llvm
::
FunctionType
*
ftype_tvm_parallel_launch_
{
nullptr
};
llvm
::
FunctionType
*
ftype_tvm_parallel_barrier_
{
nullptr
};
llvm
::
FunctionType
*
ftype_tvm_parallel_barrier_
{
nullptr
};
llvm
::
FunctionType
*
ftype_tvm_register_system_symbol_
{
nullptr
};
llvm
::
FunctionType
*
ftype_tvm_register_system_symbol_
{
nullptr
};
// Lazy entry for function call.
llvm
::
FunctionType
*
ftype_tvm_static_init_callback_
{
nullptr
};
llvm
::
FunctionType
*
ftype_tvm_static_init_
{
nullptr
};
// The acting body
// The acting body
llvm
::
BasicBlock
*
block_
{
nullptr
};
llvm
::
BasicBlock
*
block_
{
nullptr
};
/*! \brief native vector bits of current targetx*/
/*! \brief native vector bits of current targetx*/
...
@@ -241,6 +244,12 @@ class CodeGenLLVM :
...
@@ -241,6 +244,12 @@ class CodeGenLLVM :
llvm
::
Value
*
CreateVecFlip
(
llvm
::
Value
*
vec
);
llvm
::
Value
*
CreateVecFlip
(
llvm
::
Value
*
vec
);
llvm
::
Value
*
CreateVecConcat
(
std
::
vector
<
llvm
::
Value
*>
vecs
);
llvm
::
Value
*
CreateVecConcat
(
std
::
vector
<
llvm
::
Value
*>
vecs
);
llvm
::
Value
*
CreateVecPad
(
llvm
::
Value
*
vec
,
int
target_lanes
);
llvm
::
Value
*
CreateVecPad
(
llvm
::
Value
*
vec
,
int
target_lanes
);
llvm
::
Value
*
PackClosureData
(
const
Array
<
Var
>&
fields
);
void
UnpackClosureData
(
llvm
::
Value
*
cdata
,
const
Array
<
Var
>&
fields
,
std
::
unordered_map
<
const
Variable
*
,
llvm
::
Value
*>*
vmap
);
// Create static initialization
void
CreateStaticInit
(
const
std
::
string
&
init_fname
,
const
Stmt
&
body
);
// Create parallel launch
// Create parallel launch
void
CreateParallelLaunch
(
const
Stmt
&
body
,
int
num_task
);
void
CreateParallelLaunch
(
const
Stmt
&
body
,
int
num_task
);
// Create serial for
// Create serial for
...
...
src/pass/combine_context_call.cc
View file @
b40d43c4
...
@@ -47,7 +47,8 @@ class ContextCallCombiner final : public IRMutator {
...
@@ -47,7 +47,8 @@ class ContextCallCombiner final : public IRMutator {
}
}
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
attr_key
==
attr
::
thread_extent
)
{
if
(
op
->
attr_key
==
attr
::
thread_extent
||
op
->
attr_key
==
attr
::
coproc_uop_scope
)
{
// Map of comparison expression to variable
// Map of comparison expression to variable
std
::
map
<
Expr
,
Var
,
CompareExpr
>
temp
;
std
::
map
<
Expr
,
Var
,
CompareExpr
>
temp
;
std
::
swap
(
temp
,
ctx_map_
);
std
::
swap
(
temp
,
ctx_map_
);
...
...
src/pass/lift_attr_scope.cc
0 → 100644
View file @
b40d43c4
/*!
* Copyright (c) 2017 by Contributors
*
* \brief Lift specified AttrStmt scope to outer if
* the body contains the same scope.
* \file lift_attr_scope.cc
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
namespace
tvm
{
namespace
ir
{
// NOTE: this optimization can only be applied
// to a few specified attr keys
class
AttrScopeLifter
:
public
IRMutator
{
public
:
explicit
AttrScopeLifter
(
std
::
string
attr_key
)
:
attr_key_
(
attr_key
)
{}
Stmt
Lift
(
Stmt
stmt
)
{
stmt
=
Mutate
(
stmt
);
if
(
attr_node_
.
defined
())
{
stmt
=
AttrStmt
::
make
(
attr_node_
,
attr_key_
,
attr_value_
,
stmt
);
}
return
stmt
;
}
// do not go beyond
Stmt
Mutate_
(
const
Allocate
*
op
,
const
Stmt
&
s
)
final
{
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
op
=
stmt
.
as
<
Allocate
>
();
if
(
attr_node_
.
defined
())
{
Stmt
body
=
AttrStmt
::
make
(
attr_node_
,
attr_key_
,
attr_value_
,
op
->
body
);
// undefine them
attr_node_
=
NodeRef
();
attr_value_
=
Expr
();
return
Allocate
::
make
(
op
->
buffer_var
,
op
->
type
,
op
->
extents
,
op
->
condition
,
body
,
op
->
new_expr
,
op
->
free_function
);
}
else
{
return
stmt
;
}
}
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
attr_key
==
attr_key_
)
{
attr_node_
=
op
->
node
;
attr_value_
=
op
->
value
;
return
op
->
body
;
}
else
{
return
IRMutator
::
Mutate_
(
op
,
s
);
}
}
Stmt
Mutate_
(
const
Block
*
op
,
const
Stmt
&
s
)
final
{
Stmt
first
=
this
->
Mutate
(
op
->
first
);
NodeRef
first_node_
;
Expr
first_value_
;
std
::
swap
(
first_node_
,
attr_node_
);
std
::
swap
(
first_value_
,
attr_value_
);
Stmt
rest
=
this
->
Mutate
(
op
->
rest
);
if
(
attr_node_
.
defined
()
&&
attr_value_
.
defined
()
&&
first_node_
.
defined
()
&&
first_value_
.
defined
()
&&
attr_node_
.
same_as
(
first_node_
)
&&
attr_value_
.
same_as
(
first_value_
))
{
if
(
first
.
same_as
(
op
->
first
)
&&
rest
.
same_as
(
op
->
rest
))
{
return
s
;
}
else
{
return
Block
::
make
(
first
,
rest
);
}
}
else
{
if
(
first_node_
.
defined
())
{
first
=
AttrStmt
::
make
(
first_node_
,
attr_key_
,
first_value_
,
first
);
}
if
(
attr_node_
.
defined
())
{
rest
=
AttrStmt
::
make
(
attr_node_
,
attr_key_
,
attr_value_
,
rest
);
// undefine them
attr_node_
=
NodeRef
();
attr_value_
=
Expr
();
}
if
(
first
.
same_as
(
op
->
first
)
&&
rest
.
same_as
(
op
->
rest
))
{
return
s
;
}
else
{
return
Block
::
make
(
first
,
rest
);
}
}
}
Stmt
Mutate_
(
const
IfThenElse
*
op
,
const
Stmt
&
s
)
final
{
if
(
!
op
->
then_case
.
defined
())
{
return
IRMutator
::
Mutate_
(
op
,
s
);
}
Stmt
then_case
=
this
->
Mutate
(
op
->
then_case
);
NodeRef
first_node_
;
Expr
first_value_
;
std
::
swap
(
first_node_
,
attr_node_
);
std
::
swap
(
first_value_
,
attr_value_
);
Stmt
else_case
=
this
->
Mutate
(
op
->
else_case
);
if
(
attr_node_
.
defined
()
&&
attr_value_
.
defined
()
&&
first_node_
.
defined
()
&&
first_value_
.
defined
()
&&
attr_node_
.
same_as
(
first_node_
)
&&
attr_value_
.
same_as
(
first_value_
))
{
if
(
then_case
.
same_as
(
op
->
then_case
)
&&
else_case
.
same_as
(
op
->
else_case
))
{
return
s
;
}
else
{
return
IfThenElse
::
make
(
op
->
condition
,
then_case
,
else_case
);
}
}
else
{
if
(
first_node_
.
defined
())
{
then_case
=
AttrStmt
::
make
(
first_node_
,
attr_key_
,
first_value_
,
then_case
);
}
if
(
attr_node_
.
defined
())
{
else_case
=
AttrStmt
::
make
(
attr_node_
,
attr_key_
,
attr_value_
,
else_case
);
// undefine them
attr_node_
=
NodeRef
();
attr_value_
=
Expr
();
}
if
(
then_case
.
same_as
(
op
->
then_case
)
&&
else_case
.
same_as
(
op
->
else_case
))
{
return
s
;
}
else
{
return
IfThenElse
::
make
(
op
->
condition
,
then_case
,
else_case
);
}
}
}
private
:
std
::
string
attr_key_
;
NodeRef
attr_node_
;
Expr
attr_value_
;
};
Stmt
LiftAttrScope
(
Stmt
stmt
,
std
::
string
attr_key
)
{
return
AttrScopeLifter
(
attr_key
).
Lift
(
stmt
);
}
}
// namespace ir
}
// namespace tvm
src/runtime/c_runtime_api.cc
View file @
b40d43c4
...
@@ -234,6 +234,17 @@ int TVMBackendFreeWorkspace(int device_type,
...
@@ -234,6 +234,17 @@ int TVMBackendFreeWorkspace(int device_type,
return
0
;
return
0
;
}
}
int
TVMBackendRunOnce
(
void
**
handle
,
int
(
*
f
)(
void
*
),
void
*
cdata
,
int
nbytes
)
{
if
(
*
handle
==
nullptr
)
{
*
handle
=
reinterpret_cast
<
void
*>
(
1
);
return
(
*
f
)(
cdata
);
}
return
0
;
}
int
TVMFuncFree
(
TVMFunctionHandle
func
)
{
int
TVMFuncFree
(
TVMFunctionHandle
func
)
{
API_BEGIN
();
API_BEGIN
();
delete
static_cast
<
PackedFunc
*>
(
func
);
delete
static_cast
<
PackedFunc
*>
(
func
);
...
...
tests/python/unittest/test_codegen_static_init.py
0 → 100644
View file @
b40d43c4
import
tvm
import
numpy
as
np
def
test_static_init
():
dtype
=
'int64'
n
=
tvm
.
var
(
'n'
)
Ab
=
tvm
.
decl_buffer
((
n
,
),
dtype
)
i
=
tvm
.
var
(
'i'
)
ib
=
tvm
.
ir_builder
.
create
()
A
=
ib
.
buffer_ptr
(
Ab
)
cp
=
tvm
.
thread_axis
((
0
,
1
),
"cop"
)
finit
=
tvm
.
make
.
StringImm
(
"TVMBackendRunOnce"
)
ib
.
scope_attr
(
cp
,
"coproc_uop_scope"
,
finit
)
with
ib
.
for_range
(
0
,
n
,
"i"
,
for_type
=
"parallel"
)
as
i
:
A
[
i
]
=
A
[
i
]
+
1
stmt
=
ib
.
get
()
fapi
=
tvm
.
ir_pass
.
MakeAPI
(
stmt
,
"ramp"
,
[
Ab
],
0
,
True
)
fapi
=
tvm
.
ir_pass
.
LowerTVMBuiltin
(
fapi
)
f
=
tvm
.
codegen
.
build_module
(
fapi
,
"llvm"
)
a
=
tvm
.
nd
.
array
(
np
.
zeros
(
10
,
dtype
=
dtype
))
f
(
a
)
f
(
a
)
np
.
testing
.
assert_equal
(
a
.
asnumpy
(),
np
.
ones
(
a
.
shape
[
0
]))
if
__name__
==
"__main__"
:
test_static_init
()
tests/python/unittest/test_pass_lift_attr_scope.py
0 → 100644
View file @
b40d43c4
import
tvm
def
test_coproc_lift
():
ib
=
tvm
.
ir_builder
.
create
()
n
=
tvm
.
var
(
"n"
)
cp
=
tvm
.
thread_axis
((
0
,
1
),
"cop"
)
value
=
tvm
.
make
.
StringImm
(
"xxx"
)
A
=
ib
.
allocate
(
"float32"
,
n
,
name
=
"A"
,
scope
=
"global"
)
with
ib
.
for_range
(
0
,
n
,
name
=
"i"
)
as
i
:
with
ib
.
for_range
(
0
,
10
,
name
=
"j"
)
as
j
:
ib
.
scope_attr
(
cp
,
"coproc_uop_scope"
,
value
)
A
[
i
]
=
A
[
i
]
+
1
with
ib
.
for_range
(
0
,
10
,
name
=
"j"
)
as
j
:
ib
.
scope_attr
(
cp
,
"coproc_uop_scope"
,
value
)
A
[
j
]
=
A
[
j
]
+
2
body
=
ib
.
get
()
body
=
tvm
.
ir_pass
.
LiftAttrScope
(
body
,
"coproc_uop_scope"
)
assert
body
.
body
.
body
.
node
==
cp
if
__name__
==
"__main__"
:
test_coproc_lift
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment