Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
8e51af2f
Commit
8e51af2f
authored
Jul 16, 2017
by
Tianqi Chen
Committed by
GitHub
Jul 16, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] CombineContextCall (#255)
parent
36f20b54
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
198 additions
and
1 deletions
+198
-1
include/tvm/ir.h
+12
-0
include/tvm/ir_pass.h
+20
-0
python/tvm/build.py
+1
-0
python/tvm/intrin.py
+25
-1
src/api/api_pass.cc
+1
-0
src/pass/combine_context_call.cc
+98
-0
src/pass/ir_deep_compare.cc
+12
-0
tests/python/unittest/test_pass_combine_context_call.py
+29
-0
No files found.
include/tvm/ir.h
View file @
8e51af2f
...
...
@@ -302,6 +302,18 @@ constexpr const char* tvm_stack_make_array = "tvm_stack_make_array";
*/
constexpr
const
char
*
tvm_call_packed
=
"tvm_call_packed"
;
/*!
* \brief See pesudo code
* Mark the content as thread local context, can get optimized
* by only call the call once at thread start.
*
* Do not allow nesting(getting a thread context from another).
*
* Handle tvm_thread_context(Expr call) {
* return call;
* }
*/
constexpr
const
char
*
tvm_thread_context
=
"tvm_thread_context"
;
/*!
* \brief Lowered version of call packed, the space of value and
* type codes are explicitly allocated.
*
...
...
include/tvm/ir_pass.h
View file @
8e51af2f
...
...
@@ -61,6 +61,19 @@ bool Equal(const Expr& lhs, const Expr& rhs);
bool
Equal
(
const
Stmt
&
lhs
,
const
Stmt
&
rhs
);
/*!
* \brief Deep compare lhs and rhs.
*
* If you only want equality comparison, use Equal
* which will also tie definitions. The compare mode
* will give order of expression in total order.
*
* \param lhs The left operand
* \param rhs The right operand
* \return The comparison result.
*/
int
Compare
(
const
Expr
&
lhs
,
const
Expr
&
rhs
);
/*!
* \brief verifies whether the IR stmt or Expr is in SSA form.
* That is: each VarExpr is defined and assigned once(in Let/For)
*
...
...
@@ -315,6 +328,13 @@ LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
LoweredFunc
LowerPackedCall
(
LoweredFunc
f
);
/*!
* \brief Combine context function calls.
* \param f The host function to be lowered.
* \return Transformed function.
*/
LoweredFunc
CombineContextCall
(
LoweredFunc
f
);
/*!
* \brief Lower intrinsic function calls.
* \param f The device function to be lowered.
* \param target The target device.
...
...
python/tvm/build.py
View file @
8e51af2f
...
...
@@ -321,6 +321,7 @@ def build(sch,
device_type
=
ndarray
.
context
(
device
,
0
)
.
device_type
fhost
=
[
ir_pass
.
BindDeviceType
(
x
,
device_type
)
for
x
in
fhost
]
fhost
=
[
ir_pass
.
LowerPackedCall
(
x
)
for
x
in
fhost
]
fhost
=
[
ir_pass
.
CombineContextCall
(
x
)
for
x
in
fhost
]
if
fdevice
:
if
not
target_host
:
...
...
python/tvm/intrin.py
View file @
8e51af2f
...
...
@@ -89,7 +89,7 @@ def call_pure_extern(dtype, func_name, *args):
The data type of the result.
func_name: str
The
intrinsic
function name.
The
extern
function name.
args : list
Positional arguments.
...
...
@@ -102,6 +102,30 @@ def call_pure_extern(dtype, func_name, *args):
return
_make
.
Call
(
dtype
,
func_name
,
convert
(
args
),
_Call
.
PureExtern
,
None
,
0
)
def
call_extern
(
dtype
,
func_name
,
*
args
):
"""Build expression by calling a extern function.
Parameters
----------
dtype : str
The data type of the result.
func_name: str
The extern function name.
args : list
Positional arguments.
Returns
-------
call : Expr
The call expression.
"""
return
_make
.
Call
(
dtype
,
func_name
,
convert
(
args
),
_Call
.
Extern
,
None
,
0
)
def
exp
(
x
):
"""Take exponetial of input x.
...
...
src/api/api_pass.cc
View file @
8e51af2f
...
...
@@ -102,5 +102,6 @@ REGISTER_PASS1(NarrowChannelAccess);
REGISTER_PASS2
(
LowerThreadAllreduce
);
REGISTER_PASS2
(
LowerIntrin
);
REGISTER_PASS1
(
LowerPackedCall
);
REGISTER_PASS1
(
CombineContextCall
);
}
// namespace ir
}
// namespace tvm
src/pass/combine_context_call.cc
0 → 100644
View file @
8e51af2f
/*!
* Copyright (c) 2017 by Contributors
* Combine calls into context related function into one.
*
* \file combine_context_call.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <map>
namespace
tvm
{
namespace
ir
{
// Calculate the statistics of packed function.
// These information are needed during codegen.
class
ContextCallCombiner
final
:
public
IRMutator
{
public
:
struct
CompareExpr
{
bool
operator
()(
const
Expr
&
lhs
,
const
Expr
&
rhs
)
const
{
return
Compare
(
lhs
,
rhs
)
<
0
;
}
};
Expr
Mutate_
(
const
Call
*
op
,
const
Expr
&
e
)
final
{
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_thread_context
))
{
CHECK_EQ
(
op
->
args
.
size
(),
1U
);
Expr
ctx
=
op
->
args
[
0
];
auto
it
=
ctx_map_
.
find
(
ctx
);
if
(
it
!=
ctx_map_
.
end
())
{
return
it
->
second
;
}
else
{
CHECK
(
ctx
.
type
().
is_handle
());
std
::
string
name
;
if
(
const
Call
*
call
=
ctx
.
as
<
Call
>
())
{
name
=
call
->
name
+
"_cache"
;
}
else
{
name
=
"ctx_cache_"
;
}
Var
ctx_var
(
name
,
ctx
.
type
());
ctx_map_
[
ctx
]
=
ctx_var
;
return
ctx_var
;
}
}
else
{
return
IRMutator
::
Mutate_
(
op
,
e
);
}
}
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
attr_key
==
attr
::
thread_extent
)
{
// Map of comparison expression to variable
std
::
map
<
Expr
,
Var
,
CompareExpr
>
temp
;
std
::
swap
(
temp
,
ctx_map_
);
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
std
::
swap
(
temp
,
ctx_map_
);
return
BuildContext
(
temp
,
stmt
);
}
else
{
return
IRMutator
::
Mutate_
(
op
,
s
);
}
}
Stmt
Mutate_
(
const
For
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
for_type
==
ForType
::
Parallel
)
{
// Map of comparison expression to variable
std
::
map
<
Expr
,
Var
,
CompareExpr
>
temp
;
std
::
swap
(
temp
,
ctx_map_
);
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
std
::
swap
(
temp
,
ctx_map_
);
return
BuildContext
(
temp
,
stmt
);
}
else
{
return
IRMutator
::
Mutate_
(
op
,
s
);
}
}
Stmt
Combine
(
Stmt
stmt
)
{
return
BuildContext
(
ctx_map_
,
this
->
Mutate
(
stmt
));
}
private
:
static
Stmt
BuildContext
(
const
std
::
map
<
Expr
,
Var
,
CompareExpr
>&
cmap
,
Stmt
body
)
{
for
(
const
auto
&
kv
:
cmap
)
{
body
=
LetStmt
::
make
(
kv
.
second
,
kv
.
first
,
body
);
}
return
body
;
}
// Map of comparison expression to variable
std
::
map
<
Expr
,
Var
,
CompareExpr
>
ctx_map_
;
};
LoweredFunc
CombineContextCall
(
LoweredFunc
f
)
{
auto
n
=
std
::
make_shared
<
LoweredFuncNode
>
(
*
f
.
operator
->
());
n
->
body
=
ContextCallCombiner
().
Combine
(
n
->
body
);
return
LoweredFunc
(
n
);
}
}
// namespace ir
}
// namespace tvm
src/pass/ir_deep_compare.cc
View file @
8e51af2f
...
...
@@ -35,8 +35,15 @@ class IRDeepCompare :
return
order_
==
0
;
}
int
Compare
(
const
Expr
&
lhs
,
const
Expr
&
rhs
)
{
tie_def_
=
false
;
VisitExpr
(
lhs
,
rhs
);
return
order_
;
}
void
VisitExpr
(
const
Expr
&
n
,
const
Expr
&
other
)
override
{
if
(
order_
!=
0
)
return
;
if
(
n
.
same_as
(
other
))
return
;
if
(
CompareValue
(
n
->
type_index
(),
other
->
type_index
())
!=
0
)
return
;
if
(
CompareType
(
n
.
type
(),
other
.
type
())
!=
0
)
return
;
ExprComparator
::
VisitExpr
(
n
,
other
);
...
...
@@ -44,6 +51,7 @@ class IRDeepCompare :
void
VisitStmt
(
const
Stmt
&
n
,
const
Stmt
&
other
)
override
{
if
(
order_
!=
0
)
return
;
if
(
n
.
same_as
(
other
))
return
;
if
(
CompareValue
(
n
->
type_index
(),
other
->
type_index
())
!=
0
)
return
;
StmtComparator
::
VisitStmt
(
n
,
other
);
}
...
...
@@ -413,5 +421,9 @@ bool Equal(const Expr& lhs, const Expr& rhs) {
return
IRDeepCompare
().
Equal
(
lhs
,
rhs
);
}
int
Compare
(
const
Expr
&
lhs
,
const
Expr
&
rhs
)
{
return
IRDeepCompare
().
Compare
(
lhs
,
rhs
);
}
}
// namespace ir
}
// namespace tvm
tests/python/unittest/test_pass_combine_context_call.py
0 → 100644
View file @
8e51af2f
import
tvm
def
test_for
():
dev_type
=
tvm
.
var
(
"dev_type"
)
def
device_context
(
dev_id
):
ctx
=
tvm
.
call_extern
(
"handle"
,
"device_context"
,
dev_type
,
dev_id
)
return
tvm
.
make
.
Call
(
"handle"
,
"tvm_thread_context"
,
[
ctx
],
tvm
.
expr
.
Call
.
Intrinsic
,
None
,
0
)
ib
=
tvm
.
ir_builder
.
create
()
n
=
tvm
.
var
(
"n"
)
A
=
ib
.
allocate
(
"float32"
,
n
,
name
=
"A"
,
scope
=
"global"
)
with
ib
.
for_range
(
0
,
n
,
name
=
"i"
)
as
i
:
ib
.
emit
(
tvm
.
call_extern
(
"int32"
,
"fadd"
,
device_context
(
0
),
A
))
with
ib
.
for_range
(
0
,
10
,
name
=
"j"
)
as
j
:
ib
.
emit
(
tvm
.
call_extern
(
"int32"
,
"fadd"
,
device_context
(
1
),
A
))
ib
.
emit
(
tvm
.
call_extern
(
"int32"
,
"fadd"
,
device_context
(
0
),
A
))
body
=
ib
.
get
()
f
=
tvm
.
ir_pass
.
MakeAPI
(
body
,
"func"
,
[
dev_type
,
n
],
2
,
True
)
f
=
tvm
.
ir_pass
.
CombineContextCall
(
f
)
assert
f
.
body
.
value
.
dtype
==
"handle"
assert
f
.
body
.
body
.
value
.
dtype
==
"handle"
if
__name__
==
"__main__"
:
test_for
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment