Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
1400edac
Commit
1400edac
authored
Jun 18, 2017
by
Tianqi Chen
Committed by
GitHub
Jun 18, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[IR] Include PrefetchIR (#189)
parent
eaf0fde3
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
77 additions
and
10 deletions
+77
-10
HalideIR
+1
-1
include/tvm/ir.h
+6
-0
include/tvm/ir_functor_ext.h
+5
-0
include/tvm/ir_mutator.h
+1
-0
include/tvm/ir_visitor.h
+1
-0
include/tvm/schedule.h
+6
-0
src/README.md
+1
-0
src/arithmetic/bound_deducer.cc
+1
-1
src/op/op_util.cc
+18
-2
src/pass/inject_virtual_thread.cc
+1
-1
src/pass/inline.cc
+1
-1
src/pass/ir_mutator.cc
+25
-0
src/pass/ir_visitor.cc
+7
-1
src/pass/loop_partition.cc
+1
-1
src/pass/lower_thread_allreduce.cc
+1
-1
src/pass/storage_rewrite.cc
+1
-1
No files found.
HalideIR
@
41fe60a7
Subproject commit
efe5b5cc3c89da5d5e39570f6776d39d8acacacc
Subproject commit
41fe60a76fe6e5669540acf1ef3595bc38025157
include/tvm/ir.h
View file @
1400edac
...
@@ -158,6 +158,11 @@ constexpr const char* device_context_type = "device_context_type";
...
@@ -158,6 +158,11 @@ constexpr const char* device_context_type = "device_context_type";
constexpr
const
char
*
loop_scope
=
"loop_scope"
;
constexpr
const
char
*
loop_scope
=
"loop_scope"
;
/*! \brief Mark of reduce scope */
/*! \brief Mark of reduce scope */
constexpr
const
char
*
reduce_scope
=
"reduce_scope"
;
constexpr
const
char
*
reduce_scope
=
"reduce_scope"
;
/*!
* \brief Mark of prefetch scope, value=offset,
* run prefetch of Tensor on the current loop scope
*/
constexpr
const
char
*
prefetch_scope
=
"prefetch_scope"
;
/*! \brief Mark of scan update scope */
/*! \brief Mark of scan update scope */
constexpr
const
char
*
scan_update_scope
=
"scan_update_scope"
;
constexpr
const
char
*
scan_update_scope
=
"scan_update_scope"
;
/*! \brief Mark of scan init scope */
/*! \brief Mark of scan init scope */
...
@@ -371,6 +376,7 @@ using Halide::Internal::Provide;
...
@@ -371,6 +376,7 @@ using Halide::Internal::Provide;
using
Halide
::
Internal
::
Allocate
;
using
Halide
::
Internal
::
Allocate
;
using
Halide
::
Internal
::
Free
;
using
Halide
::
Internal
::
Free
;
using
Halide
::
Internal
::
Realize
;
using
Halide
::
Internal
::
Realize
;
using
Halide
::
Internal
::
Prefetch
;
using
Halide
::
Internal
::
Block
;
using
Halide
::
Internal
::
Block
;
using
Halide
::
Internal
::
IfThenElse
;
using
Halide
::
Internal
::
IfThenElse
;
using
Halide
::
Internal
::
Evaluate
;
using
Halide
::
Internal
::
Evaluate
;
...
...
include/tvm/ir_functor_ext.h
View file @
1400edac
...
@@ -17,6 +17,9 @@ namespace ir {
...
@@ -17,6 +17,9 @@ namespace ir {
* You can use this as a more powerful Visitor, since it allows you to
* You can use this as a more powerful Visitor, since it allows you to
* define function signatures of Visit Function.
* define function signatures of Visit Function.
*
*
* This helps you avoid book-keeping the return value of a Visitor via state,
* which can easily cause bugs when the state is incorrectly maintained.
*
* \code
* \code
* // A functor that set variable to b. and calculate results.
* // A functor that set variable to b. and calculate results.
* class MyExprFunctor
* class MyExprFunctor
...
@@ -223,6 +226,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
...
@@ -223,6 +226,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual
R
VisitStmt_
(
const
ProducerConsumer
*
op
,
Args
...
args
)
STMT_FUNCTOR_DEFAULT
;
virtual
R
VisitStmt_
(
const
ProducerConsumer
*
op
,
Args
...
args
)
STMT_FUNCTOR_DEFAULT
;
virtual
R
VisitStmt_
(
const
Provide
*
op
,
Args
...
args
)
STMT_FUNCTOR_DEFAULT
;
virtual
R
VisitStmt_
(
const
Provide
*
op
,
Args
...
args
)
STMT_FUNCTOR_DEFAULT
;
virtual
R
VisitStmt_
(
const
Realize
*
op
,
Args
...
args
)
STMT_FUNCTOR_DEFAULT
;
virtual
R
VisitStmt_
(
const
Realize
*
op
,
Args
...
args
)
STMT_FUNCTOR_DEFAULT
;
virtual
R
VisitStmt_
(
const
Prefetch
*
op
,
Args
...
args
)
STMT_FUNCTOR_DEFAULT
;
virtual
R
VisitStmt_
(
const
Block
*
op
,
Args
...
args
)
STMT_FUNCTOR_DEFAULT
;
virtual
R
VisitStmt_
(
const
Block
*
op
,
Args
...
args
)
STMT_FUNCTOR_DEFAULT
;
virtual
R
VisitStmt_
(
const
Evaluate
*
op
,
Args
...
args
)
STMT_FUNCTOR_DEFAULT
;
virtual
R
VisitStmt_
(
const
Evaluate
*
op
,
Args
...
args
)
STMT_FUNCTOR_DEFAULT
;
virtual
R
VisitStmtDefault_
(
const
Node
*
op
,
Args
...)
{
virtual
R
VisitStmtDefault_
(
const
Node
*
op
,
Args
...)
{
...
@@ -245,6 +249,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
...
@@ -245,6 +249,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH
(
ProducerConsumer
);
IR_STMT_FUNCTOR_DISPATCH
(
ProducerConsumer
);
IR_STMT_FUNCTOR_DISPATCH
(
Provide
);
IR_STMT_FUNCTOR_DISPATCH
(
Provide
);
IR_STMT_FUNCTOR_DISPATCH
(
Realize
);
IR_STMT_FUNCTOR_DISPATCH
(
Realize
);
IR_STMT_FUNCTOR_DISPATCH
(
Prefetch
);
IR_STMT_FUNCTOR_DISPATCH
(
Block
);
IR_STMT_FUNCTOR_DISPATCH
(
Block
);
IR_STMT_FUNCTOR_DISPATCH
(
Evaluate
);
IR_STMT_FUNCTOR_DISPATCH
(
Evaluate
);
return
vtable
;
return
vtable
;
...
...
include/tvm/ir_mutator.h
View file @
1400edac
...
@@ -66,6 +66,7 @@ class IRMutator {
...
@@ -66,6 +66,7 @@ class IRMutator {
virtual
Stmt
Mutate_
(
const
ProducerConsumer
*
op
,
const
Stmt
&
s
);
virtual
Stmt
Mutate_
(
const
ProducerConsumer
*
op
,
const
Stmt
&
s
);
virtual
Stmt
Mutate_
(
const
Provide
*
op
,
const
Stmt
&
s
);
virtual
Stmt
Mutate_
(
const
Provide
*
op
,
const
Stmt
&
s
);
virtual
Stmt
Mutate_
(
const
Realize
*
op
,
const
Stmt
&
s
);
virtual
Stmt
Mutate_
(
const
Realize
*
op
,
const
Stmt
&
s
);
virtual
Stmt
Mutate_
(
const
Prefetch
*
op
,
const
Stmt
&
s
);
virtual
Stmt
Mutate_
(
const
Block
*
op
,
const
Stmt
&
s
);
virtual
Stmt
Mutate_
(
const
Block
*
op
,
const
Stmt
&
s
);
virtual
Stmt
Mutate_
(
const
Evaluate
*
op
,
const
Stmt
&
s
);
virtual
Stmt
Mutate_
(
const
Evaluate
*
op
,
const
Stmt
&
s
);
...
...
include/tvm/ir_visitor.h
View file @
1400edac
...
@@ -116,6 +116,7 @@ class IRVisitor {
...
@@ -116,6 +116,7 @@ class IRVisitor {
virtual
void
Visit_
(
const
ProducerConsumer
*
op
);
virtual
void
Visit_
(
const
ProducerConsumer
*
op
);
virtual
void
Visit_
(
const
Provide
*
op
);
virtual
void
Visit_
(
const
Provide
*
op
);
virtual
void
Visit_
(
const
Realize
*
op
);
virtual
void
Visit_
(
const
Realize
*
op
);
virtual
void
Visit_
(
const
Prefetch
*
op
);
virtual
void
Visit_
(
const
Block
*
op
);
virtual
void
Visit_
(
const
Block
*
op
);
virtual
void
Visit_
(
const
Evaluate
*
op
);
virtual
void
Visit_
(
const
Evaluate
*
op
);
virtual
void
Visit_
(
const
IntImm
*
op
);
virtual
void
Visit_
(
const
IntImm
*
op
);
...
...
include/tvm/schedule.h
View file @
1400edac
...
@@ -461,10 +461,16 @@ class IterVarAttrNode : public Node {
...
@@ -461,10 +461,16 @@ class IterVarAttrNode : public Node {
IterVarType
iter_type
{
kDataPar
};
IterVarType
iter_type
{
kDataPar
};
/*! \brief The thread this iter Var binds, can be null */
/*! \brief The thread this iter Var binds, can be null */
IterVar
bind_thread
;
IterVar
bind_thread
;
/*! \brief List of tensor to be prefetched in this loop */
Array
<
Tensor
>
prefetch_data
;
/*! \brief The offset used in each prefetch */
Array
<
Expr
>
prefetch_offset
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"iter_type"
,
&
iter_type
);
v
->
Visit
(
"iter_type"
,
&
iter_type
);
v
->
Visit
(
"bind_thread"
,
&
bind_thread
);
v
->
Visit
(
"bind_thread"
,
&
bind_thread
);
v
->
Visit
(
"prefetch_data"
,
&
prefetch_data
);
v
->
Visit
(
"prefetch_offset"
,
&
prefetch_offset
);
}
}
static
constexpr
const
char
*
_type_key
=
"IterVarAttr"
;
static
constexpr
const
char
*
_type_key
=
"IterVarAttr"
;
...
...
src/README.md
View file @
1400edac
...
@@ -13,3 +13,4 @@ There can be internal header files within each module that sit in src.
...
@@ -13,3 +13,4 @@ There can be internal header files within each module that sit in src.
-
pass The optimization pass on the IR structure
-
pass The optimization pass on the IR structure
-
codegen The code generator.
-
codegen The code generator.
-
runtime Minimum runtime related codes
-
runtime Minimum runtime related codes
-
contrib Contrib extension libraries
src/arithmetic/bound_deducer.cc
View file @
1400edac
...
@@ -212,7 +212,7 @@ void BoundDeducer::Deduce() {
...
@@ -212,7 +212,7 @@ void BoundDeducer::Deduce() {
success
=
false
;
success
=
false
;
return
;
return
;
}
}
// get the sign of every subexpr
expr_map_
=
EvalSetForEachSubExpr
(
expr_
,
hint_map_
);
expr_map_
=
EvalSetForEachSubExpr
(
expr_
,
hint_map_
);
Visit
(
expr_
);
Visit
(
expr_
);
...
...
src/op/op_util.cc
View file @
1400edac
...
@@ -55,14 +55,18 @@ MakeLoopNest(const Stage& stage,
...
@@ -55,14 +55,18 @@ MakeLoopNest(const Stage& stage,
// Mark the iter var in the IR, to remember the point
// Mark the iter var in the IR, to remember the point
if
(
bind_iv
->
thread_tag
.
length
()
==
0
)
{
if
(
bind_iv
->
thread_tag
.
length
()
==
0
)
{
ForType
for_type
=
ForType
::
Serial
;
ForType
for_type
=
ForType
::
Serial
;
IterVarAttr
it_attr
;
if
(
stage
->
iter_var_attrs
.
count
(
iv
))
{
if
(
stage
->
iter_var_attrs
.
count
(
iv
))
{
switch
(
stage
->
iter_var_attrs
[
iv
]
->
iter_type
)
{
it_attr
=
stage
->
iter_var_attrs
[
iv
];
}
if
(
it_attr
.
defined
())
{
switch
(
it_attr
->
iter_type
)
{
case
kUnrolled
:
for_type
=
ForType
::
Unrolled
;
break
;
case
kUnrolled
:
for_type
=
ForType
::
Unrolled
;
break
;
case
kVectorized
:
for_type
=
ForType
::
Vectorized
;
break
;
case
kVectorized
:
for_type
=
ForType
::
Vectorized
;
break
;
case
kParallelized
:
for_type
=
ForType
::
Parallel
;
break
;
case
kParallelized
:
for_type
=
ForType
::
Parallel
;
break
;
case
kDataPar
:
break
;
case
kDataPar
:
break
;
default
:
LOG
(
FATAL
)
<<
"Unknown iter type"
default
:
LOG
(
FATAL
)
<<
"Unknown iter type"
<<
stage
->
iter_var_attrs
[
iv
]
->
iter_type
<<
it_attr
->
iter_type
<<
" in the iter_var_attrs"
;
<<
" in the iter_var_attrs"
;
}
}
}
}
...
@@ -85,6 +89,18 @@ MakeLoopNest(const Stage& stage,
...
@@ -85,6 +89,18 @@ MakeLoopNest(const Stage& stage,
nest
[
i
+
1
].
emplace_back
(
nest
[
i
+
1
].
emplace_back
(
LetStmt
::
make
(
var
,
new_value
,
no_op
));
LetStmt
::
make
(
var
,
new_value
,
no_op
));
}
}
// Emit one prefetch_scope AttrStmt per tensor that the iter var's attributes
// request to prefetch, pairing each tensor with its prefetch offset.
if (it_attr.defined() && it_attr->prefetch_data.size() != 0) {
  CHECK(!is_one(dom->extent))
      << "Cannot prefetch on trivial loop with extent=1";
  // Every prefetched tensor must come with exactly one offset.
  CHECK_EQ(it_attr->prefetch_data.size(),
           it_attr->prefetch_offset.size());
  // Iterate with a dedicated index `j`. Re-declaring `i` here would shadow
  // the enclosing `i` (the one used for the LetStmt pushed into nest[i + 1]
  // just above), so the attributes would be appended to nest[j + 1] rather
  // than the current nest level and could index past the end of `nest`.
  // NOTE(review): the enclosing loop is outside this hunk -- confirm that
  // `i` is the loop-nest level index.
  for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) {
    nest[i + 1].emplace_back(
        AttrStmt::make(it_attr->prefetch_data[j],
                       ir::attr::prefetch_scope,
                       it_attr->prefetch_offset[j], no_op));
  }
}
}
else
if
(
bind_iv
->
thread_tag
==
"vthread"
)
{
}
else
if
(
bind_iv
->
thread_tag
==
"vthread"
)
{
// virtual thread
// virtual thread
// Always restrict threaded IterVar to start from 0.
// Always restrict threaded IterVar to start from 0.
...
...
src/pass/inject_virtual_thread.cc
View file @
1400edac
...
@@ -13,7 +13,7 @@ namespace tvm {
...
@@ -13,7 +13,7 @@ namespace tvm {
namespace
ir
{
namespace
ir
{
// If expression is touched by var.
// If expression is touched by var.
class
ExprTouched
:
public
IRVisitor
{
class
ExprTouched
final
:
public
IRVisitor
{
public
:
public
:
explicit
ExprTouched
(
const
std
::
unordered_set
<
const
Variable
*>
&
touched
)
explicit
ExprTouched
(
const
std
::
unordered_set
<
const
Variable
*>
&
touched
)
:
touched_var_
(
touched
)
{}
:
touched_var_
(
touched
)
{}
...
...
src/pass/inline.cc
View file @
1400edac
...
@@ -12,7 +12,7 @@ namespace ir {
...
@@ -12,7 +12,7 @@ namespace ir {
// inliner to inline a function
// inliner to inline a function
// the result may not be SSA,
// the result may not be SSA,
// ConvertSSA needs to be applied after this pass
// ConvertSSA needs to be applied after this pass
class
IRInline
:
public
IRMutator
{
class
IRInline
final
:
public
IRMutator
{
public
:
public
:
IRInline
(
FunctionRef
f
,
Array
<
Var
>
args
,
Expr
body
)
IRInline
(
FunctionRef
f
,
Array
<
Var
>
args
,
Expr
body
)
:
f_
(
f
),
args_
(
args
),
body_
(
body
)
{}
:
f_
(
f
),
args_
(
args
),
body_
(
body
)
{}
...
...
src/pass/ir_mutator.cc
View file @
1400edac
...
@@ -180,6 +180,31 @@ Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
...
@@ -180,6 +180,31 @@ Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
}
}
}
}
/*!
 * \brief Mutate a Prefetch statement by mutating the min and extent of
 *        each of its bounds.
 * \param op The Prefetch node being mutated.
 * \param s The statement handle that refers to op.
 * \return s itself when no min/extent was replaced by a different
 *         expression (as judged by same_as); otherwise a new Prefetch with
 *         the same func, value_index and type and the rebuilt bounds.
 */
Stmt IRMutator::Mutate_(const Prefetch* op, const Stmt& s) {
  Halide::Internal::Region rewritten;
  bool any_changed = false;

  // Rebuild every range; min is mutated before extent for each dimension.
  for (size_t dim = 0; dim < op->bounds.size(); ++dim) {
    Expr prev_min = op->bounds[dim]->min;
    Expr prev_extent = op->bounds[dim]->extent;
    Expr next_min = this->Mutate(prev_min);
    Expr next_extent = this->Mutate(prev_extent);

    any_changed = any_changed ||
        !next_min.same_as(prev_min) ||
        !next_extent.same_as(prev_extent);
    rewritten.push_back(Range::make_by_min_extent(next_min, next_extent));
  }

  // Nothing was rewritten: hand back the original statement untouched.
  if (!any_changed) return s;

  return Prefetch::make(op->func, op->value_index, op->type, rewritten);
}
Stmt
IRMutator
::
Mutate_
(
const
Block
*
op
,
const
Stmt
&
s
)
{
Stmt
IRMutator
::
Mutate_
(
const
Block
*
op
,
const
Stmt
&
s
)
{
Stmt
first
=
this
->
Mutate
(
op
->
first
);
Stmt
first
=
this
->
Mutate
(
op
->
first
);
Stmt
rest
=
this
->
Mutate
(
op
->
rest
);
Stmt
rest
=
this
->
Mutate
(
op
->
rest
);
...
...
src/pass/ir_visitor.cc
View file @
1400edac
...
@@ -174,7 +174,6 @@ void IRVisitor::Visit_(const Provide *op) {
...
@@ -174,7 +174,6 @@ void IRVisitor::Visit_(const Provide *op) {
}
}
void
IRVisitor
::
Visit_
(
const
Realize
*
op
)
{
void
IRVisitor
::
Visit_
(
const
Realize
*
op
)
{
// Mutate the bounds
for
(
size_t
i
=
0
;
i
<
op
->
bounds
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
op
->
bounds
.
size
();
i
++
)
{
this
->
Visit
(
op
->
bounds
[
i
]
->
min
);
this
->
Visit
(
op
->
bounds
[
i
]
->
min
);
this
->
Visit
(
op
->
bounds
[
i
]
->
extent
);
this
->
Visit
(
op
->
bounds
[
i
]
->
extent
);
...
@@ -184,6 +183,13 @@ void IRVisitor::Visit_(const Realize *op) {
...
@@ -184,6 +183,13 @@ void IRVisitor::Visit_(const Realize *op) {
this
->
Visit
(
op
->
condition
);
this
->
Visit
(
op
->
condition
);
}
}
void
IRVisitor
::
Visit_
(
const
Prefetch
*
op
)
{
for
(
size_t
i
=
0
;
i
<
op
->
bounds
.
size
();
i
++
)
{
this
->
Visit
(
op
->
bounds
[
i
]
->
min
);
this
->
Visit
(
op
->
bounds
[
i
]
->
extent
);
}
}
void
IRVisitor
::
Visit_
(
const
Block
*
op
)
{
void
IRVisitor
::
Visit_
(
const
Block
*
op
)
{
this
->
Visit
(
op
->
first
);
this
->
Visit
(
op
->
first
);
this
->
Visit
(
op
->
rest
);
this
->
Visit
(
op
->
rest
);
...
...
src/pass/loop_partition.cc
View file @
1400edac
...
@@ -42,7 +42,7 @@ bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) {
...
@@ -42,7 +42,7 @@ bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) {
// Rule:
// Rule:
// - the range should not be const
// - the range should not be const
// - there exists a condition expression in the scope that uses the var
// - there exists a condition expression in the scope that uses the var
class
CandidateSelector
:
public
IRVisitor
{
class
CandidateSelector
final
:
public
IRVisitor
{
public
:
public
:
using
VarIsUsed
=
bool
;
using
VarIsUsed
=
bool
;
CandidateSelector
()
{}
CandidateSelector
()
{}
...
...
src/pass/lower_thread_allreduce.cc
View file @
1400edac
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
namespace
tvm
{
namespace
tvm
{
namespace
ir
{
namespace
ir
{
class
ThreadAllreduceBuilder
:
public
IRMutator
{
class
ThreadAllreduceBuilder
final
:
public
IRMutator
{
public
:
public
:
explicit
ThreadAllreduceBuilder
(
int
warp_size
)
explicit
ThreadAllreduceBuilder
(
int
warp_size
)
:
warp_size_
(
warp_size
)
{}
:
warp_size_
(
warp_size
)
{}
...
...
src/pass/storage_rewrite.cc
View file @
1400edac
...
@@ -31,7 +31,7 @@ using namespace storage;
...
@@ -31,7 +31,7 @@ using namespace storage;
// The storage needs to be kept alive between allocate and last access.
// The storage needs to be kept alive between allocate and last access.
// The free point is only inserted at the same scope of allocate.
// The free point is only inserted at the same scope of allocate.
//
//
class
StorageAccessPatternFinder
:
public
IRVisitor
{
class
StorageAccessPatternFinder
final
:
public
IRVisitor
{
public
:
public
:
// Get linear access pattern.
// Get linear access pattern.
std
::
vector
<
StmtEntry
>
GetLinearSeq
(
const
Stmt
&
s
)
{
std
::
vector
<
StmtEntry
>
GetLinearSeq
(
const
Stmt
&
s
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment