Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
a2c8a29b
Commit
a2c8a29b
authored
Feb 02, 2017
by
Tianqi Chen
Committed by
GitHub
Feb 02, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[SCHEDULE] Improve bound inference, support reduce codegen. (#30)
parent
d4af7ad6
Hide whitespace changes
Inline
Side-by-side
Showing
31 changed files
with
1039 additions
and
438 deletions
+1039
-438
include/tvm/expr.h
+12
-9
include/tvm/ir.h
+3
-3
include/tvm/ir_pass.h
+10
-11
include/tvm/operation.h
+3
-0
include/tvm/schedule.h
+37
-0
include/tvm/schedule_pass.h
+9
-0
python/tvm/api.py
+18
-18
python/tvm/build.py
+4
-2
python/tvm/schedule.py
+8
-0
src/api/api_lang.cc
+6
-0
src/api/api_pass.cc
+0
-1
src/api/api_schedule.cc
+1
-0
src/codegen/codegen_c.cc
+3
-2
src/lang/ir.cc
+5
-5
src/lang/operation.cc
+10
-1
src/pass/ir_mutator.cc
+4
-4
src/pass/ir_visitor.cc
+1
-1
src/pass/simple_passes.cc
+22
-0
src/schedule/bound.cc
+53
-11
src/schedule/compute_expr.h
+109
-0
src/schedule/int_set.cc
+377
-216
src/schedule/int_set.h
+64
-17
src/schedule/schedule_lang.cc
+45
-2
src/schedule/schedule_ops.cc
+180
-126
tests/python/integration/test_ewise.py
+2
-1
tests/python/integration/test_reduce.py
+45
-0
tests/python/unittest/test_codegen_device.py
+1
-2
tests/python/unittest/test_codegen_makeapi.py
+2
-1
tests/python/unittest/test_lang_tensor.py
+1
-1
tests/python/unittest/test_pass_storage_flatten.py
+1
-1
tests/python/unittest/test_schedule_schedule_ops.py
+3
-3
No files found.
include/tvm/expr.h
View file @
a2c8a29b
...
...
@@ -32,6 +32,9 @@ using Halide::Internal::IRPrinter;
using
Halide
::
Internal
::
Variable
;
using
Halide
::
Internal
::
make_const
;
using
Halide
::
Internal
::
make_zero
;
using
Halide
::
Internal
::
as_const_int
;
using
Halide
::
Internal
::
as_const_uint
;
inline
Type
TVMType2Type
(
TVMType
t
)
{
...
...
@@ -126,25 +129,25 @@ using Halide::abs;
using
Halide
::
select
;
/*!
* \brief sum of of source expression over
rdom
* \brief sum of of source expression over
axis
* \param source The source expression.
* \param
rdom
List of iteration variables that will be used for reduction.
* \param
axis
List of iteration variables that will be used for reduction.
*/
Expr
sum
(
Expr
source
,
Array
<
IterVar
>
rdom
);
Expr
sum
(
Expr
source
,
Array
<
IterVar
>
axis
);
/*!
* \brief max of of source expression over
rdom
* \brief max of of source expression over
axis
* \param source The source expression.
* \param
rdom
List of iteration variables that will be used for reduction.
* \param
axis
List of iteration variables that will be used for reduction.
*/
Expr
max
(
Expr
source
,
Array
<
IterVar
>
rdom
);
Expr
max
(
Expr
source
,
Array
<
IterVar
>
axis
);
/*!
* \brief max of of source expression over
rdom
* \brief max of of source expression over
axis
* \param source The source expression.
* \param
rdom
List of iteration variables that will be used for reduction.
* \param
axis
List of iteration variables that will be used for reduction.
*/
Expr
min
(
Expr
source
,
Array
<
IterVar
>
rdom
);
Expr
min
(
Expr
source
,
Array
<
IterVar
>
axis
);
// print functions for expr
...
...
include/tvm/ir.h
View file @
a2c8a29b
...
...
@@ -30,8 +30,8 @@ struct Reduce : public ExprNode<Reduce> {
std
::
string
op
;
/*! \brief The source operand */
Expr
source
;
/*! \brief The reduction
domain
s */
Array
<
IterVar
>
rdom
;
/*! \brief The reduction
axi
s */
Array
<
IterVar
>
axis
;
/*! \brief construct expr from op and rdom */
static
Expr
make
(
std
::
string
op
,
Expr
src
,
Array
<
IterVar
>
rdom
);
...
...
@@ -40,7 +40,7 @@ struct Reduce : public ExprNode<Reduce> {
v
->
Visit
(
"dtype"
,
&
type
);
v
->
Visit
(
"op"
,
&
op
);
v
->
Visit
(
"source"
,
&
source
);
v
->
Visit
(
"
rdom"
,
&
rdom
);
v
->
Visit
(
"
axis"
,
&
axis
);
}
static
const
IRNodeType
_type_info
=
IRNodeType
::
ExtensionExpr
;
static
constexpr
const
char
*
_type_key
=
"Reduce"
;
...
...
include/tvm/ir_pass.h
View file @
a2c8a29b
...
...
@@ -3,8 +3,8 @@
* \file ir_pass.h
* \brief Collection of IR pass functions
*
*
All
the pass functions in this file are for Stmt,
*
W
e can use PassFunction(Evaluate(expr)) to apply it to Expr
*
When
the pass functions in this file are for Stmt,
*
w
e can use PassFunction(Evaluate(expr)) to apply it to Expr
*/
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
...
...
@@ -38,15 +38,6 @@ inline Stmt Simplify(Stmt a) {
}
/*!
* \brief Schedule s' dependent operations.
*
* \param s The schedule to be realized
* \param dom_map The domain of each iter vars.
* \return the result Stmt
*/
Stmt
ScheduleOps
(
Schedule
s
,
Map
<
IterVar
,
Range
>
dom_map
);
/*!
* \brief verifies whether the IR stmt or Expr is in SSA form.
* That is: each VarExpr is defined and assigned once(in Let/For)
*
...
...
@@ -70,6 +61,14 @@ bool HasSideEffect(const Expr& e);
Stmt
ConvertSSA
(
Stmt
stmt
);
/*!
* \brief Substitute the var specified in key->var to be value.
* \param stmt The source statement to be substituted
* \param value_map The map of new values.
* \return The converted form.
*/
Stmt
Substitute
(
Stmt
stmt
,
const
Map
<
IterVar
,
Expr
>&
value_map
);
/*!
* \brief inline all calls of f in stmt.
*
* \param f The function reference to be inlined
...
...
include/tvm/operation.h
View file @
a2c8a29b
...
...
@@ -49,6 +49,8 @@ class ComputeOpNode : public OperationNode {
public
:
/*! \brief IterVar on each axis */
Array
<
IterVar
>
axis
;
/*! \brief IterVar on each reduction axis, if the body is a Reduce */
Array
<
IterVar
>
reduce_axis
;
/*! \brief the compute expression */
Expr
body
;
/*! \brief constructor */
...
...
@@ -64,6 +66,7 @@ class ComputeOpNode : public OperationNode {
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"axis"
,
&
axis
);
v
->
Visit
(
"reduce_axis"
,
&
reduce_axis
);
v
->
Visit
(
"body"
,
&
body
);
}
static
Operation
make
(
std
::
string
name
,
...
...
include/tvm/schedule.h
View file @
a2c8a29b
...
...
@@ -123,6 +123,8 @@ class Stage : public NodeRef {
IterVar
*
p_x_outer
,
IterVar
*
p_y_outer
,
IterVar
*
p_x_inner
,
IterVar
*
p_y_inner
,
Expr
x_factor
,
Expr
y_factor
);
// declare container type
using
ContainerType
=
StageNode
;
};
/*!
...
...
@@ -153,10 +155,21 @@ class Schedule : public NodeRef {
return
this
->
operator
[](
tensor
->
op
);
}
/*!
* \brief Normalize the schedule.
* This is needed before bound inference.
* Insert necessary RebaseNode to make sure all leaf_iter_vars
* are in form [0, extent)
*
* \return A normalized schedule, can be same as current one.
*/
void
normalize
();
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline
const
ScheduleNode
*
operator
->
()
const
;
// declare container type
using
ContainerType
=
ScheduleNode
;
};
/*!
...
...
@@ -308,6 +321,30 @@ class FuseNode : public IterVarRelationNode {
TVM_DECLARE_NODE_TYPE_INFO
(
FuseNode
);
};
/*!
* \brief Rebase the iteration to make min to be 0.
* This is useful to normalize the Schedule
* to make every leaf variable's min to be 0.
*/
class
RebaseNode
:
public
IterVarRelationNode
{
public
:
/*! \brief The parent domain */
IterVar
parent
;
/*! \brief The inner domain */
IterVar
rebased
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"parent"
,
&
parent
);
v
->
Visit
(
"rebased"
,
&
rebased
);
}
static
IterVarRelation
make
(
IterVar
parent
,
IterVar
rebased
);
static
constexpr
const
char
*
_type_key
=
"Rebase"
;
TVM_DECLARE_NODE_TYPE_INFO
(
RebaseNode
);
};
// implementations
inline
const
StageNode
*
Stage
::
operator
->
()
const
{
return
static_cast
<
const
StageNode
*>
(
node_
.
get
());
...
...
include/tvm/schedule_pass.h
View file @
a2c8a29b
...
...
@@ -24,6 +24,15 @@ namespace schedule {
*/
Map
<
IterVar
,
Range
>
InferBound
(
Schedule
sch
);
/*!
* \brief Schedule s' dependent operations.
*
* \param s The schedule to be realized
* \param dom_map The domain of each iter vars.
* \return the result Stmt
*/
Stmt
ScheduleOps
(
Schedule
s
,
Map
<
IterVar
,
Range
>
dom_map
);
}
// namespace schedule
}
// namespace tvm
#endif // TVM_SCHEDULE_PASS_H_
python/tvm/api.py
View file @
a2c8a29b
...
...
@@ -212,51 +212,51 @@ def IterVar(dom=None, name=None, thread_tag=''):
return
_api_internal
.
_IterVar
(
dom
,
name
,
thread_tag
)
def
sum
(
expr
,
rdom
):
"""Create a sum expression over
rdom
def
sum
(
expr
,
axis
):
"""Create a sum expression over
axis
Parameters
----------
expr : Expr
The source expression.
rdom : RDomain
The reduction
domainx
axis : IterVar
The reduction
IterVar axis
"""
rdom
=
rdom
if
isinstance
(
rdom
,
list
)
else
[
rdom
]
x
=
_make
.
Reduce
(
"Add"
,
expr
,
rdom
)
axis
=
axis
if
isinstance
(
axis
,
list
)
else
[
axis
]
x
=
_make
.
Reduce
(
"Add"
,
expr
,
axis
)
return
x
def
min
(
expr
,
rdom
):
"""Create a min expression over
rdom
def
min
(
expr
,
axis
):
"""Create a min expression over
axis
Parameters
----------
expr : Expr
The source expression.
rdom : RDomain
The reduction
domainx
axis : IterVar
The reduction
IterVar axis
"""
rdom
=
rdom
if
isinstance
(
rdom
,
list
)
else
[
rdom
]
x
=
_make
.
Reduce
(
"Min"
,
expr
,
rdom
)
axis
=
axis
if
isinstance
(
axis
,
list
)
else
[
axis
]
x
=
_make
.
Reduce
(
"Min"
,
expr
,
axis
)
return
x
def
max
(
expr
,
rdom
):
"""Create a min expression over
rdom
def
max
(
expr
,
axis
):
"""Create a min expression over
axis
Parameters
----------
expr : Expr
The source expression.
rdom : RDomain
The reduction
domainx
axis : IterVar
The reduction
IterVar axis
"""
rdom
=
rdom
if
isinstance
(
rdom
,
list
)
else
[
rdom
]
x
=
_make
.
Reduce
(
"Max"
,
expr
,
rdom
)
axis
=
axis
if
isinstance
(
axis
,
list
)
else
[
axis
]
x
=
_make
.
Reduce
(
"Max"
,
expr
,
axis
)
return
x
...
...
python/tvm/build.py
View file @
a2c8a29b
...
...
@@ -62,9 +62,10 @@ def build(sch,
# lowering
bounds
=
schedule
.
InferBound
(
sch
)
stmt
=
ir_pass
.
ScheduleOps
(
sch
,
bounds
)
stmt
=
schedule
.
ScheduleOps
(
sch
,
bounds
)
stmt
=
ir_pass
.
StorageFlatten
(
stmt
,
binds
)
stmt
=
ir_pass
.
Simplify
(
stmt
)
print
(
stmt
)
fapi
=
codegen
.
MakeAPI
(
stmt
,
name
,
arg_list
,
len
(
arg_list
))
fsplits
=
codegen
.
SplitHostDevice
(
fapi
)
...
...
@@ -73,7 +74,8 @@ def build(sch,
for
i
,
f
in
enumerate
(
fsplits
):
t
=
target
if
i
>=
1
else
"c"
record_codes
.
append
(
codegen
.
CompileToC
(
f
,
output_ssa
,
t
))
for
c
in
record_codes
:
print
(
c
)
if
target
==
"cuda"
:
ret
=
codegen
.
BuildNVRTC
(
fsplits
,
"stackvm"
)
elif
target
==
"opencl"
:
...
...
python/tvm/schedule.py
View file @
a2c8a29b
...
...
@@ -33,6 +33,14 @@ class Schedule(NodeBase):
raise
ValueError
(
"Cannot find the operation
%
s in schedule"
%
(
str
(
k
)))
return
self
.
stage_map
[
k
]
def
normalize
(
self
):
"""Build a normalized schedule.
Insert necessary rebase to make certain iter var to start from 0.
This is needed before bound inference and followup step.
"""
_api_internal
.
_ScheduleNormalize
(
self
)
@register_node
class
Stage
(
NodeBase
):
"""A Stage represents schedule for one operation."""
...
...
src/api/api_lang.cc
View file @
a2c8a29b
...
...
@@ -253,4 +253,10 @@ TVM_REGISTER_API(_StageTile)
*
ret
=
Array
<
IterVar
>
({
x_outer
,
y_outer
,
x_inner
,
y_inner
});
});
TVM_REGISTER_API
(
_ScheduleNormalize
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Schedule
()
.
normalize
();
});
}
// namespace tvm
src/api/api_pass.cc
View file @
a2c8a29b
...
...
@@ -51,7 +51,6 @@ TVM_REGISTER_API(_pass_Equal)
REGISTER_PASS1
(
ConvertSSA
);
REGISTER_PASS1
(
VerifySSA
);
REGISTER_PASS4
(
Inline
);
REGISTER_PASS2
(
ScheduleOps
);
REGISTER_PASS2
(
StorageFlatten
);
}
// namespace ir
...
...
src/api/api_schedule.cc
View file @
a2c8a29b
...
...
@@ -29,6 +29,7 @@ namespace schedule {
REGISTER_SCHEDULE_PASS1
(
InferBound
);
REGISTER_SCHEDULE_PASS1
(
CreateReadGraph
);
REGISTER_SCHEDULE_PASS2
(
PostDFSOrder
);
REGISTER_SCHEDULE_PASS2
(
ScheduleOps
);
}
// namespace schedule
}
// namespace tvm
src/codegen/codegen_c.cc
View file @
a2c8a29b
...
...
@@ -2,6 +2,7 @@
* Copyright (c) 2017 by Contributors
* \file codegen_c.cc
*/
#include <iomanip>
#include "./codegen_c.h"
namespace
tvm
{
...
...
@@ -216,7 +217,7 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N
switch
(
op
->
type
.
bits
())
{
case
64
:
case
32
:
{
std
::
ostringstream
temp
;
temp
<<
op
->
value
;
temp
<<
std
::
scientific
<<
op
->
value
;
if
(
op
->
type
.
bits
()
==
32
)
temp
<<
'f'
;
p
->
MarkConst
(
temp
.
str
());
os
<<
temp
.
str
();
...
...
@@ -225,7 +226,7 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N
case
16
:
{
os
<<
'('
;
p
->
PrintType
(
op
->
type
,
os
);
os
<<
')'
<<
op
->
value
<<
'f'
;
os
<<
')'
<<
std
::
scientific
<<
op
->
value
<<
'f'
;
break
;
}
default
:
LOG
(
FATAL
)
<<
"Bad bit-width for float: "
<<
op
->
type
<<
"
\n
"
;
...
...
src/lang/ir.cc
View file @
a2c8a29b
...
...
@@ -26,7 +26,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
<<
op
->
op
<<
", "
;
p
->
print
(
op
->
source
);
p
->
stream
<<
",
rdom="
<<
op
->
rdom
<<
")"
;
p
->
stream
<<
",
axis="
<<
op
->
axis
<<
")"
;
});
}
// namespace Internal
...
...
@@ -35,16 +35,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
namespace
tvm
{
namespace
ir
{
Expr
Reduce
::
make
(
std
::
string
op
,
Expr
source
,
Array
<
IterVar
>
rdom
)
{
Expr
Reduce
::
make
(
std
::
string
op
,
Expr
source
,
Array
<
IterVar
>
axis
)
{
auto
n
=
std
::
make_shared
<
Reduce
>
();
CHECK
(
source
.
defined
());
for
(
size_t
i
=
0
;
i
<
rdom
.
size
();
++
i
)
{
CHECK
(
rdom
[
i
].
defined
());
for
(
size_t
i
=
0
;
i
<
axis
.
size
();
++
i
)
{
CHECK
(
axis
[
i
].
defined
());
}
n
->
type
=
source
.
type
();
n
->
source
=
source
;
n
->
op
=
op
;
n
->
rdom
=
rdom
;
n
->
axis
=
axis
;
return
Expr
(
n
);
}
...
...
src/lang/operation.cc
View file @
a2c8a29b
...
...
@@ -4,6 +4,7 @@
*/
#include <tvm/operation.h>
#include <tvm/tensor.h>
#include <tvm/ir.h>
#include <memory>
namespace
tvm
{
...
...
@@ -57,7 +58,12 @@ Tensor Placeholder(Array<Expr> shape, Type dtype, std::string name) {
// ComputeOpNode
Array
<
IterVar
>
ComputeOpNode
::
root_iter_vars
()
const
{
return
axis
;
if
(
reduce_axis
.
size
()
==
0
)
return
axis
;
Array
<
IterVar
>
ret
=
axis
;
for
(
IterVar
iv
:
reduce_axis
)
{
ret
.
push_back
(
iv
);
}
return
ret
;
}
Type
ComputeOpNode
::
output_dtype
(
size_t
i
)
const
{
...
...
@@ -101,6 +107,9 @@ Operation ComputeOpNode::make(std::string name,
n
->
name
=
name
;
n
->
axis
=
axis
;
n
->
body
=
body
;
if
(
n
->
body
->
is_type
<
ir
::
Reduce
>
())
{
n
->
reduce_axis
=
n
->
body
.
as
<
ir
::
Reduce
>
()
->
axis
;
}
return
Operation
(
n
);
}
...
...
src/pass/ir_mutator.cc
View file @
a2c8a29b
...
...
@@ -37,7 +37,7 @@ inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
}
}
inline
Array
<
IterVar
>
Mutate
RDom
(
Array
<
IterVar
>
rdom
,
IRMutator
*
m
)
{
inline
Array
<
IterVar
>
Mutate
IterVarArr
(
Array
<
IterVar
>
rdom
,
IRMutator
*
m
)
{
std
::
vector
<
IterVar
>
new_dom
(
rdom
.
size
());
bool
changed
=
false
;
for
(
size_t
i
=
0
;
i
<
rdom
.
size
();
i
++
)
{
...
...
@@ -237,13 +237,13 @@ Expr IRMutator::Mutate_(const Let *op, const Expr& e) {
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
.
set_dispatch
<
Reduce
>
([](
const
Reduce
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
Array
<
IterVar
>
new_
rdom
=
MutateRDom
(
op
->
rdom
,
m
);
Array
<
IterVar
>
new_
axis
=
MutateIterVarArr
(
op
->
axis
,
m
);
Expr
new_source
=
m
->
Mutate
(
op
->
source
);
if
(
op
->
rdom
.
same_as
(
new_rdom
)
&&
if
(
op
->
axis
.
same_as
(
new_axis
)
&&
op
->
source
.
same_as
(
new_source
))
{
return
e
;
}
else
{
return
Reduce
::
make
(
op
->
op
,
new_source
,
new_
rdom
);
return
Reduce
::
make
(
op
->
op
,
new_source
,
new_
axis
);
}
});
...
...
src/pass/ir_visitor.cc
View file @
a2c8a29b
...
...
@@ -120,7 +120,7 @@ void IRVisitor::Visit_(const Call *op) {
TVM_STATIC_IR_FUNCTOR
(
IRVisitor
,
vtable
)
.
set_dispatch
<
Reduce
>
([](
const
Reduce
*
op
,
IRVisitor
*
v
)
{
VisitRDom
(
op
->
rdom
,
v
);
VisitRDom
(
op
->
axis
,
v
);
v
->
Visit
(
op
->
source
);
})
.
set_dispatch
<
IntImm
>
(
NoOp
)
...
...
src/pass/simple_passes.cc
View file @
a2c8a29b
...
...
@@ -5,6 +5,7 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
namespace
tvm
{
...
...
@@ -32,5 +33,26 @@ bool HasSideEffect(const Expr& e) {
v
.
Visit
(
e
);
return
v
.
has_side_effect_
;
}
class
IRSubstitue
:
public
IRMutator
{
public
:
Expr
Mutate_
(
const
Variable
*
op
,
const
Expr
&
e
)
final
{
auto
it
=
smap
.
find
(
op
);
if
(
it
!=
smap
.
end
())
{
return
it
->
second
;
}
else
{
return
e
;
}
}
std
::
unordered_map
<
const
Variable
*
,
Expr
>
smap
;
};
Stmt
Substitute
(
Stmt
stmt
,
const
Map
<
IterVar
,
Expr
>&
value_map
)
{
IRSubstitue
m
;
for
(
auto
kv
:
value_map
)
{
m
.
smap
[
kv
.
first
->
var
.
get
()]
=
kv
.
second
;
}
return
m
.
Mutate
(
stmt
);
}
}
// namespace ir
}
// namespace tvm
src/schedule/bound.cc
View file @
a2c8a29b
...
...
@@ -54,6 +54,11 @@ void PassDown(const Stage& s,
const
Range
&
range_inner
=
state
.
at
(
r
->
inner
);
state
[
r
->
fused
]
=
Range
::
make_with_min_extent
(
0
,
range_outer
->
extent
*
range_inner
->
extent
);
}
else
if
(
rel
.
as
<
RebaseNode
>
())
{
const
RebaseNode
*
r
=
rel
.
as
<
RebaseNode
>
();
CHECK
(
state
.
count
(
r
->
parent
));
state
[
r
->
rebased
]
=
Range
::
make_with_min_extent
(
0
,
state
.
at
(
r
->
parent
)
->
extent
);
}
else
{
LOG
(
FATAL
)
<<
"unknown relation type"
;
}
...
...
@@ -85,6 +90,13 @@ void PassUp(const Stage& s,
&
outer
,
&
inner
);
state
[
r
->
outer
]
=
outer
;
state
[
r
->
inner
]
=
inner
;
}
else
if
(
rel
.
as
<
RebaseNode
>
())
{
IntSet
parent
;
const
RebaseNode
*
r
=
rel
.
as
<
RebaseNode
>
();
PassUp
(
r
,
dom_map
,
state
.
at
(
r
->
rebased
),
&
parent
);
state
[
r
->
parent
]
=
parent
;
}
else
{
LOG
(
FATAL
)
<<
"unknown relation type"
;
}
...
...
@@ -109,9 +121,15 @@ void PassToOperation(
// Eventually, we need to change the inference to be a Pull style inference
if
(
tensor
->
op
.
as
<
ComputeOpNode
>
())
{
auto
root_iter_vars
=
tensor
->
op
->
root_iter_vars
();
CHECK_EQ
(
tensor
.
ndim
(),
root_iter_vars
.
size
());
for
(
size_t
i
=
0
;
i
<
tensor
.
ndim
();
++
i
)
{
(
*
result
)[
root_iter_vars
[
i
]].
push_back
(
dim_bounds
[
i
]);
const
ComputeOpNode
*
op
=
tensor
->
op
.
as
<
ComputeOpNode
>
();
CHECK_EQ
(
op
->
axis
.
size
()
+
op
->
reduce_axis
.
size
(),
root_iter_vars
.
size
());
for
(
size_t
i
=
0
;
i
<
op
->
axis
.
size
();
++
i
)
{
(
*
result
)[
op
->
axis
[
i
]].
push_back
(
dim_bounds
[
i
]);
}
// reduction.
for
(
size_t
i
=
0
;
i
<
op
->
reduce_axis
.
size
();
++
i
)
{
(
*
result
)[
op
->
reduce_axis
[
i
]].
push_back
(
IntSet
::
range
(
op
->
reduce_axis
[
i
]
->
dom
));
}
}
else
{
LOG
(
FATAL
)
<<
"unknown operation mode "
<<
tensor
->
op
->
type_key
();
...
...
@@ -173,9 +191,9 @@ bool ScopeRelax(const IterVar& iv, const std::string& scope) {
{
"local"
,
2
}
};
static
std
::
unordered_map
<
std
::
string
,
int
>
thread_tag_rank
{
{
"
grid
Idx.x"
,
0
},
{
"
grid
Idx.y"
,
0
},
{
"
grid
Idx.z"
,
0
},
{
"
block
Idx.x"
,
0
},
{
"
block
Idx.y"
,
0
},
{
"
block
Idx.z"
,
0
},
{
"threadIdx.x"
,
1
},
{
"threadIdx.y"
,
1
},
{
"threadIdx.z"
,
1
}
...
...
@@ -194,8 +212,6 @@ void InferBound(const Stage& stage,
(
*
rmap
)[
iv
]
=
iv
->
dom
;
}
}
// get range of all child iter vars.
PassDown
(
stage
,
rmap
);
if
(
stage
->
attach_type
==
kScope
)
{
Stage
parent
=
stage
->
attach_stage
;
...
...
@@ -206,10 +222,18 @@ void InferBound(const Stage& stage,
bool
fix_value
=
true
;
for
(
auto
iv
:
parent
->
leaf_iter_vars
)
{
Range
vrange
=
rmap
->
at
(
iv
);
CHECK
(
is_zero
(
vrange
->
min
))
<<
"InferBound requires every leaf iter var's min equals 0, "
<<
"call schedule.normalize to achieve this."
;
// special optimization to remove trivial loop
if
(
is_one
(
vrange
->
extent
))
{
up_state
[
iv
]
=
IntSet
::
single_point
(
vrange
->
min
);
}
if
(
fix_value
&&
!
ScopeRelax
(
iv
,
stage
->
scope
))
{
up_state
[
iv
]
=
IntSet
::
mak
e_point
(
iv
->
var
);
up_state
[
iv
]
=
IntSet
::
singl
e_point
(
iv
->
var
);
}
else
{
up_state
[
iv
]
=
IntSet
::
make_range
(
rmap
->
at
(
iv
)
);
up_state
[
iv
]
=
IntSet
::
range
(
vrange
);
}
if
(
stage
->
attach_ivar
==
iv
)
{
fix_value
=
false
;
...
...
@@ -223,12 +247,30 @@ void InferBound(const Stage& stage,
bp_state
[
iv
]
=
{
up_state
.
at
(
iv
)};
}
auto
result
=
BoundProp
(
post_order
,
&
bp_state
);
// Set relaxation
Map
<
IterVar
,
IntSet
>
relax_set
;
Stage
s
=
stage
;
while
(
s
->
attach_type
==
kScope
)
{
s
=
s
->
attach_stage
;
for
(
auto
iv
:
s
->
leaf_iter_vars
)
{
if
(
ScopeRelax
(
iv
,
stage
->
scope
))
{
relax_set
.
Set
(
iv
,
IntSet
::
range
(
rmap
->
at
(
iv
)));
}
}
}
for
(
auto
iv
:
stage
->
op
->
root_iter_vars
())
{
CHECK
(
result
.
count
(
iv
));
CHECK
(
!
rmap
->
count
(
iv
));
(
*
rmap
)[
iv
]
=
result
.
at
(
iv
).
GetCoverRange
();
Range
r
=
result
.
at
(
iv
).
cover_range
(
iv
->
dom
);
if
(
relax_set
.
size
()
!=
0
)
{
r
=
EvalSet
(
r
,
relax_set
).
cover_range
(
iv
->
dom
);
}
(
*
rmap
)[
iv
]
=
r
;
}
}
// get range of all child iter vars.
PassDown
(
stage
,
rmap
);
}
...
...
src/schedule/compute_expr.h
0 → 100644
View file @
a2c8a29b
/*!
* Copyright (c) 2017 by Contributors
* \file compute_expr.h
* \brief Utility integer expression with quick eager simplification.
* This is weaker than Simplify but can be done Eagerly.
*/
#ifndef TVM_SCHEDULE_COMPUTE_EXPR_H_
#define TVM_SCHEDULE_COMPUTE_EXPR_H_
#include <tvm/ir.h>
#include <pass/Interval.h>
namespace
tvm
{
namespace
schedule
{
using
Halide
::
Internal
::
add_would_overflow
;
using
Halide
::
Internal
::
sub_would_overflow
;
using
Halide
::
Internal
::
mul_would_overflow
;
/*!
* \brief Compute the expression with the given binary op.
* \param lhs The left operand
* \param rhs The right operand
* \return The result.
*/
template
<
typename
OP
>
inline
Expr
ComputeExpr
(
Expr
lhs
,
Expr
rhs
)
{
return
OP
::
make
(
lhs
,
rhs
);
}
template
<
typename
T
>
inline
bool
GetConst
(
Expr
e
,
T
*
out
);
template
<>
bool
GetConst
<
int64_t
>
(
Expr
e
,
int64_t
*
out
)
{
if
(
e
.
type
().
is_vector
())
return
false
;
const
int64_t
*
v
=
as_const_int
(
e
);
if
(
v
)
{
*
out
=
*
v
;
return
true
;
}
else
{
return
false
;
}
}
template
<>
bool
GetConst
<
uint64_t
>
(
Expr
e
,
uint64_t
*
out
)
{
if
(
e
.
type
().
is_vector
())
return
false
;
const
uint64_t
*
v
=
as_const_uint
(
e
);
if
(
v
)
{
*
out
=
*
v
;
return
true
;
}
else
{
return
false
;
}
}
#define TVM_CONST_PROPAGATION(OP_NAME, OP) \
int64_t ia = 0, ib = 0; \
if (GetConst(a, &ia) && GetConst(b, &ib)) { \
if (OP_NAME ## _would_overflow(a.type().bits(), ia, ib)) { \
LOG(FATAL) << "signed int overflow"; \
} \
return ir::IntImm::make(a.type(), ia OP ib); \
} \
uint64_t ua = 0, ub = 0; \
if (GetConst(a, &ua) && GetConst(b, &ub)) { \
return ir::UIntImm::make(a.type(), ua + ub); \
} \
template
<>
inline
Expr
ComputeExpr
<
ir
::
Add
>
(
Expr
a
,
Expr
b
)
{
if
(
is_zero
(
a
))
return
b
;
if
(
is_zero
(
b
))
return
a
;
TVM_CONST_PROPAGATION
(
add
,
+
);
return
ir
::
Add
::
make
(
a
,
b
);
}
template
<>
inline
Expr
ComputeExpr
<
ir
::
Sub
>
(
Expr
a
,
Expr
b
)
{
if
(
is_zero
(
b
))
return
a
;
TVM_CONST_PROPAGATION
(
sub
,
-
);
return
ir
::
Add
::
make
(
a
,
b
);
}
template
<>
inline
Expr
ComputeExpr
<
ir
::
Mul
>
(
Expr
a
,
Expr
b
)
{
if
(
is_one
(
a
))
return
b
;
if
(
is_one
(
b
))
return
a
;
TVM_CONST_PROPAGATION
(
mul
,
*
);
return
ir
::
Mul
::
make
(
a
,
b
);
}
template
<>
inline
Expr
ComputeExpr
<
ir
::
Div
>
(
Expr
a
,
Expr
b
)
{
if
(
is_one
(
b
))
return
a
;
return
ir
::
Mul
::
make
(
a
,
b
);
}
template
<>
inline
Expr
ComputeExpr
<
ir
::
Max
>
(
Expr
a
,
Expr
b
)
{
return
Halide
::
Internal
::
Interval
::
make_max
(
a
,
b
);
}
template
<>
inline
Expr
ComputeExpr
<
ir
::
Min
>
(
Expr
a
,
Expr
b
)
{
return
Halide
::
Internal
::
Interval
::
make_min
(
a
,
b
);
}
}
// namespace schedule
}
// namespace tvm
#endif // TVM_SCHEDULE_COMPUTE_EXPR_H_
src/schedule/int_set.cc
View file @
a2c8a29b
/*!
* Copyright (c) 2016 by Contributors
* \file int_set.cc
* \file int_set
_impl
.cc
* \brief The integer set functions
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <pass/Interval.h>
#include "./int_set.h"
#include "./compute_expr.h"
namespace
tvm
{
namespace
schedule
{
using
Halide
::
Internal
::
Interval
;
using
namespace
ir
;
/*! \brief Set of continuous interval */
struct
IntervalSet
:
public
IntSetNode
{
/*! \brief the internal interval*/
Interval
i
;
static
IntSet
make
(
Interval
i
)
{
std
::
shared_ptr
<
IntervalSet
>
n
=
std
::
make_shared
<
IntervalSet
>
();
n
->
i
=
i
;
return
IntSet
(
n
);
}
static
IntSet
make
(
Expr
min
,
Expr
max
)
{
std
::
shared_ptr
<
IntervalSet
>
n
=
std
::
make_shared
<
IntervalSet
>
();
n
->
i
.
min
=
min
;
n
->
i
.
max
=
max
;
return
IntSet
(
n
);
}
static
constexpr
const
char
*
_type_key
=
"IntervalSet"
;
TVM_DECLARE_NODE_TYPE_INFO
(
IntervalSet
);
};
/*!
* \brief Internal node container of int set.
* \brief set represented by strided integers
* Reserved for cases where strided access is supported.
*/
class
IntSetNode
:
public
Node
{
public
:
/*! \brief The base range scope */
Range
base
;
/*! \brief additional strided domain */
Array
<
Range
>
domain
;
/*! \brief The stride of each strided domain */
Array
<
Expr
>
stride
;
/*!
* \brief The concrete set,
* used when concrete execution is enabled.
*/
std
::
vector
<
int32_t
>
concrete
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"base"
,
&
base
);
v
->
Visit
(
"domain"
,
&
domain
);
v
->
Visit
(
"stride"
,
&
stride
);
}
static
constexpr
const
char
*
_type_key
=
"IntSet"
;
TVM_DECLARE_NODE_TYPE_INFO
(
IntSetNode
);
struct
StrideSet
:
public
IntSetNode
{
/*! \brief the base inetrval */
Interval
base
;
/*! \brief additional extents in positive number */
Array
<
Expr
>
extents
;
/*! \brief additional strides in positive number */
Array
<
Expr
>
strides
;
static
constexpr
const
char
*
_type_key
=
"StrideSet"
;
TVM_DECLARE_NODE_TYPE_INFO
(
StrideSet
);
};
TVM_REGISTER_NODE_TYPE
(
IntSetNode
);
inline
IntSet
IntSet
::
cover_interval
()
const
{
if
((
*
this
).
as
<
IntervalSet
>
())
return
*
this
;
const
StrideSet
*
s
=
(
*
this
).
as
<
StrideSet
>
();
if
(
s
)
{
CHECK_NE
(
s
->
extents
.
size
(),
0U
);
Expr
max
=
s
->
base
.
max
;
for
(
size_t
i
=
0
;
i
<
s
->
extents
.
size
();
++
i
)
{
max
=
max
+
s
->
extents
[
i
]
*
s
->
strides
[
i
]
-
s
->
strides
[
i
];
}
return
IntervalSet
::
make
(
s
->
base
.
min
,
max
);
}
LOG
(
FATAL
)
<<
"cannot convert set "
<<
(
*
this
)
->
type_key
()
<<
" to interval"
;
return
IntSet
::
everything
();
}
Range
IntSet
::
cover_range
(
Range
max_range
)
const
{
IntSet
temp
;
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
if
(
s_int
==
nullptr
)
{
temp
=
this
->
cover_interval
();
s_int
=
temp
.
as
<
IntervalSet
>
();
}
if
(
s_int
->
i
.
is_bounded
())
{
return
Range
::
make_with_min_extent
(
s_int
->
i
.
min
,
Simplify
(
s_int
->
i
.
max
+
1
-
s_int
->
i
.
min
));
}
return
max_range
;
}
namespace
{
bool
IntSet
::
is_everything
()
const
{
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
return
(
s_int
&&
s_int
->
i
.
is_everything
());
}
inline
bool
Match
(
const
Expr
&
e
,
int64_t
value
)
{
const
ir
::
IntImm
*
v
=
e
.
as
<
ir
::
IntImm
>
();
return
v
!=
nullptr
&&
v
->
value
;
bool
IntSet
::
is_single_point
()
const
{
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
return
(
s_int
&&
s_int
->
i
.
is_single_point
())
;
}
// whether a exactly matches b.
inline
bool
Match
(
const
IntSet
&
a
,
const
Range
&
b
)
{
if
(
a
->
base
==
b
&&
a
->
domain
.
size
()
==
0
&&
a
->
concrete
.
size
()
==
0
)
{
return
true
;
}
else
{
return
false
;
}
IntSet
IntSet
::
everything
()
{
return
IntervalSet
::
make
(
Interval
::
everything
());
}
// whether a exactly matches b.
inline
bool
Match
(
const
IntSet
&
a
,
const
Expr
&
b
)
{
if
(
a
->
domain
.
size
()
==
0
&&
a
->
concrete
.
size
()
==
0
)
{
return
Match
(
a
->
base
->
extent
,
1
)
&&
a
->
base
->
min
.
same_as
(
b
);
}
else
{
return
false
;
}
IntSet
IntSet
::
single_point
(
Expr
x
)
{
return
IntervalSet
::
make
(
Interval
::
single_point
(
x
));
}
inline
bool
IsNumber
(
const
IntSet
&
s
)
{
if
(
s
->
domain
.
size
()
!=
0
)
return
false
;
if
(
s
->
concrete
.
size
()
!=
0
)
{
return
s
->
concrete
.
size
()
==
1
;
IntSet
IntSet
::
range
(
Range
r
)
{
// must make sure it can be matched back by MatchRange.
if
(
is_one
(
r
->
extent
))
{
return
IntSet
::
single_point
(
r
->
min
);
}
if
(
is_positive_const
(
r
->
extent
)
&&
is_const
(
r
->
min
))
{
return
IntervalSet
::
make
(
r
->
min
,
ComputeExpr
<
Sub
>
(
ComputeExpr
<
Add
>
(
r
->
extent
,
r
->
min
),
1
));
}
return
Match
(
s
->
base
->
extent
,
1
);
return
IntervalSet
::
make
(
r
->
min
,
(
r
->
extent
+
r
->
min
)
-
1
);
}
inline
Expr
AsNumber
(
const
IntSet
&
s
)
{
return
s
->
base
->
min
;
// Check if a is created from b.
inline
bool
MatchRange
(
const
IntSet
&
a
,
const
Range
&
b
)
{
const
IntervalSet
*
a_int
=
a
.
as
<
IntervalSet
>
();
if
(
!
a_int
)
return
false
;
const
Interval
&
i
=
a_int
->
i
;
if
(
!
i
.
min
.
same_as
(
b
))
return
false
;
if
(
is_one
(
b
->
extent
))
return
i
.
is_single_point
();
if
(
is_positive_const
(
b
->
extent
)
&&
is_const
(
b
->
min
))
{
// deep equality
return
Equal
(
ComputeExpr
<
Sub
>
(
ComputeExpr
<
Add
>
(
b
->
extent
,
b
->
min
),
1
),
a_int
->
i
.
max
);
}
const
Sub
*
sub
=
i
.
max
.
as
<
Sub
>
();
if
(
!
sub
)
return
false
;
if
(
is_one
(
sub
->
b
))
return
false
;
const
Add
*
add
=
sub
->
a
.
as
<
Add
>
();
return
add
&&
add
->
a
.
same_as
(
b
->
min
)
&&
add
->
b
.
same_as
(
b
->
extent
);
}
// set combination rule by operators
template
<
typename
T
>
inline
IntSet
BinaryCombine
(
IntSet
a
,
IntSet
b
)
{
LOG
(
WARNING
)
<<
"cannot evaluate binary op "
<<
T
::
_type_key
;
return
IntSet
::
make_all_set
();
inline
bool
MatchPoint
(
const
IntSet
&
a
,
const
Expr
&
b
)
{
const
IntervalSet
*
a_int
=
a
.
as
<
IntervalSet
>
();
if
(
!
a_int
)
return
false
;
const
Interval
&
i
=
a_int
->
i
;
return
i
.
is_single_point
()
&&
i
.
min
.
same_as
(
b
);
}
template
<>
inline
IntSet
BinaryCombine
<
Add
>
(
IntSet
a
,
IntSet
b
)
{
auto
n
=
std
::
make_shared
<
IntSetNode
>
(
*
(
a
.
operator
->
()));
for
(
size_t
i
=
0
;
i
<
b
->
domain
.
size
();
++
i
)
{
n
->
domain
.
push_back
(
b
->
domain
[
i
]);
n
->
stride
.
push_back
(
b
->
stride
[
i
]);
}
if
(
IsNumber
(
a
))
{
n
->
base
=
Range
::
make_with_min_extent
(
a
->
base
->
min
+
b
->
base
->
min
,
b
->
base
->
extent
);
}
else
if
(
IsNumber
(
b
))
{
n
->
base
=
Range
::
make_with_min_extent
(
a
->
base
->
min
+
b
->
base
->
min
,
a
->
base
->
extent
);
}
else
{
n
->
base
=
Range
::
make_with_min_extent
(
a
->
base
->
min
+
b
->
base
->
min
,
a
->
base
->
extent
+
b
->
base
->
extent
-
1
);
IntSet
Union
(
const
Array
<
IntSet
>&
set
)
{
if
(
set
.
size
()
==
1
)
return
set
[
0
];
Interval
x
=
set
[
0
].
cover_interval
().
as
<
IntervalSet
>
()
->
i
;
for
(
size_t
i
=
1
;
i
<
set
.
size
();
++
i
)
{
x
.
include
(
set
[
i
].
cover_interval
().
as
<
IntervalSet
>
()
->
i
);
}
return
Int
Set
(
n
);
return
Int
ervalSet
::
make
(
x
);
}
inline
Range
Negation
(
Range
a
)
{
if
(
Match
(
a
->
extent
,
1
))
{
return
Range
::
make_with_min_extent
(
-
a
->
min
,
a
->
extent
);
}
else
{
return
Range
::
make_with_min_extent
(
-
(
a
->
min
+
a
->
extent
-
1
),
a
->
extent
);
// type traits
template
<
typename
OP
>
struct
is_logical_op
{
static
const
bool
value
=
false
;
};
#define TVM_DECLARE_LOGICAL_OP(OP) \
template<> \
struct is_logical_op<ir::OP> { \
static const bool value = true; \
};
// interval related.
template
<
typename
OP
>
inline
IntSet
CombineInterval
(
Interval
a
,
Interval
b
)
{
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
return
IntSet
::
single_point
(
ComputeExpr
<
OP
>
(
a
.
min
,
b
.
min
));
}
LOG
(
WARNING
)
<<
"Return Everything in CombineInterval "
<<
OP
::
_type_key
;
return
IntSet
::
everything
();
}
inline
IntSet
Negation
(
IntSet
a
)
{
CHECK_EQ
(
a
->
concrete
.
size
(),
0U
);
auto
n
=
std
::
make_shared
<
IntSetNode
>
();
n
->
base
=
Negation
(
a
->
base
);
for
(
size_t
i
=
0
;
i
<
a
->
domain
.
size
();
++
i
)
{
n
->
domain
.
push_back
(
Negation
(
a
->
domain
[
i
]));
n
->
stride
.
push_back
(
a
->
stride
[
i
]);
template
<>
inline
IntSet
CombineInterval
<
Add
>
(
Interval
a
,
Interval
b
)
{
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
return
IntSet
::
single_point
(
ComputeExpr
<
Add
>
(
a
.
min
,
b
.
min
));
}
Interval
r
=
Interval
::
everything
();
if
(
a
.
has_lower_bound
()
&&
b
.
has_lower_bound
())
{
r
.
min
=
ComputeExpr
<
Add
>
(
a
.
min
,
b
.
min
);
}
return
IntSet
(
a
);
if
(
a
.
has_upper_bound
()
&&
b
.
has_upper_bound
())
{
r
.
max
=
ComputeExpr
<
Add
>
(
a
.
max
,
b
.
max
);
}
return
IntervalSet
::
make
(
r
);
}
template
<>
inline
IntSet
BinaryCombine
<
Sub
>
(
IntSet
a
,
IntSet
b
)
{
return
BinaryCombine
<
Add
>
(
a
,
Negation
(
b
));
inline
IntSet
CombineInterval
<
Sub
>
(
Interval
a
,
Interval
b
)
{
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
return
IntSet
::
single_point
(
ComputeExpr
<
Sub
>
(
a
.
min
,
b
.
min
));
}
Interval
r
=
Interval
::
everything
();
if
(
a
.
has_lower_bound
()
&&
b
.
has_upper_bound
())
{
r
.
min
=
ComputeExpr
<
Sub
>
(
a
.
min
,
b
.
max
);
}
if
(
a
.
has_upper_bound
()
&&
b
.
has_lower_bound
())
{
r
.
max
=
ComputeExpr
<
Sub
>
(
a
.
max
,
b
.
min
);
}
return
IntervalSet
::
make
(
r
);
}
inline
IntSet
BinaryMul
(
IntSet
a
,
Expr
b
)
{
// copy construct
if
(
Match
(
b
,
1
))
return
a
;
if
(
Match
(
b
,
-
1
))
return
Negation
(
a
);
auto
n
=
std
::
make_shared
<
IntSetNode
>
();
n
->
base
=
Range
::
make_with_min_extent
(
0
,
1
);
n
->
domain
.
push_back
(
a
->
base
);
n
->
stride
.
push_back
(
b
);
for
(
size_t
i
=
0
;
i
<
a
->
domain
.
size
();
++
i
)
{
n
->
domain
.
push_back
(
a
->
domain
[
i
]);
n
->
stride
.
push_back
(
a
->
stride
[
i
]
*
b
);
}
return
IntSet
(
a
);
template
<>
inline
IntSet
CombineInterval
<
Mul
>
(
Interval
a
,
Interval
b
)
{
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
return
IntSet
::
single_point
(
ComputeExpr
<
Mul
>
(
a
.
min
,
b
.
min
));
}
if
(
a
.
is_single_point
()
&&
!
b
.
is_single_point
())
{
std
::
swap
(
a
,
b
);
}
if
(
b
.
is_single_point
())
{
if
(
is_zero
(
b
.
min
))
return
IntSet
::
single_point
(
0
);
if
(
is_one
(
b
.
min
))
return
IntervalSet
::
make
(
a
);
Expr
e1
=
a
.
has_lower_bound
()
?
ComputeExpr
<
Mul
>
(
a
.
min
,
b
.
min
)
:
a
.
min
;
Expr
e2
=
a
.
has_upper_bound
()
?
ComputeExpr
<
Mul
>
(
a
.
max
,
b
.
min
)
:
a
.
max
;
// This is relaxiation
// TODO(tqchen): consider convert to StrideSet.
if
(
is_positive_const
(
b
.
min
))
{
return
IntervalSet
::
make
(
e1
,
e2
);
}
else
if
(
is_negative_const
(
b
.
min
))
{
return
IntervalSet
::
make
(
e2
,
e1
);
}
else
if
(
a
.
is_bounded
())
{
Expr
cmp
=
b
.
min
>=
make_zero
(
b
.
min
.
type
().
element_of
());
return
IntervalSet
::
make
(
select
(
cmp
,
e1
,
e2
),
select
(
cmp
,
e2
,
e1
));
}
}
LOG
(
WARNING
)
<<
"Return Everything in CombineInterval Mul"
;
return
IntSet
::
everything
();
}
template
<>
inline
IntSet
BinaryCombine
<
Mul
>
(
IntSet
a
,
IntSet
b
)
{
if
(
IsNumber
(
a
))
{
return
BinaryMul
(
a
,
AsNumber
(
b
));
}
else
if
(
IsNumber
(
b
))
{
return
BinaryMul
(
b
,
AsNumber
(
a
));
}
else
{
return
IntSet
::
make_all_set
();
inline
IntSet
CombineInterval
<
Max
>
(
Interval
a
,
Interval
b
)
{
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
return
IntSet
::
single_point
(
ComputeExpr
<
Max
>
(
a
.
min
,
b
.
min
));
}
return
IntervalSet
::
make
(
Interval
::
make_max
(
a
.
min
,
b
.
min
),
Interval
::
make_max
(
a
.
max
,
b
.
max
));
}
}
// namespace
inline
const
IntSetNode
*
IntSet
::
operator
->
()
const
{
return
static_cast
<
const
IntSetNode
*>
(
node_
.
get
());
template
<>
inline
IntSet
CombineInterval
<
Min
>
(
Interval
a
,
Interval
b
)
{
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
return
IntSet
::
single_point
(
ComputeExpr
<
Min
>
(
a
.
min
,
b
.
min
));
}
return
IntervalSet
::
make
(
Interval
::
make_min
(
a
.
min
,
b
.
min
),
Interval
::
make_min
(
a
.
max
,
b
.
max
));
}
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
.
set_dispatch
<
IntSetNode
>
([](
const
IntSetNode
*
op
,
IRPrinter
*
p
)
{
p
->
stream
<<
"int-set(base="
;
p
->
print
(
op
->
base
);
p
->
stream
<<
')'
;
});
template
<
typename
OP
>
inline
IntSet
CombineInterval_
(
IntSet
a
,
IntSet
b
)
{
return
CombineInterval
<
OP
>
(
a
.
as
<
IntervalSet
>
()
->
i
,
b
.
as
<
IntervalSet
>
()
->
i
);
}
IntSet
IntSet
::
make_range
(
Range
dom
)
{
auto
n
=
std
::
make_shared
<
IntSetNode
>
();
n
->
base
=
dom
;
// stride related
inline
IntSet
AsStrideSet
(
IntSet
a
)
{
if
(
a
.
as
<
StrideSet
>
())
return
a
;
const
IntervalSet
*
s
=
a
.
as
<
IntervalSet
>
();
CHECK
(
s
->
i
.
is_bounded
());
std
::
shared_ptr
<
StrideSet
>
n
=
std
::
make_shared
<
StrideSet
>
();
n
->
base
=
s
->
i
;
return
IntSet
(
n
);
}
template
<
typename
OP
>
inline
IntSet
CombineSets
(
IntSet
a
,
IntSet
b
)
{
return
CombineInterval_
<
OP
>
(
a
.
cover_interval
(),
b
.
cover_interval
());
}
Range
IntSet
::
GetCoverRange
()
const
{
const
IntSetNode
*
s
=
operator
->
();
CHECK
(
s
!=
nullptr
)
<<
"empty set"
;
if
(
s
->
domain
.
size
()
==
0
&&
s
->
concrete
.
size
()
==
0
)
{
return
s
->
base
;
template
<>
inline
IntSet
CombineSets
<
Add
>
(
IntSet
a
,
IntSet
b
)
{
const
IntervalSet
*
a_int
=
a
.
as
<
IntervalSet
>
();
const
IntervalSet
*
b_int
=
b
.
as
<
IntervalSet
>
();
if
(
a_int
&&
is_zero
(
a_int
->
i
.
min
))
return
b
;
if
(
b_int
&&
is_zero
(
b_int
->
i
.
min
))
return
a
;
a
=
AsStrideSet
(
a
);
b
=
AsStrideSet
(
b
);
const
StrideSet
*
a_stride
=
a
.
as
<
StrideSet
>
();
const
StrideSet
*
b_stride
=
b
.
as
<
StrideSet
>
();
auto
n
=
std
::
make_shared
<
StrideSet
>
(
*
a_stride
);
for
(
size_t
i
=
0
;
i
<
b_stride
->
extents
.
size
();
++
i
)
{
n
->
extents
.
push_back
(
b_stride
->
extents
[
i
]);
n
->
strides
.
push_back
(
b_stride
->
strides
[
i
]);
}
LOG
(
FATAL
)
<<
"not yet implemented"
;
return
Range
();
n
->
base
=
CombineInterval
<
Add
>
(
a_stride
->
base
,
b_stride
->
base
).
as
<
IntervalSet
>
()
->
i
;
return
IntSet
(
n
);
}
IntSet
IntSet
::
make_point
(
Expr
point
)
{
return
IntSet
::
make_range
(
Range
::
make_with_min_extent
(
point
,
1
));
inline
IntSet
NegateSet
(
IntSet
a
)
{
const
IntervalSet
*
a_int
=
a
.
as
<
IntervalSet
>
();
if
(
a_int
)
{
if
(
a_int
->
i
.
is_single_point
())
{
return
IntSet
::
single_point
(
-
a_int
->
i
.
min
);
}
else
{
Interval
r
=
Interval
::
everything
();
if
(
a_int
->
i
.
has_upper_bound
())
{
r
.
min
=
-
(
a_int
->
i
.
max
);
}
if
(
a_int
->
i
.
has_lower_bound
())
{
r
.
max
=
-
(
a_int
->
i
.
min
);
}
return
IntervalSet
::
make
(
r
);
}
}
else
{
return
NegateSet
(
a
.
cover_interval
());
}
}
IntSet
IntSet
::
make_all_set
()
{
LOG
(
FATAL
)
<<
"TODO"
;
return
IntSet
(
);
template
<>
inline
IntSet
CombineSets
<
Sub
>
(
IntSet
a
,
IntSet
b
)
{
return
CombineSets
<
Add
>
(
a
,
NegateSet
(
b
)
);
}
IntSet
Union
(
const
Array
<
IntSet
>&
set
)
{
if
(
set
.
size
()
==
1
)
return
set
[
0
];
LOG
(
FATAL
)
<<
"TODO"
;
return
IntSet
();
TVM_DECLARE_LOGICAL_OP
(
And
);
TVM_DECLARE_LOGICAL_OP
(
Or
);
TVM_DECLARE_LOGICAL_OP
(
EQ
);
TVM_DECLARE_LOGICAL_OP
(
NE
);
TVM_DECLARE_LOGICAL_OP
(
GE
);
TVM_DECLARE_LOGICAL_OP
(
GT
);
TVM_DECLARE_LOGICAL_OP
(
LE
);
TVM_DECLARE_LOGICAL_OP
(
LT
);
TVM_DECLARE_LOGICAL_OP
(
Not
);
// generic combine operations of two sets
template
<
typename
OP
>
inline
IntSet
Combine
(
const
IntSet
&
a
,
const
IntSet
&
b
)
{
if
(
is_logical_op
<
OP
>::
value
)
{
return
IntervalSet
::
make
(
0
,
1
);
}
const
IntervalSet
*
a_int
=
a
.
as
<
IntervalSet
>
();
const
IntervalSet
*
b_int
=
b
.
as
<
IntervalSet
>
();
if
(
a_int
&&
a_int
->
i
.
is_everything
())
return
a
;
if
(
b_int
&&
b_int
->
i
.
is_everything
())
return
b
;
if
(
a_int
&&
b_int
)
{
return
CombineInterval
<
OP
>
(
a_int
->
i
,
b_int
->
i
);
}
if
(
a_int
&&
!
(
a_int
->
i
.
is_bounded
()))
{
return
CombineInterval_
<
OP
>
(
a
,
b
.
cover_interval
());
}
if
(
b_int
&&
!
(
b_int
->
i
.
is_bounded
()))
{
return
CombineInterval_
<
OP
>
(
a
.
cover_interval
(),
b
);
}
return
CombineSets
<
OP
>
(
a
,
b
);
}
// Implementation of Evaluations and passing.
void
PassUp
(
const
SplitNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
outer
,
...
...
@@ -215,33 +358,21 @@ void PassUp(const SplitNode* s,
if
(
dom_map
.
count
(
s
->
outer
)
&&
dom_map
.
count
(
s
->
inner
)
&&
dom_map
.
count
(
s
->
parent
)
&&
Match
(
outer
,
dom_map
.
at
(
s
->
outer
))
&&
Match
(
inner
,
dom_map
.
at
(
s
->
inner
)))
{
*
parent
=
IntSet
::
make_
range
(
dom_map
.
at
(
s
->
parent
));
Match
Range
(
outer
,
dom_map
.
at
(
s
->
outer
))
&&
Match
Range
(
inner
,
dom_map
.
at
(
s
->
inner
)))
{
*
parent
=
IntSet
::
range
(
dom_map
.
at
(
s
->
parent
));
return
;
}
Expr
factor
=
dom_map
.
at
(
s
->
inner
)
->
extent
;
Expr
parent_min
=
dom_map
.
at
(
s
->
parent
)
->
min
;
CHECK
(
outer
.
defined
());
CHECK
(
inner
.
defined
());
CHECK
(
factor
.
defined
());
// copy construct
auto
n
=
std
::
make_shared
<
IntSetNode
>
(
*
(
inner
.
operator
->
()));
if
(
IsNumber
(
outer
))
{
// shift the base offset
n
->
base
=
Range
::
make_with_min_extent
(
AsNumber
(
outer
)
*
factor
+
inner
->
base
->
min
,
inner
->
base
->
extent
);
}
else
{
// default use all domains in the data.
n
->
domain
.
push_back
(
outer
->
base
);
n
->
stride
.
push_back
(
factor
);
for
(
size_t
i
=
0
;
i
<
outer
->
domain
.
size
();
++
i
)
{
n
->
domain
.
push_back
(
outer
->
domain
[
i
]);
n
->
stride
.
push_back
(
outer
->
stride
[
i
]
*
factor
);
}
}
*
parent
=
IntSet
(
n
);
*
parent
=
Combine
<
Add
>
(
Combine
<
Add
>
(
Combine
<
Mul
>
(
outer
,
IntSet
::
single_point
(
factor
)),
inner
),
IntSet
::
single_point
(
parent_min
));
}
void
PassUp
(
const
FuseNode
*
s
,
...
...
@@ -253,29 +384,51 @@ void PassUp(const FuseNode* s,
CHECK
(
dom_map
.
count
(
s
->
inner
));
CHECK
(
dom_map
.
count
(
s
->
fused
));
if
(
Match
(
fused
,
dom_map
.
at
(
s
->
fused
)))
{
*
outer
=
IntSet
::
make_
range
(
dom_map
.
at
(
s
->
outer
));
*
inner
=
IntSet
::
make_
range
(
dom_map
.
at
(
s
->
inner
));
if
(
Match
Range
(
fused
,
dom_map
.
at
(
s
->
fused
)))
{
*
outer
=
IntSet
::
range
(
dom_map
.
at
(
s
->
outer
));
*
inner
=
IntSet
::
range
(
dom_map
.
at
(
s
->
inner
));
return
;
}
if
(
IsNumber
(
fused
))
{
Expr
value
=
AsNumber
(
fused
);
Expr
outer_min
=
dom_map
.
at
(
s
->
outer
)
->
min
;
Expr
inner_min
=
dom_map
.
at
(
s
->
inner
)
->
min
;
const
IntervalSet
*
fused_int
=
fused
.
as
<
IntervalSet
>
();
if
(
fused_int
&&
fused_int
->
i
.
is_single_point
())
{
Expr
value
=
fused_int
->
i
.
min
;
Expr
factor
=
dom_map
.
at
(
s
->
inner
)
->
extent
;
*
outer
=
IntSet
::
make_point
(
value
/
factor
);
*
inner
=
IntSet
::
make_point
(
value
%
factor
);
Expr
v_outer
=
value
/
factor
;
Expr
v_inner
=
value
%
factor
;
if
(
!
is_zero
(
outer_min
))
v_outer
=
v_outer
+
outer_min
;
if
(
!
is_zero
(
inner_min
))
v_inner
=
v_inner
+
inner_min
;
*
outer
=
IntSet
::
single_point
(
v_outer
);
*
inner
=
IntSet
::
single_point
(
v_inner
);
}
else
{
LOG
(
WARNING
)
<<
"use fallback inference rule in fuse"
;
// simply use the entire set, this rule can be enhanced.
*
outer
=
IntSet
::
make_range
(
dom_map
.
at
(
s
->
outer
));
*
inner
=
IntSet
::
make_range
(
dom_map
.
at
(
s
->
inner
));
*
outer
=
IntSet
::
range
(
dom_map
.
at
(
s
->
outer
));
*
inner
=
IntSet
::
range
(
dom_map
.
at
(
s
->
inner
));
return
;
}
}
void
PassUp
(
const
RebaseNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
rebased
,
IntSet
*
parent
)
{
CHECK
(
dom_map
.
count
(
s
->
parent
));
if
(
MatchRange
(
rebased
,
dom_map
.
at
(
s
->
rebased
)))
{
*
parent
=
IntSet
::
range
(
dom_map
.
at
(
s
->
parent
));
return
;
}
Expr
parent_min
=
dom_map
.
at
(
s
->
parent
)
->
min
;
*
parent
=
Combine
<
Add
>
(
rebased
,
IntSet
::
single_point
(
parent_min
));
}
namespace
{
// evaluator to evaluate the int set
class
IRSetEvaluator
{
// Evaluator to evalute the epxression.
class
IntSetEvaluator
{
public
:
inline
IntSet
Eval
(
Expr
expr
)
{
static
const
FType
&
f
=
vtable
();
...
...
@@ -283,11 +436,11 @@ class IRSetEvaluator {
return
f
(
expr
,
expr
,
this
);
}
else
{
LOG
(
WARNING
)
<<
"cannot evaluate set type "
<<
expr
->
type_key
();
return
IntSet
::
make_all_set
();
return
IntSet
::
everything
();
}
}
using
FType
=
tvm
::
IRFunctor
<
IntSet
(
const
NodeRef
&
,
const
Expr
&
,
I
R
SetEvaluator
*
)
>
;
using
FType
=
tvm
::
IRFunctor
<
IntSet
(
const
NodeRef
&
,
const
Expr
&
,
I
nt
SetEvaluator
*
)
>
;
static
FType
&
vtable
()
{
// NOLINT(*)
static
FType
inst
;
return
inst
;
}
...
...
@@ -295,76 +448,84 @@ class IRSetEvaluator {
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
dom_map
;
};
inline
IntSet
ConstOp
(
const
NodeRef
&
,
const
Expr
&
e
,
I
R
SetEvaluator
*
)
{
return
IntSet
::
mak
e_point
(
e
);
inline
IntSet
ConstOp
(
const
NodeRef
&
,
const
Expr
&
e
,
I
nt
SetEvaluator
*
)
{
return
IntSet
::
singl
e_point
(
e
);
}
TVM_STATIC_IR_FUNCTOR
(
I
R
SetEvaluator
,
vtable
)
TVM_STATIC_IR_FUNCTOR
(
I
nt
SetEvaluator
,
vtable
)
.
set_dispatch
<
IntImm
>
(
ConstOp
)
.
set_dispatch
<
UIntImm
>
(
ConstOp
)
.
set_dispatch
<
FloatImm
>
(
ConstOp
);
TVM_STATIC_IR_FUNCTOR
(
I
R
SetEvaluator
,
vtable
)
.
set_dispatch
<
Variable
>
([](
const
Variable
*
op
,
const
Expr
&
e
,
I
R
SetEvaluator
*
m
)
{
TVM_STATIC_IR_FUNCTOR
(
I
nt
SetEvaluator
,
vtable
)
.
set_dispatch
<
Variable
>
([](
const
Variable
*
op
,
const
Expr
&
e
,
I
nt
SetEvaluator
*
m
)
{
auto
it
=
m
->
dom_map
.
find
(
op
);
if
(
it
!=
m
->
dom_map
.
end
())
{
return
it
->
second
;
}
else
{
return
IntSet
::
mak
e_point
(
e
);
return
IntSet
::
singl
e_point
(
e
);
}
});
// binary operator
template
<
typename
T
>
inline
IntSet
Binary
(
const
T
*
op
,
const
Expr
&
e
,
I
R
SetEvaluator
*
m
)
{
inline
IntSet
Binary
(
const
T
*
op
,
const
Expr
&
e
,
I
nt
SetEvaluator
*
m
)
{
IntSet
a
=
m
->
Eval
(
op
->
a
);
IntSet
b
=
m
->
Eval
(
op
->
b
);
if
(
IsNumber
(
a
)
&&
IsNumber
(
b
))
{
if
(
Match
(
a
,
op
->
a
)
&&
Match
(
b
,
op
->
b
))
{
return
IntSet
::
make_point
(
e
);
}
else
{
return
IntSet
::
make_point
(
T
::
make
(
AsNumber
(
a
),
AsNumber
(
b
)));
}
}
else
{
return
BinaryCombine
<
T
>
(
a
,
b
);
if
(
MatchPoint
(
a
,
op
->
a
)
&&
MatchPoint
(
b
,
op
->
b
))
{
return
IntSet
::
single_point
(
e
);
}
IntSet
r
=
Combine
<
T
>
(
a
,
b
);
return
r
;
}
TVM_STATIC_IR_FUNCTOR
(
I
R
SetEvaluator
,
vtable
)
TVM_STATIC_IR_FUNCTOR
(
I
nt
SetEvaluator
,
vtable
)
.
set_dispatch
<
Add
>
(
Binary
<
Add
>
)
.
set_dispatch
<
Sub
>
(
Binary
<
Sub
>
)
.
set_dispatch
<
Mul
>
(
Binary
<
Mul
>
)
.
set_dispatch
<
Div
>
(
Binary
<
Div
>
)
.
set_dispatch
<
Mod
>
(
Binary
<
Mod
>
)
.
set_dispatch
<
Min
>
(
Binary
<
Min
>
)
.
set_dispatch
<
Max
>
(
Binary
<
Max
>
);
// use simply bound for logical expressions for now.
inline
IntSet
Logical
(
const
NodeRef
&
,
const
Expr
&
e
,
IRSetEvaluator
*
)
{
return
IntSet
::
make_range
(
Range
::
make_with_min_extent
(
0
,
2
));
}
TVM_STATIC_IR_FUNCTOR
(
IRSetEvaluator
,
vtable
)
.
set_dispatch
<
EQ
>
(
Logical
)
.
set_dispatch
<
NE
>
(
Logical
)
.
set_dispatch
<
LT
>
(
Logical
)
.
set_dispatch
<
LE
>
(
Logical
)
.
set_dispatch
<
GT
>
(
Logical
)
.
set_dispatch
<
GE
>
(
Logical
)
.
set_dispatch
<
And
>
(
Logical
)
.
set_dispatch
<
Or
>
(
Logical
);
}
// namespace
.
set_dispatch
<
Max
>
(
Binary
<
Max
>
)
.
set_dispatch
<
EQ
>
(
Binary
<
EQ
>
)
.
set_dispatch
<
NE
>
(
Binary
<
NE
>
)
.
set_dispatch
<
LT
>
(
Binary
<
LT
>
)
.
set_dispatch
<
LE
>
(
Binary
<
LE
>
)
.
set_dispatch
<
GT
>
(
Binary
<
GT
>
)
.
set_dispatch
<
GE
>
(
Binary
<
GE
>
)
.
set_dispatch
<
And
>
(
Binary
<
And
>
)
.
set_dispatch
<
Or
>
(
Binary
<
Or
>
);
IntSet
EvalSet
(
Expr
e
,
const
Map
<
IterVar
,
IntSet
>&
dom_map
)
{
I
R
SetEvaluator
m
;
I
nt
SetEvaluator
m
;
for
(
auto
kv
:
dom_map
)
{
m
.
dom_map
[
kv
.
first
->
var
.
as
<
Variable
>
()]
=
kv
.
second
;
}
return
m
.
Eval
(
e
);
}
IntSet
EvalSet
(
Range
r
,
const
Map
<
IterVar
,
IntSet
>&
dom_map
)
{
IntSetEvaluator
m
;
for
(
auto
kv
:
dom_map
)
{
m
.
dom_map
[
kv
.
first
->
var
.
as
<
Variable
>
()]
=
kv
.
second
;
}
IntSet
min_set
=
m
.
Eval
(
r
->
min
);
IntSet
ext_set
=
m
.
Eval
(
r
->
extent
).
cover_interval
();
const
Interval
&
ei
=
ext_set
.
as
<
IntervalSet
>
()
->
i
;
if
(
!
ei
.
has_upper_bound
())
return
IntSet
::
everything
();
ext_set
=
IntervalSet
::
make
(
0
,
ComputeExpr
<
Sub
>
(
ei
.
max
,
1
));
return
Combine
<
Add
>
(
min_set
,
ext_set
);
}
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
.
set_dispatch
<
IntervalSet
>
([](
const
IntervalSet
*
op
,
IRPrinter
*
p
)
{
p
->
stream
<<
"interval-set["
<<
"["
<<
op
->
i
.
min
<<
", "
<<
op
->
i
.
max
<<
']'
;
});
}
// namespace schedule
}
// namespace tvm
src/schedule/int_set.h
View file @
a2c8a29b
...
...
@@ -22,35 +22,48 @@ class IntSet : public NodeRef {
public
:
/*! \brief constructor */
IntSet
()
{}
// constructor from not
de
ontainer.
// constructor from not
c
ontainer.
explicit
IntSet
(
std
::
shared_ptr
<
Node
>
n
)
:
NodeRef
(
n
)
{}
/*! \return whether the set is empty */
inline
bool
is_empty
()
const
{
return
!
defined
();
}
/*!
* \return a range that covers the IntSet
*/
Range
GetCoverRange
()
const
;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline
const
IntSetNode
*
operator
->
()
const
;
/*!
* \param dom The domain to be created.
* \return create integer set from existing domain
* \brief Find a range that covers the region.
* \param max_range The range to be covered.
* \return The covering range.
*/
Range
cover_range
(
Range
max_range
)
const
;
/*!
* \brief find an interval that covers the set.
* \return The covering interval set.
*/
static
IntSet
make_range
(
Range
dom
);
IntSet
cover_interval
()
const
;
/*! \return Whether the set represent everything */
bool
is_everything
()
const
;
/*! \return Whether the set is a single point */
bool
is_single_point
()
const
;
/*! \return Whether the set contains everything */
static
IntSet
everything
();
/*!
* \param point
* \return create integer set that only contains one point
* \brief construct a point set.
* \param point The point in the set.
* \return construct a single point set
*/
static
IntSet
mak
e_point
(
Expr
point
);
static
IntSet
singl
e_point
(
Expr
point
);
/*!
* \return create integer set that represents everything
* \brief Construct a set representing a range.
* \param r The range
* \return constructed set.
*/
static
IntSet
make_all_set
();
static
IntSet
range
(
Range
r
);
};
/*!
* \brief Base class of all IntSet containers.
*/
struct
IntSetNode
:
public
Node
{
};
/*!
...
...
@@ -63,6 +76,18 @@ class IntSet : public NodeRef {
*/
IntSet
EvalSet
(
Expr
e
,
const
Map
<
IterVar
,
IntSet
>&
dom_map
);
/*!
* \brief Find an symbolic integer set that contains is union over
* all the possible conditional values in dom_map.
*
* \param r The initial range.
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values.
*/
IntSet
EvalSet
(
Range
r
,
const
Map
<
IterVar
,
IntSet
>&
dom_map
);
/*!
* \brief Conditional upward message passing.
*
...
...
@@ -99,6 +124,23 @@ void PassUp(const FuseNode* s,
const
IntSet
&
fused
,
IntSet
*
outer
,
IntSet
*
inner
);
/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Fuse relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param rebased domain of rebased iteration.
* \param parent The result domain of parent iteration.
*/
void
PassUp
(
const
RebaseNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
fused
,
IntSet
*
parent
);
/*!
* \brief Create an union set of all sets
* \param sets The sets to be unioned
...
...
@@ -106,6 +148,11 @@ void PassUp(const FuseNode* s,
*/
IntSet
Union
(
const
Array
<
IntSet
>&
sets
);
// implementation
inline
const
IntSetNode
*
IntSet
::
operator
->
()
const
{
return
static_cast
<
const
IntSetNode
*>
(
node_
.
get
());
}
}
// namespace schedule
}
// namespace tvm
...
...
src/schedule/schedule_lang.cc
View file @
a2c8a29b
...
...
@@ -81,7 +81,7 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*)
}
}
CHECK
(
found
)
<<
"Cannot
compute at a iteration variable that is not part of parent leaf
vars"
;
<<
"Cannot
find the specified axis in parent stage's leaf_iter_
vars"
;
return
*
this
;
}
...
...
@@ -165,7 +165,6 @@ Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
return
*
this
;
}
Schedule
::
Schedule
(
Array
<
Operation
>
ops
)
{
auto
n
=
std
::
make_shared
<
ScheduleNode
>
();
n
->
roots
=
ops
;
...
...
@@ -203,9 +202,53 @@ IterVarRelation FuseNode::make(
return
IterVarRelation
(
n
);
}
IterVarRelation
RebaseNode
::
make
(
IterVar
parent
,
IterVar
rebased
)
{
auto
n
=
std
::
make_shared
<
RebaseNode
>
();
n
->
parent
=
parent
;
n
->
rebased
=
rebased
;
return
IterVarRelation
(
n
);
}
void
Schedule
::
normalize
()
{
std
::
unordered_map
<
IterVar
,
IterVar
>
rebase_map
;
std
::
unordered_map
<
const
Node
*
,
int
>
attach_mark
;
for
(
Stage
s
:
(
*
this
)
->
stages
)
{
if
(
s
->
attach_type
==
kScope
)
{
attach_mark
[
s
->
attach_stage
.
get
()]
=
1
;
}
}
for
(
Stage
s
:
(
*
this
)
->
stages
)
{
if
(
!
attach_mark
.
count
(
s
.
get
()))
continue
;
auto
root_iter_vars
=
s
->
op
->
root_iter_vars
();
ArrayNode
*
leaf_vars
=
s
->
leaf_iter_vars
.
CopyOnWrite
();
for
(
IterVar
iv
:
root_iter_vars
)
{
size_t
idx
=
FindIterVar
(
leaf_vars
,
iv
);
if
(
idx
<
leaf_vars
->
data
.
size
())
{
// insert rebase
IterVar
rebased
(
Range
(),
iv
->
var
->
name_hint
+
".rb"
);
s
->
relations
.
push_back
(
RebaseNode
::
make
(
iv
,
rebased
));
leaf_vars
->
data
[
idx
]
=
rebased
.
node_
;
rebase_map
[
iv
]
=
rebased
;
}
}
}
// remap the parent relation
for
(
Stage
s
:
(
*
this
)
->
stages
)
{
if
(
s
->
attach_type
!=
kScope
)
continue
;
if
(
rebase_map
.
count
(
s
->
attach_ivar
))
{
s
->
attach_ivar
=
rebase_map
.
at
(
s
->
attach_ivar
);
}
}
}
TVM_REGISTER_NODE_TYPE
(
StageNode
);
TVM_REGISTER_NODE_TYPE
(
SplitNode
);
TVM_REGISTER_NODE_TYPE
(
FuseNode
);
TVM_REGISTER_NODE_TYPE
(
RebaseNode
);
TVM_REGISTER_NODE_TYPE
(
ScheduleNode
);
}
// namespace tvm
src/
pass
/schedule_ops.cc
→
src/
schedule
/schedule_ops.cc
View file @
a2c8a29b
...
...
@@ -8,12 +8,44 @@
#include <tvm/ir_visitor.h>
#include <tvm/schedule_pass.h>
#include ".
/scope
.h"
#include "./i
r_util
.h"
#include ".
./schedule
/graph.h"
#include ".
./pass/ir_util
.h"
#include "./i
nt_set
.h"
#include "./graph.h"
namespace
tvm
{
namespace
ir
{
namespace
schedule
{
using
namespace
ir
;
/*!
* \brief message passing to find if IterVar is related to reduction.
* \param s The stage to be used.
* \param p_state The message passing state
* IterVar->flag
*/
void
PassDownFlag
(
const
Stage
&
s
,
std
::
unordered_map
<
IterVar
,
int
>*
p_state
)
{
auto
&
state
=
*
p_state
;
for
(
IterVarRelation
rel
:
s
->
relations
)
{
if
(
rel
.
as
<
SplitNode
>
())
{
const
SplitNode
*
s
=
rel
.
as
<
SplitNode
>
();
int
flag
=
state
.
at
(
s
->
parent
);
state
[
s
->
outer
]
=
flag
;
state
[
s
->
inner
]
=
flag
;
}
else
if
(
rel
.
as
<
FuseNode
>
())
{
const
FuseNode
*
s
=
rel
.
as
<
FuseNode
>
();
int
flag_outer
=
state
.
at
(
s
->
outer
);
int
flag_inner
=
state
.
at
(
s
->
inner
);
state
[
s
->
fused
]
=
flag_outer
|
flag_inner
;
}
else
if
(
rel
.
as
<
RebaseNode
>
())
{
const
RebaseNode
*
s
=
rel
.
as
<
RebaseNode
>
();
int
flag
=
state
.
at
(
s
->
parent
);
state
[
s
->
rebased
]
=
flag
;
}
else
{
LOG
(
FATAL
)
<<
"unknown relation type"
;
}
}
}
/*!
* \brief use message passing to calculate the assignment of each Var inside the loop body.
...
...
@@ -37,7 +69,7 @@ void PassUpOffset(const Stage& s,
state
[
s
->
parent
]
=
inner
+
outer
*
factor
;
// add min if they exist
if
(
!
is_zero
(
parent_min
))
{
state
[
s
->
parent
]
=
parent_min
+
state
[
s
->
parent
]
;
state
[
s
->
parent
]
=
state
[
s
->
parent
]
+
parent_min
;
}
}
else
if
(
rel
.
as
<
FuseNode
>
())
{
const
FuseNode
*
s
=
rel
.
as
<
FuseNode
>
();
...
...
@@ -49,10 +81,20 @@ void PassUpOffset(const Stage& s,
state
[
s
->
inner
]
=
value
%
factor
;
// add min if they exist
if
(
!
is_zero
(
outer_min
))
{
state
[
s
->
outer
]
=
outer_min
+
state
[
s
->
outer
]
;
state
[
s
->
outer
]
=
state
[
s
->
outer
]
+
outer_min
;
}
if
(
!
is_zero
(
inner_min
))
{
state
[
s
->
inner
]
=
outer_min
+
state
[
s
->
inner
];
state
[
s
->
inner
]
=
state
[
s
->
inner
]
+
inner_min
;
}
}
else
if
(
rel
.
as
<
RebaseNode
>
())
{
const
RebaseNode
*
s
=
rel
.
as
<
RebaseNode
>
();
Expr
value
=
state
.
at
(
s
->
rebased
);
Expr
parent_min
=
dom_map
.
at
(
s
->
parent
)
->
min
;
// add min if they exist
if
(
!
is_zero
(
parent_min
))
{
state
[
s
->
parent
]
=
value
+
parent_min
;
}
else
{
state
[
s
->
parent
]
=
value
;
}
}
else
{
LOG
(
FATAL
)
<<
"unknown relation type"
;
...
...
@@ -60,76 +102,54 @@ void PassUpOffset(const Stage& s,
}
}
/*!
* \brief split the expr by addition.
* \param expr The expression to be splitted.
* \param loop_level The loop level of each Variable
* \param result vector of (level, expr)
* The level gives the mimimum loop level this expression need to be computed.
* The Expr gives the expression content.
*/
void
SplitByAdd
(
Expr
expr
,
const
std
::
unordered_map
<
const
Variable
*
,
size_t
>&
loop_level
,
std
::
vector
<
std
::
pair
<
size_t
,
Expr
>
>
*
result
)
{
const
Add
*
op
=
expr
.
as
<
Add
>
();
if
(
op
!=
nullptr
)
{
SplitByAdd
(
op
->
a
,
loop_level
,
result
);
SplitByAdd
(
op
->
b
,
loop_level
,
result
);
}
else
{
size_t
max_level
=
0
;
auto
fvisit
=
[
&
max_level
,
&
loop_level
](
const
NodeRef
&
n
)
{
const
Variable
*
op
=
n
.
as
<
Variable
>
();
if
(
op
!=
nullptr
)
{
auto
it
=
loop_level
.
find
(
op
);
if
(
it
!=
loop_level
.
end
())
{
max_level
=
std
::
max
(
max_level
,
it
->
second
);
}
}
};
PostOrderVisit
(
expr
,
fvisit
);
result
->
push_back
(
std
::
make_pair
(
max_level
,
expr
));
}
}
/*!
* \brief Make the loop nest of the correspondings schedule.
* \param sch The schedule.
* \param dom_map The domain map.
*
* \return a nested representation of loop statements.
* The flattened Stmt are ordered from outmost to inner most order.
*/
std
::
vector
<
std
::
vector
<
Stmt
>
>
MakeLoopNest
(
const
Stage
&
sch
,
const
Map
<
IterVar
,
Range
>&
dom_map
)
{
// optional, use let to define some CSE in dom_map.
std
::
vector
<
std
::
vector
<
Stmt
>
>
MakeLoopNest
(
const
Stage
&
sch
,
const
Map
<
IterVar
,
Range
>&
dom_map
,
size_t
begin_loop
,
bool
reduce_init_loop
,
std
::
unordered_map
<
IterVar
,
Expr
>*
p_value_map
,
const
std
::
unordered_map
<
IterVar
,
bool
>&
skip_iter
)
{
auto
leaf_iter_vars
=
sch
->
leaf_iter_vars
;
std
::
unordered_map
<
IterVar
,
Expr
>
offset
;
std
::
unordered_map
<
const
Variable
*
,
size_t
>
loop_level
;
Stmt
no_op
=
Evaluate
::
make
(
0
);
// create the loop nest
std
::
vector
<
std
::
vector
<
Stmt
>
>
nest
;
nest
.
resize
(
leaf_iter_vars
.
size
()
+
1
);
std
::
unordered_map
<
IterVar
,
Expr
>&
value_map
=
*
p_value_map
;
for
(
size_t
i
=
0
;
i
<
leaf_iter_vars
.
size
();
++
i
)
{
for
(
size_t
i
=
begin_loop
;
i
<
leaf_iter_vars
.
size
();
++
i
)
{
auto
iv
=
leaf_iter_vars
[
i
];
if
(
skip_iter
.
count
(
iv
)
&&
skip_iter
.
at
(
iv
))
{
// skip this iteration.
value_map
[
iv
]
=
iv
->
var
;
continue
;
}
Range
dom
=
dom_map
.
at
(
iv
);
// initialize the offset and loop_level
offset
[
iv
]
=
iv
->
var
;
loop_level
[
iv
->
var
.
as
<
Variable
>
()]
=
i
+
1
;
Var
var
=
iv
->
var
;
if
(
reduce_init_loop
)
{
var
=
Var
(
iv
->
var
->
name_hint
+
".init"
,
iv
->
var
.
type
());
}
// Mark the iter var in the IR, to remember the point
if
(
iv
->
thread_tag
.
length
()
==
0
)
{
if
(
is_zero
(
dom
->
min
))
{
if
(
is_one
(
dom
->
extent
))
{
nest
[
i
+
1
].
emplace_back
(
LetStmt
::
make
(
var
,
dom
->
min
,
no_op
));
value_map
[
iv
]
=
dom
->
min
;
}
else
if
(
is_zero
(
dom
->
min
))
{
nest
[
i
+
1
].
emplace_back
(
For
::
make
(
iv
->
var
,
0
,
dom
->
extent
,
For
::
make
(
var
,
0
,
dom
->
extent
,
ForType
::
Serial
,
DeviceAPI
::
None
,
no_op
));
value_map
[
iv
]
=
var
;
}
else
{
Var
idx
(
iv
->
var
->
name_hint
+
".idx"
,
iv
->
var
.
type
());
nest
[
i
+
1
].
emplace_back
(
For
::
make
(
idx
,
0
,
dom
->
extent
,
ForType
::
Serial
,
DeviceAPI
::
None
,
no_op
));
Expr
new_value
=
dom
->
min
+
idx
;
value_map
[
iv
]
=
new_value
;
nest
[
i
+
1
].
emplace_back
(
LetStmt
::
make
(
iv
->
var
,
dom
->
min
+
idx
,
no_op
));
LetStmt
::
make
(
var
,
new_value
,
no_op
));
}
}
else
{
// Always restrict threaded IterVar to starts from 0.
...
...
@@ -137,69 +157,73 @@ std::vector<std::vector<Stmt> > MakeLoopNest(
// annotate the extent of the IterVar
nest
[
i
+
1
].
emplace_back
(
AttrStmt
::
make
(
iv
,
"thread_extent"
,
dom
->
extent
,
no_op
));
value_map
[
iv
]
=
var
;
}
if
(
!
reduce_init_loop
)
{
// annotate the extent of the IterVar
nest
[
i
+
1
].
emplace_back
(
AttrStmt
::
make
(
iv
,
"scope"
,
iv
->
var
,
no_op
));
}
// annotate the extent of the IterVar
nest
[
i
+
1
].
emplace_back
(
AttrStmt
::
make
(
iv
,
"scope"
,
iv
->
var
,
no_op
));
}
// message passing to get offset of root iter vars.
PassUpOffset
(
sch
,
dom_map
,
&
offset
);
for
(
IterVar
iv
:
sch
->
op
->
root_iter_vars
())
{
Expr
value
=
offset
.
at
(
iv
);
if
(
!
value
.
same_as
(
iv
->
var
))
{
using
Entry
=
std
::
pair
<
size_t
,
Expr
>
;
std
::
vector
<
Entry
>
splits
;
SplitByAdd
(
value
,
loop_level
,
&
splits
);
PassUpOffset
(
sch
,
dom_map
,
&
value_map
);
return
nest
;
}
Expr
offset
=
0
;
size_t
nsplit_left
=
splits
.
size
()
-
1
;
for
(
size_t
i
=
0
;
i
<=
leaf_iter_vars
.
size
();
++
i
)
{
size_t
hit
=
0
;
for
(
const
auto
&
kv
:
splits
)
{
if
(
kv
.
first
==
i
)
{
if
(
is_zero
(
offset
))
{
offset
=
kv
.
second
;
}
else
{
offset
=
offset
+
kv
.
second
;
++
hit
;
}
}
}
nsplit_left
-=
hit
;
if
(
hit
!=
0
)
{
std
::
ostringstream
os
;
os
<<
iv
->
var
->
name_hint
<<
".at.l"
<<
i
;
Var
base_offset
(
os
.
str
());
if
(
nsplit_left
==
0
)
{
base_offset
=
iv
->
var
;
}
nest
[
i
].
emplace_back
(
LetStmt
::
make
(
base_offset
,
offset
,
no_op
));
offset
=
base_offset
;
}
}
Range
dom
=
dom_map
.
at
(
iv
);
if
(
!
offset
.
same_as
(
iv
->
var
))
{
// define the iv->var
nest
.
back
().
emplace_back
(
LetStmt
::
make
(
iv
->
var
,
offset
,
no_op
));
Stmt
MakeLoop
(
const
Stage
&
s
,
const
Map
<
IterVar
,
Range
>&
dom_map
,
Stmt
provide
,
Stmt
init
)
{
std
::
unordered_map
<
IterVar
,
Expr
>
value_map
;
auto
nest
=
MakeLoopNest
(
s
,
dom_map
,
0
,
false
,
&
value_map
,
{});
provide
=
Substitute
(
provide
,
value_map
);
if
(
init
.
defined
())
{
// try to find the location to insert the initialization.
// Fuse the initialization and provide loop when possible.
std
::
unordered_map
<
IterVar
,
int
>
reduce_state
;
const
ComputeOpNode
*
compute
=
s
->
op
.
as
<
ComputeOpNode
>
();
for
(
IterVar
iv
:
compute
->
reduce_axis
)
{
reduce_state
[
iv
]
=
2
;
}
for
(
IterVar
iv
:
compute
->
axis
)
{
reduce_state
[
iv
]
=
1
;
}
// find which iter var is related to reduction and which is related to axis.
PassDownFlag
(
s
,
&
reduce_state
);
auto
leaf_iter_vars
=
s
->
leaf_iter_vars
;
std
::
unordered_map
<
IterVar
,
Expr
>
init_value_map
;
// first first loop that is related to reduction.
size_t
begin_loop
=
leaf_iter_vars
.
size
();
for
(
size_t
i
=
0
;
i
<
leaf_iter_vars
.
size
();
++
i
)
{
auto
iv
=
leaf_iter_vars
[
i
];
int
flag
=
reduce_state
.
at
(
iv
);
if
((
flag
&
2
)
!=
0
)
{
begin_loop
=
i
;
break
;
}
Expr
condition
=
(
iv
->
var
-
dom
->
min
)
<
dom
->
extent
;
// Boundary condition checking
// Need better boundary condition here.
nest
.
back
().
emplace_back
(
IfThenElse
::
make
(
condition
,
no_op
));
init_value_map
[
iv
]
=
value_map
.
at
(
iv
);
}
// skip loops that does not relates to axis.
std
::
unordered_map
<
IterVar
,
bool
>
skip_iter
;
for
(
size_t
i
=
begin_loop
;
i
<
leaf_iter_vars
.
size
();
++
i
)
{
auto
iv
=
leaf_iter_vars
[
i
];
int
flag
=
reduce_state
.
at
(
iv
);
if
((
flag
&
1
)
==
0
)
skip_iter
[
iv
]
=
true
;
}
auto
init_nest
=
MakeLoopNest
(
s
,
dom_map
,
begin_loop
,
true
,
&
init_value_map
,
skip_iter
);
init
=
Substitute
(
init
,
init_value_map
);
init
=
MergeNest
(
init_nest
,
init
);
// common nest
std
::
vector
<
std
::
vector
<
Stmt
>
>
common
(
nest
.
begin
(),
nest
.
begin
()
+
begin_loop
);
std
::
vector
<
std
::
vector
<
Stmt
>
>
reduce
(
nest
.
begin
()
+
begin_loop
,
nest
.
end
());
provide
=
MergeNest
(
reduce
,
provide
);
return
MergeNest
(
common
,
Block
::
make
(
init
,
provide
));
}
else
{
return
MergeNest
(
nest
,
provide
);
}
return
nest
;
}
/*!
* \brief Make pipeline specifically for compute op node.
* \param op The compute node
* \param tensors The tensors generated by provide.
*/
Stmt
MakeProvide
(
const
ComputeOpNode
*
op
,
const
std
::
vector
<
Tensor
>&
tensors
)
{
Tensor
t
=
tensors
[
0
];
...
...
@@ -210,13 +234,6 @@ Stmt MakeProvide(const ComputeOpNode* op,
return
Provide
::
make
(
t
->
op
,
t
->
value_index
,
op
->
body
,
args
);
}
/*!
* \brief Make pipeline specifically for compute op node.
* \param op The compute node
* \param dom_map The domain map
* \param tensors The tensors generated by provide.
* \param body The content of the pipeline.
*/
Stmt
MakeRealize
(
const
ComputeOpNode
*
op
,
const
Map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
vector
<
Tensor
>&
tensors
,
...
...
@@ -230,6 +247,38 @@ Stmt MakeRealize(const ComputeOpNode* op,
bounds
,
make_const
(
Bool
(
1
),
true
),
body
);
}
void
MakeReduction
(
const
ComputeOpNode
*
op
,
const
std
::
vector
<
Tensor
>&
tensors
,
const
Map
<
IterVar
,
Range
>&
dom_map
,
Stmt
*
init
,
Stmt
*
provide
)
{
Stmt
no_op
=
Evaluate
::
make
(
0
);
Tensor
t
=
tensors
[
0
];
std
::
vector
<
Stmt
>
nest
;
Array
<
Expr
>
args
;
for
(
IterVar
iv
:
op
->
axis
)
{
args
.
push_back
(
iv
->
var
);
}
const
Reduce
*
reduce
=
op
->
body
.
as
<
Reduce
>
();
CHECK
(
reduce
);
Expr
init_value
,
update_value
;
if
(
reduce
->
op
==
"Add"
)
{
init_value
=
make_zero
(
reduce
->
type
);
update_value
=
Add
::
make
(
t
(
args
),
reduce
->
source
);
}
else
if
(
reduce
->
op
==
"Max"
)
{
init_value
=
reduce
->
type
.
min
();
update_value
=
Max
::
make
(
t
(
args
),
reduce
->
source
);
}
else
if
(
reduce
->
op
==
"Min"
)
{
init_value
=
reduce
->
type
.
max
();
update_value
=
Min
::
make
(
t
(
args
),
reduce
->
source
);
}
else
{
LOG
(
FATAL
)
<<
"Unsupported reduction "
<<
reduce
->
op
;
}
*
init
=
Provide
::
make
(
t
->
op
,
t
->
value_index
,
init_value
,
args
);
*
provide
=
Provide
::
make
(
t
->
op
,
t
->
value_index
,
update_value
,
args
);
}
Stmt
MakePipeline
(
const
Stage
&
sch
,
const
Map
<
IterVar
,
Range
>&
dom_map
,
Stmt
consumer
)
{
...
...
@@ -238,14 +287,20 @@ Stmt MakePipeline(const Stage& sch,
tensors
.
emplace_back
(
sch
->
op
.
output
(
i
));
}
Stmt
provide
;
if
(
sch
->
op
.
as
<
ComputeOpNode
>
())
{
provide
=
MakeProvide
(
sch
->
op
.
as
<
ComputeOpNode
>
(),
tensors
);
Stmt
init
,
provide
;
const
ComputeOpNode
*
compute
=
sch
->
op
.
as
<
ComputeOpNode
>
();
if
(
compute
)
{
if
(
compute
->
reduce_axis
.
size
()
==
0
)
{
provide
=
MakeProvide
(
compute
,
tensors
);
}
else
{
MakeReduction
(
compute
,
tensors
,
dom_map
,
&
init
,
&
provide
);
}
}
else
{
LOG
(
FATAL
)
<<
"not supported op "
<<
sch
->
op
->
type_key
();
}
std
::
vector
<
std
::
vector
<
Stmt
>
>
nest
=
MakeLoopNest
(
sch
,
dom_map
);
Stmt
producer
=
M
ergeNest
(
nest
,
provide
);
Stmt
producer
=
M
akeLoop
(
sch
,
dom_map
,
provide
,
init
);
producer
=
ProducerConsumer
::
make
(
sch
->
op
,
true
,
producer
);
Stmt
pipeline
=
producer
;
...
...
@@ -306,7 +361,6 @@ Stmt InjectInline(const Operation op, Stmt body) {
return
Inline
(
body
,
op
,
args
,
compute
->
body
);
}
Stmt
ScheduleOps
(
Schedule
sch
,
Map
<
IterVar
,
Range
>
dom_map
)
{
Stmt
body
=
Stmt
();
...
...
@@ -330,5 +384,5 @@ Stmt ScheduleOps(
return
body
;
}
}
// namespace
ir
}
// namespace
schedule
}
// namespace tvm
tests/python/integration/test_ewise.py
View file @
a2c8a29b
...
...
@@ -18,7 +18,8 @@ def test_add():
# one line to build the function.
codes
=
[]
fadd
=
tvm
.
build
(
s
,
args
=
[
A
,
B
,
C
],
fadd
=
tvm
.
build
(
s
,
args
=
[
A
,
B
,
C
],
target
=
"cuda"
,
name
=
"myadd"
,
record_codes
=
codes
)
for
c
in
codes
:
...
...
tests/python/integration/test_reduce.py
0 → 100644
View file @
a2c8a29b
import
tvm
import
numpy
as
np
def
test_sum
():
# graph
n
=
tvm
.
Var
(
'n'
)
m
=
tvm
.
Var
(
'm'
)
A
=
tvm
.
placeholder
((
n
,
m
),
name
=
'A'
)
k
=
tvm
.
IterVar
((
0
,
m
))
B
=
tvm
.
compute
((
n
,),
lambda
i
:
tvm
.
sum
(
A
[
i
,
k
],
axis
=
k
),
name
=
'B'
)
# schedule
s
=
tvm
.
Schedule
(
B
.
op
)
# create iter var and assign them tags.
num_thread
=
1
block_x
=
tvm
.
IterVar
(
thread_tag
=
"blockIdx.x"
)
thread_x
=
tvm
.
IterVar
((
0
,
num_thread
),
thread_tag
=
"threadIdx.x"
)
_
,
x
=
s
[
B
]
.
split
(
B
.
op
.
axis
[
0
],
factor
=
num_thread
,
outer
=
block_x
)
_
,
x
=
s
[
B
]
.
split
(
x
,
outer
=
thread_x
)
tvm
.
init_opencl
()
codes
=
[]
fsum
=
tvm
.
build
(
s
,
args
=
[
A
,
B
],
target
=
"opencl"
,
name
=
"myadd"
,
record_codes
=
codes
)
for
c
in
codes
:
print
(
c
)
num_device
=
1
for
i
in
range
(
num_device
):
ctx
=
tvm
.
opencl
(
i
)
if
not
ctx
.
enabled
:
continue
# launch the kernel.
n
=
1028
m
=
129
#a = tvm.nd.array(np.zeros((n, m)).astype(A.dtype), ctx)
a
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
(
n
,
m
))
.
astype
(
A
.
dtype
),
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
zeros
(
n
,
dtype
=
B
.
dtype
),
ctx
)
fsum
(
a
,
b
)
np
.
testing
.
assert_allclose
(
b
.
asnumpy
(),
np
.
sum
(
a
.
asnumpy
(),
axis
=
1
),
rtol
=
1e-4
)
if
__name__
==
"__main__"
:
test_sum
()
tests/python/unittest/test_codegen_device.py
View file @
a2c8a29b
...
...
@@ -18,8 +18,7 @@ def test_add_pipeline():
# compile to IR
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
stmt
=
tvm
.
ir_pass
.
ScheduleOps
(
s
,
bounds
)
stmt
=
tvm
.
schedule
.
ScheduleOps
(
s
,
bounds
)
Ab
=
tvm
.
Buffer
(
A
.
shape
,
A
.
dtype
,
name
=
'A'
)
Bb
=
tvm
.
Buffer
(
B
.
shape
,
B
.
dtype
,
name
=
'B'
)
Cb
=
tvm
.
Buffer
(
C
.
shape
,
C
.
dtype
,
name
=
'C'
)
...
...
tests/python/unittest/test_codegen_makeapi.py
View file @
a2c8a29b
...
...
@@ -10,12 +10,13 @@ def test_makeapi():
s
=
tvm
.
Schedule
(
C
.
op
)
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
stmt
=
tvm
.
ir_pass
.
ScheduleOps
(
s
,
bounds
)
stmt
=
tvm
.
schedule
.
ScheduleOps
(
s
,
bounds
)
Ab
=
tvm
.
Buffer
(
A
.
shape
,
A
.
dtype
,
name
=
'A'
)
Bb
=
tvm
.
Buffer
(
B
.
shape
,
B
.
dtype
,
name
=
'B'
)
Cb
=
tvm
.
Buffer
(
C
.
shape
,
C
.
dtype
,
name
=
'C'
)
stmt
=
tvm
.
ir_pass
.
StorageFlatten
(
stmt
,
{
A
:
Ab
,
B
:
Bb
,
C
:
Cb
})
num_packed_args
=
2
f
=
tvm
.
codegen
.
MakeAPI
(
stmt
,
"myadd"
,
[
n
,
Ab
,
Bb
,
Cb
],
num_packed_args
)
assert
(
f
.
handle_data_type
[
Ab
.
data
]
.
dtype
==
Ab
.
dtype
)
...
...
tests/python/unittest/test_lang_tensor.py
View file @
a2c8a29b
...
...
@@ -26,7 +26,7 @@ def test_tensor_reduce():
B
=
tvm
.
placeholder
((
n
,
l
),
name
=
'B'
)
T
=
tvm
.
compute
((
m
,
n
,
l
),
lambda
i
,
j
,
k
:
A
[
i
,
k
]
*
B
[
j
,
k
])
rv
=
tvm
.
IterVar
((
0
,
A
.
shape
[
1
]),
name
=
"k"
)
C
=
tvm
.
compute
((
m
,
n
),
lambda
i
,
j
:
tvm
.
sum
(
T
(
i
,
j
,
rv
+
1
),
rdom
=
rv
))
C
=
tvm
.
compute
((
m
,
n
),
lambda
i
,
j
:
tvm
.
sum
(
T
(
i
,
j
,
rv
+
1
),
axis
=
rv
))
# json load save
C_json
=
tvm
.
save_json
(
C
)
C_loaded
=
tvm
.
load_json
(
C_json
)
...
...
tests/python/unittest/test_pass_storage_flatten.py
View file @
a2c8a29b
...
...
@@ -12,7 +12,7 @@ def test_flatten2():
s
[
A1
]
.
compute_at
(
s
[
A2
],
xo
)
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
assert
isinstance
(
bounds
,
tvm
.
collections
.
Map
)
stmt
=
tvm
.
ir_pass
.
ScheduleOps
(
s
,
bounds
)
stmt
=
tvm
.
schedule
.
ScheduleOps
(
s
,
bounds
)
print
(
stmt
)
Ab
=
tvm
.
Buffer
(
A
.
shape
,
A
.
dtype
,
name
=
'A'
)
...
...
tests/python/unittest/test_
pass
_schedule_ops.py
→
tests/python/unittest/test_
schedule
_schedule_ops.py
View file @
a2c8a29b
...
...
@@ -11,7 +11,7 @@ def test_schedule0():
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
assert
isinstance
(
bounds
,
tvm
.
collections
.
Map
)
stmt
=
tvm
.
ir_pass
.
ScheduleOps
(
s
,
bounds
)
stmt
=
tvm
.
schedule
.
ScheduleOps
(
s
,
bounds
)
print
(
stmt
)
def
test_schedule1
():
...
...
@@ -24,7 +24,7 @@ def test_schedule1():
xo
,
xi
=
s
[
A1
]
.
split
(
A1
.
op
.
axis
[
0
],
8
)
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
assert
isinstance
(
bounds
,
tvm
.
collections
.
Map
)
stmt
=
tvm
.
ir_pass
.
ScheduleOps
(
s
,
bounds
)
stmt
=
tvm
.
schedule
.
ScheduleOps
(
s
,
bounds
)
print
(
stmt
)
def
test_schedule2
():
...
...
@@ -39,7 +39,7 @@ def test_schedule2():
s
[
A1
]
.
compute_at
(
s
[
A2
],
xo
)
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
assert
isinstance
(
bounds
,
tvm
.
collections
.
Map
)
stmt
=
tvm
.
ir_pass
.
ScheduleOps
(
s
,
bounds
)
stmt
=
tvm
.
schedule
.
ScheduleOps
(
s
,
bounds
)
print
(
stmt
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment