Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
8e04361c
Commit
8e04361c
authored
Nov 18, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Refactor IR Pass
parent
ff6b8d82
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
203 additions
and
153 deletions
+203
-153
HalideIR
+1
-1
include/tvm/expr.h
+1
-0
include/tvm/ir_mutator.h
+6
-13
include/tvm/ir_pass.h
+26
-0
include/tvm/ir_visitor.h
+1
-1
include/tvm/tensor.h
+7
-1
src/pass/ir_mutator.cc
+41
-67
src/pass/ir_pass.cc
+39
-0
src/pass/ir_visitor.cc
+48
-48
tests/cpp/ir_mutator_test.cc
+3
-22
tests/cpp/ir_pass_test.cc
+29
-0
tests/cpp/ir_visitor_test.cc
+1
-0
No files found.
HalideIR
@
4becbde6
Subproject commit
89b7939957d66a37dd6083ad6b09a5644e73fd8b
Subproject commit
4becbde67c8aa565941b02648cea90f50211f8dc
include/tvm/expr.h
View file @
8e04361c
...
...
@@ -27,6 +27,7 @@ using Halide::abs;
using
Halide
::
select
;
using
Halide
::
Expr
;
using
Halide
::
IR
::
FunctionBaseNode
;
using
Halide
::
Internal
::
Stmt
;
class
Var
:
public
Halide
::
VarExpr
{
...
...
include/tvm/ir_mutator.h
View file @
8e04361c
...
...
@@ -29,7 +29,7 @@ class IRMutator {
* \brief mutate expression
* \return the mutated expr
*/
virtual
Expr
m
utate
(
Expr
expr
)
{
virtual
Expr
M
utate
(
Expr
expr
)
{
static
const
FMutateExpr
&
f
=
vtable_expr
();
return
f
(
expr
,
expr
,
this
);
}
...
...
@@ -37,7 +37,7 @@ class IRMutator {
* \brief mutate expression
* \return the mutated stmt
*/
virtual
Stmt
m
utate
(
Stmt
stmt
)
{
virtual
Stmt
M
utate
(
Stmt
stmt
)
{
static
const
FMutateStmt
&
f
=
vtable_stmt
();
return
f
(
stmt
,
stmt
,
this
);
}
...
...
@@ -58,28 +58,21 @@ class IRMutator {
*/
class
IRMutatorExample
:
public
IRMutator
{
public
:
Expr
m
utate
(
Expr
expr
)
final
{
Expr
M
utate
(
Expr
expr
)
final
{
static
const
FMutateExpr
&
f
=
IRMutatorExample
::
vtable_expr
();
return
(
f
.
can_dispatch
(
expr
)
?
f
(
expr
,
expr
,
this
)
:
IRMutator
::
m
utate
(
expr
));
f
(
expr
,
expr
,
this
)
:
IRMutator
::
M
utate
(
expr
));
}
Stmt
m
utate
(
Stmt
stmt
)
final
{
Stmt
M
utate
(
Stmt
stmt
)
final
{
static
const
FMutateStmt
&
f
=
IRMutatorExample
::
vtable_stmt
();
return
(
f
.
can_dispatch
(
stmt
)
?
f
(
stmt
,
stmt
,
this
)
:
IRMutator
::
m
utate
(
stmt
));
f
(
stmt
,
stmt
,
this
)
:
IRMutator
::
M
utate
(
stmt
));
}
// to be implemented by child class
static
FMutateExpr
&
vtable_expr
();
// NOLINT(*)
static
FMutateStmt
&
vtable_stmt
();
// NOLINT(*)
};
/*!
* \brief Substitute occurance of IRNode to be expr
* \param replacements The replacement rule of substitution
* \param expr The expression to be substituted.
*/
Expr
Substitute
(
const
std
::
unordered_map
<
const
IRNode
*
,
Expr
>&
replacements
,
Expr
expr
);
}
// namespace ir
}
// namespace tvm
#endif // TVM_IR_MUTATOR_H_
include/tvm/ir_pass.h
0 → 100644
View file @
8e04361c
/*!
* Copyright (c) 2016 by Contributors
* \file ir_pass.h
* \brief Collection of IR pass functions and visit functions
*/
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
#include <tvm/ir_node.h>
#include <unordered_map>
#include "./expr.h"
namespace
tvm
{
namespace
ir
{
/*!
* \brief Substitute occurance of IRNode in expr
* \param replacements The replacement rule of substitution
* \param expr The expression to be substituted.
*/
Expr
Substitute
(
const
std
::
unordered_map
<
const
IRNode
*
,
Expr
>&
replacements
,
Expr
expr
);
}
// namespace ir
}
// namespace tvm
#endif // TVM_IR_PASS_H_
include/tvm/ir_visitor.h
View file @
8e04361c
...
...
@@ -24,7 +24,7 @@ class IRVisitor {
/*!
* \brief recursively visit an IR node
*/
virtual
void
v
isit
(
const
IRNodeRef
&
node
)
{
virtual
void
V
isit
(
const
IRNodeRef
&
node
)
{
static
const
FVisit
&
f
=
vtable
();
if
(
node
.
defined
())
f
(
node
,
this
);
}
...
...
include/tvm/tensor.h
View file @
8e04361c
...
...
@@ -101,7 +101,7 @@ class Tensor : public FunctionRef {
};
/*! \brief Node to represent a tensor */
class
TensorNode
:
public
Node
{
class
TensorNode
:
public
FunctionBase
Node
{
public
:
/*! \brief The shape of the tensor */
Array
<
Expr
>
shape
;
...
...
@@ -125,6 +125,12 @@ class TensorNode : public Node {
v
->
Visit
(
"dim_var"
,
&
dim_var
);
v
->
Visit
(
"source"
,
&
source
);
}
const
std
::
string
&
func_name
()
const
final
{
return
name
;
}
int
outputs
()
const
final
{
return
1
;
}
static
Tensor
make
(
Array
<
Expr
>
shape
,
std
::
string
name
,
Type
dtype
,
...
...
src/pass/ir_mutator.cc
View file @
8e04361c
...
...
@@ -8,32 +8,6 @@
namespace
tvm
{
namespace
ir
{
namespace
{
// visitor to implement apply
class
IRSubstitute
:
public
IRMutator
{
public
:
Expr
mutate
(
Expr
expr
)
final
{
const
IRNode
*
v
=
expr
.
get
();
if
(
v
!=
nullptr
)
{
auto
it
=
replacements_
.
find
(
v
);
if
(
it
!=
replacements_
.
end
())
{
return
it
->
second
;
}
}
return
IRMutator
::
mutate
(
expr
);
}
explicit
IRSubstitute
(
const
std
::
unordered_map
<
const
IRNode
*
,
Expr
>&
replacements
)
:
replacements_
(
replacements
)
{}
private
:
const
std
::
unordered_map
<
const
IRNode
*
,
Expr
>&
replacements_
;
};
}
// namespace
Expr
Substitute
(
const
std
::
unordered_map
<
const
IRNode
*
,
Expr
>&
replacements
,
Expr
expr
)
{
return
IRSubstitute
(
replacements
).
mutate
(
expr
);
}
IRMutator
::
FMutateExpr
&
IRMutator
::
vtable_expr
()
{
// NOLINT(*)
static
FMutateExpr
inst
;
return
inst
;
}
...
...
@@ -57,7 +31,7 @@ inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
bool
changed
=
false
;
for
(
size_t
i
=
0
;
i
<
arr
.
size
();
i
++
)
{
Expr
old_elem
=
arr
[
i
];
Expr
new_elem
=
m
->
m
utate
(
old_elem
);
Expr
new_elem
=
m
->
M
utate
(
old_elem
);
if
(
!
new_elem
.
same_as
(
old_elem
))
changed
=
true
;
new_arr
[
i
]
=
new_elem
;
}
...
...
@@ -73,8 +47,8 @@ inline RDomain MutateRDom(RDomain rdom, IRMutator *m) {
bool
changed
=
false
;
for
(
size_t
i
=
0
;
i
<
rdom
->
domain
.
size
();
i
++
)
{
Range
r
=
rdom
->
domain
[
i
];
Expr
new_min
=
m
->
m
utate
(
r
->
min
);
Expr
new_extent
=
m
->
m
utate
(
r
->
extent
);
Expr
new_min
=
m
->
M
utate
(
r
->
min
);
Expr
new_extent
=
m
->
M
utate
(
r
->
extent
);
if
(
!
r
->
min
.
same_as
(
new_min
))
changed
=
true
;
if
(
!
r
->
extent
.
same_as
(
new_extent
))
changed
=
true
;
new_dom
[
i
]
=
Range
::
make_with_min_extent
(
new_min
,
new_extent
);
...
...
@@ -89,7 +63,7 @@ inline RDomain MutateRDom(RDomain rdom, IRMutator *m) {
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
.
set_dispatch
<
Reduce
>
([](
const
Reduce
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
RDomain
new_rdom
=
MutateRDom
(
op
->
rdom
,
m
);
Expr
new_source
=
m
->
m
utate
(
op
->
source
);
Expr
new_source
=
m
->
M
utate
(
op
->
source
);
if
(
op
->
rdom
.
same_as
(
new_rdom
)
&&
op
->
source
.
same_as
(
new_source
))
{
return
e
;
...
...
@@ -107,7 +81,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
.
set_dispatch
<
Cast
>
([](
const
Cast
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
value
=
m
->
m
utate
(
op
->
value
);
Expr
value
=
m
->
M
utate
(
op
->
value
);
if
(
value
.
same_as
(
op
->
value
))
{
return
e
;
}
else
{
...
...
@@ -118,8 +92,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
// binary operator
template
<
typename
T
>
inline
Expr
Binary
(
const
T
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
a
=
m
->
m
utate
(
op
->
a
);
Expr
b
=
m
->
m
utate
(
op
->
b
);
Expr
a
=
m
->
M
utate
(
op
->
a
);
Expr
b
=
m
->
M
utate
(
op
->
b
);
if
(
a
.
same_as
(
op
->
a
)
&&
b
.
same_as
(
op
->
b
))
{
return
e
;
...
...
@@ -147,7 +121,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
.
set_dispatch
<
Not
>
([](
const
Not
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
a
=
m
->
m
utate
(
op
->
a
);
Expr
a
=
m
->
M
utate
(
op
->
a
);
if
(
a
.
same_as
(
op
->
a
))
{
return
e
;
}
else
{
...
...
@@ -155,9 +129,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
}
})
.
set_dispatch
<
Select
>
([](
const
Select
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
cond
=
m
->
m
utate
(
op
->
condition
);
Expr
t
=
m
->
m
utate
(
op
->
true_value
);
Expr
f
=
m
->
m
utate
(
op
->
false_value
);
Expr
cond
=
m
->
M
utate
(
op
->
condition
);
Expr
t
=
m
->
M
utate
(
op
->
true_value
);
Expr
f
=
m
->
M
utate
(
op
->
false_value
);
if
(
cond
.
same_as
(
op
->
condition
)
&&
t
.
same_as
(
op
->
true_value
)
&&
f
.
same_as
(
op
->
false_value
))
{
...
...
@@ -167,7 +141,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
}
})
.
set_dispatch
<
Load
>
([](
const
Load
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
index
=
m
->
m
utate
(
op
->
index
);
Expr
index
=
m
->
M
utate
(
op
->
index
);
if
(
index
.
same_as
(
op
->
index
))
{
return
e
;
}
else
{
...
...
@@ -175,8 +149,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
}
})
.
set_dispatch
<
Ramp
>
([](
const
Ramp
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
base
=
m
->
m
utate
(
op
->
base
);
Expr
stride
=
m
->
m
utate
(
op
->
stride
);
Expr
base
=
m
->
M
utate
(
op
->
base
);
Expr
stride
=
m
->
M
utate
(
op
->
stride
);
if
(
base
.
same_as
(
op
->
base
)
&&
stride
.
same_as
(
op
->
stride
))
{
return
e
;
...
...
@@ -185,7 +159,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
}
})
.
set_dispatch
<
Broadcast
>
([](
const
Broadcast
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
value
=
m
->
m
utate
(
op
->
value
);
Expr
value
=
m
->
M
utate
(
op
->
value
);
if
(
value
.
same_as
(
op
->
value
))
{
return
e
;
}
else
{
...
...
@@ -202,8 +176,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
}
})
.
set_dispatch
<
Let
>
([](
const
Let
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
value
=
m
->
m
utate
(
op
->
value
);
Expr
body
=
m
->
m
utate
(
op
->
body
);
Expr
value
=
m
->
M
utate
(
op
->
value
);
Expr
body
=
m
->
M
utate
(
op
->
body
);
if
(
value
.
same_as
(
op
->
value
)
&&
body
.
same_as
(
op
->
body
))
{
return
e
;
...
...
@@ -214,8 +188,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_stmt
)
.
set_dispatch
<
LetStmt
>
([](
const
LetStmt
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
value
=
m
->
m
utate
(
op
->
value
);
Stmt
body
=
m
->
m
utate
(
op
->
body
);
Expr
value
=
m
->
M
utate
(
op
->
value
);
Stmt
body
=
m
->
M
utate
(
op
->
body
);
if
(
value
.
same_as
(
op
->
value
)
&&
body
.
same_as
(
op
->
body
))
{
return
s
;
...
...
@@ -224,8 +198,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}
})
.
set_dispatch
<
AssertStmt
>
([](
const
AssertStmt
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
condition
=
m
->
m
utate
(
op
->
condition
);
Expr
message
=
m
->
m
utate
(
op
->
message
);
Expr
condition
=
m
->
M
utate
(
op
->
condition
);
Expr
message
=
m
->
M
utate
(
op
->
message
);
if
(
condition
.
same_as
(
op
->
condition
)
&&
message
.
same_as
(
op
->
message
))
{
return
s
;
...
...
@@ -234,7 +208,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}
})
.
set_dispatch
<
ProducerConsumer
>
([](
const
ProducerConsumer
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Stmt
body
=
m
->
m
utate
(
op
->
body
);
Stmt
body
=
m
->
M
utate
(
op
->
body
);
if
(
body
.
same_as
(
op
->
body
))
{
return
s
;
}
else
{
...
...
@@ -242,9 +216,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}
})
.
set_dispatch
<
For
>
([](
const
For
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
min
=
m
->
m
utate
(
op
->
min
);
Expr
extent
=
m
->
m
utate
(
op
->
extent
);
Stmt
body
=
m
->
m
utate
(
op
->
body
);
Expr
min
=
m
->
M
utate
(
op
->
min
);
Expr
extent
=
m
->
M
utate
(
op
->
extent
);
Stmt
body
=
m
->
M
utate
(
op
->
body
);
if
(
min
.
same_as
(
op
->
min
)
&&
extent
.
same_as
(
op
->
extent
)
&&
body
.
same_as
(
op
->
body
))
{
...
...
@@ -255,8 +229,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}
})
.
set_dispatch
<
Store
>
([](
const
Store
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
value
=
m
->
m
utate
(
op
->
value
);
Expr
index
=
m
->
m
utate
(
op
->
index
);
Expr
value
=
m
->
M
utate
(
op
->
value
);
Expr
index
=
m
->
M
utate
(
op
->
index
);
if
(
value
.
same_as
(
op
->
value
)
&&
index
.
same_as
(
op
->
index
))
{
return
s
;
}
else
{
...
...
@@ -276,14 +250,14 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
std
::
vector
<
Expr
>
new_extents
;
bool
all_extents_unmodified
=
true
;
for
(
size_t
i
=
0
;
i
<
op
->
extents
.
size
();
i
++
)
{
new_extents
.
push_back
(
m
->
m
utate
(
op
->
extents
[
i
]));
new_extents
.
push_back
(
m
->
M
utate
(
op
->
extents
[
i
]));
all_extents_unmodified
&=
new_extents
[
i
].
same_as
(
op
->
extents
[
i
]);
}
Stmt
body
=
m
->
m
utate
(
op
->
body
);
Expr
condition
=
m
->
m
utate
(
op
->
condition
);
Stmt
body
=
m
->
M
utate
(
op
->
body
);
Expr
condition
=
m
->
M
utate
(
op
->
condition
);
Expr
new_expr
;
if
(
op
->
new_expr
.
defined
())
{
new_expr
=
m
->
m
utate
(
op
->
new_expr
);
new_expr
=
m
->
M
utate
(
op
->
new_expr
);
}
if
(
all_extents_unmodified
&&
body
.
same_as
(
op
->
body
)
&&
...
...
@@ -308,16 +282,16 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
for
(
size_t
i
=
0
;
i
<
op
->
bounds
.
size
();
i
++
)
{
Expr
old_min
=
op
->
bounds
[
i
]
->
min
;
Expr
old_extent
=
op
->
bounds
[
i
]
->
extent
;
Expr
new_min
=
m
->
m
utate
(
old_min
);
Expr
new_extent
=
m
->
m
utate
(
old_extent
);
Expr
new_min
=
m
->
M
utate
(
old_min
);
Expr
new_extent
=
m
->
M
utate
(
old_extent
);
if
(
!
new_min
.
same_as
(
old_min
))
bounds_changed
=
true
;
if
(
!
new_extent
.
same_as
(
old_extent
))
bounds_changed
=
true
;
new_bounds
.
push_back
(
Range
::
make_by_min_extent
(
new_min
,
new_extent
));
}
Stmt
body
=
m
->
m
utate
(
op
->
body
);
Expr
condition
=
m
->
m
utate
(
op
->
condition
);
Stmt
body
=
m
->
M
utate
(
op
->
body
);
Expr
condition
=
m
->
M
utate
(
op
->
condition
);
if
(
!
bounds_changed
&&
body
.
same_as
(
op
->
body
)
&&
condition
.
same_as
(
op
->
condition
))
{
...
...
@@ -328,8 +302,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}
})
.
set_dispatch
<
Block
>
([](
const
Block
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Stmt
first
=
m
->
m
utate
(
op
->
first
);
Stmt
rest
=
m
->
m
utate
(
op
->
rest
);
Stmt
first
=
m
->
M
utate
(
op
->
first
);
Stmt
rest
=
m
->
M
utate
(
op
->
rest
);
if
(
first
.
same_as
(
op
->
first
)
&&
rest
.
same_as
(
op
->
rest
))
{
return
s
;
...
...
@@ -338,9 +312,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}
})
.
set_dispatch
<
IfThenElse
>
([](
const
IfThenElse
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
condition
=
m
->
m
utate
(
op
->
condition
);
Stmt
then_case
=
m
->
m
utate
(
op
->
then_case
);
Stmt
else_case
=
m
->
m
utate
(
op
->
else_case
);
Expr
condition
=
m
->
M
utate
(
op
->
condition
);
Stmt
then_case
=
m
->
M
utate
(
op
->
then_case
);
Stmt
else_case
=
m
->
M
utate
(
op
->
else_case
);
if
(
condition
.
same_as
(
op
->
condition
)
&&
then_case
.
same_as
(
op
->
then_case
)
&&
else_case
.
same_as
(
op
->
else_case
))
{
...
...
@@ -350,7 +324,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}
})
.
set_dispatch
<
Evaluate
>
([](
const
Evaluate
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
v
=
m
->
m
utate
(
op
->
value
);
Expr
v
=
m
->
M
utate
(
op
->
value
);
if
(
v
.
same_as
(
op
->
value
))
{
return
s
;
}
else
{
...
...
src/pass/ir_pass.cc
0 → 100644
View file @
8e04361c
/*!
* Copyright (c) 2016 by Contributors
* \file ir_pass.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <unordered_set>
namespace
tvm
{
namespace
ir
{
namespace
{
// visitor to implement apply
class
IRSubstitute
:
public
IRMutator
{
public
:
Expr
Mutate
(
Expr
expr
)
final
{
const
IRNode
*
v
=
expr
.
get
();
if
(
v
!=
nullptr
)
{
auto
it
=
replacements_
.
find
(
v
);
if
(
it
!=
replacements_
.
end
())
{
return
it
->
second
;
}
}
return
IRMutator
::
Mutate
(
expr
);
}
explicit
IRSubstitute
(
const
std
::
unordered_map
<
const
IRNode
*
,
Expr
>&
replacements
)
:
replacements_
(
replacements
)
{}
private
:
const
std
::
unordered_map
<
const
IRNode
*
,
Expr
>&
replacements_
;
};
}
// namespace
Expr
Substitute
(
const
std
::
unordered_map
<
const
IRNode
*
,
Expr
>&
replacements
,
Expr
expr
)
{
return
IRSubstitute
(
replacements
).
Mutate
(
expr
);
}
}
// namespace ir
}
// namespace tvm
src/pass/ir_visitor.cc
View file @
8e04361c
...
...
@@ -14,10 +14,10 @@ class IRApplyVisit : public IRVisitor {
public
:
explicit
IRApplyVisit
(
std
::
function
<
void
(
const
IRNodeRef
&
)
>
f
)
:
f_
(
f
)
{}
void
v
isit
(
const
IRNodeRef
&
node
)
final
{
void
V
isit
(
const
IRNodeRef
&
node
)
final
{
if
(
visited_
.
count
(
node
.
get
())
!=
0
)
return
;
visited_
.
insert
(
node
.
get
());
IRVisitor
::
v
isit
(
node
);
IRVisitor
::
V
isit
(
node
);
f_
(
node
);
}
...
...
@@ -25,18 +25,18 @@ class IRApplyVisit : public IRVisitor {
std
::
function
<
void
(
const
IRNodeRef
&
)
>
f_
;
std
::
unordered_set
<
const
Node
*>
visited_
;
};
}
// namespace
void
PostOrderVisit
(
const
IRNodeRef
&
node
,
std
::
function
<
void
(
const
IRNodeRef
&
)
>
fvisit
)
{
IRApplyVisit
(
fvisit
).
Visit
(
node
);
}
IRVisitor
::
FVisit
&
IRVisitor
::
vtable
()
{
// NOLINT(*)
static
FVisit
inst
;
return
inst
;
}
void
PostOrderVisit
(
const
IRNodeRef
&
node
,
std
::
function
<
void
(
const
IRNodeRef
&
)
>
fvisit
)
{
IRApplyVisit
v
(
fvisit
);
v
.
visit
(
node
);
}
// namespace to register the functors.
namespace
{
...
...
@@ -47,22 +47,22 @@ void NoOp(const IRNodeRef& n, IRVisitor* v) {
inline
void
VisitArray
(
Array
<
Expr
>
arr
,
IRVisitor
*
v
)
{
for
(
size_t
i
=
0
;
i
<
arr
.
size
();
i
++
)
{
v
->
v
isit
(
arr
[
i
]);
v
->
V
isit
(
arr
[
i
]);
}
}
inline
void
VisitRDom
(
RDomain
rdom
,
IRVisitor
*
v
)
{
for
(
size_t
i
=
0
;
i
<
rdom
->
domain
.
size
();
i
++
)
{
Range
r
=
rdom
->
domain
[
i
];
v
->
v
isit
(
r
->
min
);
v
->
v
isit
(
r
->
extent
);
v
->
V
isit
(
r
->
min
);
v
->
V
isit
(
r
->
extent
);
}
}
TVM_STATIC_IR_FUNCTOR
(
IRVisitor
,
vtable
)
.
set_dispatch
<
Reduce
>
([](
const
Reduce
*
op
,
IRVisitor
*
v
)
{
VisitRDom
(
op
->
rdom
,
v
);
v
->
v
isit
(
op
->
source
);
v
->
V
isit
(
op
->
source
);
});
TVM_STATIC_IR_FUNCTOR
(
IRVisitor
,
vtable
)
...
...
@@ -74,14 +74,14 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
TVM_STATIC_IR_FUNCTOR
(
IRVisitor
,
vtable
)
.
set_dispatch
<
Cast
>
([](
const
Cast
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
value
);
v
->
V
isit
(
op
->
value
);
});
// binary operator
template
<
typename
T
>
inline
void
Binary
(
const
T
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
a
);
v
->
v
isit
(
op
->
b
);
v
->
V
isit
(
op
->
a
);
v
->
V
isit
(
op
->
b
);
}
TVM_STATIC_IR_FUNCTOR
(
IRVisitor
,
vtable
)
...
...
@@ -103,51 +103,51 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
TVM_STATIC_IR_FUNCTOR
(
IRVisitor
,
vtable
)
.
set_dispatch
<
Not
>
([](
const
Not
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
a
);
v
->
V
isit
(
op
->
a
);
})
.
set_dispatch
<
Select
>
([](
const
Select
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
condition
);
v
->
v
isit
(
op
->
true_value
);
v
->
v
isit
(
op
->
false_value
);
v
->
V
isit
(
op
->
condition
);
v
->
V
isit
(
op
->
true_value
);
v
->
V
isit
(
op
->
false_value
);
})
.
set_dispatch
<
Load
>
([](
const
Load
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
index
);
v
->
V
isit
(
op
->
index
);
})
.
set_dispatch
<
Ramp
>
([](
const
Ramp
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
base
);
v
->
v
isit
(
op
->
stride
);
v
->
V
isit
(
op
->
base
);
v
->
V
isit
(
op
->
stride
);
})
.
set_dispatch
<
Broadcast
>
([](
const
Broadcast
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
value
);
v
->
V
isit
(
op
->
value
);
})
.
set_dispatch
<
Call
>
([](
const
Call
*
op
,
IRVisitor
*
v
)
{
VisitArray
(
op
->
args
,
v
);
})
.
set_dispatch
<
Let
>
([](
const
Let
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
value
);
v
->
v
isit
(
op
->
body
);
v
->
V
isit
(
op
->
value
);
v
->
V
isit
(
op
->
body
);
});
TVM_STATIC_IR_FUNCTOR
(
IRVisitor
,
vtable
)
.
set_dispatch
<
LetStmt
>
([](
const
LetStmt
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
value
);
v
->
v
isit
(
op
->
body
);
v
->
V
isit
(
op
->
value
);
v
->
V
isit
(
op
->
body
);
})
.
set_dispatch
<
AssertStmt
>
([](
const
AssertStmt
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
condition
);
v
->
v
isit
(
op
->
message
);
v
->
V
isit
(
op
->
condition
);
v
->
V
isit
(
op
->
message
);
})
.
set_dispatch
<
ProducerConsumer
>
([](
const
ProducerConsumer
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
body
);
v
->
V
isit
(
op
->
body
);
})
.
set_dispatch
<
For
>
([](
const
For
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
min
);
v
->
v
isit
(
op
->
extent
);
v
->
v
isit
(
op
->
body
);
v
->
V
isit
(
op
->
min
);
v
->
V
isit
(
op
->
extent
);
v
->
V
isit
(
op
->
body
);
})
.
set_dispatch
<
Store
>
([](
const
Store
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
value
);
v
->
v
isit
(
op
->
index
);
v
->
V
isit
(
op
->
value
);
v
->
V
isit
(
op
->
index
);
})
.
set_dispatch
<
Provide
>
([](
const
Provide
*
op
,
IRVisitor
*
v
)
{
VisitArray
(
op
->
args
,
v
);
...
...
@@ -155,36 +155,36 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
})
.
set_dispatch
<
Allocate
>
([](
const
Allocate
*
op
,
IRVisitor
*
v
)
{
for
(
size_t
i
=
0
;
i
<
op
->
extents
.
size
();
i
++
)
{
v
->
v
isit
(
op
->
extents
[
i
]);
v
->
V
isit
(
op
->
extents
[
i
]);
}
v
->
v
isit
(
op
->
body
);
v
->
v
isit
(
op
->
condition
);
v
->
V
isit
(
op
->
body
);
v
->
V
isit
(
op
->
condition
);
if
(
op
->
new_expr
.
defined
())
{
v
->
v
isit
(
op
->
new_expr
);
v
->
V
isit
(
op
->
new_expr
);
}
})
.
set_dispatch
<
Free
>
(
NoOp
)
.
set_dispatch
<
Realize
>
([](
const
Realize
*
op
,
IRVisitor
*
v
)
{
// Mutate the bounds
for
(
size_t
i
=
0
;
i
<
op
->
bounds
.
size
();
i
++
)
{
v
->
v
isit
(
op
->
bounds
[
i
]
->
min
);
v
->
v
isit
(
op
->
bounds
[
i
]
->
extent
);
v
->
V
isit
(
op
->
bounds
[
i
]
->
min
);
v
->
V
isit
(
op
->
bounds
[
i
]
->
extent
);
}
v
->
v
isit
(
op
->
body
);
v
->
v
isit
(
op
->
condition
);
v
->
V
isit
(
op
->
body
);
v
->
V
isit
(
op
->
condition
);
})
.
set_dispatch
<
Block
>
([](
const
Block
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
first
);
v
->
v
isit
(
op
->
rest
);
v
->
V
isit
(
op
->
first
);
v
->
V
isit
(
op
->
rest
);
})
.
set_dispatch
<
IfThenElse
>
([](
const
IfThenElse
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
condition
);
v
->
v
isit
(
op
->
then_case
);
v
->
v
isit
(
op
->
else_case
);
v
->
V
isit
(
op
->
condition
);
v
->
V
isit
(
op
->
then_case
);
v
->
V
isit
(
op
->
else_case
);
})
.
set_dispatch
<
Evaluate
>
([](
const
Evaluate
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
value
);
v
->
V
isit
(
op
->
value
);
});
}
// namespace
...
...
tests/cpp/ir_mutator_test.cc
View file @
8e04361c
...
...
@@ -13,10 +13,10 @@ class IRVar2Const : public IRMutator {
public
:
VarExpr
var
;
int
int_val
;
Expr
m
utate
(
Expr
expr
)
final
{
Expr
M
utate
(
Expr
expr
)
final
{
static
const
FMutateExpr
&
f
=
IRVar2Const
::
vtable_expr
();
return
(
f
.
can_dispatch
(
expr
)
?
f
(
expr
,
expr
,
this
)
:
IRMutator
::
m
utate
(
expr
));
f
(
expr
,
expr
,
this
)
:
IRMutator
::
M
utate
(
expr
));
}
static
FMutateExpr
&
vtable_expr
();
};
...
...
@@ -46,31 +46,12 @@ TEST(IRMutator, Basic) {
IRVar2Const
mu
;
mu
.
var
=
y
;
mu
.
int_val
=
10
;
auto
zz
=
mu
.
m
utate
(
z
);
auto
zz
=
mu
.
M
utate
(
z
);
std
::
ostringstream
os
;
os
<<
zz
;
CHECK
(
os
.
str
()
==
"(x + 10)"
);
}
TEST
(
IRMutator
,
Substitute
)
{
using
namespace
Halide
::
Internal
;
using
namespace
tvm
;
Var
x
(
"x"
),
y
;
auto
z
=
x
+
y
;
{
auto
zz
=
Substitute
({{
y
.
get
(),
11
}},
z
);
std
::
ostringstream
os
;
os
<<
zz
;
CHECK
(
os
.
str
()
==
"(x + 11)"
);
}
{
auto
zz
=
Substitute
({{
z
.
get
(),
11
}},
z
);
std
::
ostringstream
os
;
os
<<
zz
;
CHECK
(
os
.
str
()
==
"11"
);
}
}
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
testing
::
FLAGS_gtest_death_test_style
=
"threadsafe"
;
...
...
tests/cpp/ir_pass_test.cc
0 → 100644
View file @
8e04361c
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_pass.h>
TEST
(
IRPass
,
Substitute
)
{
using
namespace
Halide
::
Internal
;
using
namespace
tvm
;
Var
x
(
"x"
),
y
;
auto
z
=
x
+
y
;
{
auto
zz
=
ir
::
Substitute
({{
y
.
get
(),
11
}},
z
);
std
::
ostringstream
os
;
os
<<
zz
;
CHECK
(
os
.
str
()
==
"(x + 11)"
);
}
{
auto
zz
=
ir
::
Substitute
({{
z
.
get
(),
11
}},
z
);
std
::
ostringstream
os
;
os
<<
zz
;
CHECK
(
os
.
str
()
==
"11"
);
}
}
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
testing
::
FLAGS_gtest_death_test_style
=
"threadsafe"
;
return
RUN_ALL_TESTS
();
}
tests/cpp/ir_visitor_test.cc
View file @
8e04361c
...
...
@@ -2,6 +2,7 @@
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
TEST
(
IRVisitor
,
CountVar
)
{
using
namespace
Halide
::
Internal
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment