Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
be8de13f
Commit
be8de13f
authored
Nov 06, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Enable IRFunctor based IRMutator
parent
0a392dd0
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
502 additions
and
6 deletions
+502
-6
HalideIR
+1
-1
include/tvm/domain.h
+6
-2
include/tvm/expr.h
+7
-2
include/tvm/ir_mutator.h
+83
-0
src/lang/domain.cc
+8
-0
src/lang/ir.cc
+1
-1
src/pass/ir_mutator.cc
+337
-0
tests/cpp/ir_mutator_test.cc
+59
-0
No files found.
HalideIR
@
89b79399
Subproject commit
ec84af1359c841df622f683048968348381e328a
Subproject commit
89b7939957d66a37dd6083ad6b09a5644e73fd8b
include/tvm/domain.h
View file @
be8de13f
...
@@ -36,6 +36,8 @@ class Range : public Halide::IR::Range {
...
@@ -36,6 +36,8 @@ class Range : public Halide::IR::Range {
* \param end The end of the range.
* \param end The end of the range.
*/
*/
Range
(
Expr
begin
,
Expr
end
);
Range
(
Expr
begin
,
Expr
end
);
static
Range
make_with_min_extent
(
Expr
min
,
Expr
extent
);
};
};
/*! \brief Domain is a multi-dimensional range */
/*! \brief Domain is a multi-dimensional range */
...
@@ -74,6 +76,8 @@ class RDomain : public NodeRef {
...
@@ -74,6 +76,8 @@ class RDomain : public NodeRef {
inline
Var
i0
()
const
{
inline
Var
i0
()
const
{
return
index
(
0
);
return
index
(
0
);
}
}
// low level constructor
static
RDomain
make
(
Array
<
Var
>
index
,
Domain
domain
);
};
};
/*! \brief use RDom as alias of RDomain */
/*! \brief use RDom as alias of RDomain */
...
@@ -88,8 +92,8 @@ class RDomainNode : public Node {
...
@@ -88,8 +92,8 @@ class RDomainNode : public Node {
Domain
domain
;
Domain
domain
;
/*! \brief constructor */
/*! \brief constructor */
RDomainNode
()
{}
RDomainNode
()
{}
RDomainNode
(
Array
<
Var
>
&&
index
,
Domain
&&
domain
)
RDomainNode
(
Array
<
Var
>
index
,
Domain
domain
)
:
index
(
std
::
move
(
index
)),
domain
(
std
::
move
(
domain
)
)
{
:
index
(
index
),
domain
(
domain
)
{
}
}
const
char
*
type_key
()
const
override
{
const
char
*
type_key
()
const
override
{
return
"RDomain"
;
return
"RDomain"
;
...
...
include/tvm/expr.h
View file @
be8de13f
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
#include <ir/Expr.h>
#include <ir/Expr.h>
#include <ir/IROperator.h>
#include <ir/IROperator.h>
#include <
type_traits
>
#include <
string
>
#include "./base.h"
#include "./base.h"
namespace
tvm
{
namespace
tvm
{
...
@@ -28,7 +28,12 @@ using Halide::select;
...
@@ -28,7 +28,12 @@ using Halide::select;
using
Halide
::
Expr
;
using
Halide
::
Expr
;
using
Halide
::
Internal
::
Stmt
;
using
Halide
::
Internal
::
Stmt
;
using
Var
=
Halide
::
VarExpr
;
class
Var
:
public
Halide
::
VarExpr
{
public
:
explicit
Var
(
const
std
::
string
&
name_hint
=
"v"
,
Type
t
=
Int
(
32
))
:
VarExpr
(
name_hint
,
t
)
{}
};
}
// namespace tvm
}
// namespace tvm
#endif // TVM_EXPR_H_
#endif // TVM_EXPR_H_
include/tvm/ir_mutator.h
0 → 100644
View file @
be8de13f
/*!
* Copyright (c) 2016 by Contributors
* \file ir_mutator.h
* \brief Defines general IRMutation pass
*/
#ifndef TVM_IR_MUTATOR_H_
#define TVM_IR_MUTATOR_H_
#include <tvm/ir_node.h>
#include "./expr.h"
namespace
tvm
{
namespace
ir
{
/*!
* \brief a base class for mutator to iterative mutate the IR
*
* This IRMutator is implemented via IRFunctor instead of Visitor Pattern.
* This enables easy extensions of possible new IRNode.
* It also makes changing return types easier.
*
* \note If you want to return a different type other than Expr and Stmt,
* Simply following the same pattern as IRMutator and create a seperate class.
* \sa IRFunctor
*/
class
IRMutator
{
public
:
/*!
* \brief mutate expression
* \return the mutated expr
*/
virtual
Expr
mutate
(
Expr
expr
)
{
static
const
FMutateExpr
&
f
=
vtable_expr
();
return
f
(
expr
,
expr
,
this
);
}
/*!
* \brief mutate expression
* \return the mutated stmt
*/
virtual
Stmt
mutate
(
Stmt
stmt
)
{
static
const
FMutateStmt
&
f
=
vtable_stmt
();
return
f
(
stmt
,
stmt
,
this
);
}
/*! \brief destructor */
virtual
~
IRMutator
()
{}
/*! \brief functor type of expr mutation */
using
FMutateExpr
=
IRFunctor
<
Expr
(
const
IRNodeRef
&
,
const
Expr
&
,
IRMutator
*
)
>
;
/*! \brief functor type of stmt mutation */
using
FMutateStmt
=
IRFunctor
<
Stmt
(
const
IRNodeRef
&
,
const
Stmt
&
,
IRMutator
*
)
>
;
/*! \return internal vtable of expr */
static
FMutateExpr
&
vtable_expr
();
// NOLINT(*)
/*! \return internal stmt of expr */
static
FMutateStmt
&
vtable_stmt
();
// NOLINT(*)
};
/*!
* \brief templatized base class of subclass of IRMutator
*
* Use "curiously recurring template pattern" to implement mutate for you.
* Child class need to declare IRMutatorBase<T>::vtable_expr and IRMutatorBase<T>::vtable_stmt
*
* \note This only implement direct subclass from IRMutator, similar code
* can be created to implement deeper subclassing when needed.
*/
class
IRMutatorExample
:
public
IRMutator
{
public
:
Expr
mutate
(
Expr
expr
)
final
{
static
const
FMutateExpr
&
f
=
IRMutatorExample
::
vtable_expr
();
return
(
f
.
can_dispatch
(
expr
)
?
f
(
expr
,
expr
,
this
)
:
IRMutator
::
mutate
(
expr
));
}
Stmt
mutate
(
Stmt
stmt
)
final
{
static
const
FMutateStmt
&
f
=
IRMutatorExample
::
vtable_stmt
();
return
(
f
.
can_dispatch
(
stmt
)
?
f
(
stmt
,
stmt
,
this
)
:
IRMutator
::
mutate
(
stmt
));
}
// to be implemented by child class
static
FMutateExpr
&
vtable_expr
();
// NOLINT(*)
static
FMutateStmt
&
vtable_stmt
();
// NOLINT(*)
};
}
// namespace ir
}
// namespace tvm
#endif // TVM_IR_MUTATOR_H_
src/lang/domain.cc
View file @
be8de13f
...
@@ -12,6 +12,10 @@ Range::Range(Expr begin, Expr end)
...
@@ -12,6 +12,10 @@ Range::Range(Expr begin, Expr end)
// TODO(tqchen) add simplify to end - begin
// TODO(tqchen) add simplify to end - begin
}
}
Range
Range
::
make_with_min_extent
(
Expr
min
,
Expr
extent
)
{
return
Range
(
std
::
make_shared
<
Halide
::
IR
::
RangeNode
>
(
min
,
extent
));
}
RDomain
::
RDomain
(
Domain
domain
)
{
RDomain
::
RDomain
(
Domain
domain
)
{
std
::
vector
<
Var
>
index
;
std
::
vector
<
Var
>
index
;
for
(
size_t
i
=
0
;
i
<
domain
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
domain
.
size
();
++
i
)
{
...
@@ -24,6 +28,10 @@ RDomain::RDomain(Domain domain) {
...
@@ -24,6 +28,10 @@ RDomain::RDomain(Domain domain) {
std
::
move
(
idx
),
std
::
move
(
domain
));
std
::
move
(
idx
),
std
::
move
(
domain
));
}
}
RDomain
RDomain
::
make
(
Array
<
Var
>
index
,
Domain
domain
)
{
return
RDomain
(
std
::
make_shared
<
RDomainNode
>
(
index
,
domain
));
}
TVM_REGISTER_NODE_TYPE
(
RDomainNode
);
TVM_REGISTER_NODE_TYPE
(
RDomainNode
);
}
// namespace tvm
}
// namespace tvm
src/lang/ir.cc
View file @
be8de13f
...
@@ -20,7 +20,7 @@ namespace Internal {
...
@@ -20,7 +20,7 @@ namespace Internal {
using
tvm
::
ir
::
Reduce
;
using
tvm
::
ir
::
Reduce
;
template
<>
template
<>
void
ExprNode
<
Reduce
>::
accept
(
IRVisitor
*
v
)
const
{
void
ExprNode
<
Reduce
>::
accept
(
IRVisitor
*
v
,
const
Expr
&
)
const
{
LOG
(
FATAL
)
<<
"Reduce do not work with IRVisitor yet"
;
LOG
(
FATAL
)
<<
"Reduce do not work with IRVisitor yet"
;
}
}
...
...
src/pass/ir_mutator.cc
0 → 100644
View file @
be8de13f
/*!
* Copyright (c) 2016 by Contributors
* \file ir_mutator.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
namespace
tvm
{
namespace
ir
{
IRMutator
::
FMutateExpr
&
IRMutator
::
vtable_expr
()
{
// NOLINT(*)
static
FMutateExpr
inst
;
return
inst
;
}
IRMutator
::
FMutateStmt
&
IRMutator
::
vtable_stmt
()
{
// NOLINT(*)
static
FMutateStmt
inst
;
return
inst
;
}
// namespace to register the functors.
namespace
{
using
namespace
Halide
::
Internal
;
// const expr
inline
Expr
ReturnSelfExpr
(
const
IRNodeRef
&
,
const
Expr
&
e
,
IRMutator
*
)
{
return
e
;
}
inline
Array
<
Expr
>
MutateArray
(
Array
<
Expr
>
arr
,
IRMutator
*
m
)
{
std
::
vector
<
Expr
>
new_arr
(
arr
.
size
());
bool
changed
=
false
;
for
(
size_t
i
=
0
;
i
<
arr
.
size
();
i
++
)
{
Expr
old_elem
=
arr
[
i
];
Expr
new_elem
=
m
->
mutate
(
old_elem
);
if
(
!
new_elem
.
same_as
(
old_elem
))
changed
=
true
;
new_arr
[
i
]
=
new_elem
;
}
if
(
!
changed
)
{
return
arr
;
}
else
{
return
Array
<
Expr
>
(
new_arr
);
}
}
inline
RDomain
MutateRDom
(
RDomain
rdom
,
IRMutator
*
m
)
{
std
::
vector
<
Range
>
new_dom
(
rdom
->
domain
.
size
());
bool
changed
=
false
;
for
(
size_t
i
=
0
;
i
<
rdom
->
domain
.
size
();
i
++
)
{
Range
r
=
rdom
->
domain
[
i
];
Expr
new_min
=
m
->
mutate
(
r
->
min
);
Expr
new_extent
=
m
->
mutate
(
r
->
extent
);
if
(
!
r
->
min
.
same_as
(
new_min
))
changed
=
true
;
if
(
!
r
->
extent
.
same_as
(
new_extent
))
changed
=
true
;
new_dom
[
i
]
=
Range
::
make_with_min_extent
(
new_min
,
new_extent
);
}
if
(
!
changed
)
{
return
rdom
;
}
else
{
return
RDomain
::
make
(
rdom
->
index
,
Domain
(
new_dom
));
}
}
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
.
set_dispatch
<
Reduce
>
([](
const
Reduce
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
RDomain
new_rdom
=
MutateRDom
(
op
->
rdom
,
m
);
Expr
new_source
=
m
->
mutate
(
op
->
source
);
if
(
op
->
rdom
.
same_as
(
new_rdom
)
&&
op
->
source
.
same_as
(
new_source
))
{
return
e
;
}
else
{
return
Reduce
::
make
(
op
->
op
,
new_source
,
new_rdom
);
}
});
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
.
set_dispatch
<
IntImm
>
(
ReturnSelfExpr
)
.
set_dispatch
<
UIntImm
>
(
ReturnSelfExpr
)
.
set_dispatch
<
FloatImm
>
(
ReturnSelfExpr
)
.
set_dispatch
<
StringImm
>
(
ReturnSelfExpr
)
.
set_dispatch
<
Variable
>
(
ReturnSelfExpr
);
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
.
set_dispatch
<
Cast
>
([](
const
Cast
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
value
=
m
->
mutate
(
op
->
value
);
if
(
value
.
same_as
(
op
->
value
))
{
return
e
;
}
else
{
return
Cast
::
make
(
op
->
type
,
value
);
}
});
// binary operator
template
<
typename
T
>
inline
Expr
Binary
(
const
T
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
a
=
m
->
mutate
(
op
->
a
);
Expr
b
=
m
->
mutate
(
op
->
b
);
if
(
a
.
same_as
(
op
->
a
)
&&
b
.
same_as
(
op
->
b
))
{
return
e
;
}
else
{
return
T
::
make
(
a
,
b
);
}
}
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
.
set_dispatch
<
Add
>
(
Binary
<
Add
>
)
.
set_dispatch
<
Sub
>
(
Binary
<
Sub
>
)
.
set_dispatch
<
Mul
>
(
Binary
<
Mul
>
)
.
set_dispatch
<
Div
>
(
Binary
<
Div
>
)
.
set_dispatch
<
Mod
>
(
Binary
<
Mod
>
)
.
set_dispatch
<
Min
>
(
Binary
<
Min
>
)
.
set_dispatch
<
Max
>
(
Binary
<
Max
>
)
.
set_dispatch
<
EQ
>
(
Binary
<
EQ
>
)
.
set_dispatch
<
NE
>
(
Binary
<
NE
>
)
.
set_dispatch
<
LT
>
(
Binary
<
LT
>
)
.
set_dispatch
<
LE
>
(
Binary
<
LE
>
)
.
set_dispatch
<
GT
>
(
Binary
<
GT
>
)
.
set_dispatch
<
GE
>
(
Binary
<
GE
>
)
.
set_dispatch
<
And
>
(
Binary
<
And
>
)
.
set_dispatch
<
Or
>
(
Binary
<
Or
>
);
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
.
set_dispatch
<
Not
>
([](
const
Not
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
a
=
m
->
mutate
(
op
->
a
);
if
(
a
.
same_as
(
op
->
a
))
{
return
e
;
}
else
{
return
Not
::
make
(
a
);
}
})
.
set_dispatch
<
Select
>
([](
const
Select
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
cond
=
m
->
mutate
(
op
->
condition
);
Expr
t
=
m
->
mutate
(
op
->
true_value
);
Expr
f
=
m
->
mutate
(
op
->
false_value
);
if
(
cond
.
same_as
(
op
->
condition
)
&&
t
.
same_as
(
op
->
true_value
)
&&
f
.
same_as
(
op
->
false_value
))
{
return
e
;
}
else
{
return
Select
::
make
(
cond
,
t
,
f
);
}
})
.
set_dispatch
<
Load
>
([](
const
Load
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
index
=
m
->
mutate
(
op
->
index
);
if
(
index
.
same_as
(
op
->
index
))
{
return
e
;
}
else
{
return
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
index
);
}
})
.
set_dispatch
<
Ramp
>
([](
const
Ramp
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
base
=
m
->
mutate
(
op
->
base
);
Expr
stride
=
m
->
mutate
(
op
->
stride
);
if
(
base
.
same_as
(
op
->
base
)
&&
stride
.
same_as
(
op
->
stride
))
{
return
e
;
}
else
{
return
Ramp
::
make
(
base
,
stride
,
op
->
lanes
);
}
})
.
set_dispatch
<
Broadcast
>
([](
const
Broadcast
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
value
=
m
->
mutate
(
op
->
value
);
if
(
value
.
same_as
(
op
->
value
))
{
return
e
;
}
else
{
return
Broadcast
::
make
(
value
,
op
->
lanes
);
}
})
.
set_dispatch
<
Call
>
([](
const
Call
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
auto
new_args
=
MutateArray
(
op
->
args
,
m
);
if
(
op
->
args
.
same_as
(
new_args
))
{
return
e
;
}
else
{
return
Call
::
make
(
op
->
type
,
op
->
name
,
new_args
,
op
->
call_type
,
op
->
func
,
op
->
value_index
);
}
})
.
set_dispatch
<
Let
>
([](
const
Let
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
value
=
m
->
mutate
(
op
->
value
);
Expr
body
=
m
->
mutate
(
op
->
body
);
if
(
value
.
same_as
(
op
->
value
)
&&
body
.
same_as
(
op
->
body
))
{
return
e
;
}
else
{
return
Let
::
make
(
op
->
var
,
value
,
body
);
}
});
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_stmt
)
.
set_dispatch
<
LetStmt
>
([](
const
LetStmt
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
value
=
m
->
mutate
(
op
->
value
);
Stmt
body
=
m
->
mutate
(
op
->
body
);
if
(
value
.
same_as
(
op
->
value
)
&&
body
.
same_as
(
op
->
body
))
{
return
s
;
}
else
{
return
LetStmt
::
make
(
op
->
var
,
value
,
body
);
}
})
.
set_dispatch
<
AssertStmt
>
([](
const
AssertStmt
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
condition
=
m
->
mutate
(
op
->
condition
);
Expr
message
=
m
->
mutate
(
op
->
message
);
if
(
condition
.
same_as
(
op
->
condition
)
&&
message
.
same_as
(
op
->
message
))
{
return
s
;
}
else
{
return
AssertStmt
::
make
(
condition
,
message
);
}
})
.
set_dispatch
<
ProducerConsumer
>
([](
const
ProducerConsumer
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Stmt
body
=
m
->
mutate
(
op
->
body
);
if
(
body
.
same_as
(
op
->
body
))
{
return
s
;
}
else
{
return
ProducerConsumer
::
make
(
op
->
func
,
op
->
is_producer
,
body
);
}
})
.
set_dispatch
<
For
>
([](
const
For
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
min
=
m
->
mutate
(
op
->
min
);
Expr
extent
=
m
->
mutate
(
op
->
extent
);
Stmt
body
=
m
->
mutate
(
op
->
body
);
if
(
min
.
same_as
(
op
->
min
)
&&
extent
.
same_as
(
op
->
extent
)
&&
body
.
same_as
(
op
->
body
))
{
return
s
;
}
else
{
return
For
::
make
(
op
->
loop_var
,
min
,
extent
,
op
->
for_type
,
op
->
device_api
,
body
);
}
})
.
set_dispatch
<
Store
>
([](
const
Store
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
value
=
m
->
mutate
(
op
->
value
);
Expr
index
=
m
->
mutate
(
op
->
index
);
if
(
value
.
same_as
(
op
->
value
)
&&
index
.
same_as
(
op
->
index
))
{
return
s
;
}
else
{
return
Store
::
make
(
op
->
buffer_var
,
value
,
index
);
}
})
.
set_dispatch
<
Provide
>
([](
const
Provide
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
auto
new_args
=
MutateArray
(
op
->
args
,
m
);
auto
new_values
=
MutateArray
(
op
->
values
,
m
);
if
(
op
->
args
.
same_as
(
new_args
)
&&
op
->
values
.
same_as
(
new_values
))
{
return
s
;
}
else
{
return
Provide
::
make
(
op
->
func
,
new_values
,
new_args
);
}
})
.
set_dispatch
<
Allocate
>
([](
const
Allocate
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
std
::
vector
<
Expr
>
new_extents
;
bool
all_extents_unmodified
=
true
;
for
(
size_t
i
=
0
;
i
<
op
->
extents
.
size
();
i
++
)
{
new_extents
.
push_back
(
m
->
mutate
(
op
->
extents
[
i
]));
all_extents_unmodified
&=
new_extents
[
i
].
same_as
(
op
->
extents
[
i
]);
}
Stmt
body
=
m
->
mutate
(
op
->
body
);
Expr
condition
=
m
->
mutate
(
op
->
condition
);
Expr
new_expr
;
if
(
op
->
new_expr
.
defined
())
{
new_expr
=
m
->
mutate
(
op
->
new_expr
);
}
if
(
all_extents_unmodified
&&
body
.
same_as
(
op
->
body
)
&&
condition
.
same_as
(
op
->
condition
)
&&
new_expr
.
same_as
(
op
->
new_expr
))
{
return
s
;
}
else
{
return
Allocate
::
make
(
op
->
buffer_var
,
op
->
type
,
new_extents
,
condition
,
body
,
new_expr
,
op
->
free_function
);
}
})
.
set_dispatch
<
Free
>
([](
const
Free
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
return
s
;
})
.
set_dispatch
<
Realize
>
([](
const
Realize
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Region
new_bounds
;
bool
bounds_changed
=
false
;
// Mutate the bounds
for
(
size_t
i
=
0
;
i
<
op
->
bounds
.
size
();
i
++
)
{
Expr
old_min
=
op
->
bounds
[
i
]
->
min
;
Expr
old_extent
=
op
->
bounds
[
i
]
->
extent
;
Expr
new_min
=
m
->
mutate
(
old_min
);
Expr
new_extent
=
m
->
mutate
(
old_extent
);
if
(
!
new_min
.
same_as
(
old_min
))
bounds_changed
=
true
;
if
(
!
new_extent
.
same_as
(
old_extent
))
bounds_changed
=
true
;
new_bounds
.
push_back
(
Range
::
make_by_min_extent
(
new_min
,
new_extent
));
}
Stmt
body
=
m
->
mutate
(
op
->
body
);
Expr
condition
=
m
->
mutate
(
op
->
condition
);
if
(
!
bounds_changed
&&
body
.
same_as
(
op
->
body
)
&&
condition
.
same_as
(
op
->
condition
))
{
return
s
;
}
else
{
return
Realize
::
make
(
op
->
func
,
op
->
types
,
new_bounds
,
condition
,
body
);
}
})
.
set_dispatch
<
Block
>
([](
const
Block
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Stmt
first
=
m
->
mutate
(
op
->
first
);
Stmt
rest
=
m
->
mutate
(
op
->
rest
);
if
(
first
.
same_as
(
op
->
first
)
&&
rest
.
same_as
(
op
->
rest
))
{
return
s
;
}
else
{
return
Block
::
make
(
first
,
rest
);
}
})
.
set_dispatch
<
IfThenElse
>
([](
const
IfThenElse
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
condition
=
m
->
mutate
(
op
->
condition
);
Stmt
then_case
=
m
->
mutate
(
op
->
then_case
);
Stmt
else_case
=
m
->
mutate
(
op
->
else_case
);
if
(
condition
.
same_as
(
op
->
condition
)
&&
then_case
.
same_as
(
op
->
then_case
)
&&
else_case
.
same_as
(
op
->
else_case
))
{
return
s
;
}
else
{
return
IfThenElse
::
make
(
condition
,
then_case
,
else_case
);
}
})
.
set_dispatch
<
Evaluate
>
([](
const
Evaluate
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
v
=
m
->
mutate
(
op
->
value
);
if
(
v
.
same_as
(
op
->
value
))
{
return
s
;
}
else
{
return
Evaluate
::
make
(
v
);
}
});
}
// namespace
}
// namespace ir
}
// namespace tvm
tests/cpp/ir_mutator_test.cc
0 → 100644
View file @
be8de13f
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_mutator.h>
namespace
{
using
namespace
tvm
::
ir
;
using
namespace
Halide
::
Internal
;
using
namespace
Halide
;
// replace variable to constant
class
IRVar2Const
:
public
IRMutator
{
public
:
VarExpr
var
;
int
int_val
;
Expr
mutate
(
Expr
expr
)
final
{
static
const
FMutateExpr
&
f
=
IRVar2Const
::
vtable_expr
();
return
(
f
.
can_dispatch
(
expr
)
?
f
(
expr
,
expr
,
this
)
:
IRMutator
::
mutate
(
expr
));
}
static
FMutateExpr
&
vtable_expr
();
};
// implement vtable
IRMutator
::
FMutateExpr
&
IRVar2Const
::
vtable_expr
()
{
// NOLINT(*)
static
FMutateExpr
inst
;
return
inst
;
}
TVM_STATIC_IR_FUNCTOR
(
IRVar2Const
,
vtable_expr
)
.
set_dispatch
<
Variable
>
([](
const
Variable
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
IRVar2Const
*
vm
=
static_cast
<
IRVar2Const
*>
(
m
);
if
(
e
.
same_as
(
vm
->
var
))
{
return
IntImm
::
make
(
Int
(
32
),
vm
->
int_val
);
}
else
{
return
e
;
}
});
}
// namespace
TEST
(
IRMutator
,
Basic
)
{
using
namespace
Halide
::
Internal
;
using
namespace
tvm
;
Var
x
(
"x"
),
y
;
auto
z
=
x
+
y
;
IRVar2Const
mu
;
mu
.
var
=
y
;
mu
.
int_val
=
10
;
auto
zz
=
mu
.
mutate
(
z
);
std
::
ostringstream
os
;
os
<<
zz
;
CHECK
(
os
.
str
()
==
"(x + 10)"
);
}
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
testing
::
FLAGS_gtest_death_test_style
=
"threadsafe"
;
return
RUN_ALL_TESTS
();
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment