Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
4578048c
Commit
4578048c
authored
Aug 30, 2017
by
Tianqi Chen
Committed by
GitHub
Aug 30, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] IRTransform to enable IR pass proptype in python (#401)
parent
8ef26606
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
110 additions
and
1 deletions
+110
-1
include/tvm/ir_mutator.h
+19
-0
src/api/api_pass.cc
+2
-0
src/api/dsl_api.cc
+0
-0
src/pass/coproc_sync.cc
+0
-1
src/pass/ir_mutator.cc
+54
-0
tests/python/unittest/test_arith_detect_linear_equation.py
+6
-0
tests/python/unittest/test_pass_ir_transform.py
+29
-0
No files found.
include/tvm/ir_mutator.h
View file @
4578048c
...
...
@@ -102,6 +102,25 @@ class IRMutator {
virtual
Expr
Mutate_
(
const
Shuffle
*
op
,
const
Expr
&
e
);
};
/*!
* \brief recursively visit the ir in post DFS order node, and transform it
*
* \param node The ir to be transformed.
* \param preorder The function called in before recursive mutation
* If preorder returns None, then the transform will proceed to recursive call.
* If preorder returns a not None Stmt/Expr, the transformer will simply return it and
* won't do further recursion.
* \param postorder The function called after recursive mutation.
* The recursive mutation result is passed to postorder for further mutation.
* \param only_enable List of StringImm.
* If it is empty, all IRNode will call preorder/postorder
* If it is not empty, preorder/postorder will only be called
* when the IRNode's type key is in the list.
*/
Stmt
IRTransform
(
const
Stmt
&
node
,
const
runtime
::
PackedFunc
&
preorder
,
const
runtime
::
PackedFunc
&
postorder
,
const
Array
<
Expr
>&
only_enable
=
{});
}
// namespace ir
}
// namespace tvm
#endif // TVM_IR_MUTATOR_H_
src/api/api_pass.cc
View file @
4578048c
...
...
@@ -7,6 +7,7 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/api_registry.h>
namespace
tvm
{
...
...
@@ -88,6 +89,7 @@ REGISTER_PASS1(VerifySSA);
REGISTER_PASS1
(
RewriteUnsafeSelect
);
REGISTER_PASS4
(
Inline
);
REGISTER_PASS3
(
StorageFlatten
);
REGISTER_PASS4
(
IRTransform
);
REGISTER_PASS1
(
VectorizeLoop
);
REGISTER_PASS4
(
UnrollLoop
);
REGISTER_PASS2
(
ThreadSync
);
...
...
src/api/dsl_api.cc
View file @
4578048c
src/pass/coproc_sync.cc
View file @
4578048c
...
...
@@ -11,7 +11,6 @@
#include "./ir_util.h"
#include "./storage_access.h"
namespace
tvm
{
namespace
ir
{
...
...
src/pass/ir_mutator.cc
View file @
4578048c
...
...
@@ -4,11 +4,65 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/packed_func_ext.h>
#include "./ir_util.h"
namespace
tvm
{
namespace
ir
{
class
IRTransformer
final
:
public
IRMutator
{
public
:
IRTransformer
(
const
runtime
::
PackedFunc
&
f_preorder
,
const
runtime
::
PackedFunc
&
f_postorder
,
const
std
::
unordered_set
<
uint32_t
>&
only_enable
)
:
f_preorder_
(
f_preorder
),
f_postorder_
(
f_postorder
),
only_enable_
(
only_enable
)
{
}
Stmt
Mutate
(
Stmt
stmt
)
final
{
return
MutateInternal
<
Stmt
>
(
stmt
);
}
Expr
Mutate
(
Expr
expr
)
final
{
return
MutateInternal
<
Expr
>
(
expr
);
}
private
:
template
<
typename
T
>
T
MutateInternal
(
T
node
)
{
if
(
only_enable_
.
size
()
&&
!
only_enable_
.
count
(
node
->
type_index
()))
{
return
IRMutator
::
Mutate
(
node
);
}
if
(
f_preorder_
!=
nullptr
)
{
T
pre
=
f_preorder_
(
node
);
if
(
pre
.
defined
())
return
pre
;
}
node
=
IRMutator
::
Mutate
(
node
);
if
(
f_postorder_
!=
nullptr
)
{
T
post
=
f_postorder_
(
node
);
if
(
post
.
defined
())
return
post
;
}
return
node
;
}
// The functions
const
runtime
::
PackedFunc
&
f_preorder_
;
const
runtime
::
PackedFunc
&
f_postorder_
;
// type indices enabled.
const
std
::
unordered_set
<
uint32_t
>&
only_enable_
;
};
Stmt
IRTransform
(
const
Stmt
&
ir_node
,
const
runtime
::
PackedFunc
&
f_preorder
,
const
runtime
::
PackedFunc
&
f_postorder
,
const
Array
<
Expr
>&
only_enable
)
{
std
::
unordered_set
<
uint32_t
>
only_type_index
;
for
(
Expr
s
:
only_enable
)
{
only_type_index
.
insert
(
Node
::
TypeKey2Index
(
s
.
as
<
StringImm
>
()
->
value
.
c_str
()));
}
return
IRTransformer
(
f_preorder
,
f_postorder
,
only_type_index
)
.
Mutate
(
ir_node
);
}
IRMutator
::
FMutateExpr
&
IRMutator
::
vtable_expr
()
{
// NOLINT(*)
static
FMutateExpr
inst
;
return
inst
;
}
...
...
tests/python/unittest/test_arith_detect_linear_equation.py
View file @
4578048c
...
...
@@ -14,5 +14,11 @@ def test_basic():
assert
m
[
1
]
.
value
==
5
assert
tvm
.
ir_pass
.
Simplify
(
m
[
0
]
-
(
b
*
6
+
7
+
1
))
.
value
==
0
m
=
tvm
.
arith
.
DetectLinearEquation
(
a
*
b
+
7
,
a
)
assert
m
[
1
]
==
b
m
=
tvm
.
arith
.
DetectLinearEquation
(
b
*
7
,
a
)
assert
m
[
1
]
.
value
==
0
if
__name__
==
"__main__"
:
test_basic
()
tests/python/unittest/test_pass_ir_transform.py
0 → 100644
View file @
4578048c
import
tvm
def
test_ir_transform
():
ib
=
tvm
.
ir_builder
.
create
()
n
=
tvm
.
var
(
"n"
)
with
ib
.
for_range
(
0
,
n
,
name
=
"i"
)
as
i
:
with
ib
.
for_range
(
0
,
10
,
name
=
"j"
)
as
j
:
x
=
tvm
.
call_extern
(
"int32"
,
"TestA"
,
i
*
3
+
j
*
1
)
ib
.
emit
(
tvm
.
call_extern
(
"int32"
,
"TestB"
,
x
))
ib
.
emit
(
tvm
.
call_extern
(
"int32"
,
"TestC"
,
x
))
body
=
ib
.
get
()
def
preorder
(
op
):
if
op
.
name
==
"TestC"
:
return
tvm
.
const
(
0
,
"int32"
)
return
None
def
postorder
(
op
):
assert
isinstance
(
op
,
tvm
.
expr
.
Call
)
if
op
.
name
==
"TestA"
:
return
tvm
.
call_extern
(
"int32"
,
"TestB"
,
op
.
args
[
0
]
+
1
)
return
op
body
=
tvm
.
ir_pass
.
IRTransform
(
body
,
preorder
,
postorder
,
[
"Call"
])
stmt_list
=
tvm
.
make
.
stmt_list
(
body
.
body
.
body
)
assert
stmt_list
[
0
]
.
value
.
args
[
0
]
.
name
==
"TestB"
assert
stmt_list
[
1
]
.
value
.
value
==
0
if
__name__
==
"__main__"
:
test_ir_transform
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment