Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
be8de13f
Commit
be8de13f
authored
8 years ago
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Enable IRFunctor based IRMutator
parent
0a392dd0
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
502 additions
and
6 deletions
+502
-6
HalideIR
+1
-1
include/tvm/domain.h
+6
-2
include/tvm/expr.h
+7
-2
include/tvm/ir_mutator.h
+83
-0
src/lang/domain.cc
+8
-0
src/lang/ir.cc
+1
-1
src/pass/ir_mutator.cc
+337
-0
tests/cpp/ir_mutator_test.cc
+59
-0
No files found.
HalideIR
@
89b79399
Subproject commit
ec84af1359c841df622f683048968348381e328a
Subproject commit
89b7939957d66a37dd6083ad6b09a5644e73fd8b
This diff is collapsed.
Click to expand it.
include/tvm/domain.h
View file @
be8de13f
...
...
@@ -36,6 +36,8 @@ class Range : public Halide::IR::Range {
* \param end The end of the range.
*/
Range
(
Expr
begin
,
Expr
end
);
static
Range
make_with_min_extent
(
Expr
min
,
Expr
extent
);
};
/*! \brief Domain is a multi-dimensional range */
...
...
@@ -74,6 +76,8 @@ class RDomain : public NodeRef {
inline
Var
i0
()
const
{
return
index
(
0
);
}
// low level constructor
static
RDomain
make
(
Array
<
Var
>
index
,
Domain
domain
);
};
/*! \brief use RDom as alias of RDomain */
...
...
@@ -88,8 +92,8 @@ class RDomainNode : public Node {
Domain
domain
;
/*! \brief constructor */
RDomainNode
()
{}
RDomainNode
(
Array
<
Var
>
&&
index
,
Domain
&&
domain
)
:
index
(
std
::
move
(
index
)),
domain
(
std
::
move
(
domain
)
)
{
RDomainNode
(
Array
<
Var
>
index
,
Domain
domain
)
:
index
(
index
),
domain
(
domain
)
{
}
const
char
*
type_key
()
const
override
{
return
"RDomain"
;
...
...
This diff is collapsed.
Click to expand it.
include/tvm/expr.h
View file @
be8de13f
...
...
@@ -8,7 +8,7 @@
#include <ir/Expr.h>
#include <ir/IROperator.h>
#include <
type_traits
>
#include <
string
>
#include "./base.h"
namespace
tvm
{
...
...
@@ -28,7 +28,12 @@ using Halide::select;
using
Halide
::
Expr
;
using
Halide
::
Internal
::
Stmt
;
using
Var
=
Halide
::
VarExpr
;
class
Var
:
public
Halide
::
VarExpr
{
public
:
explicit
Var
(
const
std
::
string
&
name_hint
=
"v"
,
Type
t
=
Int
(
32
))
:
VarExpr
(
name_hint
,
t
)
{}
};
}
// namespace tvm
#endif // TVM_EXPR_H_
This diff is collapsed.
Click to expand it.
include/tvm/ir_mutator.h
0 → 100644
View file @
be8de13f
/*!
* Copyright (c) 2016 by Contributors
* \file ir_mutator.h
* \brief Defines general IRMutation pass
*/
#ifndef TVM_IR_MUTATOR_H_
#define TVM_IR_MUTATOR_H_
#include <tvm/ir_node.h>
#include "./expr.h"
namespace
tvm
{
namespace
ir
{
/*!
* \brief a base class for mutator to iterative mutate the IR
*
* This IRMutator is implemented via IRFunctor instead of Visitor Pattern.
* This enables easy extensions of possible new IRNode.
* It also makes changing return types easier.
*
* \note If you want to return a different type other than Expr and Stmt,
* Simply following the same pattern as IRMutator and create a seperate class.
* \sa IRFunctor
*/
class
IRMutator
{
public
:
/*!
* \brief mutate expression
* \return the mutated expr
*/
virtual
Expr
mutate
(
Expr
expr
)
{
static
const
FMutateExpr
&
f
=
vtable_expr
();
return
f
(
expr
,
expr
,
this
);
}
/*!
* \brief mutate expression
* \return the mutated stmt
*/
virtual
Stmt
mutate
(
Stmt
stmt
)
{
static
const
FMutateStmt
&
f
=
vtable_stmt
();
return
f
(
stmt
,
stmt
,
this
);
}
/*! \brief destructor */
virtual
~
IRMutator
()
{}
/*! \brief functor type of expr mutation */
using
FMutateExpr
=
IRFunctor
<
Expr
(
const
IRNodeRef
&
,
const
Expr
&
,
IRMutator
*
)
>
;
/*! \brief functor type of stmt mutation */
using
FMutateStmt
=
IRFunctor
<
Stmt
(
const
IRNodeRef
&
,
const
Stmt
&
,
IRMutator
*
)
>
;
/*! \return internal vtable of expr */
static
FMutateExpr
&
vtable_expr
();
// NOLINT(*)
/*! \return internal stmt of expr */
static
FMutateStmt
&
vtable_stmt
();
// NOLINT(*)
};
/*!
* \brief templatized base class of subclass of IRMutator
*
* Use "curiously recurring template pattern" to implement mutate for you.
* Child class need to declare IRMutatorBase<T>::vtable_expr and IRMutatorBase<T>::vtable_stmt
*
* \note This only implement direct subclass from IRMutator, similar code
* can be created to implement deeper subclassing when needed.
*/
class
IRMutatorExample
:
public
IRMutator
{
public
:
Expr
mutate
(
Expr
expr
)
final
{
static
const
FMutateExpr
&
f
=
IRMutatorExample
::
vtable_expr
();
return
(
f
.
can_dispatch
(
expr
)
?
f
(
expr
,
expr
,
this
)
:
IRMutator
::
mutate
(
expr
));
}
Stmt
mutate
(
Stmt
stmt
)
final
{
static
const
FMutateStmt
&
f
=
IRMutatorExample
::
vtable_stmt
();
return
(
f
.
can_dispatch
(
stmt
)
?
f
(
stmt
,
stmt
,
this
)
:
IRMutator
::
mutate
(
stmt
));
}
// to be implemented by child class
static
FMutateExpr
&
vtable_expr
();
// NOLINT(*)
static
FMutateStmt
&
vtable_stmt
();
// NOLINT(*)
};
}
// namespace ir
}
// namespace tvm
#endif // TVM_IR_MUTATOR_H_
This diff is collapsed.
Click to expand it.
src/lang/domain.cc
View file @
be8de13f
...
...
@@ -12,6 +12,10 @@ Range::Range(Expr begin, Expr end)
// TODO(tqchen) add simplify to end - begin
}
Range
Range
::
make_with_min_extent
(
Expr
min
,
Expr
extent
)
{
return
Range
(
std
::
make_shared
<
Halide
::
IR
::
RangeNode
>
(
min
,
extent
));
}
RDomain
::
RDomain
(
Domain
domain
)
{
std
::
vector
<
Var
>
index
;
for
(
size_t
i
=
0
;
i
<
domain
.
size
();
++
i
)
{
...
...
@@ -24,6 +28,10 @@ RDomain::RDomain(Domain domain) {
std
::
move
(
idx
),
std
::
move
(
domain
));
}
RDomain
RDomain
::
make
(
Array
<
Var
>
index
,
Domain
domain
)
{
return
RDomain
(
std
::
make_shared
<
RDomainNode
>
(
index
,
domain
));
}
TVM_REGISTER_NODE_TYPE
(
RDomainNode
);
}
// namespace tvm
This diff is collapsed.
Click to expand it.
src/lang/ir.cc
View file @
be8de13f
...
...
@@ -20,7 +20,7 @@ namespace Internal {
using
tvm
::
ir
::
Reduce
;
template
<>
void
ExprNode
<
Reduce
>::
accept
(
IRVisitor
*
v
)
const
{
void
ExprNode
<
Reduce
>::
accept
(
IRVisitor
*
v
,
const
Expr
&
)
const
{
LOG
(
FATAL
)
<<
"Reduce do not work with IRVisitor yet"
;
}
...
...
This diff is collapsed.
Click to expand it.
src/pass/ir_mutator.cc
0 → 100644
View file @
be8de13f
/*!
* Copyright (c) 2016 by Contributors
* \file ir_mutator.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
namespace
tvm
{
namespace
ir
{
IRMutator
::
FMutateExpr
&
IRMutator
::
vtable_expr
()
{
// NOLINT(*)
static
FMutateExpr
inst
;
return
inst
;
}
IRMutator
::
FMutateStmt
&
IRMutator
::
vtable_stmt
()
{
// NOLINT(*)
static
FMutateStmt
inst
;
return
inst
;
}
// namespace to register the functors.
namespace
{
using
namespace
Halide
::
Internal
;
// const expr
inline
Expr
ReturnSelfExpr
(
const
IRNodeRef
&
,
const
Expr
&
e
,
IRMutator
*
)
{
return
e
;
}
inline
Array
<
Expr
>
MutateArray
(
Array
<
Expr
>
arr
,
IRMutator
*
m
)
{
std
::
vector
<
Expr
>
new_arr
(
arr
.
size
());
bool
changed
=
false
;
for
(
size_t
i
=
0
;
i
<
arr
.
size
();
i
++
)
{
Expr
old_elem
=
arr
[
i
];
Expr
new_elem
=
m
->
mutate
(
old_elem
);
if
(
!
new_elem
.
same_as
(
old_elem
))
changed
=
true
;
new_arr
[
i
]
=
new_elem
;
}
if
(
!
changed
)
{
return
arr
;
}
else
{
return
Array
<
Expr
>
(
new_arr
);
}
}
inline
RDomain
MutateRDom
(
RDomain
rdom
,
IRMutator
*
m
)
{
std
::
vector
<
Range
>
new_dom
(
rdom
->
domain
.
size
());
bool
changed
=
false
;
for
(
size_t
i
=
0
;
i
<
rdom
->
domain
.
size
();
i
++
)
{
Range
r
=
rdom
->
domain
[
i
];
Expr
new_min
=
m
->
mutate
(
r
->
min
);
Expr
new_extent
=
m
->
mutate
(
r
->
extent
);
if
(
!
r
->
min
.
same_as
(
new_min
))
changed
=
true
;
if
(
!
r
->
extent
.
same_as
(
new_extent
))
changed
=
true
;
new_dom
[
i
]
=
Range
::
make_with_min_extent
(
new_min
,
new_extent
);
}
if
(
!
changed
)
{
return
rdom
;
}
else
{
return
RDomain
::
make
(
rdom
->
index
,
Domain
(
new_dom
));
}
}
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
.
set_dispatch
<
Reduce
>
([](
const
Reduce
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
RDomain
new_rdom
=
MutateRDom
(
op
->
rdom
,
m
);
Expr
new_source
=
m
->
mutate
(
op
->
source
);
if
(
op
->
rdom
.
same_as
(
new_rdom
)
&&
op
->
source
.
same_as
(
new_source
))
{
return
e
;
}
else
{
return
Reduce
::
make
(
op
->
op
,
new_source
,
new_rdom
);
}
});
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
.
set_dispatch
<
IntImm
>
(
ReturnSelfExpr
)
.
set_dispatch
<
UIntImm
>
(
ReturnSelfExpr
)
.
set_dispatch
<
FloatImm
>
(
ReturnSelfExpr
)
.
set_dispatch
<
StringImm
>
(
ReturnSelfExpr
)
.
set_dispatch
<
Variable
>
(
ReturnSelfExpr
);
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
.
set_dispatch
<
Cast
>
([](
const
Cast
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
value
=
m
->
mutate
(
op
->
value
);
if
(
value
.
same_as
(
op
->
value
))
{
return
e
;
}
else
{
return
Cast
::
make
(
op
->
type
,
value
);
}
});
// binary operator
template
<
typename
T
>
inline
Expr
Binary
(
const
T
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
a
=
m
->
mutate
(
op
->
a
);
Expr
b
=
m
->
mutate
(
op
->
b
);
if
(
a
.
same_as
(
op
->
a
)
&&
b
.
same_as
(
op
->
b
))
{
return
e
;
}
else
{
return
T
::
make
(
a
,
b
);
}
}
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
.
set_dispatch
<
Add
>
(
Binary
<
Add
>
)
.
set_dispatch
<
Sub
>
(
Binary
<
Sub
>
)
.
set_dispatch
<
Mul
>
(
Binary
<
Mul
>
)
.
set_dispatch
<
Div
>
(
Binary
<
Div
>
)
.
set_dispatch
<
Mod
>
(
Binary
<
Mod
>
)
.
set_dispatch
<
Min
>
(
Binary
<
Min
>
)
.
set_dispatch
<
Max
>
(
Binary
<
Max
>
)
.
set_dispatch
<
EQ
>
(
Binary
<
EQ
>
)
.
set_dispatch
<
NE
>
(
Binary
<
NE
>
)
.
set_dispatch
<
LT
>
(
Binary
<
LT
>
)
.
set_dispatch
<
LE
>
(
Binary
<
LE
>
)
.
set_dispatch
<
GT
>
(
Binary
<
GT
>
)
.
set_dispatch
<
GE
>
(
Binary
<
GE
>
)
.
set_dispatch
<
And
>
(
Binary
<
And
>
)
.
set_dispatch
<
Or
>
(
Binary
<
Or
>
);
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
.
set_dispatch
<
Not
>
([](
const
Not
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
a
=
m
->
mutate
(
op
->
a
);
if
(
a
.
same_as
(
op
->
a
))
{
return
e
;
}
else
{
return
Not
::
make
(
a
);
}
})
.
set_dispatch
<
Select
>
([](
const
Select
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
cond
=
m
->
mutate
(
op
->
condition
);
Expr
t
=
m
->
mutate
(
op
->
true_value
);
Expr
f
=
m
->
mutate
(
op
->
false_value
);
if
(
cond
.
same_as
(
op
->
condition
)
&&
t
.
same_as
(
op
->
true_value
)
&&
f
.
same_as
(
op
->
false_value
))
{
return
e
;
}
else
{
return
Select
::
make
(
cond
,
t
,
f
);
}
})
.
set_dispatch
<
Load
>
([](
const
Load
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
index
=
m
->
mutate
(
op
->
index
);
if
(
index
.
same_as
(
op
->
index
))
{
return
e
;
}
else
{
return
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
index
);
}
})
.
set_dispatch
<
Ramp
>
([](
const
Ramp
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
base
=
m
->
mutate
(
op
->
base
);
Expr
stride
=
m
->
mutate
(
op
->
stride
);
if
(
base
.
same_as
(
op
->
base
)
&&
stride
.
same_as
(
op
->
stride
))
{
return
e
;
}
else
{
return
Ramp
::
make
(
base
,
stride
,
op
->
lanes
);
}
})
.
set_dispatch
<
Broadcast
>
([](
const
Broadcast
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
value
=
m
->
mutate
(
op
->
value
);
if
(
value
.
same_as
(
op
->
value
))
{
return
e
;
}
else
{
return
Broadcast
::
make
(
value
,
op
->
lanes
);
}
})
.
set_dispatch
<
Call
>
([](
const
Call
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
auto
new_args
=
MutateArray
(
op
->
args
,
m
);
if
(
op
->
args
.
same_as
(
new_args
))
{
return
e
;
}
else
{
return
Call
::
make
(
op
->
type
,
op
->
name
,
new_args
,
op
->
call_type
,
op
->
func
,
op
->
value_index
);
}
})
.
set_dispatch
<
Let
>
([](
const
Let
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Expr
value
=
m
->
mutate
(
op
->
value
);
Expr
body
=
m
->
mutate
(
op
->
body
);
if
(
value
.
same_as
(
op
->
value
)
&&
body
.
same_as
(
op
->
body
))
{
return
e
;
}
else
{
return
Let
::
make
(
op
->
var
,
value
,
body
);
}
});
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_stmt
)
.
set_dispatch
<
LetStmt
>
([](
const
LetStmt
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
value
=
m
->
mutate
(
op
->
value
);
Stmt
body
=
m
->
mutate
(
op
->
body
);
if
(
value
.
same_as
(
op
->
value
)
&&
body
.
same_as
(
op
->
body
))
{
return
s
;
}
else
{
return
LetStmt
::
make
(
op
->
var
,
value
,
body
);
}
})
.
set_dispatch
<
AssertStmt
>
([](
const
AssertStmt
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
condition
=
m
->
mutate
(
op
->
condition
);
Expr
message
=
m
->
mutate
(
op
->
message
);
if
(
condition
.
same_as
(
op
->
condition
)
&&
message
.
same_as
(
op
->
message
))
{
return
s
;
}
else
{
return
AssertStmt
::
make
(
condition
,
message
);
}
})
.
set_dispatch
<
ProducerConsumer
>
([](
const
ProducerConsumer
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Stmt
body
=
m
->
mutate
(
op
->
body
);
if
(
body
.
same_as
(
op
->
body
))
{
return
s
;
}
else
{
return
ProducerConsumer
::
make
(
op
->
func
,
op
->
is_producer
,
body
);
}
})
.
set_dispatch
<
For
>
([](
const
For
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
min
=
m
->
mutate
(
op
->
min
);
Expr
extent
=
m
->
mutate
(
op
->
extent
);
Stmt
body
=
m
->
mutate
(
op
->
body
);
if
(
min
.
same_as
(
op
->
min
)
&&
extent
.
same_as
(
op
->
extent
)
&&
body
.
same_as
(
op
->
body
))
{
return
s
;
}
else
{
return
For
::
make
(
op
->
loop_var
,
min
,
extent
,
op
->
for_type
,
op
->
device_api
,
body
);
}
})
.
set_dispatch
<
Store
>
([](
const
Store
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
value
=
m
->
mutate
(
op
->
value
);
Expr
index
=
m
->
mutate
(
op
->
index
);
if
(
value
.
same_as
(
op
->
value
)
&&
index
.
same_as
(
op
->
index
))
{
return
s
;
}
else
{
return
Store
::
make
(
op
->
buffer_var
,
value
,
index
);
}
})
.
set_dispatch
<
Provide
>
([](
const
Provide
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
auto
new_args
=
MutateArray
(
op
->
args
,
m
);
auto
new_values
=
MutateArray
(
op
->
values
,
m
);
if
(
op
->
args
.
same_as
(
new_args
)
&&
op
->
values
.
same_as
(
new_values
))
{
return
s
;
}
else
{
return
Provide
::
make
(
op
->
func
,
new_values
,
new_args
);
}
})
.
set_dispatch
<
Allocate
>
([](
const
Allocate
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
std
::
vector
<
Expr
>
new_extents
;
bool
all_extents_unmodified
=
true
;
for
(
size_t
i
=
0
;
i
<
op
->
extents
.
size
();
i
++
)
{
new_extents
.
push_back
(
m
->
mutate
(
op
->
extents
[
i
]));
all_extents_unmodified
&=
new_extents
[
i
].
same_as
(
op
->
extents
[
i
]);
}
Stmt
body
=
m
->
mutate
(
op
->
body
);
Expr
condition
=
m
->
mutate
(
op
->
condition
);
Expr
new_expr
;
if
(
op
->
new_expr
.
defined
())
{
new_expr
=
m
->
mutate
(
op
->
new_expr
);
}
if
(
all_extents_unmodified
&&
body
.
same_as
(
op
->
body
)
&&
condition
.
same_as
(
op
->
condition
)
&&
new_expr
.
same_as
(
op
->
new_expr
))
{
return
s
;
}
else
{
return
Allocate
::
make
(
op
->
buffer_var
,
op
->
type
,
new_extents
,
condition
,
body
,
new_expr
,
op
->
free_function
);
}
})
.
set_dispatch
<
Free
>
([](
const
Free
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
return
s
;
})
.
set_dispatch
<
Realize
>
([](
const
Realize
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Region
new_bounds
;
bool
bounds_changed
=
false
;
// Mutate the bounds
for
(
size_t
i
=
0
;
i
<
op
->
bounds
.
size
();
i
++
)
{
Expr
old_min
=
op
->
bounds
[
i
]
->
min
;
Expr
old_extent
=
op
->
bounds
[
i
]
->
extent
;
Expr
new_min
=
m
->
mutate
(
old_min
);
Expr
new_extent
=
m
->
mutate
(
old_extent
);
if
(
!
new_min
.
same_as
(
old_min
))
bounds_changed
=
true
;
if
(
!
new_extent
.
same_as
(
old_extent
))
bounds_changed
=
true
;
new_bounds
.
push_back
(
Range
::
make_by_min_extent
(
new_min
,
new_extent
));
}
Stmt
body
=
m
->
mutate
(
op
->
body
);
Expr
condition
=
m
->
mutate
(
op
->
condition
);
if
(
!
bounds_changed
&&
body
.
same_as
(
op
->
body
)
&&
condition
.
same_as
(
op
->
condition
))
{
return
s
;
}
else
{
return
Realize
::
make
(
op
->
func
,
op
->
types
,
new_bounds
,
condition
,
body
);
}
})
.
set_dispatch
<
Block
>
([](
const
Block
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Stmt
first
=
m
->
mutate
(
op
->
first
);
Stmt
rest
=
m
->
mutate
(
op
->
rest
);
if
(
first
.
same_as
(
op
->
first
)
&&
rest
.
same_as
(
op
->
rest
))
{
return
s
;
}
else
{
return
Block
::
make
(
first
,
rest
);
}
})
.
set_dispatch
<
IfThenElse
>
([](
const
IfThenElse
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
condition
=
m
->
mutate
(
op
->
condition
);
Stmt
then_case
=
m
->
mutate
(
op
->
then_case
);
Stmt
else_case
=
m
->
mutate
(
op
->
else_case
);
if
(
condition
.
same_as
(
op
->
condition
)
&&
then_case
.
same_as
(
op
->
then_case
)
&&
else_case
.
same_as
(
op
->
else_case
))
{
return
s
;
}
else
{
return
IfThenElse
::
make
(
condition
,
then_case
,
else_case
);
}
})
.
set_dispatch
<
Evaluate
>
([](
const
Evaluate
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
v
=
m
->
mutate
(
op
->
value
);
if
(
v
.
same_as
(
op
->
value
))
{
return
s
;
}
else
{
return
Evaluate
::
make
(
v
);
}
});
}
// namespace
}
// namespace ir
}
// namespace tvm
This diff is collapsed.
Click to expand it.
tests/cpp/ir_mutator_test.cc
0 → 100644
View file @
be8de13f
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_mutator.h>
namespace
{
using
namespace
tvm
::
ir
;
using
namespace
Halide
::
Internal
;
using
namespace
Halide
;
// replace variable to constant
class
IRVar2Const
:
public
IRMutator
{
public
:
VarExpr
var
;
int
int_val
;
Expr
mutate
(
Expr
expr
)
final
{
static
const
FMutateExpr
&
f
=
IRVar2Const
::
vtable_expr
();
return
(
f
.
can_dispatch
(
expr
)
?
f
(
expr
,
expr
,
this
)
:
IRMutator
::
mutate
(
expr
));
}
static
FMutateExpr
&
vtable_expr
();
};
// implement vtable
IRMutator
::
FMutateExpr
&
IRVar2Const
::
vtable_expr
()
{
// NOLINT(*)
static
FMutateExpr
inst
;
return
inst
;
}
TVM_STATIC_IR_FUNCTOR
(
IRVar2Const
,
vtable_expr
)
.
set_dispatch
<
Variable
>
([](
const
Variable
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
IRVar2Const
*
vm
=
static_cast
<
IRVar2Const
*>
(
m
);
if
(
e
.
same_as
(
vm
->
var
))
{
return
IntImm
::
make
(
Int
(
32
),
vm
->
int_val
);
}
else
{
return
e
;
}
});
}
// namespace
TEST
(
IRMutator
,
Basic
)
{
using
namespace
Halide
::
Internal
;
using
namespace
tvm
;
Var
x
(
"x"
),
y
;
auto
z
=
x
+
y
;
IRVar2Const
mu
;
mu
.
var
=
y
;
mu
.
int_val
=
10
;
auto
zz
=
mu
.
mutate
(
z
);
std
::
ostringstream
os
;
os
<<
zz
;
CHECK
(
os
.
str
()
==
"(x + 10)"
);
}
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
testing
::
FLAGS_gtest_death_test_style
=
"threadsafe"
;
return
RUN_ALL_TESTS
();
}
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment