Commit a2c8a29b
Authored Feb 02, 2017 by Tianqi Chen
Committed by GitHub, Feb 02, 2017
[SCHEDULE] Improve bound inference, support reduce codegen. (#30)
parent d4af7ad6
Showing 31 changed files with 1039 additions and 438 deletions (+1039, -438)
include/tvm/expr.h  +12 -9
include/tvm/ir.h  +3 -3
include/tvm/ir_pass.h  +10 -11
include/tvm/operation.h  +3 -0
include/tvm/schedule.h  +37 -0
include/tvm/schedule_pass.h  +9 -0
python/tvm/api.py  +18 -18
python/tvm/build.py  +4 -2
python/tvm/schedule.py  +8 -0
src/api/api_lang.cc  +6 -0
src/api/api_pass.cc  +0 -1
src/api/api_schedule.cc  +1 -0
src/codegen/codegen_c.cc  +3 -2
src/lang/ir.cc  +5 -5
src/lang/operation.cc  +10 -1
src/pass/ir_mutator.cc  +4 -4
src/pass/ir_visitor.cc  +1 -1
src/pass/simple_passes.cc  +22 -0
src/schedule/bound.cc  +53 -11
src/schedule/compute_expr.h  +109 -0
src/schedule/int_set.cc  +377 -216
src/schedule/int_set.h  +64 -17
src/schedule/schedule_lang.cc  +45 -2
src/schedule/schedule_ops.cc  +180 -126
tests/python/integration/test_ewise.py  +2 -1
tests/python/integration/test_reduce.py  +45 -0
tests/python/unittest/test_codegen_device.py  +1 -2
tests/python/unittest/test_codegen_makeapi.py  +2 -1
tests/python/unittest/test_lang_tensor.py  +1 -1
tests/python/unittest/test_pass_storage_flatten.py  +1 -1
tests/python/unittest/test_schedule_schedule_ops.py  +3 -3
include/tvm/expr.h
@@ -32,6 +32,9 @@ using Halide::Internal::IRPrinter;
 using Halide::Internal::Variable;
 using Halide::Internal::make_const;
+using Halide::Internal::make_zero;
+using Halide::Internal::as_const_int;
+using Halide::Internal::as_const_uint;

 inline Type TVMType2Type(TVMType t) {
...
@@ -126,25 +129,25 @@ using Halide::abs;
 using Halide::select;

 /*!
- * \brief sum of of source expression over rdom
+ * \brief sum of of source expression over axis
  * \param source The source expression.
- * \param rdom List of iteration variables that will be used for reduction.
+ * \param axis List of iteration variables that will be used for reduction.
  */
-Expr sum(Expr source, Array<IterVar> rdom);
+Expr sum(Expr source, Array<IterVar> axis);

 /*!
- * \brief max of of source expression over rdom
+ * \brief max of of source expression over axis
  * \param source The source expression.
- * \param rdom List of iteration variables that will be used for reduction.
+ * \param axis List of iteration variables that will be used for reduction.
  */
-Expr max(Expr source, Array<IterVar> rdom);
+Expr max(Expr source, Array<IterVar> axis);

 /*!
- * \brief max of of source expression over rdom
+ * \brief max of of source expression over axis
  * \param source The source expression.
- * \param rdom List of iteration variables that will be used for reduction.
+ * \param axis List of iteration variables that will be used for reduction.
  */
-Expr min(Expr source, Array<IterVar> rdom);
+Expr min(Expr source, Array<IterVar> axis);

 // print functions for expr
...
include/tvm/ir.h
@@ -30,8 +30,8 @@ struct Reduce : public ExprNode<Reduce> {
   std::string op;
   /*! \brief The source operand */
   Expr source;
-  /*! \brief The reduction domains */
-  Array<IterVar> rdom;
+  /*! \brief The reduction axis */
+  Array<IterVar> axis;
   /*! \brief construct expr from op and rdom */
   static Expr make(std::string op, Expr src, Array<IterVar> rdom);
...
@@ -40,7 +40,7 @@ struct Reduce : public ExprNode<Reduce> {
     v->Visit("dtype", &type);
     v->Visit("op", &op);
     v->Visit("source", &source);
-    v->Visit("rdom", &rdom);
+    v->Visit("axis", &axis);
   }
   static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
   static constexpr const char* _type_key = "Reduce";
...
include/tvm/ir_pass.h
@@ -3,8 +3,8 @@
  * \file ir_pass.h
  * \brief Collection of IR pass functions
  *
- *  All the pass functions in this file are for Stmt,
- *  We can use PassFunction(Evaluate(expr)) to apply it to Expr
+ *  When the pass functions in this file are for Stmt,
+ *  we can use PassFunction(Evaluate(expr)) to apply it to Expr
  */
 #ifndef TVM_IR_PASS_H_
 #define TVM_IR_PASS_H_
...
@@ -38,15 +38,6 @@ inline Stmt Simplify(Stmt a) {
 }

 /*!
- * \brief Schedule s' dependent operations.
- *
- * \param s The schedule to be realized
- * \param dom_map The domain of each iter vars.
- * \return the result Stmt
- */
-Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
-
-/*!
  * \brief verifies whether the IR stmt or Expr is in SSA form.
  *  That is: each VarExpr is defined and assigned once(in Let/For)
  *
...
@@ -70,6 +61,14 @@ bool HasSideEffect(const Expr& e);
 Stmt ConvertSSA(Stmt stmt);

 /*!
+ * \brief Substitute the var specified in key->var to be value.
+ * \param stmt The source statement to be substituted
+ * \param value_map The map of new values.
+ * \return The converted form.
+ */
+Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map);
+
+/*!
  * \brief inline all calls of f in stmt.
  *
  * \param f The function reference to be inlined
...
include/tvm/operation.h
@@ -49,6 +49,8 @@ class ComputeOpNode : public OperationNode {
  public:
   /*! \brief IterVar on each axis */
   Array<IterVar> axis;
+  /*! \brief IterVar on each reduction axis, if the body is a Reduce */
+  Array<IterVar> reduce_axis;
   /*! \brief the compute expression */
   Expr body;
   /*! \brief constructor */
...
@@ -64,6 +66,7 @@ class ComputeOpNode : public OperationNode {
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("name", &name);
     v->Visit("axis", &axis);
+    v->Visit("reduce_axis", &reduce_axis);
     v->Visit("body", &body);
   }
   static Operation make(std::string name,
...
include/tvm/schedule.h
@@ -123,6 +123,8 @@ class Stage : public NodeRef {
              IterVar* p_x_outer, IterVar* p_y_outer,
              IterVar* p_x_inner, IterVar* p_y_inner,
              Expr x_factor, Expr y_factor);
+  // declare container type
+  using ContainerType = StageNode;
 };

 /*!
...
@@ -153,10 +155,21 @@ class Schedule : public NodeRef {
     return this->operator[](tensor->op);
   }
   /*!
+   * \brief Normalize the schedule.
+   *  This is needed before bound inference.
+   *  Insert necessary RebaseNode to make sure all leaf_iter_vars
+   *  are in form [0, extent)
+   *
+   * \return A normalized schedule, can be same as current one.
+   */
+  void normalize();
+  /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
    */
   inline const ScheduleNode* operator->() const;
+  // declare container type
+  using ContainerType = ScheduleNode;
 };

 /*!
...
@@ -308,6 +321,30 @@ class FuseNode : public IterVarRelationNode {
   TVM_DECLARE_NODE_TYPE_INFO(FuseNode);
 };

+/*!
+ * \brief Rebase the iteration to make min to be 0.
+ *  This is useful to normalize the Schedule
+ *  to make every leaf variable's min to be 0.
+ */
+class RebaseNode : public IterVarRelationNode {
+ public:
+  /*! \brief The parent domain */
+  IterVar parent;
+  /*! \brief The inner domain */
+  IterVar rebased;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("parent", &parent);
+    v->Visit("rebased", &rebased);
+  }
+
+  static IterVarRelation make(IterVar parent, IterVar rebased);
+
+  static constexpr const char* _type_key = "Rebase";
+  TVM_DECLARE_NODE_TYPE_INFO(RebaseNode);
+};
+
 // implementations
 inline const StageNode* Stage::operator->() const {
   return static_cast<const StageNode*>(node_.get());
...
include/tvm/schedule_pass.h
@@ -24,6 +24,15 @@ namespace schedule {
  */
 Map<IterVar, Range> InferBound(Schedule sch);

+/*!
+ * \brief Schedule s' dependent operations.
+ *
+ * \param s The schedule to be realized
+ * \param dom_map The domain of each iter vars.
+ * \return the result Stmt
+ */
+Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
+
 }  // namespace schedule
 }  // namespace tvm
 #endif  // TVM_SCHEDULE_PASS_H_
python/tvm/api.py
@@ -212,51 +212,51 @@ def IterVar(dom=None, name=None, thread_tag=''):
     return _api_internal._IterVar(dom, name, thread_tag)


-def sum(expr, rdom):
-    """Create a sum expression over rdom
+def sum(expr, axis):
+    """Create a sum expression over axis

     Parameters
     ----------
     expr : Expr
         The source expression.

-    rdom : RDomain
-        The reduction domain
+    axis : IterVar
+        The reduction IterVar axis
     """
-    rdom = rdom if isinstance(rdom, list) else [rdom]
-    x = _make.Reduce("Add", expr, rdom)
+    axis = axis if isinstance(axis, list) else [axis]
+    x = _make.Reduce("Add", expr, axis)
     return x


-def min(expr, rdom):
-    """Create a min expression over rdom
+def min(expr, axis):
+    """Create a min expression over axis

     Parameters
     ----------
     expr : Expr
         The source expression.

-    rdom : RDomain
-        The reduction domain
+    axis : IterVar
+        The reduction IterVar axis
     """
-    rdom = rdom if isinstance(rdom, list) else [rdom]
-    x = _make.Reduce("Min", expr, rdom)
+    axis = axis if isinstance(axis, list) else [axis]
+    x = _make.Reduce("Min", expr, axis)
     return x


-def max(expr, rdom):
-    """Create a min expression over rdom
+def max(expr, axis):
+    """Create a min expression over axis

     Parameters
     ----------
     expr : Expr
         The source expression.

-    rdom : RDomain
-        The reduction domain
+    axis : IterVar
+        The reduction IterVar axis
     """
-    rdom = rdom if isinstance(rdom, list) else [rdom]
-    x = _make.Reduce("Max", expr, rdom)
+    axis = axis if isinstance(axis, list) else [axis]
+    x = _make.Reduce("Max", expr, axis)
     return x
...
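With this rename, reductions are created by passing one or more IterVars through the axis argument. A minimal usage sketch (variable names are illustrative and follow the style of the tests added in this commit; the calls assume the Python API as of this revision):

    import tvm
    n = tvm.Var('n')
    m = tvm.Var('m')
    A = tvm.placeholder((n, m), name='A')
    k = tvm.IterVar((0, m), name='k')
    # reduce A over the reduction axis k; `axis` replaces the old `rdom` argument
    B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')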
python/tvm/build.py
@@ -62,9 +62,10 @@ def build(sch,
     # lowering
     bounds = schedule.InferBound(sch)
-    stmt = ir_pass.ScheduleOps(sch, bounds)
+    stmt = schedule.ScheduleOps(sch, bounds)
     stmt = ir_pass.StorageFlatten(stmt, binds)
     stmt = ir_pass.Simplify(stmt)
+    print(stmt)
     fapi = codegen.MakeAPI(stmt, name, arg_list, len(arg_list))
     fsplits = codegen.SplitHostDevice(fapi)
...
@@ -73,7 +74,8 @@ def build(sch,
     for i, f in enumerate(fsplits):
         t = target if i >= 1 else "c"
         record_codes.append(codegen.CompileToC(f, output_ssa, t))
+    for c in record_codes:
+        print(c)
     if target == "cuda":
         ret = codegen.BuildNVRTC(fsplits, "stackvm")
     elif target == "opencl":
...
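The lowering sequence used inside build() can also be driven by hand. A rough sketch, assuming a schedule s built for the reduction example above (argument binding and StorageFlatten are omitted):

    bounds = tvm.schedule.InferBound(s)          # Map from IterVar to inferred Range
    stmt = tvm.schedule.ScheduleOps(s, bounds)   # was ir_pass.ScheduleOps before this commit
    stmt = tvm.ir_pass.Simplify(stmt)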
python/tvm/schedule.py
@@ -33,6 +33,14 @@ class Schedule(NodeBase):
             raise ValueError("Cannot find the operation %s in schedule" % (str(k)))
         return self.stage_map[k]

+    def normalize(self):
+        """Build a normalized schedule.
+
+        Insert necessary rebase to make certain iter var to start from 0.
+        This is needed before bound inference and followup step.
+        """
+        _api_internal._ScheduleNormalize(self)
+
 @register_node
 class Stage(NodeBase):
     """A Stage represents schedule for one operation."""
...
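A short sketch of the intended call order, with the schedule constructed from the earlier reduction example (names are illustrative):

    s = tvm.Schedule(B.op)
    s.normalize()                        # rebases leaf iter vars so every min becomes 0
    bounds = tvm.schedule.InferBound(s)  # bound inference now relies on the normalized form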
src/api/api_lang.cc
@@ -253,4 +253,10 @@ TVM_REGISTER_API(_StageTile)
     *ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
   });

+TVM_REGISTER_API(_ScheduleNormalize)
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    args[0].operator Schedule()
+        .normalize();
+  });
+
 }  // namespace tvm
src/api/api_pass.cc
@@ -51,7 +51,6 @@ TVM_REGISTER_API(_pass_Equal)
 REGISTER_PASS1(ConvertSSA);
 REGISTER_PASS1(VerifySSA);
 REGISTER_PASS4(Inline);
-REGISTER_PASS2(ScheduleOps);
 REGISTER_PASS2(StorageFlatten);

 }  // namespace ir
...
src/api/api_schedule.cc
@@ -29,6 +29,7 @@ namespace schedule {
 REGISTER_SCHEDULE_PASS1(InferBound);
 REGISTER_SCHEDULE_PASS1(CreateReadGraph);
 REGISTER_SCHEDULE_PASS2(PostDFSOrder);
+REGISTER_SCHEDULE_PASS2(ScheduleOps);

 }  // namespace schedule
 }  // namespace tvm
src/codegen/codegen_c.cc
@@ -2,6 +2,7 @@
  * Copyright (c) 2017 by Contributors
  * \file codegen_c.cc
  */
+#include <iomanip>
 #include "./codegen_c.h"

 namespace tvm {
...
@@ -216,7 +217,7 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
   switch (op->type.bits()) {
     case 64: case 32: {
       std::ostringstream temp;
-      temp << op->value;
+      temp << std::scientific << op->value;
       if (op->type.bits() == 32) temp << 'f';
       p->MarkConst(temp.str());
       os << temp.str();
...
@@ -225,7 +226,7 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
     case 16: {
       os << '(';
       p->PrintType(op->type, os);
-      os << ')' << op->value << 'f';
+      os << ')' << std::scientific << op->value << 'f';
       break;
     }
     default:
       LOG(FATAL) << "Bad bit-width for float: " << op->type << "\n";
...
src/lang/ir.cc
@@ -26,7 +26,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
       << op->op
       << ", ";
     p->print(op->source);
-    p->stream << ", rdom=" << op->rdom << ")";
+    p->stream << ", axis=" << op->axis << ")";
   });
 }  // namespace Internal
...
@@ -35,16 +35,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 namespace tvm {
 namespace ir {

-Expr Reduce::make(std::string op, Expr source, Array<IterVar> rdom) {
+Expr Reduce::make(std::string op, Expr source, Array<IterVar> axis) {
   auto n = std::make_shared<Reduce>();
   CHECK(source.defined());
-  for (size_t i = 0; i < rdom.size(); ++i) {
-    CHECK(rdom[i].defined());
+  for (size_t i = 0; i < axis.size(); ++i) {
+    CHECK(axis[i].defined());
   }
   n->type = source.type();
   n->source = source;
   n->op = op;
-  n->rdom = rdom;
+  n->axis = axis;
   return Expr(n);
 }
...
View file @
a2c8a29b
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
*/
*/
#include <tvm/operation.h>
#include <tvm/operation.h>
#include <tvm/tensor.h>
#include <tvm/tensor.h>
#include <tvm/ir.h>
#include <memory>
#include <memory>
namespace
tvm
{
namespace
tvm
{
...
@@ -57,7 +58,12 @@ Tensor Placeholder(Array<Expr> shape, Type dtype, std::string name) {
...
@@ -57,7 +58,12 @@ Tensor Placeholder(Array<Expr> shape, Type dtype, std::string name) {
// ComputeOpNode
// ComputeOpNode
Array
<
IterVar
>
ComputeOpNode
::
root_iter_vars
()
const
{
Array
<
IterVar
>
ComputeOpNode
::
root_iter_vars
()
const
{
return
axis
;
if
(
reduce_axis
.
size
()
==
0
)
return
axis
;
Array
<
IterVar
>
ret
=
axis
;
for
(
IterVar
iv
:
reduce_axis
)
{
ret
.
push_back
(
iv
);
}
return
ret
;
}
}
Type
ComputeOpNode
::
output_dtype
(
size_t
i
)
const
{
Type
ComputeOpNode
::
output_dtype
(
size_t
i
)
const
{
...
@@ -101,6 +107,9 @@ Operation ComputeOpNode::make(std::string name,
...
@@ -101,6 +107,9 @@ Operation ComputeOpNode::make(std::string name,
n
->
name
=
name
;
n
->
name
=
name
;
n
->
axis
=
axis
;
n
->
axis
=
axis
;
n
->
body
=
body
;
n
->
body
=
body
;
if
(
n
->
body
->
is_type
<
ir
::
Reduce
>
())
{
n
->
reduce_axis
=
n
->
body
.
as
<
ir
::
Reduce
>
()
->
axis
;
}
return
Operation
(
n
);
return
Operation
(
n
);
}
}
...
...
src/pass/ir_mutator.cc
@@ -37,7 +37,7 @@ inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
   }
 }

-inline Array<IterVar> MutateRDom(Array<IterVar> rdom, IRMutator *m) {
+inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) {
   std::vector<IterVar> new_dom(rdom.size());
   bool changed = false;
   for (size_t i = 0; i < rdom.size(); i++) {
...
@@ -237,13 +237,13 @@ Expr IRMutator::Mutate_(const Let *op, const Expr& e) {
 TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
 .set_dispatch<Reduce>([](const Reduce* op, const Expr& e, IRMutator* m) {
-    Array<IterVar> new_rdom = MutateRDom(op->rdom, m);
+    Array<IterVar> new_axis = MutateIterVarArr(op->axis, m);
     Expr new_source = m->Mutate(op->source);
-    if (op->rdom.same_as(new_rdom) &&
+    if (op->axis.same_as(new_axis) &&
         op->source.same_as(new_source)) {
       return e;
     } else {
-      return Reduce::make(op->op, new_source, new_rdom);
+      return Reduce::make(op->op, new_source, new_axis);
     }
   });
...
src/pass/ir_visitor.cc
@@ -120,7 +120,7 @@ void IRVisitor::Visit_(const Call *op) {
 TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
 .set_dispatch<Reduce>([](const Reduce* op, IRVisitor* v) {
-    VisitRDom(op->rdom, v);
+    VisitRDom(op->axis, v);
     v->Visit(op->source);
   })
 .set_dispatch<IntImm>(NoOp)
...
src/pass/simple_passes.cc
@@ -5,6 +5,7 @@
  */
 #include <tvm/ir.h>
 #include <tvm/ir_visitor.h>
+#include <tvm/ir_mutator.h>
 #include <tvm/ir_pass.h>

 namespace tvm {
...
@@ -32,5 +33,26 @@ bool HasSideEffect(const Expr& e) {
   v.Visit(e);
   return v.has_side_effect_;
 }

+class IRSubstitue : public IRMutator {
+ public:
+  Expr Mutate_(const Variable* op, const Expr& e) final {
+    auto it = smap.find(op);
+    if (it != smap.end()) {
+      return it->second;
+    } else {
+      return e;
+    }
+  }
+  std::unordered_map<const Variable*, Expr> smap;
+};
+
+Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map) {
+  IRSubstitue m;
+  for (auto kv : value_map) {
+    m.smap[kv.first->var.get()] = kv.second;
+  }
+  return m.Mutate(stmt);
+}
+
 }  // namespace ir
 }  // namespace tvm
src/schedule/bound.cc
@@ -54,6 +54,11 @@ void PassDown(const Stage& s,
       const Range& range_inner = state.at(r->inner);
       state[r->fused] = Range::make_with_min_extent(
           0, range_outer->extent * range_inner->extent);
+    } else if (rel.as<RebaseNode>()) {
+      const RebaseNode* r = rel.as<RebaseNode>();
+      CHECK(state.count(r->parent));
+      state[r->rebased] = Range::make_with_min_extent(
+          0, state.at(r->parent)->extent);
     } else {
       LOG(FATAL) << "unknown relation type";
     }
...
@@ -85,6 +90,13 @@ void PassUp(const Stage& s,
              &outer, &inner);
       state[r->outer] = outer;
       state[r->inner] = inner;
+    } else if (rel.as<RebaseNode>()) {
+      IntSet parent;
+      const RebaseNode* r = rel.as<RebaseNode>();
+      PassUp(r, dom_map,
+             state.at(r->rebased),
+             &parent);
+      state[r->parent] = parent;
     } else {
       LOG(FATAL) << "unknown relation type";
     }
...
@@ -109,9 +121,15 @@ void PassToOperation(
   // Eventually, we need to change the inference to be a Pull style inference
   if (tensor->op.as<ComputeOpNode>()) {
     auto root_iter_vars = tensor->op->root_iter_vars();
-    CHECK_EQ(tensor.ndim(), root_iter_vars.size());
-    for (size_t i = 0; i < tensor.ndim(); ++i) {
-      (*result)[root_iter_vars[i]].push_back(dim_bounds[i]);
+    const ComputeOpNode* op = tensor->op.as<ComputeOpNode>();
+    CHECK_EQ(op->axis.size() + op->reduce_axis.size(), root_iter_vars.size());
+    for (size_t i = 0; i < op->axis.size(); ++i) {
+      (*result)[op->axis[i]].push_back(dim_bounds[i]);
+    }
+    // reduction.
+    for (size_t i = 0; i < op->reduce_axis.size(); ++i) {
+      (*result)[op->reduce_axis[i]].push_back(
+          IntSet::range(op->reduce_axis[i]->dom));
     }
   } else {
     LOG(FATAL) << "unknown operation mode " << tensor->op->type_key();
...
@@ -173,9 +191,9 @@ bool ScopeRelax(const IterVar& iv, const std::string& scope) {
     {"local", 2}
   };
   static std::unordered_map<std::string, int> thread_tag_rank{
-    {"gridIdx.x", 0},
-    {"gridIdx.y", 0},
-    {"gridIdx.z", 0},
+    {"blockIdx.x", 0},
+    {"blockIdx.y", 0},
+    {"blockIdx.z", 0},
     {"threadIdx.x", 1},
     {"threadIdx.y", 1},
     {"threadIdx.z", 1}
...
@@ -194,8 +212,6 @@ void InferBound(const Stage& stage,
       (*rmap)[iv] = iv->dom;
     }
   }
-  // get range of all child iter vars.
-  PassDown(stage, rmap);
   if (stage->attach_type == kScope) {
     Stage parent = stage->attach_stage;
...
@@ -206,10 +222,18 @@ void InferBound(const Stage& stage,
     bool fix_value = true;
     for (auto iv : parent->leaf_iter_vars) {
+      Range vrange = rmap->at(iv);
+      CHECK(is_zero(vrange->min))
+          << "InferBound requires every leaf iter var's min equals 0, "
+          << "call schedule.normalize to achieve this.";
+      // special optimization to remove trivial loop
+      if (is_one(vrange->extent)) {
+        up_state[iv] = IntSet::single_point(vrange->min);
+      }
       if (fix_value && !ScopeRelax(iv, stage->scope)) {
-        up_state[iv] = IntSet::make_point(iv->var);
+        up_state[iv] = IntSet::single_point(iv->var);
       } else {
-        up_state[iv] = IntSet::make_range(rmap->at(iv));
+        up_state[iv] = IntSet::range(vrange);
       }
       if (stage->attach_ivar == iv) {
         fix_value = false;
...
@@ -223,12 +247,30 @@ void InferBound(const Stage& stage,
       bp_state[iv] = {up_state.at(iv)};
     }
     auto result = BoundProp(post_order, &bp_state);
+
+    // Set relaxation
+    Map<IterVar, IntSet> relax_set;
+    Stage s = stage;
+    while (s->attach_type == kScope) {
+      s = s->attach_stage;
+      for (auto iv : s->leaf_iter_vars) {
+        if (ScopeRelax(iv, stage->scope)) {
+          relax_set.Set(iv, IntSet::range(rmap->at(iv)));
+        }
+      }
+    }
     for (auto iv : stage->op->root_iter_vars()) {
       CHECK(result.count(iv));
       CHECK(!rmap->count(iv));
-      (*rmap)[iv] = result.at(iv).GetCoverRange();
+      Range r = result.at(iv).cover_range(iv->dom);
+      if (relax_set.size() != 0) {
+        r = EvalSet(r, relax_set).cover_range(iv->dom);
+      }
+      (*rmap)[iv] = r;
     }
   }
+  // get range of all child iter vars.
+  PassDown(stage, rmap);
 }
...
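From the Python side, the practical consequence of these bound.cc changes is that InferBound now also produces ranges for reduction axes, and it insists that leaf iteration variables start at 0 (which normalize guarantees). A rough sketch, with the split call and attribute names assumed from the test style of this era:

    s = tvm.Schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=8)
    s.normalize()
    bounds = tvm.schedule.InferBound(s)
    # bounds maps every IterVar (xo, xi and the reduction axis k) to its inferred Range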
src/schedule/compute_expr.h
0 → 100644
/*!
 * Copyright (c) 2017 by Contributors
 * \file compute_expr.h
 * \brief Utility integer expression with quick eager simplification.
 *  This is weaker than Simplify but can be done Eagerly.
 */
#ifndef TVM_SCHEDULE_COMPUTE_EXPR_H_
#define TVM_SCHEDULE_COMPUTE_EXPR_H_

#include <tvm/ir.h>
#include <pass/Interval.h>

namespace tvm {
namespace schedule {

using Halide::Internal::add_would_overflow;
using Halide::Internal::sub_would_overflow;
using Halide::Internal::mul_would_overflow;

/*!
 * \brief Compute the expression with the given binary op.
 * \param lhs The left operand
 * \param rhs The right operand
 * \return The result.
 */
template<typename OP>
inline Expr ComputeExpr(Expr lhs, Expr rhs) {
  return OP::make(lhs, rhs);
}

template<typename T>
inline bool GetConst(Expr e, T* out);

template<>
bool GetConst<int64_t>(Expr e, int64_t* out) {
  if (e.type().is_vector()) return false;
  const int64_t* v = as_const_int(e);
  if (v) {
    *out = *v; return true;
  } else {
    return false;
  }
}

template<>
bool GetConst<uint64_t>(Expr e, uint64_t* out) {
  if (e.type().is_vector()) return false;
  const uint64_t* v = as_const_uint(e);
  if (v) {
    *out = *v; return true;
  } else {
    return false;
  }
}

#define TVM_CONST_PROPAGATION(OP_NAME, OP)                       \
  int64_t ia = 0, ib = 0;                                        \
  if (GetConst(a, &ia) && GetConst(b, &ib)) {                    \
    if (OP_NAME ## _would_overflow(a.type().bits(), ia, ib)) {   \
      LOG(FATAL) << "signed int overflow";                       \
    }                                                            \
    return ir::IntImm::make(a.type(), ia OP ib);                 \
  }                                                              \
  uint64_t ua = 0, ub = 0;                                       \
  if (GetConst(a, &ua) && GetConst(b, &ub)) {                    \
    return ir::UIntImm::make(a.type(), ua + ub);                 \
  }                                                              \

template<>
inline Expr ComputeExpr<ir::Add>(Expr a, Expr b) {
  if (is_zero(a)) return b;
  if (is_zero(b)) return a;
  TVM_CONST_PROPAGATION(add, +);
  return ir::Add::make(a, b);
}

template<>
inline Expr ComputeExpr<ir::Sub>(Expr a, Expr b) {
  if (is_zero(b)) return a;
  TVM_CONST_PROPAGATION(sub, -);
  return ir::Add::make(a, b);
}

template<>
inline Expr ComputeExpr<ir::Mul>(Expr a, Expr b) {
  if (is_one(a)) return b;
  if (is_one(b)) return a;
  TVM_CONST_PROPAGATION(mul, *);
  return ir::Mul::make(a, b);
}

template<>
inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) {
  if (is_one(b)) return a;
  return ir::Mul::make(a, b);
}

template<>
inline Expr ComputeExpr<ir::Max>(Expr a, Expr b) {
  return Halide::Internal::Interval::make_max(a, b);
}

template<>
inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
  return Halide::Internal::Interval::make_min(a, b);
}

}  // namespace schedule
}  // namespace tvm
#endif  // TVM_SCHEDULE_COMPUTE_EXPR_H_
src/schedule/int_set.cc
View file @
a2c8a29b
/*!
/*!
* Copyright (c) 2016 by Contributors
* Copyright (c) 2016 by Contributors
* \file int_set.cc
* \file int_set
_impl
.cc
* \brief The integer set functions
* \brief The integer set functions
*/
*/
#include <tvm/ir.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <pass/Interval.h>
#include "./int_set.h"
#include "./int_set.h"
#include "./compute_expr.h"
namespace
tvm
{
namespace
tvm
{
namespace
schedule
{
namespace
schedule
{
using
Halide
::
Internal
::
Interval
;
using
namespace
ir
;
using
namespace
ir
;
/*! \brief Set of continuous interval */
struct
IntervalSet
:
public
IntSetNode
{
/*! \brief the internal interval*/
Interval
i
;
static
IntSet
make
(
Interval
i
)
{
std
::
shared_ptr
<
IntervalSet
>
n
=
std
::
make_shared
<
IntervalSet
>
();
n
->
i
=
i
;
return
IntSet
(
n
);
}
static
IntSet
make
(
Expr
min
,
Expr
max
)
{
std
::
shared_ptr
<
IntervalSet
>
n
=
std
::
make_shared
<
IntervalSet
>
();
n
->
i
.
min
=
min
;
n
->
i
.
max
=
max
;
return
IntSet
(
n
);
}
static
constexpr
const
char
*
_type_key
=
"IntervalSet"
;
TVM_DECLARE_NODE_TYPE_INFO
(
IntervalSet
);
};
/*!
/*!
* \brief Internal node container of int set.
* \brief set represented by strided integers
* Reserved for cases where strided access is supported.
*/
*/
class
IntSetNode
:
public
Node
{
struct
StrideSet
:
public
IntSetNode
{
public
:
/*! \brief the base inetrval */
/*! \brief The base range scope */
Interval
base
;
Range
base
;
/*! \brief additional extents in positive number */
/*! \brief additional strided domain */
Array
<
Expr
>
extents
;
Array
<
Range
>
domain
;
/*! \brief additional strides in positive number */
/*! \brief The stride of each strided domain */
Array
<
Expr
>
strides
;
Array
<
Expr
>
stride
;
/*!
static
constexpr
const
char
*
_type_key
=
"StrideSet"
;
* \brief The concrete set,
TVM_DECLARE_NODE_TYPE_INFO
(
StrideSet
);
* used when concrete execution is enabled.
*/
std
::
vector
<
int32_t
>
concrete
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"base"
,
&
base
);
v
->
Visit
(
"domain"
,
&
domain
);
v
->
Visit
(
"stride"
,
&
stride
);
}
static
constexpr
const
char
*
_type_key
=
"IntSet"
;
TVM_DECLARE_NODE_TYPE_INFO
(
IntSetNode
);
};
};
TVM_REGISTER_NODE_TYPE
(
IntSetNode
);
inline
IntSet
IntSet
::
cover_interval
()
const
{
if
((
*
this
).
as
<
IntervalSet
>
())
return
*
this
;
const
StrideSet
*
s
=
(
*
this
).
as
<
StrideSet
>
();
if
(
s
)
{
CHECK_NE
(
s
->
extents
.
size
(),
0U
);
Expr
max
=
s
->
base
.
max
;
for
(
size_t
i
=
0
;
i
<
s
->
extents
.
size
();
++
i
)
{
max
=
max
+
s
->
extents
[
i
]
*
s
->
strides
[
i
]
-
s
->
strides
[
i
];
}
return
IntervalSet
::
make
(
s
->
base
.
min
,
max
);
}
LOG
(
FATAL
)
<<
"cannot convert set "
<<
(
*
this
)
->
type_key
()
<<
" to interval"
;
return
IntSet
::
everything
();
}
Range
IntSet
::
cover_range
(
Range
max_range
)
const
{
IntSet
temp
;
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
if
(
s_int
==
nullptr
)
{
temp
=
this
->
cover_interval
();
s_int
=
temp
.
as
<
IntervalSet
>
();
}
if
(
s_int
->
i
.
is_bounded
())
{
return
Range
::
make_with_min_extent
(
s_int
->
i
.
min
,
Simplify
(
s_int
->
i
.
max
+
1
-
s_int
->
i
.
min
));
}
return
max_range
;
}
namespace
{
bool
IntSet
::
is_everything
()
const
{
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
return
(
s_int
&&
s_int
->
i
.
is_everything
());
}
inline
bool
Match
(
const
Expr
&
e
,
int64_t
value
)
{
bool
IntSet
::
is_single_point
()
const
{
const
ir
::
IntImm
*
v
=
e
.
as
<
ir
::
IntImm
>
();
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
return
v
!=
nullptr
&&
v
->
value
;
return
(
s_int
&&
s_int
->
i
.
is_single_point
())
;
}
}
// whether a exactly matches b.
IntSet
IntSet
::
everything
()
{
inline
bool
Match
(
const
IntSet
&
a
,
return
IntervalSet
::
make
(
Interval
::
everything
());
const
Range
&
b
)
{
if
(
a
->
base
==
b
&&
a
->
domain
.
size
()
==
0
&&
a
->
concrete
.
size
()
==
0
)
{
return
true
;
}
else
{
return
false
;
}
}
}
// whether a exactly matches b.
IntSet
IntSet
::
single_point
(
Expr
x
)
{
inline
bool
Match
(
const
IntSet
&
a
,
return
IntervalSet
::
make
(
Interval
::
single_point
(
x
));
const
Expr
&
b
)
{
if
(
a
->
domain
.
size
()
==
0
&&
a
->
concrete
.
size
()
==
0
)
{
return
Match
(
a
->
base
->
extent
,
1
)
&&
a
->
base
->
min
.
same_as
(
b
);
}
else
{
return
false
;
}
}
}
inline
bool
IsNumber
(
const
IntSet
&
s
)
{
IntSet
IntSet
::
range
(
Range
r
)
{
if
(
s
->
domain
.
size
()
!=
0
)
return
false
;
// must make sure it can be matched back by MatchRange.
if
(
s
->
concrete
.
size
()
!=
0
)
{
if
(
is_one
(
r
->
extent
))
{
return
s
->
concrete
.
size
()
==
1
;
return
IntSet
::
single_point
(
r
->
min
);
}
if
(
is_positive_const
(
r
->
extent
)
&&
is_const
(
r
->
min
))
{
return
IntervalSet
::
make
(
r
->
min
,
ComputeExpr
<
Sub
>
(
ComputeExpr
<
Add
>
(
r
->
extent
,
r
->
min
),
1
));
}
}
return
Match
(
s
->
base
->
extent
,
1
);
return
IntervalSet
::
make
(
r
->
min
,
(
r
->
extent
+
r
->
min
)
-
1
);
}
}
inline
Expr
AsNumber
(
const
IntSet
&
s
)
{
// Check if a is created from b.
return
s
->
base
->
min
;
inline
bool
MatchRange
(
const
IntSet
&
a
,
const
Range
&
b
)
{
const
IntervalSet
*
a_int
=
a
.
as
<
IntervalSet
>
();
if
(
!
a_int
)
return
false
;
const
Interval
&
i
=
a_int
->
i
;
if
(
!
i
.
min
.
same_as
(
b
))
return
false
;
if
(
is_one
(
b
->
extent
))
return
i
.
is_single_point
();
if
(
is_positive_const
(
b
->
extent
)
&&
is_const
(
b
->
min
))
{
// deep equality
return
Equal
(
ComputeExpr
<
Sub
>
(
ComputeExpr
<
Add
>
(
b
->
extent
,
b
->
min
),
1
),
a_int
->
i
.
max
);
}
const
Sub
*
sub
=
i
.
max
.
as
<
Sub
>
();
if
(
!
sub
)
return
false
;
if
(
is_one
(
sub
->
b
))
return
false
;
const
Add
*
add
=
sub
->
a
.
as
<
Add
>
();
return
add
&&
add
->
a
.
same_as
(
b
->
min
)
&&
add
->
b
.
same_as
(
b
->
extent
);
}
}
// set combination rule by operators
inline
bool
MatchPoint
(
const
IntSet
&
a
,
template
<
typename
T
>
const
Expr
&
b
)
{
inline
IntSet
BinaryCombine
(
IntSet
a
,
IntSet
b
)
{
const
IntervalSet
*
a_int
=
a
.
as
<
IntervalSet
>
();
LOG
(
WARNING
)
<<
"cannot evaluate binary op "
<<
T
::
_type_key
;
if
(
!
a_int
)
return
false
;
return
IntSet
::
make_all_set
();
const
Interval
&
i
=
a_int
->
i
;
return
i
.
is_single_point
()
&&
i
.
min
.
same_as
(
b
);
}
}
template
<>
IntSet
Union
(
const
Array
<
IntSet
>&
set
)
{
inline
IntSet
BinaryCombine
<
Add
>
(
IntSet
a
,
IntSet
b
)
{
if
(
set
.
size
()
==
1
)
return
set
[
0
];
auto
n
=
std
::
make_shared
<
IntSetNode
>
(
*
(
a
.
operator
->
()));
Interval
x
=
set
[
0
].
cover_interval
().
as
<
IntervalSet
>
()
->
i
;
for
(
size_t
i
=
0
;
i
<
b
->
domain
.
size
();
++
i
)
{
for
(
size_t
i
=
1
;
i
<
set
.
size
();
++
i
)
{
n
->
domain
.
push_back
(
b
->
domain
[
i
]);
x
.
include
(
set
[
i
].
cover_interval
().
as
<
IntervalSet
>
()
->
i
);
n
->
stride
.
push_back
(
b
->
stride
[
i
]);
}
if
(
IsNumber
(
a
))
{
n
->
base
=
Range
::
make_with_min_extent
(
a
->
base
->
min
+
b
->
base
->
min
,
b
->
base
->
extent
);
}
else
if
(
IsNumber
(
b
))
{
n
->
base
=
Range
::
make_with_min_extent
(
a
->
base
->
min
+
b
->
base
->
min
,
a
->
base
->
extent
);
}
else
{
n
->
base
=
Range
::
make_with_min_extent
(
a
->
base
->
min
+
b
->
base
->
min
,
a
->
base
->
extent
+
b
->
base
->
extent
-
1
);
}
}
return
Int
Set
(
n
);
return
Int
ervalSet
::
make
(
x
);
}
}
inline
Range
Negation
(
Range
a
)
{
// type traits
if
(
Match
(
a
->
extent
,
1
))
{
template
<
typename
OP
>
return
Range
::
make_with_min_extent
(
-
a
->
min
,
a
->
extent
);
struct
is_logical_op
{
}
else
{
static
const
bool
value
=
false
;
return
Range
::
make_with_min_extent
(
-
(
a
->
min
+
a
->
extent
-
1
),
a
->
extent
);
};
#define TVM_DECLARE_LOGICAL_OP(OP) \
template<> \
struct is_logical_op<ir::OP> { \
static const bool value = true; \
};
// interval related.
template
<
typename
OP
>
inline
IntSet
CombineInterval
(
Interval
a
,
Interval
b
)
{
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
return
IntSet
::
single_point
(
ComputeExpr
<
OP
>
(
a
.
min
,
b
.
min
));
}
}
LOG
(
WARNING
)
<<
"Return Everything in CombineInterval "
<<
OP
::
_type_key
;
return
IntSet
::
everything
();
}
}
inline
IntSet
Negation
(
IntSet
a
)
{
template
<>
CHECK_EQ
(
a
->
concrete
.
size
(),
0U
);
inline
IntSet
CombineInterval
<
Add
>
(
Interval
a
,
Interval
b
)
{
auto
n
=
std
::
make_shared
<
IntSetNode
>
();
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
n
->
base
=
Negation
(
a
->
base
);
return
IntSet
::
single_point
(
ComputeExpr
<
Add
>
(
a
.
min
,
b
.
min
));
for
(
size_t
i
=
0
;
i
<
a
->
domain
.
size
();
++
i
)
{
}
n
->
domain
.
push_back
(
Negation
(
a
->
domain
[
i
]));
Interval
r
=
Interval
::
everything
();
n
->
stride
.
push_back
(
a
->
stride
[
i
]);
if
(
a
.
has_lower_bound
()
&&
b
.
has_lower_bound
())
{
r
.
min
=
ComputeExpr
<
Add
>
(
a
.
min
,
b
.
min
);
}
}
return
IntSet
(
a
);
if
(
a
.
has_upper_bound
()
&&
b
.
has_upper_bound
())
{
r
.
max
=
ComputeExpr
<
Add
>
(
a
.
max
,
b
.
max
);
}
return
IntervalSet
::
make
(
r
);
}
}
template
<>
template
<>
inline
IntSet
BinaryCombine
<
Sub
>
(
IntSet
a
,
IntSet
b
)
{
inline
IntSet
CombineInterval
<
Sub
>
(
Interval
a
,
Interval
b
)
{
return
BinaryCombine
<
Add
>
(
a
,
Negation
(
b
));
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
return
IntSet
::
single_point
(
ComputeExpr
<
Sub
>
(
a
.
min
,
b
.
min
));
}
Interval
r
=
Interval
::
everything
();
if
(
a
.
has_lower_bound
()
&&
b
.
has_upper_bound
())
{
r
.
min
=
ComputeExpr
<
Sub
>
(
a
.
min
,
b
.
max
);
}
if
(
a
.
has_upper_bound
()
&&
b
.
has_lower_bound
())
{
r
.
max
=
ComputeExpr
<
Sub
>
(
a
.
max
,
b
.
min
);
}
return
IntervalSet
::
make
(
r
);
}
}
inline
IntSet
BinaryMul
(
IntSet
a
,
Expr
b
)
{
template
<>
// copy construct
inline
IntSet
CombineInterval
<
Mul
>
(
Interval
a
,
Interval
b
)
{
if
(
Match
(
b
,
1
))
return
a
;
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
if
(
Match
(
b
,
-
1
))
return
Negation
(
a
);
return
IntSet
::
single_point
(
ComputeExpr
<
Mul
>
(
a
.
min
,
b
.
min
));
auto
n
=
std
::
make_shared
<
IntSetNode
>
();
}
n
->
base
=
Range
::
make_with_min_extent
(
0
,
1
);
if
(
a
.
is_single_point
()
&&
!
b
.
is_single_point
())
{
n
->
domain
.
push_back
(
a
->
base
);
std
::
swap
(
a
,
b
);
n
->
stride
.
push_back
(
b
);
}
for
(
size_t
i
=
0
;
i
<
a
->
domain
.
size
();
++
i
)
{
if
(
b
.
is_single_point
())
{
n
->
domain
.
push_back
(
a
->
domain
[
i
]);
if
(
is_zero
(
b
.
min
))
return
IntSet
::
single_point
(
0
);
n
->
stride
.
push_back
(
a
->
stride
[
i
]
*
b
);
if
(
is_one
(
b
.
min
))
return
IntervalSet
::
make
(
a
);
}
Expr
e1
=
a
.
has_lower_bound
()
?
ComputeExpr
<
Mul
>
(
a
.
min
,
b
.
min
)
:
a
.
min
;
return
IntSet
(
a
);
Expr
e2
=
a
.
has_upper_bound
()
?
ComputeExpr
<
Mul
>
(
a
.
max
,
b
.
min
)
:
a
.
max
;
// This is relaxiation
// TODO(tqchen): consider convert to StrideSet.
if
(
is_positive_const
(
b
.
min
))
{
return
IntervalSet
::
make
(
e1
,
e2
);
}
else
if
(
is_negative_const
(
b
.
min
))
{
return
IntervalSet
::
make
(
e2
,
e1
);
}
else
if
(
a
.
is_bounded
())
{
Expr
cmp
=
b
.
min
>=
make_zero
(
b
.
min
.
type
().
element_of
());
return
IntervalSet
::
make
(
select
(
cmp
,
e1
,
e2
),
select
(
cmp
,
e2
,
e1
));
}
}
LOG
(
WARNING
)
<<
"Return Everything in CombineInterval Mul"
;
return
IntSet
::
everything
();
}
}
template
<>
template
<>
inline
IntSet
BinaryCombine
<
Mul
>
(
IntSet
a
,
IntSet
b
)
{
inline
IntSet
CombineInterval
<
Max
>
(
Interval
a
,
Interval
b
)
{
if
(
IsNumber
(
a
))
{
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
return
BinaryMul
(
a
,
AsNumber
(
b
));
return
IntSet
::
single_point
(
ComputeExpr
<
Max
>
(
a
.
min
,
b
.
min
));
}
else
if
(
IsNumber
(
b
))
{
return
BinaryMul
(
b
,
AsNumber
(
a
));
}
else
{
return
IntSet
::
make_all_set
();
}
}
return
IntervalSet
::
make
(
Interval
::
make_max
(
a
.
min
,
b
.
min
),
Interval
::
make_max
(
a
.
max
,
b
.
max
));
}
}
}
// namespace
template
<>
inline
IntSet
CombineInterval
<
Min
>
(
Interval
a
,
Interval
b
)
{
inline
const
IntSetNode
*
IntSet
::
operator
->
()
const
{
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
return
static_cast
<
const
IntSetNode
*>
(
node_
.
get
());
return
IntSet
::
single_point
(
ComputeExpr
<
Min
>
(
a
.
min
,
b
.
min
));
}
return
IntervalSet
::
make
(
Interval
::
make_min
(
a
.
min
,
b
.
min
),
Interval
::
make_min
(
a
.
max
,
b
.
max
));
}
}
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
template
<
typename
OP
>
.
set_dispatch
<
IntSetNode
>
([](
const
IntSetNode
*
op
,
IRPrinter
*
p
)
{
inline
IntSet
CombineInterval_
(
IntSet
a
,
IntSet
b
)
{
p
->
stream
<<
"int-set(base="
;
return
CombineInterval
<
OP
>
(
p
->
print
(
op
->
base
);
a
.
as
<
IntervalSet
>
()
->
i
,
b
.
as
<
IntervalSet
>
()
->
i
);
p
->
stream
<<
')'
;
}
});
IntSet
IntSet
::
make_range
(
Range
dom
)
{
// stride related
auto
n
=
std
::
make_shared
<
IntSetNode
>
();
inline
IntSet
AsStrideSet
(
IntSet
a
)
{
n
->
base
=
dom
;
if
(
a
.
as
<
StrideSet
>
())
return
a
;
const
IntervalSet
*
s
=
a
.
as
<
IntervalSet
>
();
CHECK
(
s
->
i
.
is_bounded
());
std
::
shared_ptr
<
StrideSet
>
n
=
std
::
make_shared
<
StrideSet
>
();
n
->
base
=
s
->
i
;
return
IntSet
(
n
);
return
IntSet
(
n
);
}
}
template
<
typename
OP
>
inline
IntSet
CombineSets
(
IntSet
a
,
IntSet
b
)
{
return
CombineInterval_
<
OP
>
(
a
.
cover_interval
(),
b
.
cover_interval
());
}
Range
IntSet
::
GetCoverRange
()
const
{
template
<>
const
IntSetNode
*
s
=
operator
->
();
inline
IntSet
CombineSets
<
Add
>
(
IntSet
a
,
IntSet
b
)
{
CHECK
(
s
!=
nullptr
)
<<
"empty set"
;
const
IntervalSet
*
a_int
=
a
.
as
<
IntervalSet
>
();
if
(
s
->
domain
.
size
()
==
0
&&
s
->
concrete
.
size
()
==
0
)
{
const
IntervalSet
*
b_int
=
b
.
as
<
IntervalSet
>
();
return
s
->
base
;
if
(
a_int
&&
is_zero
(
a_int
->
i
.
min
))
return
b
;
if
(
b_int
&&
is_zero
(
b_int
->
i
.
min
))
return
a
;
a
=
AsStrideSet
(
a
);
b
=
AsStrideSet
(
b
);
const
StrideSet
*
a_stride
=
a
.
as
<
StrideSet
>
();
const
StrideSet
*
b_stride
=
b
.
as
<
StrideSet
>
();
auto
n
=
std
::
make_shared
<
StrideSet
>
(
*
a_stride
);
for
(
size_t
i
=
0
;
i
<
b_stride
->
extents
.
size
();
++
i
)
{
n
->
extents
.
push_back
(
b_stride
->
extents
[
i
]);
n
->
strides
.
push_back
(
b_stride
->
strides
[
i
]);
}
}
LOG
(
FATAL
)
<<
"not yet implemented"
;
n
->
base
=
CombineInterval
<
Add
>
(
return
Range
();
a_stride
->
base
,
b_stride
->
base
).
as
<
IntervalSet
>
()
->
i
;
return
IntSet
(
n
);
}
}
IntSet
IntSet
::
make_point
(
Expr
point
)
{
inline
IntSet
NegateSet
(
IntSet
a
)
{
return
IntSet
::
make_range
(
Range
::
make_with_min_extent
(
point
,
1
));
const
IntervalSet
*
a_int
=
a
.
as
<
IntervalSet
>
();
if
(
a_int
)
{
if
(
a_int
->
i
.
is_single_point
())
{
return
IntSet
::
single_point
(
-
a_int
->
i
.
min
);
}
else
{
Interval
r
=
Interval
::
everything
();
if
(
a_int
->
i
.
has_upper_bound
())
{
r
.
min
=
-
(
a_int
->
i
.
max
);
}
if
(
a_int
->
i
.
has_lower_bound
())
{
r
.
max
=
-
(
a_int
->
i
.
min
);
}
return
IntervalSet
::
make
(
r
);
}
}
else
{
return
NegateSet
(
a
.
cover_interval
());
}
}
}
IntSet
IntSet
::
make_all_set
()
{
template
<>
LOG
(
FATAL
)
<<
"TODO"
;
inline
IntSet
CombineSets
<
Sub
>
(
IntSet
a
,
IntSet
b
)
{
return
IntSet
(
);
return
CombineSets
<
Add
>
(
a
,
NegateSet
(
b
)
);
}
}
IntSet
Union
(
const
Array
<
IntSet
>&
set
)
{
TVM_DECLARE_LOGICAL_OP
(
And
);
if
(
set
.
size
()
==
1
)
return
set
[
0
];
TVM_DECLARE_LOGICAL_OP
(
Or
);
LOG
(
FATAL
)
<<
"TODO"
;
TVM_DECLARE_LOGICAL_OP
(
EQ
);
return
IntSet
();
TVM_DECLARE_LOGICAL_OP
(
NE
);
TVM_DECLARE_LOGICAL_OP
(
GE
);
TVM_DECLARE_LOGICAL_OP
(
GT
);
TVM_DECLARE_LOGICAL_OP
(
LE
);
TVM_DECLARE_LOGICAL_OP
(
LT
);
TVM_DECLARE_LOGICAL_OP
(
Not
);
// generic combine operations of two sets
template
<
typename
OP
>
inline
IntSet
Combine
(
const
IntSet
&
a
,
const
IntSet
&
b
)
{
if
(
is_logical_op
<
OP
>::
value
)
{
return
IntervalSet
::
make
(
0
,
1
);
}
const
IntervalSet
*
a_int
=
a
.
as
<
IntervalSet
>
();
const
IntervalSet
*
b_int
=
b
.
as
<
IntervalSet
>
();
if
(
a_int
&&
a_int
->
i
.
is_everything
())
return
a
;
if
(
b_int
&&
b_int
->
i
.
is_everything
())
return
b
;
if
(
a_int
&&
b_int
)
{
return
CombineInterval
<
OP
>
(
a_int
->
i
,
b_int
->
i
);
}
if
(
a_int
&&
!
(
a_int
->
i
.
is_bounded
()))
{
return
CombineInterval_
<
OP
>
(
a
,
b
.
cover_interval
());
}
if
(
b_int
&&
!
(
b_int
->
i
.
is_bounded
()))
{
return
CombineInterval_
<
OP
>
(
a
.
cover_interval
(),
b
);
}
return
CombineSets
<
OP
>
(
a
,
b
);
}
}
// Implementation of Evaluations and passing.
void
PassUp
(
const
SplitNode
*
s
,
void
PassUp
(
const
SplitNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
outer
,
const
IntSet
&
outer
,
...
@@ -215,33 +358,21 @@ void PassUp(const SplitNode* s,
...
@@ -215,33 +358,21 @@ void PassUp(const SplitNode* s,
if
(
dom_map
.
count
(
s
->
outer
)
&&
if
(
dom_map
.
count
(
s
->
outer
)
&&
dom_map
.
count
(
s
->
inner
)
&&
dom_map
.
count
(
s
->
inner
)
&&
dom_map
.
count
(
s
->
parent
)
&&
dom_map
.
count
(
s
->
parent
)
&&
Match
(
outer
,
dom_map
.
at
(
s
->
outer
))
&&
Match
Range
(
outer
,
dom_map
.
at
(
s
->
outer
))
&&
Match
(
inner
,
dom_map
.
at
(
s
->
inner
)))
{
Match
Range
(
inner
,
dom_map
.
at
(
s
->
inner
)))
{
*
parent
=
IntSet
::
make_
range
(
dom_map
.
at
(
s
->
parent
));
*
parent
=
IntSet
::
range
(
dom_map
.
at
(
s
->
parent
));
return
;
return
;
}
}
Expr
factor
=
dom_map
.
at
(
s
->
inner
)
->
extent
;
Expr
factor
=
dom_map
.
at
(
s
->
inner
)
->
extent
;
Expr
parent_min
=
dom_map
.
at
(
s
->
parent
)
->
min
;
CHECK
(
outer
.
defined
());
CHECK
(
outer
.
defined
());
CHECK
(
inner
.
defined
());
CHECK
(
inner
.
defined
());
CHECK
(
factor
.
defined
());
CHECK
(
factor
.
defined
());
// copy construct
auto
n
=
std
::
make_shared
<
IntSetNode
>
(
*
(
inner
.
operator
->
()));
*
parent
=
Combine
<
Add
>
(
Combine
<
Add
>
(
if
(
IsNumber
(
outer
))
{
Combine
<
Mul
>
(
outer
,
IntSet
::
single_point
(
factor
)),
inner
),
// shift the base offset
IntSet
::
single_point
(
parent_min
));
n
->
base
=
Range
::
make_with_min_extent
(
AsNumber
(
outer
)
*
factor
+
inner
->
base
->
min
,
inner
->
base
->
extent
);
}
else
{
// default use all domains in the data.
n
->
domain
.
push_back
(
outer
->
base
);
n
->
stride
.
push_back
(
factor
);
for
(
size_t
i
=
0
;
i
<
outer
->
domain
.
size
();
++
i
)
{
n
->
domain
.
push_back
(
outer
->
domain
[
i
]);
n
->
stride
.
push_back
(
outer
->
stride
[
i
]
*
factor
);
}
}
*
parent
=
IntSet
(
n
);
}
}
void
PassUp
(
const
FuseNode
*
s
,
void
PassUp
(
const
FuseNode
*
s
,
...
@@ -253,29 +384,51 @@ void PassUp(const FuseNode* s,
...
@@ -253,29 +384,51 @@ void PassUp(const FuseNode* s,
CHECK
(
dom_map
.
count
(
s
->
inner
));
CHECK
(
dom_map
.
count
(
s
->
inner
));
CHECK
(
dom_map
.
count
(
s
->
fused
));
CHECK
(
dom_map
.
count
(
s
->
fused
));
if
(
Match
(
fused
,
dom_map
.
at
(
s
->
fused
)))
{
if
(
Match
Range
(
fused
,
dom_map
.
at
(
s
->
fused
)))
{
*
outer
=
IntSet
::
make_
range
(
dom_map
.
at
(
s
->
outer
));
*
outer
=
IntSet
::
range
(
dom_map
.
at
(
s
->
outer
));
*
inner
=
IntSet
::
make_
range
(
dom_map
.
at
(
s
->
inner
));
*
inner
=
IntSet
::
range
(
dom_map
.
at
(
s
->
inner
));
return
;
return
;
}
}
if
(
IsNumber
(
fused
))
{
Expr
outer_min
=
dom_map
.
at
(
s
->
outer
)
->
min
;
Expr
value
=
AsNumber
(
fused
);
Expr
inner_min
=
dom_map
.
at
(
s
->
inner
)
->
min
;
const
IntervalSet
*
fused_int
=
fused
.
as
<
IntervalSet
>
();
if
(
fused_int
&&
fused_int
->
i
.
is_single_point
())
{
Expr
value
=
fused_int
->
i
.
min
;
Expr
factor
=
dom_map
.
at
(
s
->
inner
)
->
extent
;
Expr
factor
=
dom_map
.
at
(
s
->
inner
)
->
extent
;
*
outer
=
IntSet
::
make_point
(
value
/
factor
);
Expr
v_outer
=
value
/
factor
;
*
inner
=
IntSet
::
make_point
(
value
%
factor
);
Expr
v_inner
=
value
%
factor
;
if
(
!
is_zero
(
outer_min
))
v_outer
=
v_outer
+
outer_min
;
if
(
!
is_zero
(
inner_min
))
v_inner
=
v_inner
+
inner_min
;
*
outer
=
IntSet
::
single_point
(
v_outer
);
*
inner
=
IntSet
::
single_point
(
v_inner
);
}
else
{
}
else
{
LOG
(
WARNING
)
<<
"use fallback inference rule in fuse"
;
LOG
(
WARNING
)
<<
"use fallback inference rule in fuse"
;
// simply use the entire set, this rule can be enhanced.
// simply use the entire set, this rule can be enhanced.
*
outer
=
IntSet
::
make_range
(
dom_map
.
at
(
s
->
outer
));
*
outer
=
IntSet
::
range
(
dom_map
.
at
(
s
->
outer
));
*
inner
=
IntSet
::
make_range
(
dom_map
.
at
(
s
->
inner
));
*
inner
=
IntSet
::
range
(
dom_map
.
at
(
s
->
inner
));
return
;
}
}
void
PassUp
(
const
RebaseNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
rebased
,
IntSet
*
parent
)
{
CHECK
(
dom_map
.
count
(
s
->
parent
));
if
(
MatchRange
(
rebased
,
dom_map
.
at
(
s
->
rebased
)))
{
*
parent
=
IntSet
::
range
(
dom_map
.
at
(
s
->
parent
));
return
;
return
;
}
}
Expr
parent_min
=
dom_map
.
at
(
s
->
parent
)
->
min
;
*
parent
=
Combine
<
Add
>
(
rebased
,
IntSet
::
single_point
(
parent_min
));
}
}
namespace
{
// Evaluator to evalute the epxression.
// evaluator to evaluate the int set
class
IntSetEvaluator
{
class
IRSetEvaluator
{
public
:
public
:
inline
IntSet
Eval
(
Expr
expr
)
{
inline
IntSet
Eval
(
Expr
expr
)
{
static
const
FType
&
f
=
vtable
();
static
const
FType
&
f
=
vtable
();
...
@@ -283,11 +436,11 @@ class IRSetEvaluator {
...
@@ -283,11 +436,11 @@ class IRSetEvaluator {
return
f
(
expr
,
expr
,
this
);
return
f
(
expr
,
expr
,
this
);
}
else
{
}
else
{
LOG
(
WARNING
)
<<
"cannot evaluate set type "
<<
expr
->
type_key
();
LOG
(
WARNING
)
<<
"cannot evaluate set type "
<<
expr
->
type_key
();
return
IntSet
::
make_all_set
();
return
IntSet
::
everything
();
}
}
}
}
using
FType
=
tvm
::
IRFunctor
<
IntSet
(
const
NodeRef
&
,
const
Expr
&
,
I
R
SetEvaluator
*
)
>
;
using
FType
=
tvm
::
IRFunctor
<
IntSet
(
const
NodeRef
&
,
const
Expr
&
,
I
nt
SetEvaluator
*
)
>
;
static
FType
&
vtable
()
{
// NOLINT(*)
static
FType
&
vtable
()
{
// NOLINT(*)
static
FType
inst
;
return
inst
;
static
FType
inst
;
return
inst
;
}
}
...
@@ -295,76 +448,84 @@ class IRSetEvaluator {
...
@@ -295,76 +448,84 @@ class IRSetEvaluator {
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
dom_map
;
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
dom_map
;
};
};
inline
IntSet
ConstOp
(
const
NodeRef
&
,
const
Expr
&
e
,
I
R
SetEvaluator
*
)
{
inline
IntSet
ConstOp
(
const
NodeRef
&
,
const
Expr
&
e
,
I
nt
SetEvaluator
*
)
{
return
IntSet
::
mak
e_point
(
e
);
return
IntSet
::
singl
e_point
(
e
);
}
}
TVM_STATIC_IR_FUNCTOR
(
I
R
SetEvaluator
,
vtable
)
TVM_STATIC_IR_FUNCTOR
(
I
nt
SetEvaluator
,
vtable
)
.
set_dispatch
<
IntImm
>
(
ConstOp
)
.
set_dispatch
<
IntImm
>
(
ConstOp
)
.
set_dispatch
<
UIntImm
>
(
ConstOp
)
.
set_dispatch
<
UIntImm
>
(
ConstOp
)
.
set_dispatch
<
FloatImm
>
(
ConstOp
);
.
set_dispatch
<
FloatImm
>
(
ConstOp
);
TVM_STATIC_IR_FUNCTOR
(
I
R
SetEvaluator
,
vtable
)
TVM_STATIC_IR_FUNCTOR
(
I
nt
SetEvaluator
,
vtable
)
.
set_dispatch
<
Variable
>
([](
const
Variable
*
op
,
const
Expr
&
e
,
I
R
SetEvaluator
*
m
)
{
.
set_dispatch
<
Variable
>
([](
const
Variable
*
op
,
const
Expr
&
e
,
I
nt
SetEvaluator
*
m
)
{
auto
it
=
m
->
dom_map
.
find
(
op
);
auto
it
=
m
->
dom_map
.
find
(
op
);
if
(
it
!=
m
->
dom_map
.
end
())
{
if
(
it
!=
m
->
dom_map
.
end
())
{
return
it
->
second
;
return
it
->
second
;
}
else
{
}
else
{
return
IntSet
::
mak
e_point
(
e
);
return
IntSet
::
singl
e_point
(
e
);
}
}
});
});
// binary operator
template<typename T>
inline IntSet Binary(const T* op, const Expr& e, IntSetEvaluator* m) {
  IntSet a = m->Eval(op->a);
  IntSet b = m->Eval(op->b);
  if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) {
    return IntSet::single_point(e);
  }
  IntSet r = Combine<T>(a, b);
  return r;
}

(The old IsNumber/AsNumber/Match path with BinaryCombine is replaced by the MatchPoint check and Combine<T> above.)
TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
.set_dispatch<Add>(Binary<Add>)
.set_dispatch<Sub>(Binary<Sub>)
.set_dispatch<Mul>(Binary<Mul>)
.set_dispatch<Div>(Binary<Div>)
.set_dispatch<Mod>(Binary<Mod>)
.set_dispatch<Min>(Binary<Min>)
.set_dispatch<Max>(Binary<Max>)
.set_dispatch<EQ>(Binary<EQ>)
.set_dispatch<NE>(Binary<NE>)
.set_dispatch<LT>(Binary<LT>)
.set_dispatch<LE>(Binary<LE>)
.set_dispatch<GT>(Binary<GT>)
.set_dispatch<GE>(Binary<GE>)
.set_dispatch<And>(Binary<And>)
.set_dispatch<Or>(Binary<Or>);
}  // namespace

(The comparison and logical nodes are now handled by Binary<T> as well; the old Logical helper, commented "use simply bound for logical expressions for now" and returning IntSet::make_range(Range::make_with_min_extent(0, 2)), is removed together with its separate dispatch table.)
IntSet EvalSet(Expr e,
               const Map<IterVar, IntSet>& dom_map) {
  IntSetEvaluator m;
  for (auto kv : dom_map) {
    m.dom_map[kv.first->var.as<Variable>()] = kv.second;
  }
  return m.Eval(e);
}
IntSet EvalSet(Range r,
               const Map<IterVar, IntSet>& dom_map) {
  IntSetEvaluator m;
  for (auto kv : dom_map) {
    m.dom_map[kv.first->var.as<Variable>()] = kv.second;
  }
  IntSet min_set = m.Eval(r->min);
  IntSet ext_set = m.Eval(r->extent).cover_interval();
  const Interval& ei = ext_set.as<IntervalSet>()->i;
  if (!ei.has_upper_bound()) return IntSet::everything();
  ext_set = IntervalSet::make(0, ComputeExpr<Sub>(ei.max, 1));
  return Combine<Add>(min_set, ext_set);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntervalSet>([](const IntervalSet* op, IRPrinter* p) {
    p->stream << "interval-set["
              << "[" << op->i.min << ", " << op->i.max << ']';
  });

}  // namespace schedule
}  // namespace tvm
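This evaluator is what bound inference ultimately drives: each leaf IterVar's domain is propagated through split/fuse/rebase relations as an IntSet and surfaced to Python as a Range map. A minimal sketch against the Python API as it stands at this commit, mirroring the unit tests below; the tensor names, the computed expression and the split factor are illustrative only:

import tvm

n = tvm.Var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda i: A[i] + 1, name='B')

s = tvm.Schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], 8)

# InferBound runs the IntSet machinery above and returns a Map from
# each IterVar to the Range it has to cover.
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)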
src/schedule/int_set.h
@@ -22,35 +22,48 @@ class IntSet : public NodeRef {
 public:
  /*! \brief constructor */
  IntSet() {}
  // constructor from node container.
  explicit IntSet(std::shared_ptr<Node> n) : NodeRef(n) {}
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const IntSetNode* operator->() const;
  /*!
   * \brief Find a range that covers the region.
   * \param max_range The range to be covered.
   * \return The covering range.
   */
  Range cover_range(Range max_range) const;
  /*!
   * \brief find an interval that covers the set.
   * \return The covering interval set.
   */
  IntSet cover_interval() const;
  /*! \return Whether the set represents everything */
  bool is_everything() const;
  /*! \return Whether the set is a single point */
  bool is_single_point() const;
  /*! \return The set that contains everything */
  static IntSet everything();
  /*!
   * \brief construct a point set.
   * \param point The point in the set.
   * \return the constructed single point set
   */
  static IntSet single_point(Expr point);
  /*!
   * \brief Construct a set representing a range.
   * \param r The range
   * \return the constructed set.
   */
  static IntSet range(Range r);
};

/*!
 * \brief Base class of all IntSet containers.
 */
struct IntSetNode : public Node {
};

(The old is_empty()/GetCoverRange() helpers and the make_range/make_point/make_all_set constructors are replaced by the cover_range/cover_interval queries and the range/single_point/everything constructors above.)

@@ -63,6 +76,18 @@
 */
IntSet EvalSet(Expr e,
               const Map<IterVar, IntSet>& dom_map);
/*!
 * \brief Find a symbolic integer set that contains the union over
 *  all the possible conditional values in dom_map.
 *
 * \param r The initial range.
 * \param dom_map The domain of each variable.
 * \return An integer set that can cover all the possible values.
 */
IntSet EvalSet(Range r,
               const Map<IterVar, IntSet>& dom_map);
/*!
 * \brief Conditional upward message passing.
 *
@@ -99,6 +124,23 @@ void PassUp(const FuseNode* s,
            const IntSet& fused,
            IntSet* outer,
            IntSet* inner);
/*!
 * \brief Conditional upward message passing.
 *
 * Get domain of parent, condition on domain of children.
 * Domain is represented as IntSet.
 *
 * \param s The rebase relation node.
 * \param dom_map The old domain result from downward message passing.
 *    Contains the domain set if all the children are full set.
 * \param rebased domain of the rebased iteration.
 * \param parent The result domain of parent iteration.
 */
void PassUp(const RebaseNode* s,
            const std::unordered_map<IterVar, Range>& dom_map,
            const IntSet& rebased,
            IntSet* parent);
/*!
 * \brief Create an union set of all sets
 * \param sets The sets to be unioned
@@ -106,6 +148,11 @@
 */
IntSet Union(const Array<IntSet>& sets);

// implementation
inline const IntSetNode* IntSet::operator->() const {
  return static_cast<const IntSetNode*>(node_.get());
}
}  // namespace schedule
}  // namespace tvm
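For intuition on the new EvalSet(Range, dom_map) overload declared above: it evaluates min and extent separately and returns min_set combined with [0, extent_max - 1]. A small plain-Python illustration of that interval arithmetic (this is not TVM API, and the concrete numbers are made up):

# Suppose r = Range(min=x, extent=4) and the domain map says x covers [0, 3].
x_min, x_max = 0, 3      # assumed cover interval of x
extent_max = 4           # upper bound of the (here constant) extent
cover = (x_min + 0, x_max + extent_max - 1)
print(cover)             # (0, 6): every value r can touch lies in [0, 6]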
...
src/schedule/schedule_lang.cc
@@ -81,7 +81,7 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) {  // NOLINT(*)
    }
  }
  CHECK(found)
      << "Cannot find the specified axis in parent stage's leaf_iter_vars";
  return *this;
}

(The check message previously read "Cannot compute at a iteration variable that is not part of parent leaf vars".)

@@ -165,7 +165,6 @@ Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
  return *this;
}

Schedule::Schedule(Array<Operation> ops) {
  auto n = std::make_shared<ScheduleNode>();
  n->roots = ops;
@@ -203,9 +202,53 @@ IterVarRelation FuseNode::make(
  return IterVarRelation(n);
}

IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) {
  auto n = std::make_shared<RebaseNode>();
  n->parent = parent;
  n->rebased = rebased;
  return IterVarRelation(n);
}

void Schedule::normalize() {
  std::unordered_map<IterVar, IterVar> rebase_map;
  std::unordered_map<const Node*, int> attach_mark;

  for (Stage s : (*this)->stages) {
    if (s->attach_type == kScope) {
      attach_mark[s->attach_stage.get()] = 1;
    }
  }

  for (Stage s : (*this)->stages) {
    if (!attach_mark.count(s.get())) continue;
    auto root_iter_vars = s->op->root_iter_vars();
    ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
    for (IterVar iv : root_iter_vars) {
      size_t idx = FindIterVar(leaf_vars, iv);
      if (idx < leaf_vars->data.size()) {
        // insert rebase
        IterVar rebased(Range(), iv->var->name_hint + ".rb");
        s->relations.push_back(RebaseNode::make(iv, rebased));
        leaf_vars->data[idx] = rebased.node_;
        rebase_map[iv] = rebased;
      }
    }
  }
  // remap the parent relation
  for (Stage s : (*this)->stages) {
    if (s->attach_type != kScope) continue;
    if (rebase_map.count(s->attach_ivar)) {
      s->attach_ivar = rebase_map.at(s->attach_ivar);
    }
  }
}

TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(SplitNode);
TVM_REGISTER_NODE_TYPE(FuseNode);
TVM_REGISTER_NODE_TYPE(RebaseNode);
TVM_REGISTER_NODE_TYPE(ScheduleNode);
}  // namespace tvm
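The new Schedule::normalize() targets compute_at: when a stage is attached under another stage whose leaf iteration variable is still a root IterVar, that variable is rebased to a fresh ".rb" IterVar and the attach point is remapped. A hedged Python sketch of a schedule in that situation, modeled on test_schedule2 / test_flatten2 from this commit; the tensor names and the split factor are illustrative:

import tvm

m = tvm.Var('m')
A = tvm.placeholder((m,), name='A')
A1 = tvm.compute((m,), lambda i: A[i], name='A1')
A2 = tvm.compute((m,), lambda i: A1[i] + 1, name='A2')

s = tvm.Schedule(A2.op)
xo, xi = s[A2].split(A2.op.axis[0], 8)
s[A1].compute_at(s[A2], xo)   # A1 is computed inside A2's outer loop

bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)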
src/pass/schedule_ops.cc → src/schedule/schedule_ops.cc
@@ -8,12 +8,44 @@
#include <tvm/ir_visitor.h>
#include <tvm/schedule_pass.h>
#include "../pass/ir_util.h"
#include "./int_set.h"
#include "./graph.h"

namespace tvm {
namespace schedule {

using namespace ir;

(The file moves from src/pass to src/schedule: the old includes of "./scope.h", "./ir_util.h" and "../schedule/graph.h" become the paths above, and the code now lives in namespace schedule instead of namespace ir.)
/*!
 * \brief message passing to find if IterVar is related to reduction.
 * \param s The stage to be used.
 * \param p_state The message passing state
 *     IterVar->flag
 */
void PassDownFlag(const Stage& s,
                  std::unordered_map<IterVar, int>* p_state) {
  auto& state = *p_state;
  for (IterVarRelation rel : s->relations) {
    if (rel.as<SplitNode>()) {
      const SplitNode* s = rel.as<SplitNode>();
      int flag = state.at(s->parent);
      state[s->outer] = flag;
      state[s->inner] = flag;
    } else if (rel.as<FuseNode>()) {
      const FuseNode* s = rel.as<FuseNode>();
      int flag_outer = state.at(s->outer);
      int flag_inner = state.at(s->inner);
      state[s->fused] = flag_outer | flag_inner;
    } else if (rel.as<RebaseNode>()) {
      const RebaseNode* s = rel.as<RebaseNode>();
      int flag = state.at(s->parent);
      state[s->rebased] = flag;
    } else {
      LOG(FATAL) << "unknown relation type";
    }
  }
}
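PassDownFlag pushes a small bit mask from root IterVars down to the leaves: MakeLoop below seeds reduction axes with 2 and data-parallel axes with 1, split and rebase copy the parent's bits to their children, and fuse ORs the bits of the two fused variables. A standalone plain-Python model of that propagation (not TVM API; the variable names and the single split/fuse are assumptions for illustration):

# flag bits: 1 = related to a data-parallel axis, 2 = related to a reduction axis
state = {"i": 1, "k": 2}            # root IterVars of a compute with one reduce axis

# split i -> (i_outer, i_inner): both children inherit the parent's bits
state["i_outer"] = state["i"]
state["i_inner"] = state["i"]

# fuse (i_inner, k) -> fused: the fused var relates to both, so OR the bits
state["fused"] = state["i_inner"] | state["k"]
print(state["fused"])               # 3: touches both an axis and a reduction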
/*!
 * \brief use message passing to calculate the assignment of each Var inside the loop body.
@@ -37,7 +69,7 @@ void PassUpOffset(const Stage& s,
      state[s->parent] = inner + outer * factor;
      // add min if they exist
      if (!is_zero(parent_min)) {
        state[s->parent] = state[s->parent] + parent_min;
      }
    } else if (rel.as<FuseNode>()) {
      const FuseNode* s = rel.as<FuseNode>();
@@ -49,10 +81,20 @@ void PassUpOffset(const Stage& s,
      state[s->inner] = value % factor;
      // add min if they exist
      if (!is_zero(outer_min)) {
        state[s->outer] = state[s->outer] + outer_min;
      }
      if (!is_zero(inner_min)) {
        state[s->inner] = state[s->inner] + inner_min;
      }
    } else if (rel.as<RebaseNode>()) {
      const RebaseNode* s = rel.as<RebaseNode>();
      Expr value = state.at(s->rebased);
      Expr parent_min = dom_map.at(s->parent)->min;
      // add min if they exist
      if (!is_zero(parent_min)) {
        state[s->parent] = value + parent_min;
      } else {
        state[s->parent] = value;
      }
    } else {
      LOG(FATAL) << "unknown relation type";
@@ -60,76 +102,54 @@ void PassUpOffset(const Stage& s,
    }
  }

(The old code added outer_min to the inner axis of a fuse; it now adds inner_min, and the RebaseNode case above is new.)
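PassUpOffset recovers the value of a parent IterVar from the values chosen for its children: for a split with a given factor the parent index is inner + outer * factor, with the parent's min added back when it is non-zero, and the new rebase case simply adds the parent's min to the rebased value. A tiny plain-Python check of the split formula (the numbers are arbitrary, this is only an illustration):

factor = 8
outer, inner = 3, 5        # values picked for the two child loops
parent_min = 2             # min of the parent range, if any
parent = inner + outer * factor + parent_min
print(parent)              # 31: the element of the parent axis being touched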
(The old SplitByAdd helper, which split an index expression by addition and recorded the minimum loop level at which each term could be computed, is removed together with the offset/loop_level bookkeeping of the old MakeLoopNest; index values are now carried in an explicit value_map and substituted into the provide statement.)

std::vector<std::vector<Stmt> >
MakeLoopNest(const Stage& sch,
             const Map<IterVar, Range>& dom_map,
             size_t begin_loop,
             bool reduce_init_loop,
             std::unordered_map<IterVar, Expr>* p_value_map,
             const std::unordered_map<IterVar, bool>& skip_iter) {
  auto leaf_iter_vars = sch->leaf_iter_vars;
  Stmt no_op = Evaluate::make(0);
  // create the loop nest
  std::vector<std::vector<Stmt> > nest;
  nest.resize(leaf_iter_vars.size() + 1);
  std::unordered_map<IterVar, Expr>& value_map = *p_value_map;

  for (size_t i = begin_loop; i < leaf_iter_vars.size(); ++i) {
    auto iv = leaf_iter_vars[i];
    if (skip_iter.count(iv) && skip_iter.at(iv)) {
      // skip this iteration.
      value_map[iv] = iv->var;
      continue;
    }
    Range dom = dom_map.at(iv);
    Var var = iv->var;
    if (reduce_init_loop) {
      var = Var(iv->var->name_hint + ".init", iv->var.type());
    }
    // Mark the iter var in the IR, to remember the point
    if (iv->thread_tag.length() == 0) {
      if (is_one(dom->extent)) {
        nest[i + 1].emplace_back(
            LetStmt::make(var, dom->min, no_op));
        value_map[iv] = dom->min;
      } else if (is_zero(dom->min)) {
        nest[i + 1].emplace_back(
            For::make(var, 0, dom->extent,
                      ForType::Serial, DeviceAPI::None, no_op));
        value_map[iv] = var;
      } else {
        Var idx(iv->var->name_hint + ".idx", iv->var.type());
        nest[i + 1].emplace_back(
            For::make(idx, 0, dom->extent,
                      ForType::Serial, DeviceAPI::None, no_op));
        Expr new_value = dom->min + idx;
        value_map[iv] = new_value;
        nest[i + 1].emplace_back(
            LetStmt::make(var, new_value, no_op));
      }
    } else {
      // Always restrict threaded IterVar to starts from 0.
@@ -137,69 +157,73 @@ std::vector<std::vector<Stmt> > MakeLoopNest(
      // annotate the extent of the IterVar
      nest[i + 1].emplace_back(
          AttrStmt::make(iv, "thread_extent", dom->extent, no_op));
      value_map[iv] = var;
    }
    if (!reduce_init_loop) {
      // annotate the scope of the IterVar
      nest[i + 1].emplace_back(
          AttrStmt::make(iv, "scope", iv->var, no_op));
    }
  }
  // message passing to get offset of root iter vars.
  PassUpOffset(sch, dom_map, &value_map);
  return nest;
}

Stmt MakeLoop(const Stage& s,
              const Map<IterVar, Range>& dom_map,
              Stmt provide,
              Stmt init) {
  std::unordered_map<IterVar, Expr> value_map;
  auto nest = MakeLoopNest(s, dom_map, 0, false, &value_map, {});
  provide = Substitute(provide, value_map);
  if (init.defined()) {
    // try to find the location to insert the initialization.
    // Fuse the initialization and provide loop when possible.
    std::unordered_map<IterVar, int> reduce_state;
    const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
    for (IterVar iv : compute->reduce_axis) {
      reduce_state[iv] = 2;
    }
    for (IterVar iv : compute->axis) {
      reduce_state[iv] = 1;
    }
    // find which iter var is related to reduction and which is related to axis.
    PassDownFlag(s, &reduce_state);
    auto leaf_iter_vars = s->leaf_iter_vars;
    std::unordered_map<IterVar, Expr> init_value_map;
    // find the first loop that is related to reduction.
    size_t begin_loop = leaf_iter_vars.size();
    for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
      auto iv = leaf_iter_vars[i];
      int flag = reduce_state.at(iv);
      if ((flag & 2) != 0) {
        begin_loop = i; break;
      }
      init_value_map[iv] = value_map.at(iv);
    }
    // skip loops that do not relate to the axis.
    std::unordered_map<IterVar, bool> skip_iter;
    for (size_t i = begin_loop; i < leaf_iter_vars.size(); ++i) {
      auto iv = leaf_iter_vars[i];
      int flag = reduce_state.at(iv);
      if ((flag & 1) == 0) skip_iter[iv] = true;
    }
    auto init_nest = MakeLoopNest(
        s, dom_map, begin_loop, true, &init_value_map, skip_iter);
    init = Substitute(init, init_value_map);
    init = MergeNest(init_nest, init);
    // common nest
    std::vector<std::vector<Stmt> > common(
        nest.begin(), nest.begin() + begin_loop);
    std::vector<std::vector<Stmt> > reduce(
        nest.begin() + begin_loop, nest.end());
    provide = MergeNest(reduce, provide);
    return MergeNest(common, Block::make(init, provide));
  } else {
    return MergeNest(nest, provide);
  }
}

(The old MakeLoopNest also emitted per-root-IterVar LetStmt offsets and a crude IfThenElse boundary check, flagged "Need better boundary condition here"; that logic is gone in favour of the Substitute-based value_map.)
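MakeLoop splits the loop nest at the first leaf IterVar that carries the reduction flag: the loops before that point ("common") wrap a Block of the initialization nest and the update nest, so the output is reset once per data-parallel iteration and then accumulated. A plain-NumPy model of the structure this produces for a row sum (an illustration of the semantics, not the generated IR itself):

import numpy as np

A = np.random.uniform(size=(4, 3))
B = np.empty(4)
for i in range(4):              # common, data-parallel loop
    B[i] = 0.0                  # init nest: Provide of the identity value
    for k in range(3):          # reduction loop
        B[i] = B[i] + A[i, k]   # update Provide (see MakeReduction below)
np.testing.assert_allclose(B, A.sum(axis=1))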
/*!
 * \brief Make the provide statement for a compute op node.
 * \param op The compute node
 * \param tensors The tensors generated by provide.
 */
Stmt MakeProvide(const ComputeOpNode* op,
                 const std::vector<Tensor>& tensors) {
  Tensor t = tensors[0];
@@ -210,13 +234,6 @@ Stmt MakeProvide(const ComputeOpNode* op,
  return Provide::make(t->op, t->value_index, op->body, args);
}

/*!
 * \brief Make the realize statement for a compute op node.
 * \param op The compute node
 * \param dom_map The domain map
 * \param tensors The tensors generated by provide.
 * \param body The content of the pipeline.
 */
Stmt MakeRealize(const ComputeOpNode* op,
                 const Map<IterVar, Range>& dom_map,
                 const std::vector<Tensor>& tensors,
@@ -230,6 +247,38 @@ Stmt MakeRealize(const ComputeOpNode* op,
                      bounds, make_const(Bool(1), true), body);
}
void MakeReduction(const ComputeOpNode* op,
                   const std::vector<Tensor>& tensors,
                   const Map<IterVar, Range>& dom_map,
                   Stmt* init,
                   Stmt* provide) {
  Stmt no_op = Evaluate::make(0);
  Tensor t = tensors[0];
  std::vector<Stmt> nest;
  Array<Expr> args;
  for (IterVar iv : op->axis) {
    args.push_back(iv->var);
  }
  const Reduce* reduce = op->body.as<Reduce>();
  CHECK(reduce);
  Expr init_value, update_value;
  if (reduce->op == "Add") {
    init_value = make_zero(reduce->type);
    update_value = Add::make(t(args), reduce->source);
  } else if (reduce->op == "Max") {
    init_value = reduce->type.min();
    update_value = Max::make(t(args), reduce->source);
  } else if (reduce->op == "Min") {
    init_value = reduce->type.max();
    update_value = Min::make(t(args), reduce->source);
  } else {
    LOG(FATAL) << "Unsupported reduction " << reduce->op;
  }
  *init = Provide::make(t->op, t->value_index, init_value, args);
  *provide = Provide::make(t->op, t->value_index, update_value, args);
}
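The initialization value is the identity element of the reduction, so the first update always wins: 0 for Add, the type's minimum for Max, and the type's maximum for Min. A tiny plain-Python illustration of that choice (not TVM code; float32 is assumed only for concreteness):

import numpy as np

# identity elements MakeReduction emits, per reduction op
ident = {"Add": 0.0,
         "Max": np.finfo("float32").min,
         "Min": np.finfo("float32").max}

# folding any value against the identity leaves that value unchanged
assert ident["Add"] + 2.5 == 2.5
assert max(ident["Max"], 2.5) == 2.5
assert min(ident["Min"], 2.5) == 2.5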
Stmt MakePipeline(const Stage& sch,
                  const Map<IterVar, Range>& dom_map,
                  Stmt consumer) {
@@ -238,14 +287,20 @@ Stmt MakePipeline(const Stage& sch,
    tensors.emplace_back(sch->op.output(i));
  }

  Stmt init, provide;
  const ComputeOpNode* compute = sch->op.as<ComputeOpNode>();
  if (compute) {
    if (compute->reduce_axis.size() == 0) {
      provide = MakeProvide(compute, tensors);
    } else {
      MakeReduction(compute, tensors, dom_map, &init, &provide);
    }
  } else {
    LOG(FATAL) << "not supported op " << sch->op->type_key();
  }
  Stmt producer = MakeLoop(sch, dom_map, provide, init);
  producer = ProducerConsumer::make(sch->op, true, producer);
  Stmt pipeline = producer;

(Previously the producer was built as MergeNest(MakeLoopNest(sch, dom_map), provide) with no reduction handling.)
@@ -306,7 +361,6 @@ Stmt InjectInline(const Operation op, Stmt body) {
  return Inline(body, op, args, compute->body);
}

Stmt ScheduleOps(
    Schedule sch, Map<IterVar, Range> dom_map) {
  Stmt body = Stmt();
@@ -330,5 +384,5 @@ Stmt ScheduleOps(
  return body;
}

}  // namespace schedule
}  // namespace tvm
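With ScheduleOps moved under the schedule namespace, the Python lowering entry point changes from tvm.ir_pass.ScheduleOps to tvm.schedule.ScheduleOps, as the updated tests below show. A minimal end-to-end sketch against the API at this commit; the tensor names and the computed expression are illustrative:

import tvm

n = tvm.Var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda i: A[i] + 1, name='B')

s = tvm.Schedule(B.op)
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)   # was tvm.ir_pass.ScheduleOps
print(stmt)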
tests/python/integration/test_ewise.py
@@ -18,7 +18,8 @@ def test_add():
    # one line to build the function.
    codes = []
    fadd = tvm.build(s, args=[A, B, C],
                     target="cuda", name="myadd",
                     record_codes=codes)
    for c in codes:
...
tests/python/integration/test_reduce.py (new file, 0 → 100644)

import tvm
import numpy as np

def test_sum():
    # graph
    n = tvm.Var('n')
    m = tvm.Var('m')
    A = tvm.placeholder((n, m), name='A')
    k = tvm.IterVar((0, m))
    B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')
    # schedule
    s = tvm.Schedule(B.op)
    # create iter var and assign them tags.
    num_thread = 1
    block_x = tvm.IterVar(thread_tag="blockIdx.x")
    thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x")
    _, x = s[B].split(B.op.axis[0], factor=num_thread, outer=block_x)
    _, x = s[B].split(x, outer=thread_x)

    tvm.init_opencl()
    codes = []
    fsum = tvm.build(s,
                     args=[A, B],
                     target="opencl", name="myadd",
                     record_codes=codes)
    for c in codes:
        print(c)
    num_device = 1
    for i in range(num_device):
        ctx = tvm.opencl(i)
        if not ctx.enabled:
            continue
        # launch the kernel.
        n = 1028
        m = 129
        # a = tvm.nd.array(np.zeros((n, m)).astype(A.dtype), ctx)
        a = tvm.nd.array(np.random.uniform(size=(n, m)).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
        fsum(a, b)
        np.testing.assert_allclose(
            b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4)

if __name__ == "__main__":
    test_sum()
tests/python/unittest/test_codegen_device.py
@@ -18,8 +18,7 @@ def test_add_pipeline():
    # compile to IR
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    Ab = tvm.Buffer(A.shape, A.dtype, name='A')
    Bb = tvm.Buffer(B.shape, B.dtype, name='B')
    Cb = tvm.Buffer(C.shape, C.dtype, name='C')
...

(Here and in the tests below, the old call was tvm.ir_pass.ScheduleOps(s, bounds).)
tests/python/unittest/test_codegen_makeapi.py
@@ -10,12 +10,13 @@ def test_makeapi():
    s = tvm.Schedule(C.op)
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    Ab = tvm.Buffer(A.shape, A.dtype, name='A')
    Bb = tvm.Buffer(B.shape, B.dtype, name='B')
    Cb = tvm.Buffer(C.shape, C.dtype, name='C')
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, C: Cb})
    num_packed_args = 2
    f = tvm.codegen.MakeAPI(stmt, "myadd", [n, Ab, Bb, Cb], num_packed_args)
    assert(f.handle_data_type[Ab.data].dtype == Ab.dtype)
...
tests/python/unittest/test_lang_tensor.py
@@ -26,7 +26,7 @@ def test_tensor_reduce():
    B = tvm.placeholder((n, l), name='B')
    T = tvm.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k])
    rv = tvm.IterVar((0, A.shape[1]), name="k")
    C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rv+1), axis=rv))
    # json load save
    C_json = tvm.save_json(C)
    C_loaded = tvm.load_json(C_json)
...

(The keyword argument of tvm.sum is renamed from rdom to axis.)
tests/python/unittest/test_pass_storage_flatten.py
@@ -12,7 +12,7 @@ def test_flatten2():
    s[A1].compute_at(s[A2], xo)
    bounds = tvm.schedule.InferBound(s)
    assert isinstance(bounds, tvm.collections.Map)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    print(stmt)
    Ab = tvm.Buffer(A.shape, A.dtype, name='A')
...
tests/python/unittest/test_pass_schedule_ops.py → tests/python/unittest/test_schedule_schedule_ops.py
@@ -11,7 +11,7 @@ def test_schedule0():
    bounds = tvm.schedule.InferBound(s)
    assert isinstance(bounds, tvm.collections.Map)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    print(stmt)


def test_schedule1():
@@ -24,7 +24,7 @@ def test_schedule1():
    xo, xi = s[A1].split(A1.op.axis[0], 8)
    bounds = tvm.schedule.InferBound(s)
    assert isinstance(bounds, tvm.collections.Map)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    print(stmt)


def test_schedule2():
@@ -39,7 +39,7 @@ def test_schedule2():
    s[A1].compute_at(s[A2], xo)
    bounds = tvm.schedule.InferBound(s)
    assert isinstance(bounds, tvm.collections.Map)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    print(stmt)
...