Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
b8f0ec50
Commit
b8f0ec50
authored
Feb 11, 2017
by
Tianqi Chen
Committed by
GitHub
Feb 11, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[LANG/PASS] InjectVirtualThread (#38)
parent
526ff04c
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
662 additions
and
85 deletions
+662
-85
include/tvm/ir.h
+24
-0
include/tvm/ir_mutator.h
+1
-0
include/tvm/ir_pass.h
+20
-0
python/tvm/build.py
+2
-0
src/api/api_pass.cc
+2
-0
src/arithmetic/canonical.cc
+2
-1
src/codegen/codegen_c.cc
+2
-2
src/codegen/codegen_cuda.cc
+12
-0
src/codegen/codegen_cuda.h
+6
-0
src/pass/inject_virtual_thread.cc
+419
-0
src/pass/ir_mutator.cc
+12
-10
src/pass/lift_allocate.cc
+96
-0
src/pass/storage_flatten.cc
+20
-69
src/runtime/thread_storage_scope.h
+5
-1
src/schedule/schedule_ops.cc
+11
-2
tests/python/unittest/test_pass_virtual_thread.py
+28
-0
No files found.
include/tvm/ir.h
View file @
b8f0ec50
...
...
@@ -49,6 +49,30 @@ struct Reduce : public ExprNode<Reduce> {
static
constexpr
const
char
*
Min
=
"Min"
;
};
/*! \brief namespace of possible attribute sin AttrStmt.type_key */
namespace
attr
{
/*!
* \brief Mark scope of iteration variable, used by Schedule.
*/
constexpr
const
char
*
scope
=
"scope"
;
/*!
* \brief Mark launching extent of thread, used by device API.
*/
constexpr
const
char
*
thread_extent
=
"thread_extent"
;
/*!
* \brief Mark launching of a virtual thread.
*/
constexpr
const
char
*
virtual_thread
=
"virtual_thread"
;
/*!
* \brief Mark storage scope of buffers
*/
constexpr
const
char
*
storage_scope
=
"storage_scope"
;
/*!
* \brief Mark storage scope of realizations
*/
constexpr
const
char
*
realize_scope
=
"realize_scope"
;
}
// namespace attr
/*! \brief namespace of TVM Intrinsic functions */
namespace
intrinsic
{
// Most of the intrinsics is to enab
...
...
include/tvm/ir_mutator.h
View file @
b8f0ec50
...
...
@@ -63,6 +63,7 @@ class IRMutator {
virtual
Stmt
Mutate_
(
const
Store
*
op
,
const
Stmt
&
s
);
virtual
Stmt
Mutate_
(
const
Free
*
op
,
const
Stmt
&
s
);
virtual
Stmt
Mutate_
(
const
IfThenElse
*
op
,
const
Stmt
&
s
);
virtual
Stmt
Mutate_
(
const
Block
*
op
,
const
Stmt
&
s
);
virtual
Expr
Mutate_
(
const
Call
*
op
,
const
Expr
&
e
);
virtual
Expr
Mutate_
(
const
Load
*
op
,
const
Expr
&
s
);
virtual
Expr
Mutate_
(
const
Variable
*
op
,
const
Expr
&
e
);
...
...
include/tvm/ir_pass.h
View file @
b8f0ec50
...
...
@@ -100,6 +100,7 @@ Stmt Inline(Stmt stmt,
* \param stmt The stmt to be trasnformed.
* \param extern_buffer Map specifies external
* buffer assignment of input and outputs.
* \return Transformed stmt.
*/
Stmt
StorageFlatten
(
Stmt
stmt
,
Map
<
Tensor
,
Buffer
>
extern_buffer
);
...
...
@@ -108,16 +109,35 @@ Stmt StorageFlatten(Stmt stmt,
* \brief unroll the constant loops
* \param stmt The statment to be unrolled.
* \param max_auto_step The maximum step to stop performing automatic unrolling.
* \return Transformed stmt.
*/
Stmt
UnrollLoop
(
Stmt
stmt
,
int
max_auto_step
);
/*!
* \brief vectorize the constant loops
* \param stmt The statment to be vectorized.
* \return Transformed stmt.
*/
Stmt
VectorizeLoop
(
Stmt
stmt
);
/*!
* \brief Inject virtual thread loops into stmt.
* \param stmt The statment to be transformed.
* \return Transformed stmt.
*/
Stmt
InjectVirtualThread
(
Stmt
stmt
);
/*!
* \brief Lift storage allocation to relevant outpost location
*
* Only do this after vectorization and virtual thread injection completes.
*
* \param stmt The stmt to be trasnformed
* \return Transformed stmt.
*/
Stmt
LiftAllocate
(
Stmt
stmt
);
/*!
* \brief Make an user callable API LoweredFunc.
*
* The main task of this function is to create code to :
...
...
python/tvm/build.py
View file @
b8f0ec50
...
...
@@ -70,6 +70,8 @@ def build(sch,
stmt
=
ir_pass
.
StorageFlatten
(
stmt
,
binds
)
stmt
=
ir_pass
.
CanonicalSimplify
(
stmt
)
stmt
=
ir_pass
.
VectorizeLoop
(
stmt
)
stmt
=
ir_pass
.
InjectVirtualThread
(
stmt
)
stmt
=
ir_pass
.
LiftAllocate
(
stmt
)
stmt
=
ir_pass
.
UnrollLoop
(
stmt
,
max_auto_unroll_step
)
stmt
=
ir_pass
.
Simplify
(
stmt
)
fapi
=
ir_pass
.
MakeAPI
(
stmt
,
name
,
arg_list
,
len
(
arg_list
))
...
...
src/api/api_pass.cc
View file @
b8f0ec50
...
...
@@ -67,6 +67,8 @@ REGISTER_PASS2(UnrollLoop);
REGISTER_PASS2
(
StorageSync
);
REGISTER_PASS4
(
MakeAPI
);
REGISTER_PASS1
(
SplitHostDevice
);
REGISTER_PASS1
(
LiftAllocate
);
REGISTER_PASS1
(
InjectVirtualThread
);
}
// namespace ir
}
// namespace tvm
src/arithmetic/canonical.cc
View file @
b8f0ec50
...
...
@@ -288,7 +288,8 @@ class Canonical::Internal : public IRMutator {
}
// AttrStmt
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
{
if
(
op
->
type_key
==
"thread_extent"
)
{
if
(
op
->
type_key
==
attr
::
thread_extent
||
op
->
type_key
==
attr
::
virtual_thread
)
{
++
level_counter_
;
IterVar
iv
(
op
->
node
.
node_
);
CHECK_NE
(
iv
->
thread_tag
.
length
(),
0U
);
...
...
src/codegen/codegen_c.cc
View file @
b8f0ec50
...
...
@@ -743,7 +743,7 @@ void CodeGenC::PrintStmt(const Allocate* op) {
}
void
CodeGenC
::
PrintStmt
(
const
AttrStmt
*
op
)
{
if
(
op
->
type_key
==
"scope"
)
{
if
(
op
->
type_key
==
ir
::
attr
::
thread_extent
)
{
IterVar
iv
(
op
->
node
.
node_
);
if
(
iv
->
thread_tag
.
length
()
!=
0
)
{
if
(
!
var_idmap_
.
count
(
iv
->
var
.
get
()))
{
...
...
@@ -756,7 +756,7 @@ void CodeGenC::PrintStmt(const AttrStmt* op) {
stream
<<
";
\n
"
;
}
}
}
else
if
(
op
->
type_key
==
"storage_scope"
)
{
}
else
if
(
op
->
type_key
==
ir
::
attr
::
storage_scope
)
{
const
Variable
*
v
=
op
->
node
.
as
<
Variable
>
();
CHECK
(
v
);
alloc_storage_scope_
[
v
]
=
op
->
value
.
as
<
StringImm
>
()
->
value
;
...
...
src/codegen/codegen_cuda.cc
View file @
b8f0ec50
...
...
@@ -9,6 +9,7 @@
#include <string>
#include "./codegen_cuda.h"
#include "./codegen_stack_vm.h"
#include "../arithmetic/compute_expr.h"
#include "../runtime/cuda/cuda_common.h"
#include "../runtime/cuda/cuda_module.h"
...
...
@@ -22,6 +23,17 @@ std::string CodeGenCUDA::Compile(
return
CodeGenC
::
Compile
(
f
,
output_ssa
);
}
void
CodeGenCUDA
::
PrintStmt
(
const
ir
::
For
*
op
)
{
int
ext
;
CHECK
(
is_zero
(
op
->
min
));
if
(
arith
::
GetConstInt
(
op
->
extent
,
&
ext
)
&&
ext
<=
max_auto_unroll_
)
{
PrintIndent
();
stream
<<
"#pragma unroll
\n
"
;
}
CodeGenC
::
PrintStmt
(
op
);
}
void
CodeGenCUDA
::
PrintType
(
Type
t
,
std
::
ostream
&
os
)
const
{
// NOLINT(*)
int
lanes
=
t
.
lanes
();
if
(
t
.
is_handle
())
{
...
...
src/codegen/codegen_cuda.h
View file @
b8f0ec50
...
...
@@ -27,6 +27,7 @@ class CodeGenCUDA : public CodeGenC {
bool
output_ssa
);
// override behavior
void
PrintStmt
(
const
ir
::
For
*
op
)
final
;
void
PrintStorageSync
(
const
std
::
string
&
sync
)
final
;
void
PrintStorageScope
(
const
std
::
string
&
scope
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
PrintVecBinaryOp
(
...
...
@@ -37,6 +38,11 @@ class CodeGenCUDA : public CodeGenC {
const
std
::
string
&
vec
,
Type
t
,
int
i
,
std
::
ostream
&
os
)
final
;
// NOLINT(*)
void
PrintVecElemStore
(
const
std
::
string
&
vec
,
Type
t
,
int
i
,
const
std
::
string
&
value
)
final
;
private
:
// magic number to add pragma unroll to it.
// used to generate code that is compact but still unrolls.
int
max_auto_unroll_
{
8
};
};
}
// namespace codegen
...
...
src/pass/inject_virtual_thread.cc
0 → 100644
View file @
b8f0ec50
/*!
* Copyright (c) 2017 by Contributors
* \file inject_virtual_thread.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "../arithmetic/compute_expr.h"
namespace
tvm
{
namespace
ir
{
// If expression is touched by var.
class
ExprTouched
:
public
IRVisitor
{
public
:
explicit
ExprTouched
(
const
std
::
unordered_set
<
const
Variable
*>
&
touched
)
:
touched_var_
(
touched
)
{}
void
Visit
(
const
NodeRef
&
n
)
final
{
// early stopping
if
(
expr_touched_
)
return
;
IRVisitor
::
Visit
(
n
);
}
void
Visit_
(
const
Load
*
op
)
final
{
HandleUseVar
(
op
->
buffer_var
.
get
());
IRVisitor
::
Visit_
(
op
);
}
void
Visit_
(
const
Variable
*
op
)
final
{
HandleUseVar
(
op
);
}
void
HandleUseVar
(
const
Variable
*
var
)
{
auto
it
=
touched_var_
.
find
(
var
);
if
(
it
!=
touched_var_
.
end
())
{
expr_touched_
=
true
;
}
// rember the used vars
// in case the var get touched later in a loop.
if
(
!
expr_touched_
)
{
used_vars_
.
push_back
(
var
);
}
}
// the fields.
bool
expr_touched_
{
false
};
std
::
vector
<
const
Variable
*>
used_vars_
;
const
std
::
unordered_set
<
const
Variable
*>&
touched_var_
;
};
// Analyze if the buffers are invariant to value of var
class
VarTouchedAnalysis
:
public
IRVisitor
{
public
:
void
Visit_
(
const
LetStmt
*
op
)
{
ExprTouched
tc
(
touched_var_
);
tc
.
Visit
(
op
->
value
);
Record
(
op
->
var
.
get
(),
tc
);
this
->
Visit
(
op
->
body
);
}
void
Visit_
(
const
Store
*
op
)
{
ExprTouched
tc
(
touched_var_
);
tc
.
Visit
(
op
->
value
);
tc
.
Visit
(
op
->
index
);
Record
(
op
->
buffer_var
.
get
(),
tc
);
}
void
Visit_
(
const
For
*
op
)
{
ExprTouched
tc
(
touched_var_
);
tc
.
Visit
(
op
->
min
);
tc
.
Visit
(
op
->
extent
);
Record
(
op
->
loop_var
.
get
(),
tc
);
this
->
Visit
(
op
->
body
);
}
void
Visit_
(
const
Allocate
*
op
)
{
ExprTouched
tc
(
touched_var_
);
for
(
size_t
i
=
0
;
i
<
op
->
extents
.
size
();
++
i
)
{
tc
.
Visit
(
op
->
extents
[
i
]);
}
tc
.
Visit
(
op
->
condition
);
if
(
op
->
new_expr
.
defined
())
{
tc
.
Visit
(
op
->
new_expr
);
}
Record
(
op
->
buffer_var
.
get
(),
tc
);
this
->
Visit
(
op
->
body
);
}
void
Record
(
const
Variable
*
var
,
const
ExprTouched
&
tc
)
{
if
(
touched_var_
.
count
(
var
))
return
;
if
(
tc
.
expr_touched_
)
{
touched_var_
.
insert
(
var
);
}
else
{
for
(
const
Variable
*
r
:
tc
.
used_vars_
)
{
affect_
[
r
].
push_back
(
var
);
}
}
}
std
::
unordered_set
<
const
Variable
*>
TouchedVar
(
const
Stmt
&
stmt
,
const
Variable
*
var
)
{
touched_var_
.
insert
(
var
);
this
->
Visit
(
stmt
);
// do a DFS to push affect around dependency.
std
::
vector
<
const
Variable
*>
pending
(
touched_var_
.
begin
(),
touched_var_
.
end
());
while
(
!
pending
.
empty
())
{
const
Variable
*
v
=
pending
.
back
();
pending
.
pop_back
();
for
(
const
Variable
*
r
:
affect_
[
v
])
{
if
(
!
touched_var_
.
count
(
r
))
{
touched_var_
.
insert
(
r
);
pending
.
push_back
(
r
);
}
}
}
return
std
::
move
(
touched_var_
);
}
private
:
// Whether variable is touched by the thread variable.
std
::
unordered_set
<
const
Variable
*>
touched_var_
;
// x -> all the buffers x read from
std
::
unordered_map
<
const
Variable
*
,
std
::
vector
<
const
Variable
*>
>
affect_
;
};
// Inject virtual thread loop
// rewrite the buffer access pattern when necessary.
class
VTInjector
:
public
IRMutator
{
public
:
using
IRMutator
::
Mutate
;
// constructor
VTInjector
(
Var
var
,
int
num_threads
,
std
::
unordered_set
<
const
Variable
*>
touched_var
)
:
var_
(
var
),
num_threads_
(
num_threads
),
touched_var_
(
touched_var
)
{
}
// Inject VTLoop when needed.
Stmt
Mutate
(
Stmt
stmt
)
final
{
CHECK
(
!
visit_touched_var_
)
<<
stmt
->
type_key
()
<<
stmt
;
stmt
=
IRMutator
::
Mutate
(
stmt
);
if
(
visit_touched_var_
)
{
if
(
!
vt_loop_injected_
)
return
InjectVTLoop
(
stmt
,
false
);
visit_touched_var_
=
false
;
}
return
stmt
;
}
// Variable
Expr
Mutate_
(
const
Variable
*
op
,
const
Expr
&
e
)
final
{
if
(
touched_var_
.
count
(
op
))
{
visit_touched_var_
=
true
;
}
return
e
;
}
Expr
RewriteIndex
(
Expr
index
,
Expr
alloc_extent
)
const
{
if
(
index_rewrite_strategy_
==
0
)
{
return
index
*
num_threads_
+
var_
;
}
else
{
return
index
+
var_
*
alloc_extent
;
}
}
// Load
Expr
Mutate_
(
const
Load
*
op
,
const
Expr
&
e
)
final
{
Expr
expr
=
IRMutator
::
Mutate_
(
op
,
e
);
op
=
expr
.
as
<
Load
>
();
if
(
touched_var_
.
count
(
op
->
buffer_var
.
get
()))
{
visit_touched_var_
=
true
;
}
auto
it
=
touched_alloc_
.
find
(
op
->
buffer_var
.
get
());
if
(
it
!=
touched_alloc_
.
end
())
{
return
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
RewriteIndex
(
op
->
index
,
it
->
second
));
}
else
{
return
expr
;
}
}
// Store
Stmt
Mutate_
(
const
Store
*
op
,
const
Stmt
&
s
)
final
{
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
op
=
stmt
.
as
<
Store
>
();
if
(
touched_var_
.
count
(
op
->
buffer_var
.
get
()))
{
visit_touched_var_
=
true
;
}
auto
it
=
touched_alloc_
.
find
(
op
->
buffer_var
.
get
());
if
(
it
!=
touched_alloc_
.
end
())
{
return
Store
::
make
(
op
->
buffer_var
,
op
->
value
,
RewriteIndex
(
op
->
index
,
it
->
second
));
}
else
{
return
stmt
;
}
}
// Attribute
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
type_key
==
attr
::
scope
)
{
return
Mutate
(
op
->
body
);
}
else
{
Expr
value
=
Mutate
(
op
->
value
);
if
(
visit_touched_var_
)
{
return
InjectVTLoop
(
s
,
true
);
}
else
{
Stmt
body
=
Mutate
(
op
->
body
);
if
(
value
.
same_as
(
op
->
value
)
&&
body
.
same_as
(
op
->
body
))
{
return
s
;
}
else
{
return
AttrStmt
::
make
(
op
->
node
,
op
->
type_key
,
value
,
body
);
}
}
}
}
// LetStmt
Stmt
Mutate_
(
const
LetStmt
*
op
,
const
Stmt
&
s
)
final
{
Expr
value
=
this
->
Mutate
(
op
->
value
);
if
(
visit_touched_var_
&&
!
vt_loop_injected_
)
{
return
InjectVTLoop
(
s
,
true
);
}
visit_touched_var_
=
false
;
Stmt
body
=
Mutate
(
op
->
body
);
if
(
value
.
same_as
(
op
->
value
)
&&
body
.
same_as
(
op
->
body
))
{
return
s
;
}
else
{
return
LetStmt
::
make
(
op
->
var
,
value
,
body
);
}
}
// For
Stmt
Mutate_
(
const
For
*
op
,
const
Stmt
&
s
)
final
{
CHECK
(
is_zero
(
op
->
min
));
Expr
extent
=
Mutate
(
op
->
extent
);
if
(
visit_touched_var_
&&
!
vt_loop_injected_
)
{
Stmt
stmt
=
InjectVTLoop
(
s
,
true
);
++
max_loop_depth_
;
return
stmt
;
}
visit_touched_var_
=
false
;
Stmt
body
=
Mutate
(
op
->
body
);
++
max_loop_depth_
;
if
(
extent
.
same_as
(
op
->
extent
)
&&
body
.
same_as
(
op
->
body
))
{
return
s
;
}
else
{
return
For
::
make
(
op
->
loop_var
,
op
->
min
,
extent
,
op
->
for_type
,
op
->
device_api
,
body
);
}
}
// IfThenElse
Stmt
Mutate_
(
const
IfThenElse
*
op
,
const
Stmt
&
s
)
final
{
Expr
condition
=
this
->
Mutate
(
op
->
condition
);
if
(
visit_touched_var_
&&
!
vt_loop_injected_
)
{
return
InjectVTLoop
(
s
,
true
);
}
visit_touched_var_
=
false
;
CHECK_EQ
(
max_loop_depth_
,
0
);
Stmt
then_case
=
this
->
Mutate
(
op
->
then_case
);
Stmt
else_case
;
if
(
else_case
.
defined
())
{
int
temp
=
max_loop_depth_
;
max_loop_depth_
=
0
;
else_case
=
this
->
Mutate
(
op
->
else_case
);
max_loop_depth_
=
std
::
max
(
temp
,
max_loop_depth_
);
}
if
(
condition
.
same_as
(
op
->
condition
)
&&
then_case
.
same_as
(
op
->
then_case
)
&&
else_case
.
same_as
(
op
->
else_case
))
{
return
s
;
}
else
{
return
IfThenElse
::
make
(
condition
,
then_case
,
else_case
);
}
}
// Block
Stmt
Mutate_
(
const
Block
*
op
,
const
Stmt
&
s
)
final
{
CHECK_EQ
(
max_loop_depth_
,
0
);
Stmt
first
=
this
->
Mutate
(
op
->
first
);
int
temp
=
max_loop_depth_
;
max_loop_depth_
=
0
;
Stmt
rest
=
this
->
Mutate
(
op
->
rest
);
max_loop_depth_
=
std
::
max
(
max_loop_depth_
,
temp
);
if
(
first
.
same_as
(
op
->
first
)
&&
rest
.
same_as
(
op
->
rest
))
{
return
s
;
}
else
{
return
Block
::
make
(
first
,
rest
);
}
}
// Allocate
Stmt
Mutate_
(
const
Allocate
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
new_expr
.
defined
()
&&
!
vt_loop_injected_
)
{
return
InjectVTLoop
(
s
,
true
);
}
Expr
condition
=
Mutate
(
op
->
condition
);
if
(
visit_touched_var_
&&
!
vt_loop_injected_
)
{
return
InjectVTLoop
(
s
,
true
);
}
bool
changed
=
false
;
Array
<
Expr
>
extents
;
for
(
size_t
i
=
0
;
i
<
op
->
extents
.
size
();
i
++
)
{
Expr
new_ext
=
Mutate
(
op
->
extents
[
i
]);
if
(
visit_touched_var_
&&
!
vt_loop_injected_
)
{
return
InjectVTLoop
(
s
,
true
);
}
if
(
!
new_ext
.
same_as
(
op
->
extents
[
i
]))
changed
=
true
;
extents
.
push_back
(
new_ext
);
}
visit_touched_var_
=
false
;
Stmt
body
;
if
(
touched_var_
.
count
(
op
->
buffer_var
.
get
()))
{
// place v on highest dimension.
Expr
stride
=
extents
[
0
];
for
(
size_t
i
=
1
;
i
<
extents
.
size
();
++
i
)
{
stride
=
arith
::
ComputeExpr
<
Mul
>
(
stride
,
extents
[
i
]);
}
Array
<
Expr
>
other
;
other
.
push_back
(
num_threads_
);
for
(
Expr
e
:
extents
)
{
other
.
push_back
(
e
);
}
extents
=
other
;
changed
=
true
;
// mark this buffer get touched.
touched_alloc_
[
op
->
buffer_var
.
get
()]
=
stride
;
// Mutate the body.
body
=
Mutate
(
op
->
body
);
}
else
{
// Mutate the body.
body
=
Mutate
(
op
->
body
);
}
if
(
!
changed
&&
body
.
same_as
(
op
->
body
)
&&
condition
.
same_as
(
op
->
condition
))
{
return
s
;
}
else
{
return
Allocate
::
make
(
op
->
buffer_var
,
op
->
type
,
extents
,
condition
,
body
,
op
->
new_expr
,
op
->
free_function
);
}
}
// inject vthread loop
Stmt
InjectVTLoop
(
Stmt
stmt
,
bool
before_mutation
)
{
CHECK
(
!
vt_loop_injected_
);
// reset the flags
visit_touched_var_
=
false
;
vt_loop_injected_
=
true
;
if
(
before_mutation
)
{
stmt
=
this
->
Mutate
(
stmt
);
}
// reset the flags after processing.
vt_loop_injected_
=
false
;
visit_touched_var_
=
false
;
if
(
max_loop_depth_
==
0
)
{
// do unrolling if it is inside innermost content.
Stmt
blk
=
Substitute
(
stmt
,
{{
var_
,
make_zero
(
var_
.
type
())}});
for
(
int
i
=
1
;
i
<
num_threads_
;
++
i
)
{
blk
=
Block
::
make
(
blk
,
Substitute
(
stmt
,
{{
var_
,
make_const
(
var_
.
type
(),
i
)}}));
}
return
blk
;
}
else
{
// insert a for loop
Var
idx
(
var_
->
name_hint
+
".s"
,
var_
->
type
);
stmt
=
Substitute
(
stmt
,
{{
var_
,
idx
}});
return
For
::
make
(
idx
,
0
,
num_threads_
,
ForType
::
Serial
,
DeviceAPI
::
None
,
stmt
);
}
}
private
:
// vthread variable
Var
var_
;
// the threads/lanes
int
num_threads_
;
// Index rewriting strategy
int
index_rewrite_strategy_
{
1
};
// whethe the loop is already injected.
bool
vt_loop_injected_
{
false
};
// whether current expression get touched.
bool
visit_touched_var_
{
false
};
// the counter of loops in after mutation.
int
max_loop_depth_
{
0
};
// The variables that get touched.
std
::
unordered_set
<
const
Variable
*>
touched_var_
;
// The allocations that get touched -> extent
std
::
unordered_map
<
const
Variable
*
,
Expr
>
touched_alloc_
;
};
class
VirtualThreadInjector
:
public
IRMutator
{
public
:
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
op
=
stmt
.
as
<
AttrStmt
>
();
if
(
op
->
type_key
==
attr
::
virtual_thread
)
{
IterVar
iv
(
op
->
node
.
node_
);
int
nthread
=
static_cast
<
int
>
(
op
->
value
.
as
<
IntImm
>
()
->
value
);
VarTouchedAnalysis
vs
;
auto
touched
=
vs
.
TouchedVar
(
op
->
body
,
iv
->
var
.
get
());
VTInjector
injecter
(
iv
->
var
,
nthread
,
touched
);
return
injecter
.
Mutate
(
op
->
body
);
}
else
{
return
stmt
;
}
}
Stmt
Mutate_
(
const
Provide
*
op
,
const
Stmt
&
s
)
final
{
LOG
(
FATAL
)
<<
"Need to call StorageFlatten first"
;
return
s
;
}
};
Stmt
InjectVirtualThread
(
Stmt
stmt
)
{
stmt
=
VirtualThreadInjector
().
Mutate
(
stmt
);
return
ConvertSSA
(
stmt
);
}
}
// namespace ir
}
// namespace tvm
src/pass/ir_mutator.cc
View file @
b8f0ec50
...
...
@@ -77,6 +77,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.
DISPATCH_TO_MUTATE_STMT
(
IfThenElse
)
.
DISPATCH_TO_MUTATE_STMT
(
For
)
.
DISPATCH_TO_MUTATE_STMT
(
Allocate
)
.
DISPATCH_TO_MUTATE_STMT
(
Block
)
.
DISPATCH_TO_MUTATE_STMT
(
Free
);
Stmt
IRMutator
::
Mutate_
(
const
LetStmt
*
op
,
const
Stmt
&
s
)
{
...
...
@@ -212,6 +213,17 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
}
}
Stmt
IRMutator
::
Mutate_
(
const
Block
*
op
,
const
Stmt
&
s
)
{
Stmt
first
=
this
->
Mutate
(
op
->
first
);
Stmt
rest
=
this
->
Mutate
(
op
->
rest
);
if
(
first
.
same_as
(
op
->
first
)
&&
rest
.
same_as
(
op
->
rest
))
{
return
s
;
}
else
{
return
Block
::
make
(
first
,
rest
);
}
}
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
.
DISPATCH_TO_MUTATE_EXPR
(
Call
)
.
DISPATCH_TO_MUTATE_EXPR
(
Let
)
...
...
@@ -370,16 +382,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
return
ProducerConsumer
::
make
(
op
->
func
,
op
->
is_producer
,
body
);
}
})
.
set_dispatch
<
Block
>
([](
const
Block
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Stmt
first
=
m
->
Mutate
(
op
->
first
);
Stmt
rest
=
m
->
Mutate
(
op
->
rest
);
if
(
first
.
same_as
(
op
->
first
)
&&
rest
.
same_as
(
op
->
rest
))
{
return
s
;
}
else
{
return
Block
::
make
(
first
,
rest
);
}
})
.
set_dispatch
<
Evaluate
>
([](
const
Evaluate
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
v
=
m
->
Mutate
(
op
->
value
);
if
(
v
.
same_as
(
op
->
value
))
{
...
...
src/pass/lift_allocate.cc
0 → 100644
View file @
b8f0ec50
/*!
* Copyright (c) 2017 by Contributors
* \file lift_allocate.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_map>
#include "./ir_util.h"
#include "../runtime/thread_storage_scope.h"
namespace
tvm
{
namespace
ir
{
using
runtime
::
StorageScope
;
using
runtime
::
ThreadScope
;
class
AllocateLifter
:
public
IRMutator
{
public
:
Stmt
Lift
(
Stmt
stmt
)
{
stmt
=
this
->
Mutate
(
stmt
);
StorageScope
key
;
key
.
rank
=
0
;
stmt
=
MergeNest
(
allocs_
[
key
],
stmt
);
return
stmt
;
}
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
CHECK
(
op
->
type_key
!=
attr
::
virtual_thread
)
<<
"InjectVirtualThread before LiftStorageAlloc"
;
if
(
op
->
type_key
==
attr
::
storage_scope
)
{
StorageScope
sc
=
StorageScope
::
make
(
op
->
value
.
as
<
StringImm
>
()
->
value
);
allocs_
[
sc
].
emplace_back
(
AttrStmt
::
make
(
op
->
node
,
attr
::
storage_scope
,
op
->
value
,
Evaluate
::
make
(
0
)));
storage_scope_
[
op
->
node
.
get
()]
=
sc
;
return
this
->
Mutate
(
op
->
body
);
}
else
if
(
op
->
type_key
==
attr
::
thread_extent
)
{
IterVar
iv
(
op
->
node
.
node_
);
ThreadScope
ts
=
ThreadScope
::
make
(
iv
->
thread_tag
);
curr_thread_scope_
.
push_back
(
ts
);
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
curr_thread_scope_
.
pop_back
();
op
=
stmt
.
as
<
AttrStmt
>
();
bool
first_scope
=
true
;
for
(
const
ThreadScope
&
t
:
curr_thread_scope_
)
{
if
(
t
.
rank
==
ts
.
rank
)
first_scope
=
false
;
}
if
(
first_scope
)
{
StorageScope
key
;
key
.
rank
=
ts
.
rank
+
1
;
std
::
vector
<
Stmt
>&
vec
=
allocs_
[
key
];
if
(
vec
.
size
()
!=
0
)
{
Stmt
body
=
MergeNest
(
vec
,
op
->
body
);
vec
.
clear
();
return
AttrStmt
::
make
(
op
->
node
,
op
->
type_key
,
op
->
value
,
body
);
}
}
return
stmt
;
}
return
IRMutator
::
Mutate_
(
op
,
s
);
}
Stmt
Mutate_
(
const
For
*
op
,
const
Stmt
&
s
)
final
{
CHECK
(
op
->
for_type
!=
ForType
::
Vectorized
)
<<
"VectorizeLoop before LiftStorageAlloc"
;
return
IRMutator
::
Mutate_
(
op
,
s
);
}
Stmt
Mutate_
(
const
Allocate
*
op
,
const
Stmt
&
s
)
final
{
auto
it
=
storage_scope_
.
find
(
op
->
buffer_var
.
get
());
CHECK
(
it
!=
storage_scope_
.
end
());
allocs_
[
it
->
second
].
emplace_back
(
Allocate
::
make
(
op
->
buffer_var
,
op
->
type
,
op
->
extents
,
op
->
condition
,
Evaluate
::
make
(
0
)));
return
this
->
Mutate
(
op
->
body
);
}
private
:
// storage scope of internal allocation.
std
::
unordered_map
<
const
Node
*
,
StorageScope
>
storage_scope_
;
// The current thread scope.
std
::
vector
<
ThreadScope
>
curr_thread_scope_
;
// The allocations by rank
std
::
unordered_map
<
StorageScope
,
std
::
vector
<
Stmt
>
>
allocs_
;
};
Stmt
LiftAllocate
(
Stmt
stmt
)
{
return
AllocateLifter
().
Mutate
(
stmt
);
}
}
// namespace ir
}
// namespace tvm
src/pass/storage_flatten.cc
View file @
b8f0ec50
...
...
@@ -6,7 +6,6 @@
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_map>
#include "./ir_util.h"
#include "../runtime/thread_storage_scope.h"
namespace
tvm
{
...
...
@@ -61,46 +60,17 @@ class StorageFlattener : public IRMutator {
}
}
Stmt
Flatten
(
Stmt
stmt
)
{
stmt
=
this
->
Mutate
(
stmt
);
StorageScope
key
;
key
.
rank
=
0
;
if
(
move_alloc_out_
)
{
StorageScope
key
;
key
.
rank
=
0
;
stmt
=
MergeNest
(
allocs_
[
key
],
stmt
);
}
return
stmt
;
}
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
type_key
==
"realize_scope"
)
{
if
(
op
->
type_key
==
attr
::
realize_scope
)
{
storage_scope_
[
op
->
node
.
get
()]
=
op
->
value
.
as
<
StringImm
>
()
->
value
;
return
this
->
Mutate
(
op
->
body
);
}
else
if
(
op
->
type_key
==
"scope"
)
{
}
else
if
(
op
->
type_key
==
attr
::
thread_extent
)
{
IterVar
iv
(
op
->
node
.
node_
);
if
(
iv
->
thread_tag
.
length
()
!=
0
)
{
ThreadScope
ts
=
ThreadScope
::
make
(
iv
->
thread_tag
);
curr_thread_scope_
.
push_back
(
ts
);
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
curr_thread_scope_
.
pop_back
();
op
=
stmt
.
as
<
AttrStmt
>
();
bool
first_scope
=
true
;
for
(
const
ThreadScope
&
t
:
curr_thread_scope_
)
{
if
(
t
.
rank
==
ts
.
rank
)
first_scope
=
false
;
}
if
(
first_scope
&&
move_alloc_out_
)
{
StorageScope
key
;
key
.
rank
=
ts
.
rank
+
1
;
std
::
vector
<
Stmt
>&
vec
=
allocs_
[
key
];
if
(
vec
.
size
()
!=
0
)
{
Stmt
body
=
MergeNest
(
vec
,
op
->
body
);
vec
.
clear
();
return
AttrStmt
::
make
(
op
->
node
,
op
->
type_key
,
op
->
value
,
body
);
}
}
return
stmt
;
}
ThreadScope
ts
=
ThreadScope
::
make
(
iv
->
thread_tag
);
curr_thread_scope_
.
push_back
(
ts
);
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
curr_thread_scope_
.
pop_back
();
return
stmt
;
}
return
IRMutator
::
Mutate_
(
op
,
s
);
}
...
...
@@ -140,37 +110,22 @@ class StorageFlattener : public IRMutator {
// deduce current storage scope.
auto
it
=
storage_scope_
.
find
(
op
->
func
.
get
());
CHECK
(
it
!=
storage_scope_
.
end
());
StorageScope
key
;
key
.
rank
=
0
;
const
std
::
string
&
skey
=
it
->
second
;
if
(
skey
.
length
()
==
0
)
{
StorageScope
skey
;
const
std
::
string
&
s
tr
key
=
it
->
second
;
if
(
s
tr
key
.
length
()
==
0
)
{
if
(
curr_thread_scope_
.
size
()
!=
0
)
{
key
.
rank
=
curr_thread_scope_
.
back
().
rank
+
1
;
s
key
.
rank
=
curr_thread_scope_
.
back
().
rank
+
1
;
}
}
else
{
key
=
StorageScope
::
make
(
skey
);
}
if
(
move_alloc_out_
)
{
allocs_
[
key
].
push_back
(
AttrStmt
::
make
(
e
.
buffer
->
data
,
"storage_scope"
,
StringImm
::
make
(
key
.
to_string
()),
Evaluate
::
make
(
0
)));
allocs_
[
key
].
push_back
(
Allocate
::
make
(
e
.
buffer
->
data
,
e
.
buffer
->
dtype
,
e
.
buffer
->
shape
,
make_const
(
Bool
(
e
.
buffer
->
dtype
.
lanes
()),
true
),
Evaluate
::
make
(
0
)));
return
body
;
}
else
{
Stmt
ret
=
Allocate
::
make
(
e
.
buffer
->
data
,
e
.
buffer
->
dtype
,
e
.
buffer
->
shape
,
make_const
(
Bool
(
e
.
buffer
->
dtype
.
lanes
()),
true
),
body
);
ret
=
AttrStmt
::
make
(
e
.
buffer
->
data
,
"storage_scope"
,
StringImm
::
make
(
key
.
to_string
()),
ret
);
return
ret
;
skey
=
StorageScope
::
make
(
strkey
);
}
Stmt
ret
=
Allocate
::
make
(
e
.
buffer
->
data
,
e
.
buffer
->
dtype
,
e
.
buffer
->
shape
,
make_const
(
Bool
(
e
.
buffer
->
dtype
.
lanes
()),
true
),
body
);
ret
=
AttrStmt
::
make
(
e
.
buffer
->
data
,
attr
::
storage_scope
,
StringImm
::
make
(
skey
.
to_string
()),
ret
);
return
ret
;
}
}
...
...
@@ -217,20 +172,16 @@ class StorageFlattener : public IRMutator {
}
}
};
// whether move allocation to the outmost scope as possible.
bool
move_alloc_out_
{
true
};
// The buffer assignment map
std
::
unordered_map
<
TensorKey
,
BufferEntry
>
buf_map_
;
std
::
unordered_map
<
const
Node
*
,
std
::
string
>
storage_scope_
;
// The current thread scope.
std
::
vector
<
ThreadScope
>
curr_thread_scope_
;
// The allocations by rank
std
::
unordered_map
<
StorageScope
,
std
::
vector
<
Stmt
>
>
allocs_
;
};
Stmt
StorageFlatten
(
Stmt
stmt
,
Map
<
Tensor
,
Buffer
>
extern_buffer
)
{
stmt
=
StorageFlattener
(
extern_buffer
).
Flatten
(
stmt
);
stmt
=
StorageFlattener
(
extern_buffer
).
Mutate
(
stmt
);
return
stmt
;
}
...
...
src/runtime/thread_storage_scope.h
View file @
b8f0ec50
...
...
@@ -62,7 +62,11 @@ struct ThreadScope {
*/
static
ThreadScope
make
(
const
std
::
string
&
s
)
{
ThreadScope
r
;
if
(
s
.
compare
(
0
,
9
,
"blockIdx."
)
==
0
)
{
if
(
s
==
"vthread"
)
{
// virtual thread at the same level as local
r
.
rank
=
1
;
r
.
dim_index
=
-
1
;
}
else
if
(
s
.
compare
(
0
,
9
,
"blockIdx."
)
==
0
)
{
r
.
rank
=
0
;
r
.
dim_index
=
static_cast
<
int
>
(
s
[
9
]
-
'x'
);
}
else
if
(
s
.
compare
(
0
,
10
,
"threadIdx."
)
==
0
)
{
...
...
src/schedule/schedule_ops.cc
View file @
b8f0ec50
...
...
@@ -203,18 +203,27 @@ MakeLoopNest(const Stage& sch,
nest
[
i
+
1
].
emplace_back
(
LetStmt
::
make
(
var
,
new_value
,
no_op
));
}
}
else
if
(
iv
->
thread_tag
==
"vthread"
)
{
// virtual thread
// Always restrict threaded IterVar to starts from 0.
CHECK
(
is_zero
(
dom
->
min
));
CHECK
(
is_positive_const
(
dom
->
extent
));
// annotate the extent of the IterVar
nest
[
i
+
1
].
emplace_back
(
AttrStmt
::
make
(
iv
,
ir
::
attr
::
virtual_thread
,
dom
->
extent
,
no_op
));
value_map
[
iv
]
=
var
;
}
else
{
// Always restrict threaded IterVar to starts from 0.
CHECK
(
is_zero
(
dom
->
min
));
// annotate the extent of the IterVar
nest
[
i
+
1
].
emplace_back
(
AttrStmt
::
make
(
iv
,
"thread_extent"
,
dom
->
extent
,
no_op
));
AttrStmt
::
make
(
iv
,
ir
::
attr
::
thread_extent
,
dom
->
extent
,
no_op
));
value_map
[
iv
]
=
var
;
}
if
(
!
reduce_init_loop
)
{
// annotate the extent of the IterVar
nest
[
i
+
1
].
emplace_back
(
AttrStmt
::
make
(
iv
,
"scope"
,
iv
->
var
,
no_op
));
AttrStmt
::
make
(
iv
,
ir
::
attr
::
scope
,
iv
->
var
,
no_op
));
}
}
// message passing to get offset of root iter vars.
...
...
tests/python/unittest/test_pass_virtual_thread.py
0 → 100644
View file @
b8f0ec50
import
tvm
def
test_virtual_thread
():
m
=
tvm
.
Var
(
'm'
)
A
=
tvm
.
placeholder
((
m
,
),
name
=
'A'
)
A1
=
tvm
.
compute
((
m
,),
lambda
i
:
A
[
i
],
name
=
'A1'
)
A2
=
tvm
.
compute
((
m
,),
lambda
i
:
A1
[
i
]
+
3
,
name
=
'A2'
)
s
=
tvm
.
Schedule
(
A2
.
op
)
vx
=
tvm
.
IterVar
((
0
,
2
),
"vx"
,
thread_tag
=
"vthread"
)
xo
,
xi
=
s
[
A2
]
.
split
(
A2
.
op
.
axis
[
0
],
outer
=
vx
)
xo
,
xi
=
s
[
A2
]
.
split
(
xi
,
8
)
s
[
A1
]
.
compute_at
(
s
[
A2
],
xo
)
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
assert
isinstance
(
bounds
,
tvm
.
collections
.
Map
)
stmt
=
tvm
.
schedule
.
ScheduleOps
(
s
,
bounds
)
Ab
=
tvm
.
Buffer
(
A
.
shape
,
A
.
dtype
,
name
=
'A'
)
A2b
=
tvm
.
Buffer
(
A2
.
shape
,
A2
.
dtype
,
name
=
'A2'
)
stmt
=
tvm
.
ir_pass
.
StorageFlatten
(
stmt
,
{
A
:
Ab
,
A2
:
A2b
})
stmt
=
tvm
.
ir_pass
.
Simplify
(
stmt
)
stmt
=
tvm
.
ir_pass
.
InjectVirtualThread
(
stmt
)
print
(
stmt
)
if
__name__
==
"__main__"
:
test_virtual_thread
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment