Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
e42cc112
Commit
e42cc112
authored
Feb 04, 2017
by
Tianqi Chen
Committed by
GitHub
Feb 04, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] UnrollLoop, isolate arithmetic module. (#32)
parent
d89917b6
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
261 additions
and
161 deletions
+261
-161
include/tvm/ir_pass.h
+9
-1
include/tvm/runtime/packed_func.h
+1
-1
python/tvm/build.py
+0
-1
src/README.md
+1
-0
src/api/api_pass.cc
+10
-0
src/arithmetic/compute_expr.h
+5
-5
src/arithmetic/int_set.cc
+12
-84
src/arithmetic/int_set.h
+17
-58
src/pass/inline.cc
+18
-4
src/pass/simple_passes.cc
+2
-2
src/pass/unroll_loop.cc
+78
-0
src/schedule/bound.cc
+77
-1
src/schedule/graph.cc
+0
-1
src/schedule/schedule_ops.cc
+11
-3
tests/python/unittest/test_pass_unroll.py
+20
-0
No files found.
include/tvm/ir_pass.h
View file @
e42cc112
...
...
@@ -68,7 +68,7 @@ Stmt ConvertSSA(Stmt stmt);
* \param value_map The map of new values.
* \return The converted form.
*/
Stmt
Substitute
(
Stmt
stmt
,
const
Map
<
Iter
Var
,
Expr
>&
value_map
);
Stmt
Substitute
(
Stmt
stmt
,
const
Map
<
Var
,
Expr
>&
value_map
);
/*!
* \brief inline all calls of f in stmt.
...
...
@@ -98,6 +98,13 @@ Stmt StorageFlatten(Stmt stmt,
Map
<
Tensor
,
Buffer
>
extern_buffer
);
/*!
* \brief unroll the constant loops
* \param stmt The statment to be unrolled.
* \param max_auto_step The maximum step to stop performing automatic unrolling.
*/
Stmt
UnrollLoop
(
Stmt
stmt
,
int
max_auto_step
);
/*!
* \brief Make an user callable API LoweredFunc.
*
* The main task of this function is to create code to :
...
...
@@ -153,6 +160,7 @@ Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
*/
LoweredFunc
StorageSync
(
LoweredFunc
stmt
,
std
::
string
storage_scope
);
}
// namespace ir
}
// namespace tvm
...
...
include/tvm/runtime/packed_func.h
View file @
e42cc112
...
...
@@ -562,7 +562,7 @@ inline TVMArgValue TVMArgs::operator[](int i) const {
CHECK_LT
(
i
,
num_args
)
<<
"not enough argument passed, "
<<
num_args
<<
" passed"
<<
"
but request arg"
<<
i
;
<<
"
but request arg["
<<
i
<<
"]."
;
return
TVMArgValue
(
values
[
i
],
type_codes
[
i
]);
}
...
...
python/tvm/build.py
View file @
e42cc112
...
...
@@ -70,7 +70,6 @@ def build(sch,
fsplits
=
[
x
for
x
in
fsplits
]
for
i
in
range
(
1
,
len
(
fsplits
)):
fsplits
[
i
]
=
ir_pass
.
StorageSync
(
fsplits
[
i
],
"shared"
)
fsplits
[
i
]
=
ir_pass
.
StorageSync
(
fsplits
[
i
],
"global"
)
if
record_codes
is
not
None
:
output_ssa
=
False
...
...
src/README.md
View file @
e42cc112
...
...
@@ -3,5 +3,6 @@
-
api API functionr registration
-
lang The definition of DSL related data structure
-
schedule The operations on the schedule graph before converting to IR.
-
arithmetic Arithmetic expression and set simplification
-
pass The optimization pass on the IR structure
-
runtime Minimum runtime related codes.
src/api/api_pass.cc
View file @
e42cc112
...
...
@@ -6,6 +6,7 @@
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/api_registry.h>
namespace
tvm
{
...
...
@@ -29,6 +30,14 @@ TVM_REGISTER_API(_pass_Equal)
}
});
TVM_REGISTER_API
(
_pass_PostOrderVisit
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
PackedFunc
f
=
args
[
1
];
ir
::
PostOrderVisit
(
args
[
0
],
[
f
](
const
NodeRef
&
n
)
{
f
(
n
);
});
});
// make from two arguments
#define REGISTER_PASS1(PassName) \
TVM_REGISTER_API(_pass_## PassName) \
...
...
@@ -52,6 +61,7 @@ REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1
(
VerifySSA
);
REGISTER_PASS4
(
Inline
);
REGISTER_PASS2
(
StorageFlatten
);
REGISTER_PASS2
(
UnrollLoop
);
REGISTER_PASS2
(
StorageSync
);
REGISTER_PASS4
(
MakeAPI
);
REGISTER_PASS1
(
SplitHostDevice
);
...
...
src/
schedule
/compute_expr.h
→
src/
arithmetic
/compute_expr.h
View file @
e42cc112
...
...
@@ -4,14 +4,14 @@
* \brief Utility integer expression with quick eager simplification.
* This is weaker than Simplify but can be done Eagerly.
*/
#ifndef TVM_
SCHEDULE
_COMPUTE_EXPR_H_
#define TVM_
SCHEDULE
_COMPUTE_EXPR_H_
#ifndef TVM_
ARITHMETIC
_COMPUTE_EXPR_H_
#define TVM_
ARITHMETIC
_COMPUTE_EXPR_H_
#include <tvm/ir.h>
#include <pass/Interval.h>
namespace
tvm
{
namespace
schedule
{
namespace
arith
{
using
Halide
::
Internal
::
add_would_overflow
;
using
Halide
::
Internal
::
sub_would_overflow
;
...
...
@@ -104,6 +104,6 @@ inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
return
Halide
::
Internal
::
Interval
::
make_min
(
a
,
b
);
}
}
// namespace
schedule
}
// namespace
arith
}
// namespace tvm
#endif // TVM_
SCHEDULE
_COMPUTE_EXPR_H_
#endif // TVM_
ARITHMETIC
_COMPUTE_EXPR_H_
src/
schedule
/int_set.cc
→
src/
arithmetic
/int_set.cc
View file @
e42cc112
/*!
* Copyright (c) 201
6
by Contributors
* \file int_set
_impl
.cc
* Copyright (c) 201
7
by Contributors
* \file int_set.cc
* \brief The integer set functions
*/
#include <tvm/ir.h>
...
...
@@ -10,7 +10,7 @@
#include "./compute_expr.h"
namespace
tvm
{
namespace
schedule
{
namespace
arith
{
using
Halide
::
Internal
::
Interval
;
...
...
@@ -94,6 +94,12 @@ bool IntSet::is_single_point() const {
return
(
s_int
&&
s_int
->
i
.
is_single_point
());
}
Expr
IntSet
::
point_value
()
const
{
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
CHECK
(
s_int
&&
s_int
->
i
.
is_single_point
());
return
s_int
->
i
.
min
;
}
IntSet
IntSet
::
everything
()
{
return
IntervalSet
::
make
(
Interval
::
everything
());
}
...
...
@@ -115,8 +121,8 @@ IntSet IntSet::range(Range r) {
}
// Check if a is created from b.
inline
bool
MatchRange
(
const
IntSet
&
a
,
const
Range
&
b
)
{
bool
IntSet
::
match_range
(
const
Range
&
b
)
const
{
const
IntSet
&
a
=
*
this
;
const
IntervalSet
*
a_int
=
a
.
as
<
IntervalSet
>
();
if
(
!
a_int
)
return
false
;
const
Interval
&
i
=
a_int
->
i
;
...
...
@@ -349,84 +355,6 @@ inline IntSet Combine(const IntSet& a, const IntSet &b) {
return
CombineSets
<
OP
>
(
a
,
b
);
}
// Implementation of Evaluations and passing.
void
PassUp
(
const
SplitNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
outer
,
const
IntSet
&
inner
,
IntSet
*
parent
)
{
if
(
dom_map
.
count
(
s
->
outer
)
&&
dom_map
.
count
(
s
->
inner
)
&&
dom_map
.
count
(
s
->
parent
)
&&
MatchRange
(
outer
,
dom_map
.
at
(
s
->
outer
))
&&
MatchRange
(
inner
,
dom_map
.
at
(
s
->
inner
)))
{
*
parent
=
IntSet
::
range
(
dom_map
.
at
(
s
->
parent
));
return
;
}
Expr
factor
=
dom_map
.
at
(
s
->
inner
)
->
extent
;
Expr
parent_min
=
dom_map
.
at
(
s
->
parent
)
->
min
;
CHECK
(
outer
.
defined
());
CHECK
(
inner
.
defined
());
CHECK
(
factor
.
defined
());
*
parent
=
Combine
<
Add
>
(
Combine
<
Add
>
(
Combine
<
Mul
>
(
outer
,
IntSet
::
single_point
(
factor
)),
inner
),
IntSet
::
single_point
(
parent_min
));
}
void
PassUp
(
const
FuseNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
fused
,
IntSet
*
outer
,
IntSet
*
inner
)
{
CHECK
(
dom_map
.
count
(
s
->
outer
));
CHECK
(
dom_map
.
count
(
s
->
inner
));
CHECK
(
dom_map
.
count
(
s
->
fused
));
if
(
MatchRange
(
fused
,
dom_map
.
at
(
s
->
fused
)))
{
*
outer
=
IntSet
::
range
(
dom_map
.
at
(
s
->
outer
));
*
inner
=
IntSet
::
range
(
dom_map
.
at
(
s
->
inner
));
return
;
}
Expr
outer_min
=
dom_map
.
at
(
s
->
outer
)
->
min
;
Expr
inner_min
=
dom_map
.
at
(
s
->
inner
)
->
min
;
const
IntervalSet
*
fused_int
=
fused
.
as
<
IntervalSet
>
();
if
(
fused_int
&&
fused_int
->
i
.
is_single_point
())
{
Expr
value
=
fused_int
->
i
.
min
;
Expr
factor
=
dom_map
.
at
(
s
->
inner
)
->
extent
;
Expr
v_outer
=
value
/
factor
;
Expr
v_inner
=
value
%
factor
;
if
(
!
is_zero
(
outer_min
))
v_outer
=
v_outer
+
outer_min
;
if
(
!
is_zero
(
inner_min
))
v_inner
=
v_inner
+
inner_min
;
*
outer
=
IntSet
::
single_point
(
v_outer
);
*
inner
=
IntSet
::
single_point
(
v_inner
);
}
else
{
LOG
(
WARNING
)
<<
"use fallback inference rule in fuse"
;
// simply use the entire set, this rule can be enhanced.
*
outer
=
IntSet
::
range
(
dom_map
.
at
(
s
->
outer
));
*
inner
=
IntSet
::
range
(
dom_map
.
at
(
s
->
inner
));
return
;
}
}
void
PassUp
(
const
RebaseNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
rebased
,
IntSet
*
parent
)
{
CHECK
(
dom_map
.
count
(
s
->
parent
));
if
(
MatchRange
(
rebased
,
dom_map
.
at
(
s
->
rebased
)))
{
*
parent
=
IntSet
::
range
(
dom_map
.
at
(
s
->
parent
));
return
;
}
Expr
parent_min
=
dom_map
.
at
(
s
->
parent
)
->
min
;
*
parent
=
Combine
<
Add
>
(
rebased
,
IntSet
::
single_point
(
parent_min
));
}
// Evaluator to evalute the epxression.
class
IntSetEvaluator
{
public
:
...
...
@@ -527,5 +455,5 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
});
}
// namespace
schedule
}
// namespace
arith
}
// namespace tvm
src/
schedule
/int_set.h
→
src/
arithmetic
/int_set.h
View file @
e42cc112
...
...
@@ -3,14 +3,14 @@
* \file int_set.h
* \brief Abstraction for all integer set operations.
*/
#ifndef TVM_
SCHEDULE
_INT_SET_H_
#define TVM_
SCHEDULE
_INT_SET_H_
#ifndef TVM_
ARITHMETIC
_INT_SET_H_
#define TVM_
ARITHMETIC
_INT_SET_H_
#include <tvm/expr.h>
#include <tvm/schedule.h>
namespace
tvm
{
namespace
schedule
{
namespace
arith
{
// internal node container of int set.
class
IntSetNode
;
...
...
@@ -44,6 +44,18 @@ class IntSet : public NodeRef {
bool
is_everything
()
const
;
/*! \return Whether the set is a single point */
bool
is_single_point
()
const
;
/*!
* \brief The single point value, call only if is_single_point is true
* \return The point value.
*/
Expr
point_value
()
const
;
/*!
* \brief Try to match IntSet with range r.
*
* \note It is guanrateed that IntSet::range(r).match_range(r) == true
* \return true if we can prove they are the same.
*/
bool
match_range
(
const
Range
&
r
)
const
;
/*! \return Whether the set contains everything */
static
IntSet
everything
();
/*!
...
...
@@ -89,59 +101,6 @@ IntSet EvalSet(Range r,
const
Map
<
IterVar
,
IntSet
>&
dom_map
);
/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Split relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param outer domain of outer iteration.
* \param inner domain of inner iteration.
* \param parent The result domain of parent.
*/
void
PassUp
(
const
SplitNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
outer
,
const
IntSet
&
inner
,
IntSet
*
parent
);
/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Fuse relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param fused domain of fused iteration.
* \param outer The result domain of outer iteration.
* \param inner The result domain of inner iteration.
*/
void
PassUp
(
const
FuseNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
fused
,
IntSet
*
outer
,
IntSet
*
inner
);
/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Fuse relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param rebased domain of rebased iteration.
* \param parent The result domain of parent iteration.
*/
void
PassUp
(
const
RebaseNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
fused
,
IntSet
*
parent
);
/*!
* \brief Create an union set of all sets
* \param sets The sets to be unioned
* \return the set after union
...
...
@@ -153,7 +112,7 @@ inline const IntSetNode* IntSet::operator->() const {
return
static_cast
<
const
IntSetNode
*>
(
node_
.
get
());
}
}
// namespace
schedule
}
// namespace
arith
}
// namespace tvm
#endif // TVM_
SCHEDULE
_INT_SET_H_
#endif // TVM_
ARITHMETIC
_INT_SET_H_
src/pass/inline.cc
View file @
e42cc112
...
...
@@ -24,10 +24,24 @@ class IRInline : public IRMutator {
if
(
op
->
func
==
f_
)
{
CHECK_EQ
(
op
->
value_index
,
0
);
Expr
expr
=
body_
;
CHECK_EQ
(
args_
.
size
(),
op
->
args
.
size
())
<<
op
->
args
.
size
()
<<
" vs "
<<
args_
.
size
();
for
(
size_t
i
=
0
;
i
<
args_
.
size
();
++
i
)
{
expr
=
Let
::
make
(
args_
[
i
],
op
->
args
[
i
],
expr
);
CHECK_EQ
(
args_
.
size
(),
op
->
args
.
size
());
bool
has_side_effect
=
false
;
for
(
size_t
i
=
0
;
i
<
op
->
args
.
size
();
++
i
)
{
if
(
HasSideEffect
(
op
->
args
[
i
]))
has_side_effect
=
true
;
}
if
(
has_side_effect
)
{
for
(
size_t
i
=
0
;
i
<
args_
.
size
();
++
i
)
{
expr
=
Let
::
make
(
args_
[
i
],
op
->
args
[
i
],
expr
);
}
}
else
{
Map
<
Var
,
Expr
>
vmap
;
for
(
size_t
i
=
0
;
i
<
args_
.
size
();
++
i
)
{
vmap
.
Set
(
args_
[
i
],
op
->
args
[
i
]);
}
expr
=
Substitute
(
Evaluate
::
make
(
expr
),
vmap
).
as
<
Evaluate
>
()
->
value
;
}
return
expr
;
}
else
{
...
...
src/pass/simple_passes.cc
View file @
e42cc112
...
...
@@ -47,10 +47,10 @@ class IRSubstitue : public IRMutator {
std
::
unordered_map
<
const
Variable
*
,
Expr
>
smap
;
};
Stmt
Substitute
(
Stmt
stmt
,
const
Map
<
Iter
Var
,
Expr
>&
value_map
)
{
Stmt
Substitute
(
Stmt
stmt
,
const
Map
<
Var
,
Expr
>&
value_map
)
{
IRSubstitue
m
;
for
(
auto
kv
:
value_map
)
{
m
.
smap
[
kv
.
first
->
var
.
get
()]
=
kv
.
second
;
m
.
smap
[
kv
.
first
.
get
()]
=
kv
.
second
;
}
return
m
.
Mutate
(
stmt
);
}
...
...
src/pass/unroll_loop.cc
0 → 100644
View file @
e42cc112
/*!
* Copyright (c) 2016 by Contributors
* SSA related checks and pass.
* \file ssa.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <unordered_set>
#include <unordered_map>
#include <vector>
#include "../arithmetic//compute_expr.h"
namespace
tvm
{
namespace
ir
{
class
LoopUnroller
:
public
IRMutator
{
public
:
explicit
LoopUnroller
(
int
max_auto_step
)
:
max_auto_step_
(
max_auto_step
)
{
}
Stmt
Mutate_
(
const
For
*
op
,
const
Stmt
&
s
)
{
Stmt
stmt
=
s
;
// constant folding.
Expr
extent
=
ir
::
Simplify
(
op
->
extent
);
const
IntImm
*
v1
=
extent
.
as
<
IntImm
>
();
const
UIntImm
*
v2
=
extent
.
as
<
UIntImm
>
();
int
value
=
-
1
;
if
(
v1
!=
nullptr
)
{
value
=
static_cast
<
int
>
(
v1
->
value
);
}
if
(
v2
!=
nullptr
)
{
value
=
static_cast
<
int
>
(
v2
->
value
);
}
bool
allow_unroll
=
value
>=
0
&&
value
<=
max_auto_step_
;
if
(
op
->
for_type
==
ForType
::
Unrolled
)
{
CHECK_GE
(
value
,
0
)
<<
"Cannot unroll non-constant loop"
;
allow_unroll
=
true
;
}
if
(
allow_unroll
)
{
using
arith
::
ComputeExpr
;
if
(
value
==
0
)
return
Evaluate
::
make
(
0
);
Stmt
body
=
op
->
body
;
Map
<
Var
,
Expr
>
vmap
;
Stmt
unrolled
;
for
(
int
i
=
0
;
i
<
value
;
++
i
)
{
Var
lv
(
op
->
loop_var
.
node_
);
vmap
.
Set
(
lv
,
ComputeExpr
<
Add
>
(
op
->
min
,
make_const
(
op
->
loop_var
.
type
(),
i
)));
Stmt
step
=
Substitute
(
body
,
vmap
);
if
(
unrolled
.
defined
())
{
unrolled
=
Block
::
make
(
unrolled
,
step
);
}
else
{
unrolled
=
step
;
}
}
return
this
->
Mutate
(
unrolled
);
}
else
{
return
IRMutator
::
Mutate_
(
op
,
stmt
);
}
}
private
:
int
max_auto_step_
;
};
Stmt
UnrollLoop
(
Stmt
stmt
,
int
max_auto_step
)
{
Stmt
ret
=
LoopUnroller
(
max_auto_step
).
Mutate
(
stmt
);
return
ConvertSSA
(
ret
);
}
}
// namespace ir
}
// namespace tvm
src/schedule/bound.cc
View file @
e42cc112
...
...
@@ -7,13 +7,15 @@
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <tvm/schedule_pass.h>
#include "./int_set.h"
#include "./graph.h"
#include "../arithmetic/int_set.h"
#include "../runtime/thread_storage_scope.h"
namespace
tvm
{
namespace
schedule
{
using
namespace
arith
;
// result = ceil((a / b)), both a and b are positive integer
inline
Expr
DivCeil
(
Expr
a
,
Expr
b
)
{
return
ir
::
Simplify
((
a
+
b
-
1
)
/
b
);
...
...
@@ -70,6 +72,80 @@ void PassDown(const Stage& s,
// pass the integer set on each leave loop up to the root
// dom_map is the result of PassDown, it records the domain of each IterVar.
// dom_map can be used to get cached result in reverse construction.
// Implementation of Evaluations and passing.
void
PassUp
(
const
SplitNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
outer
,
const
IntSet
&
inner
,
IntSet
*
parent
)
{
if
(
dom_map
.
count
(
s
->
outer
)
&&
dom_map
.
count
(
s
->
inner
)
&&
dom_map
.
count
(
s
->
parent
)
&&
outer
.
match_range
(
dom_map
.
at
(
s
->
outer
))
&&
inner
.
match_range
(
dom_map
.
at
(
s
->
inner
)))
{
*
parent
=
IntSet
::
range
(
dom_map
.
at
(
s
->
parent
));
return
;
}
Expr
factor
=
dom_map
.
at
(
s
->
inner
)
->
extent
;
Expr
parent_min
=
dom_map
.
at
(
s
->
parent
)
->
min
;
CHECK
(
outer
.
defined
());
CHECK
(
inner
.
defined
());
CHECK
(
factor
.
defined
());
*
parent
=
EvalSet
(
s
->
outer
->
var
*
factor
+
s
->
inner
->
var
+
parent_min
,
{{
s
->
outer
,
outer
},
{
s
->
inner
,
inner
}});
}
void
PassUp
(
const
FuseNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
fused
,
IntSet
*
outer
,
IntSet
*
inner
)
{
CHECK
(
dom_map
.
count
(
s
->
outer
));
CHECK
(
dom_map
.
count
(
s
->
inner
));
CHECK
(
dom_map
.
count
(
s
->
fused
));
if
(
fused
.
match_range
(
dom_map
.
at
(
s
->
fused
)))
{
*
outer
=
IntSet
::
range
(
dom_map
.
at
(
s
->
outer
));
*
inner
=
IntSet
::
range
(
dom_map
.
at
(
s
->
inner
));
return
;
}
Expr
outer_min
=
dom_map
.
at
(
s
->
outer
)
->
min
;
Expr
inner_min
=
dom_map
.
at
(
s
->
inner
)
->
min
;
if
(
fused
.
is_single_point
())
{
Expr
value
=
fused
.
point_value
();
Expr
factor
=
dom_map
.
at
(
s
->
inner
)
->
extent
;
Expr
v_outer
=
value
/
factor
;
Expr
v_inner
=
value
%
factor
;
if
(
!
is_zero
(
outer_min
))
v_outer
=
v_outer
+
outer_min
;
if
(
!
is_zero
(
inner_min
))
v_inner
=
v_inner
+
inner_min
;
*
outer
=
IntSet
::
single_point
(
v_outer
);
*
inner
=
IntSet
::
single_point
(
v_inner
);
}
else
{
LOG
(
WARNING
)
<<
"use fallback inference rule in fuse"
;
// simply use the entire set, this rule can be enhanced.
*
outer
=
IntSet
::
range
(
dom_map
.
at
(
s
->
outer
));
*
inner
=
IntSet
::
range
(
dom_map
.
at
(
s
->
inner
));
return
;
}
}
void
PassUp
(
const
RebaseNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
rebased
,
IntSet
*
parent
)
{
CHECK
(
dom_map
.
count
(
s
->
parent
));
if
(
rebased
.
match_range
(
dom_map
.
at
(
s
->
rebased
)))
{
*
parent
=
IntSet
::
range
(
dom_map
.
at
(
s
->
parent
));
return
;
}
Expr
parent_min
=
dom_map
.
at
(
s
->
parent
)
->
min
;
*
parent
=
EvalSet
(
s
->
rebased
->
var
+
parent_min
,
{{
s
->
rebased
,
rebased
}});
}
void
PassUp
(
const
Stage
&
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
std
::
unordered_map
<
IterVar
,
IntSet
>*
p_state
)
{
...
...
src/schedule/graph.cc
View file @
e42cc112
...
...
@@ -6,7 +6,6 @@
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <unordered_set>
#include "./int_set.h"
#include "./graph.h"
namespace
tvm
{
...
...
src/schedule/schedule_ops.cc
View file @
e42cc112
...
...
@@ -9,13 +9,13 @@
#include <tvm/schedule_pass.h>
#include "../pass/ir_util.h"
#include ".
/int_set
.h"
#include ".
./arithmetic/compute_expr
.h"
#include "./graph.h"
#include "./compute_expr.h"
namespace
tvm
{
namespace
schedule
{
using
namespace
arith
;
using
namespace
ir
;
/*!
...
...
@@ -230,6 +230,15 @@ MakeLoopNest(const Stage& sch,
return
nest
;
}
Stmt
Substitute
(
Stmt
s
,
const
std
::
unordered_map
<
IterVar
,
Expr
>&
value_map
)
{
Map
<
Var
,
Expr
>
temp
;
for
(
const
auto
&
kv
:
value_map
)
{
temp
.
Set
(
kv
.
first
->
var
,
kv
.
second
);
}
return
ir
::
Substitute
(
s
,
temp
);
}
Stmt
MakeLoop
(
const
Stage
&
s
,
const
Map
<
IterVar
,
Range
>&
dom_map
,
Stmt
provide
,
...
...
@@ -244,7 +253,6 @@ Stmt MakeLoop(const Stage& s,
auto
nest
=
MakeLoopNest
(
s
,
dom_map
,
0
,
false
,
bound_state
,
{},
&
value_map
);
provide
=
Substitute
(
provide
,
value_map
);
if
(
init
.
defined
())
{
// try to find the location to insert the initialization.
...
...
tests/python/unittest/test_pass_unroll.py
0 → 100644
View file @
e42cc112
import
tvm
def
test_unroll_loop
():
dtype
=
'int64'
n
=
tvm
.
Var
(
'n'
)
Ab
=
tvm
.
Buffer
((
n
,
),
dtype
)
i
=
tvm
.
Var
(
'i'
)
j
=
tvm
.
Var
(
'j'
)
# for i in 0 to n-1:
stmt
=
tvm
.
make
.
For
(
i
,
n
,
2
,
0
,
0
,
tvm
.
make
.
For
(
j
,
0
,
n
,
0
,
0
,
tvm
.
make
.
Store
(
Ab
.
data
,
tvm
.
make
.
Load
(
dtype
,
Ab
.
data
,
i
)
+
1
,
j
+
1
)))
stmt
=
tvm
.
ir_pass
.
UnrollLoop
(
stmt
,
8
)
print
(
stmt
)
if
__name__
==
"__main__"
:
test_unroll_loop
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment