wenyuanbo / tic · Commits

Commit 154104b3
[PASS] Remap thread axis. (#1122)
Authored Apr 19, 2018 by Tianqi Chen, committed via GitHub on Apr 19, 2018
Parent: d0f40112

Showing 8 changed files with 118 additions and 2 deletions (+118, -2)
  include/tvm/build_module.h              +6   -0
  include/tvm/ir_pass.h                   +14  -0
  python/tvm/build_module.py              +8   -0
  python/tvm/target.py                    +1   -0
  src/api/api_pass.cc                     +1   -0
  src/codegen/build_module.cc             +2   -0
  src/pass/remap_thread_axis.cc           +83  -0
  tests/python/integration/test_ewise.py  +3   -2
include/tvm/build_module.h

@@ -32,6 +32,11 @@ class TargetNode : public Node {
  int max_num_threads = 1;
  /*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
  int thread_warp_size = 1;
  /*!
   * \brief The lowest thread index (the one that corresponds to the warp).
   *  In CUDA it is threadIdx.x, but it can be different on some platforms.
   */
  int thread_warp_index = 0;
  /*! \brief Keys for this target */
  Array<Expr> keys_array;
  /*! \brief Options for this target */

@@ -48,6 +53,7 @@ class TargetNode : public Node {
    v->Visit("device_type", &device_type);
    v->Visit("max_num_threads", &max_num_threads);
    v->Visit("thread_warp_size", &thread_warp_size);
    v->Visit("thread_warp_index", &thread_warp_index);
    v->Visit("keys_array", &keys_array);
    v->Visit("options_array", &options_array);
    v->Visit("libs_array", &libs_array);
include/tvm/ir_pass.h

@@ -417,6 +417,20 @@ LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size);

/*!
 * \brief Remap the thread axis.
 *
 *  This can be used to get an equivalent program that uses
 *  threadIdx.y in place of threadIdx.x by passing
 *  {"threadIdx.x": thread_axis("threadIdx.y")}.
 *
 * \param f The device function to be lowered.
 * \param axis_map The map from StringImm -> IterVar
 * \return Transformed function.
 */
LoweredFunc RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> axis_map);

/*!
 * \brief Lower packed function call.
 * \param f The function to be lowered.
 * \return Transformed function.
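For illustration only (not part of the commit): because the pass is also registered to the Python API (REGISTER_PASS2 in src/api/api_pass.cc below), it can be invoked directly as tvm.ir_pass.RemapThreadAxis on a device-side LoweredFunc. A minimal sketch, assuming the TVM 0.x API of this era and that fdev is a hypothetical device LoweredFunc (for example, one element of the fdevice list produced inside tvm.build after host/device splitting):

import tvm
from tvm import ir_pass

# Hypothetical input: fdev is a device-side LoweredFunc (func_type == kDeviceFunc).
# Keys convert to StringImm, values are IterVars, matching the header comment above.
axis_map = {
    tvm.convert("threadIdx.x"): tvm.thread_axis("threadIdx.y"),
    tvm.convert("threadIdx.y"): tvm.thread_axis("threadIdx.x"),
}
fdev_remapped = ir_pass.RemapThreadAxis(fdev, axis_map)  # swaps the two thread axes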
python/tvm/build_module.py

@@ -98,6 +98,7 @@ class DumpIR(object):
        schedule.ScheduleOps = self._old_sgpass
        DumpIR.scope_level -= 1


@register_node
class BuildConfig(NodeBase):
    """Configuration scope to set a build config option.

@@ -469,6 +470,13 @@ def build(sch,
    for i, func in enumerate(fdevice):
        warp_size = target.thread_warp_size
        fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size)
        warp_index = target.thread_warp_index
        if warp_index != 0:
            assert warp_index == 2
            # swap z and x
            tmap = {api.convert("threadIdx.z"): api.thread_axis("threadIdx.x"),
                    api.convert("threadIdx.x"): api.thread_axis("threadIdx.z")}
            fdevice[i] = ir_pass.RemapThreadAxis(func, tmap)

    if "gpu" in target.keys and not fdevice:
        warnings.warn(
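To see the effect end to end, here is a sketch (not from the commit; it assumes the 0.x tvm.build pipeline shown above and an OpenCL runtime with a device registered as intel_gpu). The schedule binds threadIdx.x as usual; because CreateTarget sets thread_warp_index = 2 for -device=intel_gpu (see src/codegen/build_module.cc below), build() remaps threadIdx.x to threadIdx.z in the device function, so the generated OpenCL kernel should index with get_local_id(2):

import tvm

n = 1024
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))

# Requires an OpenCL runtime; the device name triggers the remap path above.
f = tvm.build(s, [A, B], "opencl -device=intel_gpu")
print(f.imported_modules[0].get_source())  # expect get_local_id(2) in the kernel body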
python/tvm/target.py

@@ -109,6 +109,7 @@ class Target(NodeBase):
    def __exit__(self, ptype, value, trace):
        _api_internal._ExitTargetScope()


@register_node
class GenericFunc(NodeBase):
    """GenericFunc node reference. This represents a generic function
src/api/api_pass.cc

@@ -126,6 +126,7 @@ REGISTER_PASS2(LiftAttrScope);
REGISTER_PASS1(NarrowChannelAccess);
REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS2(LowerWarpMemory);
REGISTER_PASS2(RemapThreadAxis);
REGISTER_PASS2(LowerIntrin);
REGISTER_PASS1(LowerTVMBuiltin);
REGISTER_PASS1(CombineContextCall);
src/codegen/build_module.cc

@@ -78,6 +78,8 @@ Target CreateTarget(const std::string& target_name,
    t->max_num_threads = 256;
    if (t->device_name == "intel_gpu") {
      t->thread_warp_size = 16;
      // use threadIdx.z for index
      t->thread_warp_index = 2;
    }
  } else if (target_name == "metal" || target_name == "vulkan") {
    if (target_name == "metal") {
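The new TargetNode fields are visible from Python through node reflection (they are registered in VisitAttrs in include/tvm/build_module.h above), which is how build() reads them. A small sketch, assuming the 0.x tvm.target.create API:

import tvm

tgt = tvm.target.create("opencl -device=intel_gpu")
print(tgt.thread_warp_size)   # 16 for intel_gpu, per CreateTarget above
print(tgt.thread_warp_index)  # 2, i.e. threadIdx.z is the lowest (warp) index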
src/pass/remap_thread_axis.cc (new file, mode 100644)

/*!
 *  Copyright (c) 2018 by Contributors
 * \file remap_thread_axis.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <unordered_map>

namespace tvm {
namespace ir {

// Mutator to change the read pattern
class ThreadAxisRewriter : private IRMutator {
 public:
  explicit ThreadAxisRewriter(
      const std::unordered_map<std::string, IterVar>& tmap)
      : tmap_(tmap) {
  }

  Stmt Rewrite(Stmt stmt) {
    return Mutate(stmt);
  }

 private:
  Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) final {
    if (op->attr_key == attr::thread_extent) {
      IterVar iv(op->node.node_);
      CHECK_NE(iv->thread_tag.length(), 0U);
      auto it = tmap_.find(iv->thread_tag);
      if (it != tmap_.end()) {
        const IterVar& new_iv = it->second;
        const Variable* v = iv->var.get();
        if (!vmap_.count(v)) {
          vmap_[v] = new_iv->var;
        } else {
          CHECK(vmap_[v].same_as(new_iv->var));
        }
        Stmt body = this->Mutate(op->body);
        return AttrStmt::make(new_iv, op->attr_key, op->value, body);
      }
    }
    return IRMutator::Mutate_(op, stmt);
  }

  Expr Mutate_(const Variable* op, const Expr& expr) final {
    auto it = vmap_.find(op);
    if (it != vmap_.end()) return it->second;
    return IRMutator::Mutate_(op, expr);
  }

  // The thread map
  const std::unordered_map<std::string, IterVar>& tmap_;
  // variable map
  std::unordered_map<const Variable*, Var> vmap_;
};

LoweredFunc RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> thread_map) {
  std::unordered_map<std::string, IterVar> tmap;
  for (const auto& kv : thread_map) {
    const StringImm* str = kv.first.as<StringImm>();
    CHECK(str != nullptr);
    tmap[str->value] = kv.second;
  }

  CHECK_EQ(f->func_type, kDeviceFunc);
  auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
  // replace the thread axis
  for (size_t i = 0; i < n->thread_axis.size(); ++i) {
    auto it = tmap.find(n->thread_axis[i]->thread_tag);
    if (it != tmap.end()) {
      n->thread_axis.Set(i, it->second);
    }
  }
  n->body = ThreadAxisRewriter(tmap).Rewrite(n->body);
  return LoweredFunc(n);
}
}  // namespace ir
}  // namespace tvm
tests/python/integration/test_ewise.py

@@ -34,9 +34,10 @@ def test_exp():
        np.testing.assert_allclose(b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5)

    check_device("opencl -device=intel_gpu")
    check_device("cuda", "llvm")
    check_device("vulkan")
    check_device("opencl")


def test_log_pow_llvm():

@@ -196,8 +197,8 @@ def try_warp_memory():
if __name__ == "__main__":
    test_exp()
    try_warp_memory()
    test_add()
    test_log_pow_llvm()
    test_popcount()