Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
e42cc112
Commit
e42cc112
authored
7 years ago
by
Tianqi Chen
Committed by
GitHub
7 years ago
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] UnrollLoop, isolate arithmetic module. (#32)
parent
d89917b6
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
259 additions
and
159 deletions
+259
-159
include/tvm/ir_pass.h
+9
-1
include/tvm/runtime/packed_func.h
+1
-1
python/tvm/build.py
+0
-1
src/README.md
+1
-0
src/api/api_pass.cc
+10
-0
src/arithmetic/compute_expr.h
+5
-5
src/arithmetic/int_set.cc
+12
-84
src/arithmetic/int_set.h
+17
-58
src/pass/inline.cc
+16
-2
src/pass/simple_passes.cc
+2
-2
src/pass/unroll_loop.cc
+78
-0
src/schedule/bound.cc
+77
-1
src/schedule/graph.cc
+0
-1
src/schedule/schedule_ops.cc
+11
-3
tests/python/unittest/test_pass_unroll.py
+20
-0
No files found.
include/tvm/ir_pass.h
View file @
e42cc112
...
...
@@ -68,7 +68,7 @@ Stmt ConvertSSA(Stmt stmt);
* \param value_map The map of new values.
* \return The converted form.
*/
Stmt
Substitute
(
Stmt
stmt
,
const
Map
<
Iter
Var
,
Expr
>&
value_map
);
Stmt
Substitute
(
Stmt
stmt
,
const
Map
<
Var
,
Expr
>&
value_map
);
/*!
* \brief inline all calls of f in stmt.
...
...
@@ -98,6 +98,13 @@ Stmt StorageFlatten(Stmt stmt,
Map
<
Tensor
,
Buffer
>
extern_buffer
);
/*!
* \brief unroll the constant loops
* \param stmt The statement to be unrolled.
* \param max_auto_step The maximum step to stop performing automatic unrolling.
*/
Stmt
UnrollLoop
(
Stmt
stmt
,
int
max_auto_step
);
/*!
* \brief Make a user-callable API LoweredFunc.
*
* The main task of this function is to create code to :
...
...
@@ -153,6 +160,7 @@ Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
*/
LoweredFunc
StorageSync
(
LoweredFunc
stmt
,
std
::
string
storage_scope
);
}
// namespace ir
}
// namespace tvm
...
...
This diff is collapsed.
Click to expand it.
include/tvm/runtime/packed_func.h
View file @
e42cc112
...
...
@@ -562,7 +562,7 @@ inline TVMArgValue TVMArgs::operator[](int i) const {
CHECK_LT
(
i
,
num_args
)
<<
"not enough argument passed, "
<<
num_args
<<
" passed"
<<
"
but request arg"
<<
i
;
<<
"
but request arg["
<<
i
<<
"]."
;
return
TVMArgValue
(
values
[
i
],
type_codes
[
i
]);
}
...
...
This diff is collapsed.
Click to expand it.
python/tvm/build.py
View file @
e42cc112
...
...
@@ -70,7 +70,6 @@ def build(sch,
fsplits
=
[
x
for
x
in
fsplits
]
for
i
in
range
(
1
,
len
(
fsplits
)):
fsplits
[
i
]
=
ir_pass
.
StorageSync
(
fsplits
[
i
],
"shared"
)
fsplits
[
i
]
=
ir_pass
.
StorageSync
(
fsplits
[
i
],
"global"
)
if
record_codes
is
not
None
:
output_ssa
=
False
...
...
This diff is collapsed.
Click to expand it.
src/README.md
View file @
e42cc112
...
...
@@ -3,5 +3,6 @@
-
api API function registration
-
lang The definition of DSL related data structure
-
schedule The operations on the schedule graph before converting to IR.
-
arithmetic Arithmetic expression and set simplification
-
pass The optimization pass on the IR structure
-
runtime Minimum runtime related codes.
This diff is collapsed.
Click to expand it.
src/api/api_pass.cc
View file @
e42cc112
...
...
@@ -6,6 +6,7 @@
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/api_registry.h>
namespace
tvm
{
...
...
@@ -29,6 +30,14 @@ TVM_REGISTER_API(_pass_Equal)
}
});
// Expose ir::PostOrderVisit to the frontend.
// args[0] is the node to traverse, args[1] is the callback that
// receives every visited node.
TVM_REGISTER_API(_pass_PostOrderVisit)
.set_body([](TVMArgs args, TVMRetValue *ret) {
    PackedFunc callback = args[1];
    auto on_visit = [callback](const NodeRef& node) {
      callback(node);
    };
    ir::PostOrderVisit(args[0], on_visit);
  });
// make from two arguments
#define REGISTER_PASS1(PassName) \
TVM_REGISTER_API(_pass_## PassName) \
...
...
@@ -52,6 +61,7 @@ REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1
(
VerifySSA
);
REGISTER_PASS4
(
Inline
);
REGISTER_PASS2
(
StorageFlatten
);
REGISTER_PASS2
(
UnrollLoop
);
REGISTER_PASS2
(
StorageSync
);
REGISTER_PASS4
(
MakeAPI
);
REGISTER_PASS1
(
SplitHostDevice
);
...
...
This diff is collapsed.
Click to expand it.
src/
schedule
/compute_expr.h
→
src/
arithmetic
/compute_expr.h
View file @
e42cc112
...
...
@@ -4,14 +4,14 @@
* \brief Utility integer expression with quick eager simplification.
* This is weaker than Simplify but can be done Eagerly.
*/
#ifndef TVM_
SCHEDULE
_COMPUTE_EXPR_H_
#define TVM_
SCHEDULE
_COMPUTE_EXPR_H_
#ifndef TVM_
ARITHMETIC
_COMPUTE_EXPR_H_
#define TVM_
ARITHMETIC
_COMPUTE_EXPR_H_
#include <tvm/ir.h>
#include <pass/Interval.h>
namespace
tvm
{
namespace
schedule
{
namespace
arith
{
using
Halide
::
Internal
::
add_would_overflow
;
using
Halide
::
Internal
::
sub_would_overflow
;
...
...
@@ -104,6 +104,6 @@ inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
return
Halide
::
Internal
::
Interval
::
make_min
(
a
,
b
);
}
}
// namespace
schedule
}
// namespace
arith
}
// namespace tvm
#endif // TVM_
SCHEDULE
_COMPUTE_EXPR_H_
#endif // TVM_
ARITHMETIC
_COMPUTE_EXPR_H_
This diff is collapsed.
Click to expand it.
src/
schedule
/int_set.cc
→
src/
arithmetic
/int_set.cc
View file @
e42cc112
/*!
* Copyright (c) 201
6
by Contributors
* \file int_set
_impl
.cc
* Copyright (c) 201
7
by Contributors
* \file int_set.cc
* \brief The integer set functions
*/
#include <tvm/ir.h>
...
...
@@ -10,7 +10,7 @@
#include "./compute_expr.h"
namespace
tvm
{
namespace
schedule
{
namespace
arith
{
using
Halide
::
Internal
::
Interval
;
...
...
@@ -94,6 +94,12 @@ bool IntSet::is_single_point() const {
return
(
s_int
&&
s_int
->
i
.
is_single_point
());
}
// Return the value of a single-point set.
// Fails a CHECK when the set is not an IntervalSet holding exactly one point.
Expr IntSet::point_value() const {
  const IntervalSet* interval = this->as<IntervalSet>();
  CHECK(interval && interval->i.is_single_point());
  return interval->i.min;
}
IntSet
IntSet
::
everything
()
{
return
IntervalSet
::
make
(
Interval
::
everything
());
}
...
...
@@ -115,8 +121,8 @@ IntSet IntSet::range(Range r) {
}
// Check if a is created from b.
inline
bool
MatchRange
(
const
IntSet
&
a
,
const
Range
&
b
)
{
bool
IntSet
::
match_range
(
const
Range
&
b
)
const
{
const
IntSet
&
a
=
*
this
;
const
IntervalSet
*
a_int
=
a
.
as
<
IntervalSet
>
();
if
(
!
a_int
)
return
false
;
const
Interval
&
i
=
a_int
->
i
;
...
...
@@ -349,84 +355,6 @@ inline IntSet Combine(const IntSet& a, const IntSet &b) {
return
CombineSets
<
OP
>
(
a
,
b
);
}
// Implementation of Evaluations and passing.
void
PassUp
(
const
SplitNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
outer
,
const
IntSet
&
inner
,
IntSet
*
parent
)
{
if
(
dom_map
.
count
(
s
->
outer
)
&&
dom_map
.
count
(
s
->
inner
)
&&
dom_map
.
count
(
s
->
parent
)
&&
MatchRange
(
outer
,
dom_map
.
at
(
s
->
outer
))
&&
MatchRange
(
inner
,
dom_map
.
at
(
s
->
inner
)))
{
*
parent
=
IntSet
::
range
(
dom_map
.
at
(
s
->
parent
));
return
;
}
Expr
factor
=
dom_map
.
at
(
s
->
inner
)
->
extent
;
Expr
parent_min
=
dom_map
.
at
(
s
->
parent
)
->
min
;
CHECK
(
outer
.
defined
());
CHECK
(
inner
.
defined
());
CHECK
(
factor
.
defined
());
*
parent
=
Combine
<
Add
>
(
Combine
<
Add
>
(
Combine
<
Mul
>
(
outer
,
IntSet
::
single_point
(
factor
)),
inner
),
IntSet
::
single_point
(
parent_min
));
}
void
PassUp
(
const
FuseNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
fused
,
IntSet
*
outer
,
IntSet
*
inner
)
{
CHECK
(
dom_map
.
count
(
s
->
outer
));
CHECK
(
dom_map
.
count
(
s
->
inner
));
CHECK
(
dom_map
.
count
(
s
->
fused
));
if
(
MatchRange
(
fused
,
dom_map
.
at
(
s
->
fused
)))
{
*
outer
=
IntSet
::
range
(
dom_map
.
at
(
s
->
outer
));
*
inner
=
IntSet
::
range
(
dom_map
.
at
(
s
->
inner
));
return
;
}
Expr
outer_min
=
dom_map
.
at
(
s
->
outer
)
->
min
;
Expr
inner_min
=
dom_map
.
at
(
s
->
inner
)
->
min
;
const
IntervalSet
*
fused_int
=
fused
.
as
<
IntervalSet
>
();
if
(
fused_int
&&
fused_int
->
i
.
is_single_point
())
{
Expr
value
=
fused_int
->
i
.
min
;
Expr
factor
=
dom_map
.
at
(
s
->
inner
)
->
extent
;
Expr
v_outer
=
value
/
factor
;
Expr
v_inner
=
value
%
factor
;
if
(
!
is_zero
(
outer_min
))
v_outer
=
v_outer
+
outer_min
;
if
(
!
is_zero
(
inner_min
))
v_inner
=
v_inner
+
inner_min
;
*
outer
=
IntSet
::
single_point
(
v_outer
);
*
inner
=
IntSet
::
single_point
(
v_inner
);
}
else
{
LOG
(
WARNING
)
<<
"use fallback inference rule in fuse"
;
// simply use the entire set, this rule can be enhanced.
*
outer
=
IntSet
::
range
(
dom_map
.
at
(
s
->
outer
));
*
inner
=
IntSet
::
range
(
dom_map
.
at
(
s
->
inner
));
return
;
}
}
void
PassUp
(
const
RebaseNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
rebased
,
IntSet
*
parent
)
{
CHECK
(
dom_map
.
count
(
s
->
parent
));
if
(
MatchRange
(
rebased
,
dom_map
.
at
(
s
->
rebased
)))
{
*
parent
=
IntSet
::
range
(
dom_map
.
at
(
s
->
parent
));
return
;
}
Expr
parent_min
=
dom_map
.
at
(
s
->
parent
)
->
min
;
*
parent
=
Combine
<
Add
>
(
rebased
,
IntSet
::
single_point
(
parent_min
));
}
// Evaluator to evaluate the expression.
class
IntSetEvaluator
{
public
:
...
...
@@ -527,5 +455,5 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
});
}
// namespace
schedule
}
// namespace
arith
}
// namespace tvm
This diff is collapsed.
Click to expand it.
src/
schedule
/int_set.h
→
src/
arithmetic
/int_set.h
View file @
e42cc112
...
...
@@ -3,14 +3,14 @@
* \file int_set.h
* \brief Abstraction for all integer set operations.
*/
#ifndef TVM_
SCHEDULE
_INT_SET_H_
#define TVM_
SCHEDULE
_INT_SET_H_
#ifndef TVM_
ARITHMETIC
_INT_SET_H_
#define TVM_
ARITHMETIC
_INT_SET_H_
#include <tvm/expr.h>
#include <tvm/schedule.h>
namespace
tvm
{
namespace
schedule
{
namespace
arith
{
// internal node container of int set.
class
IntSetNode
;
...
...
@@ -44,6 +44,18 @@ class IntSet : public NodeRef {
bool
is_everything
()
const
;
/*! \return Whether the set is a single point */
bool
is_single_point
()
const
;
/*!
* \brief The single point value, call only if is_single_point is true
* \return The point value.
*/
Expr
point_value
()
const
;
/*!
* \brief Try to match IntSet with range r.
*
* \note It is guaranteed that IntSet::range(r).match_range(r) == true
* \return true if we can prove they are the same.
*/
bool
match_range
(
const
Range
&
r
)
const
;
/*! \return The set that contains everything */
static
IntSet
everything
();
/*!
...
...
@@ -89,59 +101,6 @@ IntSet EvalSet(Range r,
const
Map
<
IterVar
,
IntSet
>&
dom_map
);
/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Split relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param outer domain of outer iteration.
* \param inner domain of inner iteration.
* \param parent The result domain of parent.
*/
void
PassUp
(
const
SplitNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
outer
,
const
IntSet
&
inner
,
IntSet
*
parent
);
/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Fuse relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param fused domain of fused iteration.
* \param outer The result domain of outer iteration.
* \param inner The result domain of inner iteration.
*/
void
PassUp
(
const
FuseNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
fused
,
IntSet
*
outer
,
IntSet
*
inner
);
/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Fuse relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param rebased domain of rebased iteration.
* \param parent The result domain of parent iteration.
*/
void
PassUp
(
const
RebaseNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
fused
,
IntSet
*
parent
);
/*!
* \brief Create an union set of all sets
* \param sets The sets to be unioned
* \return the set after union
...
...
@@ -153,7 +112,7 @@ inline const IntSetNode* IntSet::operator->() const {
return
static_cast
<
const
IntSetNode
*>
(
node_
.
get
());
}
}
// namespace
schedule
}
// namespace
arith
}
// namespace tvm
#endif // TVM_
SCHEDULE
_INT_SET_H_
#endif // TVM_
ARITHMETIC
_INT_SET_H_
This diff is collapsed.
Click to expand it.
src/pass/inline.cc
View file @
e42cc112
...
...
@@ -24,11 +24,25 @@ class IRInline : public IRMutator {
if
(
op
->
func
==
f_
)
{
CHECK_EQ
(
op
->
value_index
,
0
);
Expr
expr
=
body_
;
CHECK_EQ
(
args_
.
size
(),
op
->
args
.
size
())
<<
op
->
args
.
size
()
<<
" vs "
<<
args_
.
size
();
CHECK_EQ
(
args_
.
size
(),
op
->
args
.
size
());
bool
has_side_effect
=
false
;
for
(
size_t
i
=
0
;
i
<
op
->
args
.
size
();
++
i
)
{
if
(
HasSideEffect
(
op
->
args
[
i
]))
has_side_effect
=
true
;
}
if
(
has_side_effect
)
{
for
(
size_t
i
=
0
;
i
<
args_
.
size
();
++
i
)
{
expr
=
Let
::
make
(
args_
[
i
],
op
->
args
[
i
],
expr
);
}
}
else
{
Map
<
Var
,
Expr
>
vmap
;
for
(
size_t
i
=
0
;
i
<
args_
.
size
();
++
i
)
{
vmap
.
Set
(
args_
[
i
],
op
->
args
[
i
]);
}
expr
=
Substitute
(
Evaluate
::
make
(
expr
),
vmap
).
as
<
Evaluate
>
()
->
value
;
}
return
expr
;
}
else
{
return
e
;
...
...
This diff is collapsed.
Click to expand it.
src/pass/simple_passes.cc
View file @
e42cc112
...
...
@@ -47,10 +47,10 @@ class IRSubstitue : public IRMutator {
std
::
unordered_map
<
const
Variable
*
,
Expr
>
smap
;
};
Stmt
Substitute
(
Stmt
stmt
,
const
Map
<
Iter
Var
,
Expr
>&
value_map
)
{
Stmt
Substitute
(
Stmt
stmt
,
const
Map
<
Var
,
Expr
>&
value_map
)
{
IRSubstitue
m
;
for
(
auto
kv
:
value_map
)
{
m
.
smap
[
kv
.
first
->
var
.
get
()]
=
kv
.
second
;
m
.
smap
[
kv
.
first
.
get
()]
=
kv
.
second
;
}
return
m
.
Mutate
(
stmt
);
}
...
...
This diff is collapsed.
Click to expand it.
src/pass/unroll_loop.cc
0 → 100644
View file @
e42cc112
/*!
* Copyright (c) 2016 by Contributors
* Loop unrolling pass.
* \file unroll_loop.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <unordered_set>
#include <unordered_map>
#include <vector>
#include "../arithmetic//compute_expr.h"
namespace
tvm
{
namespace
ir
{
/*!
 * \brief IR mutator that expands constant-extent For loops into a
 *  sequence of copies of the loop body.
 *
 * A loop is unrolled when its simplified extent is an integer constant and
 *  - the extent is in [0, max_auto_step], or
 *  - the loop is marked ForType::Unrolled (a non-constant extent is then
 *    a CHECK failure).
 * All other loops are forwarded to the default IRMutator handling.
 */
class LoopUnroller : public IRMutator {
 public:
  /*!
   * \param max_auto_step Largest constant extent that is unrolled
   *  without an explicit ForType::Unrolled annotation.
   */
  explicit LoopUnroller(int max_auto_step)
      : max_auto_step_(max_auto_step) {
  }

  /*!
   * \brief Unroll the loop if allowed, otherwise mutate it as usual.
   * \param op The loop node.
   * \param s The statement that holds op.
   * \return The replacement statement.
   */
  Stmt Mutate_(const For* op, const Stmt& s) {
    // constant folding.
    Expr extent = ir::Simplify(op->extent);
    // Keep the trip count in 64 bits. Narrowing the immediate to int
    // could wrap a huge extent into a small non-negative number and
    // make the loop look like a cheap unrolling candidate.
    int64_t value = -1;
    if (const IntImm* v1 = extent.as<IntImm>()) {
      value = v1->value;
    }
    if (const UIntImm* v2 = extent.as<UIntImm>()) {
      // Unsigned extents that do not fit in int64_t are treated as
      // non-constant instead of being wrapped.
      if (v2->value <= static_cast<uint64_t>(INT64_MAX)) {
        value = static_cast<int64_t>(v2->value);
      }
    }
    bool allow_unroll = value >= 0 && value <= max_auto_step_;
    if (op->for_type == ForType::Unrolled) {
      CHECK_GE(value, 0)
          << "Cannot unroll non-constant loop";
      allow_unroll = true;
    }
    if (!allow_unroll) {
      return IRMutator::Mutate_(op, s);
    }

    using arith::ComputeExpr;
    // An empty loop is replaced by a no-op statement.
    if (value == 0) return Evaluate::make(0);
    // The loop variable handle is the same for every step; build it once.
    Var lv(op->loop_var.node_);
    Map<Var, Expr> vmap;
    Stmt unrolled;
    for (int64_t i = 0; i < value; ++i) {
      // Bind the loop variable to min + i for this copy of the body.
      vmap.Set(lv,
               ComputeExpr<Add>(
                   op->min, make_const(op->loop_var.type(), i)));
      Stmt step = Substitute(op->body, vmap);
      if (unrolled.defined()) {
        unrolled = Block::make(unrolled, step);
      } else {
        unrolled = step;
      }
    }
    // Visit the expanded body so nested loops get a chance to unroll too.
    return this->Mutate(unrolled);
  }

 private:
  // Largest extent that is unrolled automatically.
  int max_auto_step_;
};
// Entry point of the pass: run LoopUnroller over stmt and pass the
// result through ConvertSSA (presumably because the duplicated loop
// bodies can repeat variable definitions -- confirm against ConvertSSA).
Stmt UnrollLoop(Stmt stmt, int max_auto_step) {
  LoopUnroller unroller(max_auto_step);
  return ConvertSSA(unroller.Mutate(stmt));
}
}
// namespace ir
}
// namespace tvm
This diff is collapsed.
Click to expand it.
src/schedule/bound.cc
View file @
e42cc112
...
...
@@ -7,13 +7,15 @@
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <tvm/schedule_pass.h>
#include "./int_set.h"
#include "./graph.h"
#include "../arithmetic/int_set.h"
#include "../runtime/thread_storage_scope.h"
namespace
tvm
{
namespace
schedule
{
using
namespace
arith
;
// result = ceil((a / b)), both a and b are positive integer
inline
Expr
DivCeil
(
Expr
a
,
Expr
b
)
{
return
ir
::
Simplify
((
a
+
b
-
1
)
/
b
);
...
...
@@ -70,6 +72,80 @@ void PassDown(const Stage& s,
// pass the integer set on each leave loop up to the root
// dom_map is the result of PassDown, it records the domain of each IterVar.
// dom_map can be used to get cached result in reverse construction.
// Implementation of Evaluations and passing.
void
PassUp
(
const
SplitNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
outer
,
const
IntSet
&
inner
,
IntSet
*
parent
)
{
if
(
dom_map
.
count
(
s
->
outer
)
&&
dom_map
.
count
(
s
->
inner
)
&&
dom_map
.
count
(
s
->
parent
)
&&
outer
.
match_range
(
dom_map
.
at
(
s
->
outer
))
&&
inner
.
match_range
(
dom_map
.
at
(
s
->
inner
)))
{
*
parent
=
IntSet
::
range
(
dom_map
.
at
(
s
->
parent
));
return
;
}
Expr
factor
=
dom_map
.
at
(
s
->
inner
)
->
extent
;
Expr
parent_min
=
dom_map
.
at
(
s
->
parent
)
->
min
;
CHECK
(
outer
.
defined
());
CHECK
(
inner
.
defined
());
CHECK
(
factor
.
defined
());
*
parent
=
EvalSet
(
s
->
outer
->
var
*
factor
+
s
->
inner
->
var
+
parent_min
,
{{
s
->
outer
,
outer
},
{
s
->
inner
,
inner
}});
}
void
PassUp
(
const
FuseNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
fused
,
IntSet
*
outer
,
IntSet
*
inner
)
{
CHECK
(
dom_map
.
count
(
s
->
outer
));
CHECK
(
dom_map
.
count
(
s
->
inner
));
CHECK
(
dom_map
.
count
(
s
->
fused
));
if
(
fused
.
match_range
(
dom_map
.
at
(
s
->
fused
)))
{
*
outer
=
IntSet
::
range
(
dom_map
.
at
(
s
->
outer
));
*
inner
=
IntSet
::
range
(
dom_map
.
at
(
s
->
inner
));
return
;
}
Expr
outer_min
=
dom_map
.
at
(
s
->
outer
)
->
min
;
Expr
inner_min
=
dom_map
.
at
(
s
->
inner
)
->
min
;
if
(
fused
.
is_single_point
())
{
Expr
value
=
fused
.
point_value
();
Expr
factor
=
dom_map
.
at
(
s
->
inner
)
->
extent
;
Expr
v_outer
=
value
/
factor
;
Expr
v_inner
=
value
%
factor
;
if
(
!
is_zero
(
outer_min
))
v_outer
=
v_outer
+
outer_min
;
if
(
!
is_zero
(
inner_min
))
v_inner
=
v_inner
+
inner_min
;
*
outer
=
IntSet
::
single_point
(
v_outer
);
*
inner
=
IntSet
::
single_point
(
v_inner
);
}
else
{
LOG
(
WARNING
)
<<
"use fallback inference rule in fuse"
;
// simply use the entire set, this rule can be enhanced.
*
outer
=
IntSet
::
range
(
dom_map
.
at
(
s
->
outer
));
*
inner
=
IntSet
::
range
(
dom_map
.
at
(
s
->
inner
));
return
;
}
}
// Upward message passing through a rebase relation:
// derive the set of the parent IterVar from the set of the rebased one.
//
// s        The rebase relation node.
// dom_map  Recorded range of each IterVar; must contain both
//          s->parent and s->rebased.
// rebased  The set of the rebased IterVar.
// parent   Output: the resulting set of the parent IterVar.
void PassUp(const RebaseNode* s,
            const std::unordered_map<IterVar, Range>& dom_map,
            const IntSet& rebased,
            IntSet* parent) {
  CHECK(dom_map.count(s->parent));
  // dom_map.at(s->rebased) is read right below; check it explicitly so a
  // missing entry is reported by CHECK instead of a bare std::out_of_range.
  CHECK(dom_map.count(s->rebased));
  // Rebased set equals its recorded range: the parent takes its full range.
  if (rebased.match_range(dom_map.at(s->rebased))) {
    *parent = IntSet::range(dom_map.at(s->parent));
    return;
  }
  Expr parent_min = dom_map.at(s->parent)->min;
  // parent = rebased + parent_min, evaluated over the rebased set.
  *parent = EvalSet(s->rebased->var + parent_min,
                    {{s->rebased, rebased}});
}
void
PassUp
(
const
Stage
&
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
std
::
unordered_map
<
IterVar
,
IntSet
>*
p_state
)
{
...
...
This diff is collapsed.
Click to expand it.
src/schedule/graph.cc
View file @
e42cc112
...
...
@@ -6,7 +6,6 @@
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <unordered_set>
#include "./int_set.h"
#include "./graph.h"
namespace
tvm
{
...
...
This diff is collapsed.
Click to expand it.
src/schedule/schedule_ops.cc
View file @
e42cc112
...
...
@@ -9,13 +9,13 @@
#include <tvm/schedule_pass.h>
#include "../pass/ir_util.h"
#include ".
/int_set
.h"
#include ".
./arithmetic/compute_expr
.h"
#include "./graph.h"
#include "./compute_expr.h"
namespace
tvm
{
namespace
schedule
{
using
namespace
arith
;
using
namespace
ir
;
/*!
...
...
@@ -230,6 +230,15 @@ MakeLoopNest(const Stage& sch,
return
nest
;
}
// Convenience overload: substitute by IterVar.
// Re-keys value_map by the underlying Var of each IterVar and forwards
// to ir::Substitute.
Stmt Substitute(Stmt s,
                const std::unordered_map<IterVar, Expr>& value_map) {
  Map<Var, Expr> by_var;
  for (const auto& entry : value_map) {
    by_var.Set(entry.first->var, entry.second);
  }
  return ir::Substitute(s, by_var);
}
Stmt
MakeLoop
(
const
Stage
&
s
,
const
Map
<
IterVar
,
Range
>&
dom_map
,
Stmt
provide
,
...
...
@@ -244,7 +253,6 @@ Stmt MakeLoop(const Stage& s,
auto
nest
=
MakeLoopNest
(
s
,
dom_map
,
0
,
false
,
bound_state
,
{},
&
value_map
);
provide
=
Substitute
(
provide
,
value_map
);
if
(
init
.
defined
())
{
// try to find the location to insert the initialization.
...
...
This diff is collapsed.
Click to expand it.
tests/python/unittest/test_pass_unroll.py
0 → 100644
View file @
e42cc112
import
tvm
def test_unroll_loop():
    """Build a two-level loop nest, run UnrollLoop on it and print the result."""
    dtype = 'int64'
    n = tvm.Var('n')
    Ab = tvm.Buffer((n, ), dtype)
    i = tvm.Var('i')
    j = tvm.Var('j')
    # Innermost statement: Ab[j + 1] = Ab[i] + 1
    loaded = tvm.make.Load(dtype, Ab.data, i)
    store = tvm.make.Store(Ab.data, loaded + 1, j + 1)
    # NOTE(review): argument order is presumably
    # (loop_var, min, extent, for_type, device_api, body) -- confirm.
    inner_loop = tvm.make.For(j, 0, n, 0, 0, store)
    outer_loop = tvm.make.For(i, n, 2, 0, 0, inner_loop)
    stmt = tvm.ir_pass.UnrollLoop(outer_loop, 8)
    print(stmt)
if
__name__
==
"__main__"
:
test_unroll_loop
()
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment