Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
8e04361c
Commit
8e04361c
authored
Nov 18, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Refactor IR Pass
parent
ff6b8d82
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
203 additions
and
153 deletions
+203
-153
HalideIR
+1
-1
include/tvm/expr.h
+1
-0
include/tvm/ir_mutator.h
+6
-13
include/tvm/ir_pass.h
+26
-0
include/tvm/ir_visitor.h
+1
-1
include/tvm/tensor.h
+7
-1
src/pass/ir_mutator.cc
+41
-67
src/pass/ir_pass.cc
+39
-0
src/pass/ir_visitor.cc
+48
-48
tests/cpp/ir_mutator_test.cc
+3
-22
tests/cpp/ir_pass_test.cc
+29
-0
tests/cpp/ir_visitor_test.cc
+1
-0
No files found.
HalideIR
@
4becbde6
Subproject commit
89b7939957d66a37dd6083ad6b09a5644e73fd8b
Subproject commit
4becbde67c8aa565941b02648cea90f50211f8dc
include/tvm/expr.h
View file @
8e04361c
...
@@ -27,6 +27,7 @@ using Halide::abs;
...
@@ -27,6 +27,7 @@ using Halide::abs;
using
Halide
::
select
;
using
Halide
::
select
;
using
Halide
::
Expr
;
using
Halide
::
Expr
;
using
Halide
::
IR
::
FunctionBaseNode
;
using
Halide
::
Internal
::
Stmt
;
using
Halide
::
Internal
::
Stmt
;
class
Var
:
public
Halide
::
VarExpr
{
class
Var
:
public
Halide
::
VarExpr
{
...
...
include/tvm/ir_mutator.h
View file @
8e04361c
...
@@ -29,7 +29,7 @@ class IRMutator {
...
@@ -29,7 +29,7 @@ class IRMutator {
* \brief mutate expression
* \brief mutate expression
* \return the mutated expr
* \return the mutated expr
*/
*/
virtual
Expr
m
utate
(
Expr
expr
)
{
virtual
Expr
M
utate
(
Expr
expr
)
{
static
const
FMutateExpr
&
f
=
vtable_expr
();
static
const
FMutateExpr
&
f
=
vtable_expr
();
return
f
(
expr
,
expr
,
this
);
return
f
(
expr
,
expr
,
this
);
}
}
...
@@ -37,7 +37,7 @@ class IRMutator {
...
@@ -37,7 +37,7 @@ class IRMutator {
* \brief mutate expression
* \brief mutate expression
* \return the mutated stmt
* \return the mutated stmt
*/
*/
virtual
Stmt
m
utate
(
Stmt
stmt
)
{
virtual
Stmt
M
utate
(
Stmt
stmt
)
{
static
const
FMutateStmt
&
f
=
vtable_stmt
();
static
const
FMutateStmt
&
f
=
vtable_stmt
();
return
f
(
stmt
,
stmt
,
this
);
return
f
(
stmt
,
stmt
,
this
);
}
}
...
@@ -58,28 +58,21 @@ class IRMutator {
...
@@ -58,28 +58,21 @@ class IRMutator {
*/
*/
class
IRMutatorExample
:
public
IRMutator
{
class
IRMutatorExample
:
public
IRMutator
{
public
:
public
:
Expr
m
utate
(
Expr
expr
)
final
{
Expr
M
utate
(
Expr
expr
)
final
{
static
const
FMutateExpr
&
f
=
IRMutatorExample
::
vtable_expr
();
static
const
FMutateExpr
&
f
=
IRMutatorExample
::
vtable_expr
();
return
(
f
.
can_dispatch
(
expr
)
?
return
(
f
.
can_dispatch
(
expr
)
?
f
(
expr
,
expr
,
this
)
:
IRMutator
::
m
utate
(
expr
));
f
(
expr
,
expr
,
this
)
:
IRMutator
::
M
utate
(
expr
));
}
}
Stmt
m
utate
(
Stmt
stmt
)
final
{
Stmt
M
utate
(
Stmt
stmt
)
final
{
static
const
FMutateStmt
&
f
=
IRMutatorExample
::
vtable_stmt
();
static
const
FMutateStmt
&
f
=
IRMutatorExample
::
vtable_stmt
();
return
(
f
.
can_dispatch
(
stmt
)
?
return
(
f
.
can_dispatch
(
stmt
)
?
f
(
stmt
,
stmt
,
this
)
:
IRMutator
::
m
utate
(
stmt
));
f
(
stmt
,
stmt
,
this
)
:
IRMutator
::
M
utate
(
stmt
));
}
}
// to be implemented by child class
// to be implemented by child class
static
FMutateExpr
&
vtable_expr
();
// NOLINT(*)
static
FMutateExpr
&
vtable_expr
();
// NOLINT(*)
static
FMutateStmt
&
vtable_stmt
();
// NOLINT(*)
static
FMutateStmt
&
vtable_stmt
();
// NOLINT(*)
};
};
/*!
* \brief Substitute occurance of IRNode to be expr
* \param replacements The replacement rule of substitution
* \param expr The expression to be substituted.
*/
Expr
Substitute
(
const
std
::
unordered_map
<
const
IRNode
*
,
Expr
>&
replacements
,
Expr
expr
);
}
// namespace ir
}
// namespace ir
}
// namespace tvm
}
// namespace tvm
#endif // TVM_IR_MUTATOR_H_
#endif // TVM_IR_MUTATOR_H_
include/tvm/ir_pass.h
0 → 100644
View file @
8e04361c
/*!
* Copyright (c) 2016 by Contributors
* \file ir_pass.h
* \brief Collection of IR pass functions and visit functions
*/
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
#include <tvm/ir_node.h>
#include <unordered_map>
#include "./expr.h"
namespace
tvm
{
namespace
ir
{
/*!
* \brief Substitute occurance of IRNode in expr
* \param replacements The replacement rule of substitution
* \param expr The expression to be substituted.
*/
Expr
Substitute
(
const
std
::
unordered_map
<
const
IRNode
*
,
Expr
>&
replacements
,
Expr
expr
);
}
// namespace ir
}
// namespace tvm
#endif // TVM_IR_PASS_H_
include/tvm/ir_visitor.h
View file @
8e04361c
...
@@ -24,7 +24,7 @@ class IRVisitor {
...
@@ -24,7 +24,7 @@ class IRVisitor {
/*!
/*!
* \brief recursively visit an IR node
* \brief recursively visit an IR node
*/
*/
virtual
void
v
isit
(
const
IRNodeRef
&
node
)
{
virtual
void
V
isit
(
const
IRNodeRef
&
node
)
{
static
const
FVisit
&
f
=
vtable
();
static
const
FVisit
&
f
=
vtable
();
if
(
node
.
defined
())
f
(
node
,
this
);
if
(
node
.
defined
())
f
(
node
,
this
);
}
}
...
...
include/tvm/tensor.h
View file @
8e04361c
...
@@ -101,7 +101,7 @@ class Tensor : public FunctionRef {
...
@@ -101,7 +101,7 @@ class Tensor : public FunctionRef {
};
};
/*! \brief Node to represent a tensor */
/*! \brief Node to represent a tensor */
class
TensorNode
:
public
Node
{
class
TensorNode
:
public
FunctionBase
Node
{
public
:
public
:
/*! \brief The shape of the tensor */
/*! \brief The shape of the tensor */
Array
<
Expr
>
shape
;
Array
<
Expr
>
shape
;
...
@@ -125,6 +125,12 @@ class TensorNode : public Node {
...
@@ -125,6 +125,12 @@ class TensorNode : public Node {
v
->
Visit
(
"dim_var"
,
&
dim_var
);
v
->
Visit
(
"dim_var"
,
&
dim_var
);
v
->
Visit
(
"source"
,
&
source
);
v
->
Visit
(
"source"
,
&
source
);
}
}
const
std
::
string
&
func_name
()
const
final
{
return
name
;
}
int
outputs
()
const
final
{
return
1
;
}
static
Tensor
make
(
Array
<
Expr
>
shape
,
static
Tensor
make
(
Array
<
Expr
>
shape
,
std
::
string
name
,
std
::
string
name
,
Type
dtype
,
Type
dtype
,
...
...
src/pass/ir_mutator.cc
View file @
8e04361c
...
@@ -8,32 +8,6 @@
...
@@ -8,32 +8,6 @@
namespace
tvm
{
namespace
tvm
{
namespace
ir
{
namespace
ir
{
namespace
{
// visitor to implement apply
class
IRSubstitute
:
public
IRMutator
{
public
:
Expr
mutate
(
Expr
expr
)
final
{
const
IRNode
*
v
=
expr
.
get
();
if
(
v
!=
nullptr
)
{
auto
it
=
replacements_
.
find
(
v
);
if
(
it
!=
replacements_
.
end
())
{
return
it
->
second
;
}
}
return
IRMutator
::
mutate
(
expr
);
}
explicit
IRSubstitute
(
const
std
::
unordered_map
<
const
IRNode
*
,
Expr
>&
replacements
)
:
replacements_
(
replacements
)
{}
private
:
const
std
::
unordered_map
<
const
IRNode
*
,
Expr
>&
replacements_
;
};
}
// namespace
Expr
Substitute
(
const
std
::
unordered_map
<
const
IRNode
*
,
Expr
>&
replacements
,
Expr
expr
)
{
return
IRSubstitute
(
replacements
).
mutate
(
expr
);
}
IRMutator
::
FMutateExpr
&
IRMutator
::
vtable_expr
()
{
// NOLINT(*)
IRMutator
::
FMutateExpr
&
IRMutator
::
vtable_expr
()
{
// NOLINT(*)
static
FMutateExpr
inst
;
return
inst
;
static
FMutateExpr
inst
;
return
inst
;
}
}
...
@@ -57,7 +31,7 @@ inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
...
@@ -57,7 +31,7 @@ inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
bool
changed
=
false
;
bool
changed
=
false
;
for
(
size_t
i
=
0
;
i
<
arr
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
arr
.
size
();
i
++
)
{
Expr
old_elem
=
arr
[
i
];
Expr
old_elem
=
arr
[
i
];
Expr
new_elem
=
m
->
m
utate
(
old_elem
);
Expr
new_elem
=
m
->
M
utate
(
old_elem
);
if
(
!
new_elem
.
same_as
(
old_elem
))
changed
=
true
;
if
(
!
new_elem
.
same_as
(
old_elem
))
changed
=
true
;
new_arr
[
i
]
=
new_elem
;
new_arr
[
i
]
=
new_elem
;
}
}
...
@@ -73,8 +47,8 @@ inline RDomain MutateRDom(RDomain rdom, IRMutator *m) {
...
@@ -73,8 +47,8 @@ inline RDomain MutateRDom(RDomain rdom, IRMutator *m) {
bool
changed
=
false
;
bool
changed
=
false
;
for
(
size_t
i
=
0
;
i
<
rdom
->
domain
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
rdom
->
domain
.
size
();
i
++
)
{
Range
r
=
rdom
->
domain
[
i
];
Range
r
=
rdom
->
domain
[
i
];
Expr
new_min
=
m
->
m
utate
(
r
->
min
);
Expr
new_min
=
m
->
M
utate
(
r
->
min
);
Expr
new_extent
=
m
->
m
utate
(
r
->
extent
);
Expr
new_extent
=
m
->
M
utate
(
r
->
extent
);
if
(
!
r
->
min
.
same_as
(
new_min
))
changed
=
true
;
if
(
!
r
->
min
.
same_as
(
new_min
))
changed
=
true
;
if
(
!
r
->
extent
.
same_as
(
new_extent
))
changed
=
true
;
if
(
!
r
->
extent
.
same_as
(
new_extent
))
changed
=
true
;
new_dom
[
i
]
=
Range
::
make_with_min_extent
(
new_min
,
new_extent
);
new_dom
[
i
]
=
Range
::
make_with_min_extent
(
new_min
,
new_extent
);
...
@@ -89,7 +63,7 @@ inline RDomain MutateRDom(RDomain rdom, IRMutator *m) {
...
@@ -89,7 +63,7 @@ inline RDomain MutateRDom(RDomain rdom, IRMutator *m) {
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
.
set_dispatch
<
Reduce
>
([](
const
Reduce
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
.
set_dispatch
<
Reduce
>
([](
const
Reduce
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
RDomain
new_rdom
=
MutateRDom
(
op
->
rdom
,
m
);
RDomain
new_rdom
=
MutateRDom
(
op
->
rdom
,
m
);
Expr
new_source
=
m
->
m
utate
(
op
->
source
);
Expr
new_source
=
m
->
M
utate
(
op
->
source
);
if
(
op
->
rdom
.
same_as
(
new_rdom
)
&&
if
(
op
->
rdom
.
same_as
(
new_rdom
)
&&
op
->
source
.
same_as
(
new_source
))
{
op
->
source
.
same_as
(
new_source
))
{
return
e
;
return
e
;
...
@@ -107,7 +81,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
...
@@ -107,7 +81,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
.
set_dispatch
<
Cast
>
([](
const
Cast
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
.
set_dispatch
<
Cast
>
([](
const
Cast
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
value
=
m
->
m
utate
(
op
->
value
);
Expr
value
=
m
->
M
utate
(
op
->
value
);
if
(
value
.
same_as
(
op
->
value
))
{
if
(
value
.
same_as
(
op
->
value
))
{
return
e
;
return
e
;
}
else
{
}
else
{
...
@@ -118,8 +92,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
...
@@ -118,8 +92,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
// binary operator
// binary operator
template
<
typename
T
>
template
<
typename
T
>
inline
Expr
Binary
(
const
T
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
inline
Expr
Binary
(
const
T
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
a
=
m
->
m
utate
(
op
->
a
);
Expr
a
=
m
->
M
utate
(
op
->
a
);
Expr
b
=
m
->
m
utate
(
op
->
b
);
Expr
b
=
m
->
M
utate
(
op
->
b
);
if
(
a
.
same_as
(
op
->
a
)
&&
if
(
a
.
same_as
(
op
->
a
)
&&
b
.
same_as
(
op
->
b
))
{
b
.
same_as
(
op
->
b
))
{
return
e
;
return
e
;
...
@@ -147,7 +121,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
...
@@ -147,7 +121,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
.
set_dispatch
<
Not
>
([](
const
Not
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
.
set_dispatch
<
Not
>
([](
const
Not
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
a
=
m
->
m
utate
(
op
->
a
);
Expr
a
=
m
->
M
utate
(
op
->
a
);
if
(
a
.
same_as
(
op
->
a
))
{
if
(
a
.
same_as
(
op
->
a
))
{
return
e
;
return
e
;
}
else
{
}
else
{
...
@@ -155,9 +129,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
...
@@ -155,9 +129,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
}
}
})
})
.
set_dispatch
<
Select
>
([](
const
Select
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
.
set_dispatch
<
Select
>
([](
const
Select
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
cond
=
m
->
m
utate
(
op
->
condition
);
Expr
cond
=
m
->
M
utate
(
op
->
condition
);
Expr
t
=
m
->
m
utate
(
op
->
true_value
);
Expr
t
=
m
->
M
utate
(
op
->
true_value
);
Expr
f
=
m
->
m
utate
(
op
->
false_value
);
Expr
f
=
m
->
M
utate
(
op
->
false_value
);
if
(
cond
.
same_as
(
op
->
condition
)
&&
if
(
cond
.
same_as
(
op
->
condition
)
&&
t
.
same_as
(
op
->
true_value
)
&&
t
.
same_as
(
op
->
true_value
)
&&
f
.
same_as
(
op
->
false_value
))
{
f
.
same_as
(
op
->
false_value
))
{
...
@@ -167,7 +141,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
...
@@ -167,7 +141,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
}
}
})
})
.
set_dispatch
<
Load
>
([](
const
Load
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
.
set_dispatch
<
Load
>
([](
const
Load
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
index
=
m
->
m
utate
(
op
->
index
);
Expr
index
=
m
->
M
utate
(
op
->
index
);
if
(
index
.
same_as
(
op
->
index
))
{
if
(
index
.
same_as
(
op
->
index
))
{
return
e
;
return
e
;
}
else
{
}
else
{
...
@@ -175,8 +149,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
...
@@ -175,8 +149,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
}
}
})
})
.
set_dispatch
<
Ramp
>
([](
const
Ramp
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
.
set_dispatch
<
Ramp
>
([](
const
Ramp
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
base
=
m
->
m
utate
(
op
->
base
);
Expr
base
=
m
->
M
utate
(
op
->
base
);
Expr
stride
=
m
->
m
utate
(
op
->
stride
);
Expr
stride
=
m
->
M
utate
(
op
->
stride
);
if
(
base
.
same_as
(
op
->
base
)
&&
if
(
base
.
same_as
(
op
->
base
)
&&
stride
.
same_as
(
op
->
stride
))
{
stride
.
same_as
(
op
->
stride
))
{
return
e
;
return
e
;
...
@@ -185,7 +159,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
...
@@ -185,7 +159,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
}
}
})
})
.
set_dispatch
<
Broadcast
>
([](
const
Broadcast
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
.
set_dispatch
<
Broadcast
>
([](
const
Broadcast
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
value
=
m
->
m
utate
(
op
->
value
);
Expr
value
=
m
->
M
utate
(
op
->
value
);
if
(
value
.
same_as
(
op
->
value
))
{
if
(
value
.
same_as
(
op
->
value
))
{
return
e
;
return
e
;
}
else
{
}
else
{
...
@@ -202,8 +176,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
...
@@ -202,8 +176,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
}
}
})
})
.
set_dispatch
<
Let
>
([](
const
Let
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
.
set_dispatch
<
Let
>
([](
const
Let
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
value
=
m
->
m
utate
(
op
->
value
);
Expr
value
=
m
->
M
utate
(
op
->
value
);
Expr
body
=
m
->
m
utate
(
op
->
body
);
Expr
body
=
m
->
M
utate
(
op
->
body
);
if
(
value
.
same_as
(
op
->
value
)
&&
if
(
value
.
same_as
(
op
->
value
)
&&
body
.
same_as
(
op
->
body
))
{
body
.
same_as
(
op
->
body
))
{
return
e
;
return
e
;
...
@@ -214,8 +188,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
...
@@ -214,8 +188,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_stmt
)
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_stmt
)
.
set_dispatch
<
LetStmt
>
([](
const
LetStmt
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
.
set_dispatch
<
LetStmt
>
([](
const
LetStmt
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
value
=
m
->
m
utate
(
op
->
value
);
Expr
value
=
m
->
M
utate
(
op
->
value
);
Stmt
body
=
m
->
m
utate
(
op
->
body
);
Stmt
body
=
m
->
M
utate
(
op
->
body
);
if
(
value
.
same_as
(
op
->
value
)
&&
if
(
value
.
same_as
(
op
->
value
)
&&
body
.
same_as
(
op
->
body
))
{
body
.
same_as
(
op
->
body
))
{
return
s
;
return
s
;
...
@@ -224,8 +198,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
...
@@ -224,8 +198,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}
}
})
})
.
set_dispatch
<
AssertStmt
>
([](
const
AssertStmt
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
.
set_dispatch
<
AssertStmt
>
([](
const
AssertStmt
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
condition
=
m
->
m
utate
(
op
->
condition
);
Expr
condition
=
m
->
M
utate
(
op
->
condition
);
Expr
message
=
m
->
m
utate
(
op
->
message
);
Expr
message
=
m
->
M
utate
(
op
->
message
);
if
(
condition
.
same_as
(
op
->
condition
)
&&
message
.
same_as
(
op
->
message
))
{
if
(
condition
.
same_as
(
op
->
condition
)
&&
message
.
same_as
(
op
->
message
))
{
return
s
;
return
s
;
...
@@ -234,7 +208,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
...
@@ -234,7 +208,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}
}
})
})
.
set_dispatch
<
ProducerConsumer
>
([](
const
ProducerConsumer
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
.
set_dispatch
<
ProducerConsumer
>
([](
const
ProducerConsumer
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Stmt
body
=
m
->
m
utate
(
op
->
body
);
Stmt
body
=
m
->
M
utate
(
op
->
body
);
if
(
body
.
same_as
(
op
->
body
))
{
if
(
body
.
same_as
(
op
->
body
))
{
return
s
;
return
s
;
}
else
{
}
else
{
...
@@ -242,9 +216,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
...
@@ -242,9 +216,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}
}
})
})
.
set_dispatch
<
For
>
([](
const
For
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
.
set_dispatch
<
For
>
([](
const
For
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
min
=
m
->
m
utate
(
op
->
min
);
Expr
min
=
m
->
M
utate
(
op
->
min
);
Expr
extent
=
m
->
m
utate
(
op
->
extent
);
Expr
extent
=
m
->
M
utate
(
op
->
extent
);
Stmt
body
=
m
->
m
utate
(
op
->
body
);
Stmt
body
=
m
->
M
utate
(
op
->
body
);
if
(
min
.
same_as
(
op
->
min
)
&&
if
(
min
.
same_as
(
op
->
min
)
&&
extent
.
same_as
(
op
->
extent
)
&&
extent
.
same_as
(
op
->
extent
)
&&
body
.
same_as
(
op
->
body
))
{
body
.
same_as
(
op
->
body
))
{
...
@@ -255,8 +229,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
...
@@ -255,8 +229,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}
}
})
})
.
set_dispatch
<
Store
>
([](
const
Store
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
.
set_dispatch
<
Store
>
([](
const
Store
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
value
=
m
->
m
utate
(
op
->
value
);
Expr
value
=
m
->
M
utate
(
op
->
value
);
Expr
index
=
m
->
m
utate
(
op
->
index
);
Expr
index
=
m
->
M
utate
(
op
->
index
);
if
(
value
.
same_as
(
op
->
value
)
&&
index
.
same_as
(
op
->
index
))
{
if
(
value
.
same_as
(
op
->
value
)
&&
index
.
same_as
(
op
->
index
))
{
return
s
;
return
s
;
}
else
{
}
else
{
...
@@ -276,14 +250,14 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
...
@@ -276,14 +250,14 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
std
::
vector
<
Expr
>
new_extents
;
std
::
vector
<
Expr
>
new_extents
;
bool
all_extents_unmodified
=
true
;
bool
all_extents_unmodified
=
true
;
for
(
size_t
i
=
0
;
i
<
op
->
extents
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
op
->
extents
.
size
();
i
++
)
{
new_extents
.
push_back
(
m
->
m
utate
(
op
->
extents
[
i
]));
new_extents
.
push_back
(
m
->
M
utate
(
op
->
extents
[
i
]));
all_extents_unmodified
&=
new_extents
[
i
].
same_as
(
op
->
extents
[
i
]);
all_extents_unmodified
&=
new_extents
[
i
].
same_as
(
op
->
extents
[
i
]);
}
}
Stmt
body
=
m
->
m
utate
(
op
->
body
);
Stmt
body
=
m
->
M
utate
(
op
->
body
);
Expr
condition
=
m
->
m
utate
(
op
->
condition
);
Expr
condition
=
m
->
M
utate
(
op
->
condition
);
Expr
new_expr
;
Expr
new_expr
;
if
(
op
->
new_expr
.
defined
())
{
if
(
op
->
new_expr
.
defined
())
{
new_expr
=
m
->
m
utate
(
op
->
new_expr
);
new_expr
=
m
->
M
utate
(
op
->
new_expr
);
}
}
if
(
all_extents_unmodified
&&
if
(
all_extents_unmodified
&&
body
.
same_as
(
op
->
body
)
&&
body
.
same_as
(
op
->
body
)
&&
...
@@ -308,16 +282,16 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
...
@@ -308,16 +282,16 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
for
(
size_t
i
=
0
;
i
<
op
->
bounds
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
op
->
bounds
.
size
();
i
++
)
{
Expr
old_min
=
op
->
bounds
[
i
]
->
min
;
Expr
old_min
=
op
->
bounds
[
i
]
->
min
;
Expr
old_extent
=
op
->
bounds
[
i
]
->
extent
;
Expr
old_extent
=
op
->
bounds
[
i
]
->
extent
;
Expr
new_min
=
m
->
m
utate
(
old_min
);
Expr
new_min
=
m
->
M
utate
(
old_min
);
Expr
new_extent
=
m
->
m
utate
(
old_extent
);
Expr
new_extent
=
m
->
M
utate
(
old_extent
);
if
(
!
new_min
.
same_as
(
old_min
))
bounds_changed
=
true
;
if
(
!
new_min
.
same_as
(
old_min
))
bounds_changed
=
true
;
if
(
!
new_extent
.
same_as
(
old_extent
))
bounds_changed
=
true
;
if
(
!
new_extent
.
same_as
(
old_extent
))
bounds_changed
=
true
;
new_bounds
.
push_back
(
new_bounds
.
push_back
(
Range
::
make_by_min_extent
(
new_min
,
new_extent
));
Range
::
make_by_min_extent
(
new_min
,
new_extent
));
}
}
Stmt
body
=
m
->
m
utate
(
op
->
body
);
Stmt
body
=
m
->
M
utate
(
op
->
body
);
Expr
condition
=
m
->
m
utate
(
op
->
condition
);
Expr
condition
=
m
->
M
utate
(
op
->
condition
);
if
(
!
bounds_changed
&&
if
(
!
bounds_changed
&&
body
.
same_as
(
op
->
body
)
&&
body
.
same_as
(
op
->
body
)
&&
condition
.
same_as
(
op
->
condition
))
{
condition
.
same_as
(
op
->
condition
))
{
...
@@ -328,8 +302,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
...
@@ -328,8 +302,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}
}
})
})
.
set_dispatch
<
Block
>
([](
const
Block
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
.
set_dispatch
<
Block
>
([](
const
Block
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Stmt
first
=
m
->
m
utate
(
op
->
first
);
Stmt
first
=
m
->
M
utate
(
op
->
first
);
Stmt
rest
=
m
->
m
utate
(
op
->
rest
);
Stmt
rest
=
m
->
M
utate
(
op
->
rest
);
if
(
first
.
same_as
(
op
->
first
)
&&
if
(
first
.
same_as
(
op
->
first
)
&&
rest
.
same_as
(
op
->
rest
))
{
rest
.
same_as
(
op
->
rest
))
{
return
s
;
return
s
;
...
@@ -338,9 +312,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
...
@@ -338,9 +312,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}
}
})
})
.
set_dispatch
<
IfThenElse
>
([](
const
IfThenElse
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
.
set_dispatch
<
IfThenElse
>
([](
const
IfThenElse
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
condition
=
m
->
m
utate
(
op
->
condition
);
Expr
condition
=
m
->
M
utate
(
op
->
condition
);
Stmt
then_case
=
m
->
m
utate
(
op
->
then_case
);
Stmt
then_case
=
m
->
M
utate
(
op
->
then_case
);
Stmt
else_case
=
m
->
m
utate
(
op
->
else_case
);
Stmt
else_case
=
m
->
M
utate
(
op
->
else_case
);
if
(
condition
.
same_as
(
op
->
condition
)
&&
if
(
condition
.
same_as
(
op
->
condition
)
&&
then_case
.
same_as
(
op
->
then_case
)
&&
then_case
.
same_as
(
op
->
then_case
)
&&
else_case
.
same_as
(
op
->
else_case
))
{
else_case
.
same_as
(
op
->
else_case
))
{
...
@@ -350,7 +324,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
...
@@ -350,7 +324,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}
}
})
})
.
set_dispatch
<
Evaluate
>
([](
const
Evaluate
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
.
set_dispatch
<
Evaluate
>
([](
const
Evaluate
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
v
=
m
->
m
utate
(
op
->
value
);
Expr
v
=
m
->
M
utate
(
op
->
value
);
if
(
v
.
same_as
(
op
->
value
))
{
if
(
v
.
same_as
(
op
->
value
))
{
return
s
;
return
s
;
}
else
{
}
else
{
...
...
src/pass/ir_pass.cc
0 → 100644
View file @
8e04361c
/*!
* Copyright (c) 2016 by Contributors
* \file ir_pass.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <unordered_set>
namespace
tvm
{
namespace
ir
{
namespace
{
// visitor to implement apply
class
IRSubstitute
:
public
IRMutator
{
public
:
Expr
Mutate
(
Expr
expr
)
final
{
const
IRNode
*
v
=
expr
.
get
();
if
(
v
!=
nullptr
)
{
auto
it
=
replacements_
.
find
(
v
);
if
(
it
!=
replacements_
.
end
())
{
return
it
->
second
;
}
}
return
IRMutator
::
Mutate
(
expr
);
}
explicit
IRSubstitute
(
const
std
::
unordered_map
<
const
IRNode
*
,
Expr
>&
replacements
)
:
replacements_
(
replacements
)
{}
private
:
const
std
::
unordered_map
<
const
IRNode
*
,
Expr
>&
replacements_
;
};
}
// namespace
Expr
Substitute
(
const
std
::
unordered_map
<
const
IRNode
*
,
Expr
>&
replacements
,
Expr
expr
)
{
return
IRSubstitute
(
replacements
).
Mutate
(
expr
);
}
}
// namespace ir
}
// namespace tvm
src/pass/ir_visitor.cc
View file @
8e04361c
...
@@ -14,10 +14,10 @@ class IRApplyVisit : public IRVisitor {
...
@@ -14,10 +14,10 @@ class IRApplyVisit : public IRVisitor {
public
:
public
:
explicit
IRApplyVisit
(
std
::
function
<
void
(
const
IRNodeRef
&
)
>
f
)
:
f_
(
f
)
{}
explicit
IRApplyVisit
(
std
::
function
<
void
(
const
IRNodeRef
&
)
>
f
)
:
f_
(
f
)
{}
void
v
isit
(
const
IRNodeRef
&
node
)
final
{
void
V
isit
(
const
IRNodeRef
&
node
)
final
{
if
(
visited_
.
count
(
node
.
get
())
!=
0
)
return
;
if
(
visited_
.
count
(
node
.
get
())
!=
0
)
return
;
visited_
.
insert
(
node
.
get
());
visited_
.
insert
(
node
.
get
());
IRVisitor
::
v
isit
(
node
);
IRVisitor
::
V
isit
(
node
);
f_
(
node
);
f_
(
node
);
}
}
...
@@ -25,18 +25,18 @@ class IRApplyVisit : public IRVisitor {
...
@@ -25,18 +25,18 @@ class IRApplyVisit : public IRVisitor {
std
::
function
<
void
(
const
IRNodeRef
&
)
>
f_
;
std
::
function
<
void
(
const
IRNodeRef
&
)
>
f_
;
std
::
unordered_set
<
const
Node
*>
visited_
;
std
::
unordered_set
<
const
Node
*>
visited_
;
};
};
}
// namespace
}
// namespace
void
PostOrderVisit
(
const
IRNodeRef
&
node
,
std
::
function
<
void
(
const
IRNodeRef
&
)
>
fvisit
)
{
IRApplyVisit
(
fvisit
).
Visit
(
node
);
}
IRVisitor
::
FVisit
&
IRVisitor
::
vtable
()
{
// NOLINT(*)
IRVisitor
::
FVisit
&
IRVisitor
::
vtable
()
{
// NOLINT(*)
static
FVisit
inst
;
return
inst
;
static
FVisit
inst
;
return
inst
;
}
}
void
PostOrderVisit
(
const
IRNodeRef
&
node
,
std
::
function
<
void
(
const
IRNodeRef
&
)
>
fvisit
)
{
IRApplyVisit
v
(
fvisit
);
v
.
visit
(
node
);
}
// namespace to register the functors.
// namespace to register the functors.
namespace
{
namespace
{
...
@@ -47,22 +47,22 @@ void NoOp(const IRNodeRef& n, IRVisitor* v) {
...
@@ -47,22 +47,22 @@ void NoOp(const IRNodeRef& n, IRVisitor* v) {
inline
void
VisitArray
(
Array
<
Expr
>
arr
,
IRVisitor
*
v
)
{
inline
void
VisitArray
(
Array
<
Expr
>
arr
,
IRVisitor
*
v
)
{
for
(
size_t
i
=
0
;
i
<
arr
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
arr
.
size
();
i
++
)
{
v
->
v
isit
(
arr
[
i
]);
v
->
V
isit
(
arr
[
i
]);
}
}
}
}
inline
void
VisitRDom
(
RDomain
rdom
,
IRVisitor
*
v
)
{
inline
void
VisitRDom
(
RDomain
rdom
,
IRVisitor
*
v
)
{
for
(
size_t
i
=
0
;
i
<
rdom
->
domain
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
rdom
->
domain
.
size
();
i
++
)
{
Range
r
=
rdom
->
domain
[
i
];
Range
r
=
rdom
->
domain
[
i
];
v
->
v
isit
(
r
->
min
);
v
->
V
isit
(
r
->
min
);
v
->
v
isit
(
r
->
extent
);
v
->
V
isit
(
r
->
extent
);
}
}
}
}
TVM_STATIC_IR_FUNCTOR
(
IRVisitor
,
vtable
)
TVM_STATIC_IR_FUNCTOR
(
IRVisitor
,
vtable
)
.
set_dispatch
<
Reduce
>
([](
const
Reduce
*
op
,
IRVisitor
*
v
)
{
.
set_dispatch
<
Reduce
>
([](
const
Reduce
*
op
,
IRVisitor
*
v
)
{
VisitRDom
(
op
->
rdom
,
v
);
VisitRDom
(
op
->
rdom
,
v
);
v
->
v
isit
(
op
->
source
);
v
->
V
isit
(
op
->
source
);
});
});
TVM_STATIC_IR_FUNCTOR
(
IRVisitor
,
vtable
)
TVM_STATIC_IR_FUNCTOR
(
IRVisitor
,
vtable
)
...
@@ -74,14 +74,14 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
...
@@ -74,14 +74,14 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
TVM_STATIC_IR_FUNCTOR
(
IRVisitor
,
vtable
)
TVM_STATIC_IR_FUNCTOR
(
IRVisitor
,
vtable
)
.
set_dispatch
<
Cast
>
([](
const
Cast
*
op
,
IRVisitor
*
v
)
{
.
set_dispatch
<
Cast
>
([](
const
Cast
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
value
);
v
->
V
isit
(
op
->
value
);
});
});
// binary operator
// binary operator
template
<
typename
T
>
template
<
typename
T
>
inline
void
Binary
(
const
T
*
op
,
IRVisitor
*
v
)
{
inline
void
Binary
(
const
T
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
a
);
v
->
V
isit
(
op
->
a
);
v
->
v
isit
(
op
->
b
);
v
->
V
isit
(
op
->
b
);
}
}
TVM_STATIC_IR_FUNCTOR
(
IRVisitor
,
vtable
)
TVM_STATIC_IR_FUNCTOR
(
IRVisitor
,
vtable
)
...
@@ -103,51 +103,51 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
...
@@ -103,51 +103,51 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
TVM_STATIC_IR_FUNCTOR
(
IRVisitor
,
vtable
)
TVM_STATIC_IR_FUNCTOR
(
IRVisitor
,
vtable
)
.
set_dispatch
<
Not
>
([](
const
Not
*
op
,
IRVisitor
*
v
)
{
.
set_dispatch
<
Not
>
([](
const
Not
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
a
);
v
->
V
isit
(
op
->
a
);
})
})
.
set_dispatch
<
Select
>
([](
const
Select
*
op
,
IRVisitor
*
v
)
{
.
set_dispatch
<
Select
>
([](
const
Select
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
condition
);
v
->
V
isit
(
op
->
condition
);
v
->
v
isit
(
op
->
true_value
);
v
->
V
isit
(
op
->
true_value
);
v
->
v
isit
(
op
->
false_value
);
v
->
V
isit
(
op
->
false_value
);
})
})
.
set_dispatch
<
Load
>
([](
const
Load
*
op
,
IRVisitor
*
v
)
{
.
set_dispatch
<
Load
>
([](
const
Load
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
index
);
v
->
V
isit
(
op
->
index
);
})
})
.
set_dispatch
<
Ramp
>
([](
const
Ramp
*
op
,
IRVisitor
*
v
)
{
.
set_dispatch
<
Ramp
>
([](
const
Ramp
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
base
);
v
->
V
isit
(
op
->
base
);
v
->
v
isit
(
op
->
stride
);
v
->
V
isit
(
op
->
stride
);
})
})
.
set_dispatch
<
Broadcast
>
([](
const
Broadcast
*
op
,
IRVisitor
*
v
)
{
.
set_dispatch
<
Broadcast
>
([](
const
Broadcast
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
value
);
v
->
V
isit
(
op
->
value
);
})
})
.
set_dispatch
<
Call
>
([](
const
Call
*
op
,
IRVisitor
*
v
)
{
.
set_dispatch
<
Call
>
([](
const
Call
*
op
,
IRVisitor
*
v
)
{
VisitArray
(
op
->
args
,
v
);
VisitArray
(
op
->
args
,
v
);
})
})
.
set_dispatch
<
Let
>
([](
const
Let
*
op
,
IRVisitor
*
v
)
{
.
set_dispatch
<
Let
>
([](
const
Let
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
value
);
v
->
V
isit
(
op
->
value
);
v
->
v
isit
(
op
->
body
);
v
->
V
isit
(
op
->
body
);
});
});
TVM_STATIC_IR_FUNCTOR
(
IRVisitor
,
vtable
)
TVM_STATIC_IR_FUNCTOR
(
IRVisitor
,
vtable
)
.
set_dispatch
<
LetStmt
>
([](
const
LetStmt
*
op
,
IRVisitor
*
v
)
{
.
set_dispatch
<
LetStmt
>
([](
const
LetStmt
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
value
);
v
->
V
isit
(
op
->
value
);
v
->
v
isit
(
op
->
body
);
v
->
V
isit
(
op
->
body
);
})
})
.
set_dispatch
<
AssertStmt
>
([](
const
AssertStmt
*
op
,
IRVisitor
*
v
)
{
.
set_dispatch
<
AssertStmt
>
([](
const
AssertStmt
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
condition
);
v
->
V
isit
(
op
->
condition
);
v
->
v
isit
(
op
->
message
);
v
->
V
isit
(
op
->
message
);
})
})
.
set_dispatch
<
ProducerConsumer
>
([](
const
ProducerConsumer
*
op
,
IRVisitor
*
v
)
{
.
set_dispatch
<
ProducerConsumer
>
([](
const
ProducerConsumer
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
body
);
v
->
V
isit
(
op
->
body
);
})
})
.
set_dispatch
<
For
>
([](
const
For
*
op
,
IRVisitor
*
v
)
{
.
set_dispatch
<
For
>
([](
const
For
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
min
);
v
->
V
isit
(
op
->
min
);
v
->
v
isit
(
op
->
extent
);
v
->
V
isit
(
op
->
extent
);
v
->
v
isit
(
op
->
body
);
v
->
V
isit
(
op
->
body
);
})
})
.
set_dispatch
<
Store
>
([](
const
Store
*
op
,
IRVisitor
*
v
)
{
.
set_dispatch
<
Store
>
([](
const
Store
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
value
);
v
->
V
isit
(
op
->
value
);
v
->
v
isit
(
op
->
index
);
v
->
V
isit
(
op
->
index
);
})
})
.
set_dispatch
<
Provide
>
([](
const
Provide
*
op
,
IRVisitor
*
v
)
{
.
set_dispatch
<
Provide
>
([](
const
Provide
*
op
,
IRVisitor
*
v
)
{
VisitArray
(
op
->
args
,
v
);
VisitArray
(
op
->
args
,
v
);
...
@@ -155,36 +155,36 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
...
@@ -155,36 +155,36 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
})
})
.
set_dispatch
<
Allocate
>
([](
const
Allocate
*
op
,
IRVisitor
*
v
)
{
.
set_dispatch
<
Allocate
>
([](
const
Allocate
*
op
,
IRVisitor
*
v
)
{
for
(
size_t
i
=
0
;
i
<
op
->
extents
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
op
->
extents
.
size
();
i
++
)
{
v
->
v
isit
(
op
->
extents
[
i
]);
v
->
V
isit
(
op
->
extents
[
i
]);
}
}
v
->
v
isit
(
op
->
body
);
v
->
V
isit
(
op
->
body
);
v
->
v
isit
(
op
->
condition
);
v
->
V
isit
(
op
->
condition
);
if
(
op
->
new_expr
.
defined
())
{
if
(
op
->
new_expr
.
defined
())
{
v
->
v
isit
(
op
->
new_expr
);
v
->
V
isit
(
op
->
new_expr
);
}
}
})
})
.
set_dispatch
<
Free
>
(
NoOp
)
.
set_dispatch
<
Free
>
(
NoOp
)
.
set_dispatch
<
Realize
>
([](
const
Realize
*
op
,
IRVisitor
*
v
)
{
.
set_dispatch
<
Realize
>
([](
const
Realize
*
op
,
IRVisitor
*
v
)
{
// Mutate the bounds
// Mutate the bounds
for
(
size_t
i
=
0
;
i
<
op
->
bounds
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
op
->
bounds
.
size
();
i
++
)
{
v
->
v
isit
(
op
->
bounds
[
i
]
->
min
);
v
->
V
isit
(
op
->
bounds
[
i
]
->
min
);
v
->
v
isit
(
op
->
bounds
[
i
]
->
extent
);
v
->
V
isit
(
op
->
bounds
[
i
]
->
extent
);
}
}
v
->
v
isit
(
op
->
body
);
v
->
V
isit
(
op
->
body
);
v
->
v
isit
(
op
->
condition
);
v
->
V
isit
(
op
->
condition
);
})
})
.
set_dispatch
<
Block
>
([](
const
Block
*
op
,
IRVisitor
*
v
)
{
.
set_dispatch
<
Block
>
([](
const
Block
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
first
);
v
->
V
isit
(
op
->
first
);
v
->
v
isit
(
op
->
rest
);
v
->
V
isit
(
op
->
rest
);
})
})
.
set_dispatch
<
IfThenElse
>
([](
const
IfThenElse
*
op
,
IRVisitor
*
v
)
{
.
set_dispatch
<
IfThenElse
>
([](
const
IfThenElse
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
condition
);
v
->
V
isit
(
op
->
condition
);
v
->
v
isit
(
op
->
then_case
);
v
->
V
isit
(
op
->
then_case
);
v
->
v
isit
(
op
->
else_case
);
v
->
V
isit
(
op
->
else_case
);
})
})
.
set_dispatch
<
Evaluate
>
([](
const
Evaluate
*
op
,
IRVisitor
*
v
)
{
.
set_dispatch
<
Evaluate
>
([](
const
Evaluate
*
op
,
IRVisitor
*
v
)
{
v
->
v
isit
(
op
->
value
);
v
->
V
isit
(
op
->
value
);
});
});
}
// namespace
}
// namespace
...
...
tests/cpp/ir_mutator_test.cc
View file @
8e04361c
...
@@ -13,10 +13,10 @@ class IRVar2Const : public IRMutator {
...
@@ -13,10 +13,10 @@ class IRVar2Const : public IRMutator {
public
:
public
:
VarExpr
var
;
VarExpr
var
;
int
int_val
;
int
int_val
;
Expr
m
utate
(
Expr
expr
)
final
{
Expr
M
utate
(
Expr
expr
)
final
{
static
const
FMutateExpr
&
f
=
IRVar2Const
::
vtable_expr
();
static
const
FMutateExpr
&
f
=
IRVar2Const
::
vtable_expr
();
return
(
f
.
can_dispatch
(
expr
)
?
return
(
f
.
can_dispatch
(
expr
)
?
f
(
expr
,
expr
,
this
)
:
IRMutator
::
m
utate
(
expr
));
f
(
expr
,
expr
,
this
)
:
IRMutator
::
M
utate
(
expr
));
}
}
static
FMutateExpr
&
vtable_expr
();
static
FMutateExpr
&
vtable_expr
();
};
};
...
@@ -46,31 +46,12 @@ TEST(IRMutator, Basic) {
...
@@ -46,31 +46,12 @@ TEST(IRMutator, Basic) {
IRVar2Const
mu
;
IRVar2Const
mu
;
mu
.
var
=
y
;
mu
.
var
=
y
;
mu
.
int_val
=
10
;
mu
.
int_val
=
10
;
auto
zz
=
mu
.
m
utate
(
z
);
auto
zz
=
mu
.
M
utate
(
z
);
std
::
ostringstream
os
;
std
::
ostringstream
os
;
os
<<
zz
;
os
<<
zz
;
CHECK
(
os
.
str
()
==
"(x + 10)"
);
CHECK
(
os
.
str
()
==
"(x + 10)"
);
}
}
TEST
(
IRMutator
,
Substitute
)
{
using
namespace
Halide
::
Internal
;
using
namespace
tvm
;
Var
x
(
"x"
),
y
;
auto
z
=
x
+
y
;
{
auto
zz
=
Substitute
({{
y
.
get
(),
11
}},
z
);
std
::
ostringstream
os
;
os
<<
zz
;
CHECK
(
os
.
str
()
==
"(x + 11)"
);
}
{
auto
zz
=
Substitute
({{
z
.
get
(),
11
}},
z
);
std
::
ostringstream
os
;
os
<<
zz
;
CHECK
(
os
.
str
()
==
"11"
);
}
}
int
main
(
int
argc
,
char
**
argv
)
{
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
testing
::
InitGoogleTest
(
&
argc
,
argv
);
testing
::
FLAGS_gtest_death_test_style
=
"threadsafe"
;
testing
::
FLAGS_gtest_death_test_style
=
"threadsafe"
;
...
...
tests/cpp/ir_pass_test.cc
0 → 100644
View file @
8e04361c
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_pass.h>
TEST
(
IRPass
,
Substitute
)
{
using
namespace
Halide
::
Internal
;
using
namespace
tvm
;
Var
x
(
"x"
),
y
;
auto
z
=
x
+
y
;
{
auto
zz
=
ir
::
Substitute
({{
y
.
get
(),
11
}},
z
);
std
::
ostringstream
os
;
os
<<
zz
;
CHECK
(
os
.
str
()
==
"(x + 11)"
);
}
{
auto
zz
=
ir
::
Substitute
({{
z
.
get
(),
11
}},
z
);
std
::
ostringstream
os
;
os
<<
zz
;
CHECK
(
os
.
str
()
==
"11"
);
}
}
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
testing
::
FLAGS_gtest_death_test_style
=
"threadsafe"
;
return
RUN_ALL_TESTS
();
}
tests/cpp/ir_visitor_test.cc
View file @
8e04361c
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/tvm.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
TEST
(
IRVisitor
,
CountVar
)
{
TEST
(
IRVisitor
,
CountVar
)
{
using
namespace
Halide
::
Internal
;
using
namespace
Halide
::
Internal
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment