Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
5072efae
Commit
5072efae
authored
Sep 02, 2017
by
Tianqi Chen
Committed by
GitHub
Sep 02, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] Improve vthread injection. (#411)
parent
b0d9f299
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
111 additions
and
30 deletions
+111
-30
include/tvm/ir.h
+5
-0
src/op/op_util.cc
+2
-1
src/pass/inject_double_buffer.cc
+6
-6
src/pass/inject_virtual_thread.cc
+64
-23
src/pass/lower_tvm_builtin.cc
+2
-0
tests/python/unittest/test_pass_inject_vthread.py
+32
-0
No files found.
include/tvm/ir.h
View file @
5072efae
...
...
@@ -257,6 +257,11 @@ constexpr const char* tvm_if_then_else = "tvm_if_then_else";
*/
constexpr
const
char
*
tvm_access_ptr
=
"tvm_access_ptr"
;
/*!
* \brief Return a unique context id, used for hint of workspace separation.
* Different context ids are guaranteed not to have overlapping workspaces.
*/
constexpr
const
char
*
tvm_context_id
=
"tvm_context_id"
;
/*!
* \brief tvm_tuple is not an actual function and cannot codegen.
* It is used to represent tuple structure in value field of AttrStmt,
* for the sake of giving hint to optimization.
...
...
src/op/op_util.cc
View file @
5072efae
...
...
@@ -106,7 +106,8 @@ MakeLoopNest(const Stage& stage,
it_attr
->
prefetch_offset
[
j
],
no_op
));
}
}
}
else
if
(
bind_iv
->
thread_tag
==
"vthread"
)
{
}
else
if
(
bind_iv
->
thread_tag
==
"vthread"
||
bind_iv
->
thread_tag
==
"cthread"
)
{
// virtual thread
// Always restrict threaded IterVar to starts from 0.
CHECK
(
is_zero
(
dom
->
min
));
...
...
src/pass/inject_double_buffer.cc
View file @
5072efae
...
...
@@ -69,7 +69,7 @@ class DoubleBufferInjector : public IRMutator {
Stmt
Mutate_
(
const
Allocate
*
op
,
const
Stmt
&
s
)
final
{
auto
it
=
dbuffer_info_
.
find
(
op
->
buffer_var
.
get
());
if
(
it
!=
dbuffer_info_
.
end
())
{
it
->
second
.
s
ize
=
arith
::
ComputeReduce
<
Mul
>
(
op
->
extents
);
it
->
second
.
s
tride
=
arith
::
ComputeReduce
<
Mul
>
(
op
->
extents
)
*
op
->
type
.
lanes
(
);
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
op
=
stmt
.
as
<
Allocate
>
();
Array
<
Expr
>
new_extents
{
make_const
(
op
->
extents
[
0
].
type
(),
2
)};
...
...
@@ -126,10 +126,10 @@ class DoubleBufferInjector : public IRMutator {
if
(
it
!=
dbuffer_info_
.
end
())
{
const
StorageEntry
&
e
=
it
->
second
;
CHECK
(
in_double_buffer_scope_
);
CHECK
(
e
.
s
iz
e
.
defined
());
CHECK
(
e
.
s
trid
e
.
defined
());
return
Store
::
make
(
op
->
buffer_var
,
op
->
value
,
e
.
switch_write_var
*
e
.
s
iz
e
+
op
->
index
,
e
.
switch_write_var
*
e
.
s
trid
e
+
op
->
index
,
op
->
predicate
);
}
else
{
return
stmt
;
...
...
@@ -142,11 +142,11 @@ class DoubleBufferInjector : public IRMutator {
auto
it
=
dbuffer_info_
.
find
(
op
->
buffer_var
.
get
());
if
(
it
!=
dbuffer_info_
.
end
())
{
const
StorageEntry
&
e
=
it
->
second
;
CHECK
(
e
.
s
iz
e
.
defined
());
CHECK
(
e
.
s
trid
e
.
defined
());
CHECK
(
e
.
switch_read_var
.
defined
());
return
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
e
.
switch_read_var
*
e
.
s
iz
e
+
op
->
index
,
e
.
switch_read_var
*
e
.
s
trid
e
+
op
->
index
,
op
->
predicate
);
}
else
{
return
expr
;
...
...
@@ -194,7 +194,7 @@ class DoubleBufferInjector : public IRMutator {
// Storage entry for those who need double buffering.
struct
StorageEntry
{
// The size of the buffer
Expr
s
iz
e
;
Expr
s
trid
e
;
// The loop we need
const
For
*
loop
{
nullptr
};
// The switch variable.
...
...
src/pass/inject_virtual_thread.cc
View file @
5072efae
...
...
@@ -130,22 +130,29 @@ class VTInjector : public IRMutator {
// constructor
VTInjector
(
Var
var
,
int
num_threads
,
std
::
unordered_set
<
const
Variable
*>
touched_var
)
:
var_
(
var
),
num_threads_
(
num_threads
),
touched_var_
(
touched_var
)
{
const
std
::
unordered_set
<
const
Variable
*>&
touched_var
,
bool
allow_share
)
:
var_
(
var
),
num_threads_
(
num_threads
),
touched_var_
(
touched_var
),
allow_share_
(
allow_share
)
{
}
// Inject VTLoop when needed.
Stmt
Mutate
(
Stmt
stmt
)
final
{
CHECK
(
!
visit_touched_var_
)
<<
stmt
->
type_key
()
<<
stmt
;
stmt
=
IRMutator
::
Mutate
(
stmt
);
if
(
visit_touched_var_
)
{
if
(
!
vt_loop_injected_
)
return
InjectVTLoop
(
stmt
,
false
);
if
(
visit_touched_var_
||
trigger_base_inject_
)
{
if
(
!
vt_loop_injected_
)
{
return
InjectVTLoop
(
stmt
,
false
);
}
visit_touched_var_
=
false
;
trigger_base_inject_
=
false
;
}
return
stmt
;
}
// Variable
Expr
Mutate_
(
const
Variable
*
op
,
const
Expr
&
e
)
final
{
CHECK
(
!
alloc_remap_
.
count
(
op
))
<<
"Buffer address may get rewritten in virtual thread"
;
if
(
touched_var_
.
count
(
op
))
{
visit_touched_var_
=
true
;
}
...
...
@@ -161,8 +168,8 @@ class VTInjector : public IRMutator {
if
(
touched_var_
.
count
(
op
->
buffer_var
.
get
()))
{
visit_touched_var_
=
true
;
}
auto
it
=
touched_alloc
_
.
find
(
op
->
buffer_var
.
get
());
if
(
it
!=
touched_alloc
_
.
end
())
{
auto
it
=
alloc_remap
_
.
find
(
op
->
buffer_var
.
get
());
if
(
it
!=
alloc_remap
_
.
end
())
{
return
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
RewriteIndex
(
op
->
index
,
it
->
second
),
op
->
predicate
);
...
...
@@ -170,6 +177,34 @@ class VTInjector : public IRMutator {
return
expr
;
}
}
// Call expressions.
//  - tvm_access_ptr on a buffer found in alloc_remap_: the offset argument is
//    shifted by (remapped value / lanes) * var_, and the expression is flagged
//    as touching the virtual thread variable.
//  - tvm_context_id: left unchanged when sharing is allowed, otherwise
//    replaced by var_.
//  - anything else: handled by the default IRMutator visitor.
Expr Mutate_(const Call* op, const Expr& e) final {
  if (!op->is_intrinsic(intrinsic::tvm_access_ptr)) {
    if (op->is_intrinsic(intrinsic::tvm_context_id)) {
      if (allow_share_) {
        return e;
      }
      return var_;
    }
    return IRMutator::Mutate_(op, e);
  }

  // tvm_access_ptr path: exactly five arguments are expected.
  CHECK_EQ(op->args.size(), 5U);
  Type dtype = op->args[0].type();
  const Variable* buffer = op->args[1].as<Variable>();
  auto remap = alloc_remap_.find(buffer);
  if (remap == alloc_remap_.end()) {
    // Buffer was not remapped; nothing special to do.
    return IRMutator::Mutate_(op, e);
  }

  visit_touched_var_ = true;
  Expr offset = Mutate(op->args[2]);
  Expr extent = Mutate(op->args[3]);
  // The remapped value is divided by the lane count of the access type
  // before being scaled by the virtual thread variable.
  Expr lanes = make_const(offset.type(), dtype.lanes());
  Expr stride = arith::ComputeExpr<Div>(remap->second, lanes);
  offset = stride * var_ + offset;

  // Arguments 0, 1 and 4 are forwarded untouched.
  return Call::make(op->type,
                    op->name,
                    {op->args[0], op->args[1], offset, extent, op->args[4]},
                    op->call_type);
}
// Evaluate statements: when sharing is not allowed, request a base-level
// injection (trigger_base_inject_) before running the default mutation.
Stmt Mutate_(const Evaluate* op, const Stmt& s) final {
  const bool sharing_disallowed = !allow_share_;
  trigger_base_inject_ = sharing_disallowed;
  return IRMutator::Mutate_(op, s);
}
// Store
Stmt
Mutate_
(
const
Store
*
op
,
const
Stmt
&
s
)
final
{
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
...
...
@@ -177,8 +212,9 @@ class VTInjector : public IRMutator {
if
(
touched_var_
.
count
(
op
->
buffer_var
.
get
()))
{
visit_touched_var_
=
true
;
}
auto
it
=
touched_alloc_
.
find
(
op
->
buffer_var
.
get
());
if
(
it
!=
touched_alloc_
.
end
())
{
trigger_base_inject_
=
!
allow_share_
;
auto
it
=
alloc_remap_
.
find
(
op
->
buffer_var
.
get
());
if
(
it
!=
alloc_remap_
.
end
())
{
return
Store
::
make
(
op
->
buffer_var
,
op
->
value
,
RewriteIndex
(
op
->
index
,
it
->
second
),
...
...
@@ -190,7 +226,10 @@ class VTInjector : public IRMutator {
// Attribute
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
Expr
value
=
Mutate
(
op
->
value
);
if
(
visit_touched_var_
)
{
if
(
visit_touched_var_
&&
!
vt_loop_injected_
)
{
return
InjectVTLoop
(
s
,
true
);
}
else
if
(
!
allow_share_
&&
!
vt_loop_injected_
&&
op
->
attr_key
==
attr
::
coproc_uop_scope
)
{
return
InjectVTLoop
(
s
,
true
);
}
else
{
Stmt
body
=
Mutate
(
op
->
body
);
...
...
@@ -299,24 +338,19 @@ class VTInjector : public IRMutator {
visit_touched_var_
=
false
;
Stmt
body
;
if
(
touched_var_
.
count
(
op
->
buffer_var
.
get
()))
{
// always rewrite if not allow sharing.
if
(
touched_var_
.
count
(
op
->
buffer_var
.
get
())
||
!
allow_share_
)
{
// place v on highest dimension.
Expr
stride
=
extents
[
0
];
for
(
size_t
i
=
1
;
i
<
extents
.
size
();
++
i
)
{
stride
=
arith
::
ComputeExpr
<
Mul
>
(
stride
,
extents
[
i
]);
}
if
(
op
->
type
.
lanes
()
!=
0
)
{
stride
=
stride
*
op
->
type
.
lanes
();
}
Expr
stride
=
arith
::
ComputeReduce
<
Mul
>
(
op
->
extents
)
*
op
->
type
.
lanes
();
Array
<
Expr
>
other
;
other
.
push_back
(
num_threads_
);
other
.
push_back
(
make_const
(
op
->
extents
[
0
].
type
(),
num_threads_
)
);
for
(
Expr
e
:
extents
)
{
other
.
push_back
(
e
);
}
extents
=
other
;
changed
=
true
;
// mark this buffer as touched.
touched_alloc
_
[
op
->
buffer_var
.
get
()]
=
stride
;
alloc_remap
_
[
op
->
buffer_var
.
get
()]
=
stride
;
// Mutate the body.
body
=
Mutate
(
op
->
body
);
}
else
{
...
...
@@ -340,6 +374,7 @@ class VTInjector : public IRMutator {
CHECK
(
!
vt_loop_injected_
);
// reset the flags
visit_touched_var_
=
false
;
trigger_base_inject_
=
false
;
vt_loop_injected_
=
true
;
if
(
before_mutation
)
{
stmt
=
this
->
Mutate
(
stmt
);
...
...
@@ -359,7 +394,8 @@ class VTInjector : public IRMutator {
// insert a for loop
Var
idx
(
var_
->
name_hint
+
".s"
,
var_
->
type
);
stmt
=
Substitute
(
stmt
,
{{
var_
,
idx
}});
return
For
::
make
(
idx
,
0
,
num_threads_
,
return
For
::
make
(
idx
,
make_zero
(
idx
.
type
()),
make_const
(
idx
.
type
(),
num_threads_
),
ForType
::
Serial
,
DeviceAPI
::
None
,
stmt
);
}
}
...
...
@@ -373,12 +409,16 @@ class VTInjector : public IRMutator {
bool
vt_loop_injected_
{
false
};
// whether the current expression gets touched.
bool
visit_touched_var_
{
false
};
// Trigger base stmt
bool
trigger_base_inject_
{
false
};
// the counter of loops in after mutation.
int
max_loop_depth_
{
0
};
// The variables that get touched.
std
::
unordered_set
<
const
Variable
*>
touched_var_
;
const
std
::
unordered_set
<
const
Variable
*>&
touched_var_
;
// Whether to allow sharing.
bool
allow_share_
;
// The allocations that get touched -> extent
std
::
unordered_map
<
const
Variable
*
,
Expr
>
touched_alloc
_
;
std
::
unordered_map
<
const
Variable
*
,
Expr
>
alloc_remap
_
;
};
...
...
@@ -389,10 +429,11 @@ class VirtualThreadInjector : public IRMutator {
op
=
stmt
.
as
<
AttrStmt
>
();
if
(
op
->
attr_key
==
attr
::
virtual_thread
)
{
IterVar
iv
(
op
->
node
.
node_
);
bool
allow_share
=
iv
->
thread_tag
==
"vthread"
;
int
nthread
=
static_cast
<
int
>
(
op
->
value
.
as
<
IntImm
>
()
->
value
);
VarTouchedAnalysis
vs
;
auto
touched
=
vs
.
TouchedVar
(
op
->
body
,
iv
->
var
.
get
());
VTInjector
injecter
(
iv
->
var
,
nthread
,
touched
);
VTInjector
injecter
(
iv
->
var
,
nthread
,
touched
,
allow_share
);
return
injecter
.
Mutate
(
op
->
body
);
}
else
{
return
stmt
;
...
...
src/pass/lower_tvm_builtin.cc
View file @
5072efae
...
...
@@ -140,6 +140,8 @@ class BuiltinLower : public IRMutator {
return
MakeShape
(
op
,
e
);
}
else
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_stack_make_array
))
{
return
MakeArray
(
op
,
e
);
}
else
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_context_id
))
{
return
make_zero
(
op
->
type
);
}
else
{
return
IRMutator
::
Mutate_
(
op
,
e
);
}
...
...
tests/python/unittest/test_pass_inject_vthread.py
0 → 100644
View file @
5072efae
import
tvm
def test_vthread():
    """Run InjectVirtualThread on the same loop body under two thread tags.

    With the "vthread" tag the first extent of the resulting allocation is
    expected to equal 2; with the "cthread" tag the allocation is expected
    to have three extents.
    """
    dtype = 'int64'
    n = 100
    m = 4
    nthread = 2

    def get_vthread(name):
        """Build the test statement using two thread axes tagged `name`."""
        tx = tvm.thread_axis(name)
        ty = tvm.thread_axis(name)
        builder = tvm.ir_builder.create()
        src = builder.pointer("float32", name="A")
        dst = builder.pointer("float32", name="C")
        with builder.for_range(0, n) as i:
            builder.scope_attr(tx, "virtual_thread", nthread)
            builder.scope_attr(ty, "virtual_thread", nthread)
            local = builder.allocate("float32", m, name="B", scope="shared")
            local[i] = src[i * nthread + tx]
            # Wrap the allocation in a buffer so an access pointer can be
            # handed to the extern call together with the context id.
            bbuffer = tvm.decl_buffer((m,), dtype=local.dtype,
                                      data=local.asnode())
            read_ptr = bbuffer.access_ptr("r")
            context_id = tvm.call_pure_intrin("int32", "tvm_context_id")
            builder.emit(tvm.call_extern("int32", "Run", read_ptr, context_id))
            dst[i * nthread + tx] = local[i] + 1
        return builder.get()

    shared_stmt = tvm.ir_pass.InjectVirtualThread(get_vthread("vthread"))
    assert shared_stmt.body.body.extents[0].value == 2

    separate_stmt = tvm.ir_pass.InjectVirtualThread(get_vthread("cthread"))
    assert len(separate_stmt.body.body.extents) == 3


if __name__ == "__main__":
    test_vthread()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment