Commit d114dfc9, authored Feb 16, 2017 by Tianqi Chen, committed via GitHub on Feb 16, 2017.
[SCHEDULE] Mutate dataflow in schedule, refactor Stage (#44)
Parent: 820a8597

Showing 14 changed files with 542 additions and 149 deletions.
include/tvm/operation.h                               +11   -11
include/tvm/schedule.h                                +62    -8
python/tvm/build.py                                    +0    -1
python/tvm/schedule.py                                +60    -0
src/api/api_lang.cc                                   +19    -1
src/lang/operation.cc                                  +3    -3
src/schedule/auto_inline_elem_wise.cc                  +5   -18
src/schedule/bound.cc                                  +9    -3
src/schedule/schedule_lang.cc                        +248   -32
src/schedule/schedule_ops.cc                          +95   -42
tests/cpp/tensor_test.cc                               +6    -6
tests/python/integration/test_gemm.py                  +7   -22
tests/python/unittest/test_lang_schedule.py            +2    -2
tests/python/unittest/test_schedule_schedule_ops.py   +15    -0
include/tvm/operation.h

@@ -136,7 +136,7 @@ using FCompute = std::function<Expr (const Array<Var>& i)>;
  * \param dtype the data type of the tensor.
  * \param name The name of the Tensor.
  */
-Tensor Placeholder(Array<Expr> shape,
+Tensor placeholder(Array<Expr> shape,
                    Type dtype = Float(32),
                    std::string name = "placeholder");

@@ -147,7 +147,7 @@ Tensor Placeholder(Array<Expr> shape,
  * \param fcompute The compute function to create the tensor.
  * \param name The optional name of the tensor.
  */
-Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
+Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
 /*!
  * \brief Construct new tensors by scan over scan_axis.

@@ -158,36 +158,36 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor"
  * \param state_placeholder The placeholder for the states.
  * \param name The optional name of the tensor.
  */
-Array<Tensor> Scan(IterVar scan_axis,
+Array<Tensor> scan(IterVar scan_axis,
                    Array<Tensor> init,
                    Array<Tensor> update,
                    Array<Tensor> state_placeholder,
                    std::string name = "scan");

 // same as compute, specialized for different fcompute function
-inline Tensor Compute(Array<Expr> shape,
+inline Tensor compute(Array<Expr> shape,
                       std::function<Expr(Var)> f,
                       std::string name = "tensor") {
   FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
-  return Compute(shape, fc, name);
+  return compute(shape, fc, name);
 }
-inline Tensor Compute(Array<Expr> shape,
+inline Tensor compute(Array<Expr> shape,
                       std::function<Expr(Var, Var)> f,
                       std::string name = "tensor") {
   FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
-  return Compute(shape, fc, name);
+  return compute(shape, fc, name);
 }
-inline Tensor Compute(Array<Expr> shape,
+inline Tensor compute(Array<Expr> shape,
                       std::function<Expr(Var, Var, Var)> f,
                       std::string name = "tensor") {
   FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
-  return Compute(shape, fc, name);
+  return compute(shape, fc, name);
 }
-inline Tensor Compute(Array<Expr> shape,
+inline Tensor compute(Array<Expr> shape,
                       std::function<Expr(Var, Var, Var, Var)> f,
                       std::string name = "tensor") {
   FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
-  return Compute(shape, fc, name);
+  return compute(shape, fc, name);
 }

 }  // namespace tvm
include/tvm/schedule.h

@@ -132,6 +132,13 @@ class Stage : public NodeRef {
                IterVar* p_x_inner, IterVar* p_y_inner,
                Expr x_factor, Expr y_factor);
   /*!
+   * \brief Specify thread launching group in
+   *  outer most scope of the stage.
+   *  This is only valid for composite operators.
+   * \param threads The threads to be launched.
+   */
+  Stage& outermost_threads(Array<IterVar> threads);
+  /*!
    * \brief Vectorize iteration.
    * \param var The axis to be vectorized.
    * \return reference to self.

@@ -180,6 +187,28 @@ class Schedule : public NodeRef {
     return this->operator[](tensor->op);
   }
   /*!
+   * \brief create a cache read of original tensor for readers.
+   *  This will mutate the body of the readers.
+   *  A new stage will be created for the tensor.
+   * \param tensor The tensor cached.
+   * \param scope The scope of the cache.
+   * \param readers The readers to redirect to the tensor.
+   * \return The created tensor.
+   */
+  Tensor cache_read(const Tensor& tensor,
+                    const std::string& scope,
+                    const Array<Operation>& readers);
+  /*!
+   * \brief Create a cache write tensor for producing tensor.
+   *  The the tensor will take over body of original tensor op.
+   *  The original tensor's body will be changed to an identity read
+   *  from the corresponding cache.
+   * \param tensor The tensor to be produced.
+   * \param scope The scope of the storage.
+   * \return The created tensor.
+   */
+  Tensor cache_write(const Tensor& tensor, const std::string& scope);
+  /*!
    * \brief Normalize the schedule.
    *  This is needed before bound inference.
    *  Insert necessary RebaseNode to make sure all leaf_iter_vars

@@ -193,6 +222,11 @@ class Schedule : public NodeRef {
   * \return the pointer to the internal node container
   */
  inline const ScheduleNode* operator->() const;
+ /*!
+  * \brief access the internal node container
+  * \return the pointer to the internal node container
+  */
+ inline ScheduleNode* operator->();
  // declare container type
  using ContainerType = ScheduleNode;
 };

@@ -244,10 +278,16 @@ class IterVarAttr : public NodeRef {
  */
 class StageNode : public Node {
  public:
-  /*! \brief The operation to be scheduled */
-  Operation op;
   /*! \brief The thread scope level of the stage */
   std::string scope;
+  /*! \brief The operation of stage, can be different from original op. */
+  Operation op;
+  /*!
+   * \brief The original operator.
+   *  The op field can change during schedule to alternate the dataflow,
+   *  while origin_op remains fixed.
+   */
+  Operation origin_op;
   /*! \brief All the nodes in the iter var */
   Array<IterVar> all_iter_vars;
   /*!

@@ -255,6 +295,11 @@ class StageNode : public Node {
   *  Operations can only be performed in leaves.
   */
  Array<IterVar> leaf_iter_vars;
+ /*!
+  * \brief Specify threads to be launched at the stage.
+  *  This is only valid for composite ops such as Scan.
+  */
+ Array<IterVar> outermost_threads;
  /*! \brief The relation bwteen of IterVars */
  Array<IterVarRelation> relations;
  /*! \brief additional attributes about iter var. */

@@ -265,17 +310,22 @@ class StageNode : public Node {
   IterVar attach_ivar;
   /*! \brief The stage this node attaches to */
   Stage attach_stage;
+  /*! \brief Whether this is an output stage */
+  bool is_output{false};

   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("scope", &scope);
     v->Visit("op", &op);
+    v->Visit("origin_op", &origin_op);
     v->Visit("all_iter_vars", &all_iter_vars);
     v->Visit("leaf_iter_vars", &leaf_iter_vars);
+    v->Visit("outermost_threads", &outermost_threads);
     v->Visit("relations", &relations);
     v->Visit("iter_var_attrs", &iter_var_attrs);
     v->Visit("attach_type", &attach_type);
     v->Visit("attach_ivar", &attach_ivar);
     v->Visit("attach_stage", &attach_stage);
+    v->Visit("is_output", &is_output);
   }

   static constexpr const char* _type_key = "Stage";

@@ -285,18 +335,18 @@ class StageNode : public Node {
 /*! \brief node container for schedule */
 class ScheduleNode : public Node {
  public:
-  /*! \brief The root operations */
-  Array<Operation> roots;
+  /*! \brief The output operations in original data flow graph */
+  Array<Operation> outputs;
   /*!
    * \brief list of all stages for non-placeholder ops.
-   *  The stage are ordered in PostDFS order of their op.
+   *  The stages are sorted in dependency order.
    */
   Array<Stage> stages;
   /*! \brief map of operation to the stages */
   Map<Operation, Stage> stage_map;

   void VisitAttrs(AttrVisitor* v) final {
-    v->Visit("roots", &roots);
+    v->Visit("outputs", &outputs);
     v->Visit("stages", &stages);
     v->Visit("stage_map", &stage_map);
   }

@@ -412,12 +462,16 @@ inline StageNode* Stage::operator->() {
 inline bool Stage::is_scheduled() const {
   const StageNode* n = operator->();
-  return !(n->relations.empty() && n->attach_type == kNone);
+  return !(n->relations.empty() && n->attach_type == kNone &&
+           n->all_iter_vars.same_as(n->leaf_iter_vars));
 }

 inline const ScheduleNode* Schedule::operator->() const {
   return static_cast<const ScheduleNode*>(node_.get());
 }
+inline ScheduleNode* Schedule::operator->() {
+  return static_cast<ScheduleNode*>(node_.get());
+}

 inline const IterVarRelationNode* IterVarRelation::operator->() const {
   return static_cast<const IterVarRelationNode*>(node_.get());
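Editorial note: the roots field documented above becomes outputs, so from Python the output operations of a schedule are now reached through s.outputs, as the updated test_lang_schedule.py assertion later in this commit shows. A small sketch of that access, with illustrative tensor names:

import tvm

m = tvm.Var('m')
A = tvm.placeholder((m,), name='A')
B = tvm.compute((m,), lambda i: A(i) + 1, name='B')

s = tvm.Schedule(B.op)
# the schedule's output operations are now exposed as `outputs` (was `roots`)
print(s.outputs[0].body)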
python/tvm/build.py

@@ -63,7 +63,6 @@ def build(sch,
             arg_list.append(x)
         else:
             raise ValueError("args must be Tensor, Buffer or Var")
     # lowering
     bounds = schedule.InferBound(sch)
     stmt = schedule.ScheduleOps(sch, bounds)
python/tvm/schedule.py

@@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs
 from ._ctypes._node import NodeBase, register_node
 from . import _api_internal
 from . import tensor as _tensor
+from . import collections as _collections

 @register_node
 class Buffer(NodeBase):

@@ -41,6 +42,53 @@ class Schedule(NodeBase):
         """
         _api_internal._ScheduleNormalize(self)

+    def cache_read(self, tensor, scope, readers):
+        """Create a cache read of original tensor for readers.
+
+        This will mutate the body of the readers.
+        A new cache stage will be created for the tensor.
+        Call this before doing any split/fuse schedule.
+
+        Parameters
+        ----------
+        tensor : Tensor
+            The tensor to be cached.
+        scope : str
+            The scope of cached
+        readers : list of Tensor or Operation
+            The readers to read the cache.
+
+        Returns
+        -------
+        cache : Tensor
+            The created cache tensor.
+        """
+        if isinstance(readers, (_tensor.Tensor, _tensor.Operation)):
+            readers = [readers]
+        readers = [t.op if isinstance(t, _tensor.Tensor) else t for t in readers]
+        return _api_internal._ScheduleCacheRead(self, tensor, scope, readers)
+
+    def cache_write(self, tensor, scope):
+        """Create a cache write of original tensor, before storing into tensor.
+
+        This will mutate the body of the tensor.
+        A new cache stage will created before feed into the tensor.
+
+        Parameters
+        ----------
+        tensor : Tensor
+            The tensor to be feed to.
+        scope : str
+            The scope of cached
+
+        Returns
+        -------
+        cache : Tensor
+            The created cache tensor.
+        """
+        return _api_internal._ScheduleCacheWrite(self, tensor, scope)
+
 @register_node
 class Stage(NodeBase):
     """A Stage represents schedule for one operation."""

@@ -104,6 +152,18 @@ class Stage(NodeBase):
         """
         return _api_internal._StageSetScope(self, scope)

+    def outermost_threads(self, threads):
+        """Force launch threads at outermost scope of the stage.
+
+        Parameters
+        ----------
+        threads : list of threads
+            The threads to be launched.
+        """
+        if isinstance(threads, _collections.IterVar):
+            threads = [threads]
+        _api_internal._StageOutermostThreads(self, threads)
+
     def compute_at(self, parent, scope):
         """Attach the stage at parent's scope
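Editorial note: the wrappers above normalize their arguments before crossing into C++: cache_read accepts either a single reader or a list and converts Tensor readers to their Operation, and outermost_threads accepts a single IterVar or a list. A minimal usage sketch of the two cache calls, mirroring the test_schedule_cache case added later in this commit:

import tvm

m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
B = tvm.placeholder((m, n), name='B')
C = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='C')

s = tvm.Schedule(C.op)
# a single reader or a list of readers are both accepted
AA = s.cache_read(A, "shared", readers=[C])
CC = s.cache_write(C, "shared")
# the new cache stages can be scheduled like any other stage
s[AA].compute_at(s[CC], CC.op.axis[0])
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)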
src/api/api_lang.cc

@@ -161,7 +161,7 @@ TVM_REGISTER_API(_TensorHash)
 TVM_REGISTER_API(_Placeholder)
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = Placeholder(args[0],
+    *ret = placeholder(args[0],
                        args[1],
                        args[2]);
   });

@@ -262,6 +262,12 @@ TVM_REGISTER_API(_StageTile)
     *ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
   });

+TVM_REGISTER_API(_StageOutermostThreads)
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    args[0].operator Stage()
+        .outermost_threads(args[1]);
+  });
+
 TVM_REGISTER_API(_StageUnroll)
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     args[0].operator Stage()

@@ -280,4 +286,16 @@ TVM_REGISTER_API(_ScheduleNormalize)
     .normalize();
   });

+TVM_REGISTER_API(_ScheduleCacheRead)
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = args[0].operator Schedule()
+        .cache_read(args[1], args[2], args[3]);
+  });
+
+TVM_REGISTER_API(_ScheduleCacheWrite)
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = args[0].operator Schedule()
+        .cache_write(args[1], args[2]);
+  });
+
 }  // namespace tvm
src/lang/operation.cc

@@ -53,7 +53,7 @@ Operation PlaceholderOpNode::make(std::string name,
-Tensor Placeholder(Array<Expr> shape, Type dtype, std::string name) {
+Tensor placeholder(Array<Expr> shape, Type dtype, std::string name) {
   return PlaceholderOpNode::make(name, shape, dtype).output(0);
 }

@@ -82,7 +82,7 @@ Array<Expr> ComputeOpNode::output_shape(size_t i) const {
   return Array<Expr>(shape);
 }

-Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
+Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name) {
   auto op_node = std::make_shared<ComputeOpNode>();
   // compute dimension.
   size_t ndim = shape.size();

@@ -188,7 +188,7 @@ Operation ScanOpNode::make(std::string name,
   return Operation(n);
 }

-Array<Tensor> Scan(IterVar scan_axis,
+Array<Tensor> scan(IterVar scan_axis,
                    Array<Tensor> init,
                    Array<Tensor> update,
                    Array<Tensor> state_placeholder,
src/schedule/auto_inline_elem_wise.cc

@@ -6,9 +6,11 @@
 #include <tvm/ir_visitor.h>

 namespace tvm {
-namespace ir {
+namespace schedule {
+
+using namespace ir;

-class ElemWiseDetector : public IRVisitor {
+class ElemWiseDetector : public ir::IRVisitor {
  public:
   explicit ElemWiseDetector(Array<IterVar> axis) : axis_(axis) {}

@@ -25,10 +27,7 @@ class ElemWiseDetector : public IRVisitor {
     }

     for (size_t i = 0; i < axis_.size(); ++i) {
-      // const Variable *v1 = axis_[i]->var.as<Variable>();
-      // const Variable *v2 = axis[i].as<Variable>();
       if (!axis[i].same_as(axis_[i]->var)) {
-      // if (!(v1 && v2) || (v1 != v2)) {
         is_elem_wise_ = false;
         return;
       }

@@ -52,21 +51,9 @@ bool IsElemWise(const Operation& op) {
   return false;
 }

-}  // namespace ir
-
-namespace schedule {
-
 void AutoInlineElemWise(Schedule sch) {
   for (Stage s : sch->stages) {
-    if (!s.is_scheduled() && ir::IsElemWise(s->op)) {
-      bool is_root = false;
-      for (auto r : sch->roots) {
-        if (r == s->op) {
-          is_root = true;
-          break;
-        }
-      }
-      if (!is_root) s.compute_inline();
+    if (!s.is_scheduled() && IsElemWise(s->op) && !s->is_output) {
+      s.compute_inline();
     }
   }
 }
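Editorial note: the rewritten AutoInlineElemWise above now skips output stages via the new is_output flag instead of scanning sch->roots. A rough sketch of how this pass is driven from Python, assuming the tvm.schedule.AutoInlineElemWise binding exercised by test_auto_inline (tensor names are illustrative):

import tvm

n = tvm.Var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda i: A(i) + 1, name='B')   # elementwise, not an output
C = tvm.compute((n,), lambda i: B(i) * 2, name='C')   # output stage

s = tvm.Schedule(C.op)
# B is inlined into C; C itself is an output stage and is left alone
tvm.schedule.AutoInlineElemWise(s)
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)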
src/schedule/bound.cc

@@ -294,7 +294,6 @@ void GatherOpBound(const ScanOpNode* scan,
     const TensorDom& d = tmap.at(output[i]);
     time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end());
   }
-  LOG(INFO) << time_dom.size();
   CHECK(!rmap->count(scan->scan_axis));
   Range sdom = scan->scan_axis->dom;
   Range r = arith::Union(time_dom).cover_range(sdom);

@@ -321,7 +320,7 @@ void GatherOpBound(const Operation& op,
   const ComputeOpNode* compute = op.as<ComputeOpNode>();
   const TensorDom& tdom = tmap.at(op.output(0));
   for (size_t i = 0; i < compute->axis.size(); ++i) {
-    Range r = arith::Union(tdom.data[i]).cover_range(compute->axis[i]->dom);
+    Range r = arith::Union(tdom.data.at(i)).cover_range(compute->axis[i]->dom);
     CHECK(!rmap->count(compute->axis[i]));
     (*rmap)[compute->axis[i]] = r;
   }

@@ -392,6 +391,8 @@ void InferRootBound(const Stage& stage,
           direct_consume_by_parent = true;
         }
       }
+    } else {
+      LOG(INFO) << "not in feed graph consumer = " << stage->op;
     }
   }
   // The relax set

@@ -486,7 +487,11 @@ void InferRootBound(const Stage& stage,
 }

 FeedGraph CreateFeedGraph(const Schedule& sch) {
-  auto g = CreateReadGraph(sch->roots);
+  Array<Operation> roots;
+  for (Operation op : sch->outputs) {
+    roots.push_back(sch->stage_map[op]->op);
+  }
+  auto g = CreateReadGraph(roots);
   FeedGraph fg;
   for (auto kv : g) {
     for (Tensor t : kv.second) {

@@ -523,6 +528,7 @@ AttachPath CreateAttachPath(const Schedule& sch) {
 Map<IterVar, Range> InferBound(const Schedule& sch) {
   FeedGraph feed_graph = CreateFeedGraph(sch);
   AttachPath attach_path = CreateAttachPath(sch);
   std::unordered_map<IterVar, Range> ret;
   for (size_t i = sch->stages.size(); i != 0; --i) {
     const Stage& stage = sch->stages[i - 1];
src/schedule/schedule_lang.cc

@@ -3,6 +3,8 @@
  * \file schedule.cc
  */
 #include <tvm/schedule.h>
+#include <tvm/ir_mutator.h>
+#include <unordered_set>
 #include "./graph.h"

 namespace tvm {

@@ -10,7 +12,8 @@ namespace tvm {
 namespace {

 // find first occurance location in leaf
-size_t FindIterVar(ArrayNode* array_node, const IterVar& v) {
+template<typename T>
+size_t FindNodeRef(ArrayNode* array_node, const T& v) {
   const Node* n = v.get();
   for (size_t i = 0; i < array_node->data.size(); ++i) {
     if (array_node->data[i].get() == n) return i;

@@ -19,10 +22,10 @@ size_t FindIterVar(ArrayNode* array_node, const IterVar& v) {
 }

 size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) {
-  size_t pos = FindIterVar(leaf_vars, v);
+  size_t pos = FindNodeRef(leaf_vars, v);
   if (pos < leaf_vars->data.size()) return pos;

-  if (FindIterVar(all_vars, v) < all_vars->data.size()) {
+  if (FindNodeRef(all_vars, v) < all_vars->data.size()) {
     LOG(FATAL) << "Operate on iter var " << v
                << "that has already been splitted";
   } else {

@@ -68,8 +71,9 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 Stage::Stage(Operation op) {
   auto n = std::make_shared<StageNode>();
   n->op = op;
+  n->origin_op = op;
   n->all_iter_vars = op->root_iter_vars();
-  n->leaf_iter_vars = op->root_iter_vars();
+  n->leaf_iter_vars = n->all_iter_vars;
   node_ = n;
 }

@@ -89,7 +93,7 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) {  // NOLINT(*)
     }
   }
   CHECK(found)
-      << "Cannot find the specified axis in parent stage's leaf_iter_vars";
+      << "Cannot find the axis in parent's leaf_iter_vars or outermost_threads";
   return *this;
 }

@@ -176,13 +180,63 @@ Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
   return *this;
 }

+Stage& Stage::outermost_threads(Array<IterVar> threads) {
+  StageNode* self = operator->();
+  CHECK(self->op.as<ScanOpNode>())
+      << "outermost_threads is only valid for composite ops such as ScanOp";
+  CHECK_EQ(self->outermost_threads.size(), 0U)
+      << "Already set outermost_threads";
+  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+  std::vector<std::shared_ptr<Node> > temp;
+  for (IterVar iv : threads) {
+    temp.push_back(iv.node_);
+  }
+  leaf_vars->data.insert(
+      leaf_vars->data.begin(), temp.begin(), temp.end());
+  all_vars->data.insert(
+      all_vars->data.end(), temp.begin(), temp.end());
+  (*this)->outermost_threads = threads;
+  return *this;
+}
+
+inline void SetAttr(StageNode* self, IterVar var, IterVarAttr attr) {
+  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+  FindLeafVar(all_vars, leaf_vars, var);
+  auto it = self->iter_var_attrs.find(var);
+  if (it != self->iter_var_attrs.end()) {
+    CHECK_EQ((*it).second->iter_type, attr->iter_type)
+        << "IterVar's is already set to "
+        << (*it).second << " instead of " << attr;
+  } else {
+    self->iter_var_attrs.Set(var, attr);
+  }
+}
+
+Stage& Stage::vectorize(IterVar var) {   // NOLINT(*)
+  SetAttr(operator->(), var, IterVarAttr(kVectorized));
+  return *this;
+}
+
+Stage& Stage::unroll(IterVar var) {   // NOLINT(*)
+  SetAttr(operator->(), var, IterVarAttr(kUnrolled));
+  return *this;
+}
+
 Schedule::Schedule(Array<Operation> ops) {
   auto n = std::make_shared<ScheduleNode>();
-  n->roots = ops;
-  auto g = schedule::CreateReadGraph(n->roots);
-  Array<Operation> post_order = schedule::PostDFSOrder(n->roots, g);
+  n->outputs = ops;
+  auto g = schedule::CreateReadGraph(n->outputs);
+  Array<Operation> post_order = schedule::PostDFSOrder(n->outputs, g);
+  // output set.
+  std::unordered_set<Operation> output_set;
+  for (Operation x : ops) {
+    output_set.insert(x);
+  }
   for (Operation op : post_order) {
     Stage stage(op);
+    stage->is_output = output_set.count(op);
     n->stages.push_back(stage);
     n->stage_map.Set(op, stage);
   }

@@ -237,7 +291,7 @@ void Schedule::normalize() {
     ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
     for (IterVar iv : root_iter_vars) {
-      size_t idx = FindIterVar(leaf_vars, iv);
+      size_t idx = FindNodeRef(leaf_vars, iv);
       if (idx < leaf_vars->data.size()) {
         // insert rebase
         IterVar rebased(Range(), iv->var->name_hint + ".rb");

@@ -262,35 +316,197 @@ IterVarAttr::IterVarAttr(IterVarType t) {
   node_ = n;
 }

(The old copies of SetAttr, Stage::vectorize and Stage::unroll at this position are removed; they now live after Stage::tile, as shown in the previous hunk. The TVM_REGISTER_NODE_TYPE calls move up from the end of the file to here.)

+TVM_REGISTER_NODE_TYPE(StageNode);
+TVM_REGISTER_NODE_TYPE(IterVarAttrNode);
+TVM_REGISTER_NODE_TYPE(SplitNode);
+TVM_REGISTER_NODE_TYPE(FuseNode);
+TVM_REGISTER_NODE_TYPE(RebaseNode);
+TVM_REGISTER_NODE_TYPE(ScheduleNode);
+
+using ir::TensorKey;
+
+// The replacer of cache.
+class TensorReplacer : public ir::IRMutator {
+ public:
+  TensorReplacer(const std::unordered_map<TensorKey, Tensor>& vmap)
+      : vmap_(vmap) {}
+  Expr Mutate_(const ir::Call* op, const Expr& e) {
+    if (op->call_type == ir::Call::Halide) {
+      ir::TensorKey key{op->func, op->value_index};
+      auto it = vmap_.find(key);
+      if (it != vmap_.end()) {
+        Expr ret = ir::Call::make(
+            op->type, it->second->op->name, op->args,
+            op->call_type, it->second->op, it->second->value_index);
+        found = true;
+        return IRMutator::Mutate_(ret.as<ir::Call>(), ret);
+      }
+    }
+    return IRMutator::Mutate_(op, e);
+  }
+  // whether it is found.
+  bool found{false};
+
+ private:
+  const std::unordered_map<TensorKey, Tensor>& vmap_;
+};
+
+class VarReplacer : public ir::IRMutator {
+ public:
+  explicit VarReplacer(
+      const std::unordered_map<const Variable*, Expr>& vsub)
+      : vsub_(vsub) {}
+  Expr Mutate_(const Variable* op, const Expr& e) {
+    auto it = vsub_.find(op);
+    if (it != vsub_.end()) return it->second;
+    return e;
+  }
+
+ private:
+  const std::unordered_map<const Variable*, Expr>& vsub_;
+};
+
+// Replace data flow appears in all stages given the tensor change.
+// Also update vmap if subsequent dataflow need to be replaced.
+void ReplaceDataFlow(const Array<Stage>& stages,
+                     std::unordered_map<TensorKey, Tensor>* vmap) {
+  for (Stage s : stages) {
+    if (s->op.as<ComputeOpNode>()) {
+      const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
+      TensorReplacer repl(*vmap);
+      Expr body = repl.Mutate(compute->body);
+      if (repl.found) {
+        Operation op = ComputeOpNode::make(
+            compute->name, compute->axis, body);
+        (*vmap)[TensorKey{s->op, 0}] = op.output(0);
+        s->op = op;
+      }
+    } else if (s->op.as<ScanOpNode>()) {
+      const ScanOpNode* scan = s->op.as<ScanOpNode>();
+      std::shared_ptr<ScanOpNode> n =
+          std::make_shared<ScanOpNode>(*scan);
+      // copy on write semantics ganrantees correctness
+      for (size_t i = 0; i < n->init.size(); ++i) {
+        TensorKey key{n->init[i]->op, n->init[i]->value_index};
+        if (vmap->count(key)) {
+          n->init.Set(i, vmap->at(key));
+        }
+      }
+      for (size_t i = 0; i < n->update.size(); ++i) {
+        TensorKey key{n->update[i]->op, n->update[i]->value_index};
+        if (vmap->count(key)) {
+          n->update.Set(i, vmap->at(key));
+        }
+      }
+      if (!n->init.same_as(scan->init) ||
+          !n->update.same_as(scan->update)) {
+        Operation op(n);
+        for (int i = 0; i < op->num_outputs(); ++i) {
+          (*vmap)[TensorKey{s->op, i}] = op.output(i);
+        }
+        s->op = op;
+      }
+    } else if (s->op.as<PlaceholderOpNode>()) {
+    } else {
+      LOG(FATAL) << "unhandled problem";
+    }
+  }
+}
+
+Tensor Schedule::cache_read(const Tensor& tensor,
+                            const std::string& scope,
+                            const Array<Operation>& readers) {
+  // create identity mapping.
+  std::ostringstream os;
+  os << tensor->op->name;
+  if (tensor->op->num_outputs() != 1) {
+    os << ".v" << tensor->value_index;
+  }
+  os << "." << scope;
+
+  Tensor cache = compute(tensor->shape, [&tensor](const Array<Var>& i) {
+      return tensor(Array<Expr>(i.begin(), i.end()));
+    }, os.str());
+  std::unordered_map<TensorKey, Tensor> vsub;
+  vsub[TensorKey{tensor->op, tensor->value_index}] = cache;
+
+  std::unordered_map<TensorKey, Tensor> vmap;
+  for (Operation op : readers) {
+    const ComputeOpNode* compute = op.as<ComputeOpNode>();
+    CHECK(compute)
+        << "cache read only take ComputeOp as readers";
+    Stage s = operator[](op);
+    compute = s->op.as<ComputeOpNode>();
+
+    TensorReplacer repl(vsub);
+    Expr body = repl.Mutate(compute->body);
+    CHECK(repl.found)
+        << "Cannot find " << tensor
+        << " in the body of specified reader" << op;
+    Operation repl_op = ComputeOpNode::make(
+        compute->name, compute->axis, body);
+    vmap[TensorKey{s->op, 0}] = repl_op.output(0);
+    s->op = repl_op;
+  }
+  ReplaceDataFlow((*this)->stages, &vmap);
+  ArrayNode* stages = (*this)->stages.CopyOnWrite();
+  size_t pos = FindNodeRef(stages, operator[](tensor->op));
+  Stage cache_stage = Stage(cache->op);
+  cache_stage.set_scope(scope);
+  CHECK_LT(pos, stages->data.size());
+  stages->data.insert(stages->data.begin() + pos + 1,
+                      cache_stage.node_);
+  (*this)->stage_map.Set(cache->op, cache_stage);
+  return cache;
+}
+
+Tensor Schedule::cache_write(const Tensor& tensor,
+                             const std::string& scope) {
+  Stage orig_stage = operator[](tensor->op);
+  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
+  CHECK(compute)
+      << "cache write only take ComputeOp as writers";
+  CHECK(!orig_stage.is_scheduled())
+      << "Create cache_write before doing split/fuse/reorder";
+  compute = orig_stage->op.as<ComputeOpNode>();
+  CHECK(compute);
+  Array<Expr> args;
+  Array<IterVar> new_axis;
+  std::unordered_map<const Variable*, Expr> vsub;
+  for (IterVar iv : compute->axis) {
+    args.push_back(iv->var);
+    IterVar new_iv(iv->dom, iv->var->name_hint + ".c");
+    new_axis.push_back(new_iv);
+    vsub[iv->var.get()] = new_iv->var;
+  }
+  VarReplacer repl(vsub);
+  Expr body = repl.Mutate(compute->body);
+  Operation cache_op = ComputeOpNode::make(
+      compute->name + "." + scope, new_axis, body);
+  Tensor cache_tensor = cache_op.output(0);
+  Operation orig_new_op = ComputeOpNode::make(
+      compute->name, compute->axis,
+      cache_tensor(args));
+
+  std::unordered_map<TensorKey, Tensor> vmap;
+  vmap[TensorKey{orig_stage->op, 0}] = orig_new_op.output(0);
+  ReplaceDataFlow((*this)->stages, &vmap);
+
+  // mutate orig stage
+  orig_stage->op = orig_new_op;
+  orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
+  orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
+  // create schedule for new cached stage.
+  ArrayNode* stages = (*this)->stages.CopyOnWrite();
+  size_t pos = FindNodeRef(stages, orig_stage);
+  Stage cache_stage = Stage(cache_op);
+  cache_stage.set_scope(scope);
+  CHECK_LT(pos, stages->data.size());
+  stages->data.insert(stages->data.begin() + pos,
+                      cache_stage.node_);
+  (*this)->stage_map.Set(cache_op, cache_stage);
+  return cache_tensor;
+}
+
-TVM_REGISTER_NODE_TYPE(StageNode);
-TVM_REGISTER_NODE_TYPE(IterVarAttrNode);
-TVM_REGISTER_NODE_TYPE(SplitNode);
-TVM_REGISTER_NODE_TYPE(FuseNode);
-TVM_REGISTER_NODE_TYPE(RebaseNode);
-TVM_REGISTER_NODE_TYPE(ScheduleNode);
 }  // namespace tvm
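Editorial note: the cache_read and cache_write implementations above rewrite the dataflow in place. The cache stage is inserted next to the stage it serves, its name is derived as "<tensor name>.<scope>", and the original stage's op is rebuilt while origin_op keeps the pre-rewrite operation. A rough Python-side check of the naming, assuming the op's name attribute is surfaced on the Python node unchanged:

import tvm

n = tvm.Var('n')
A = tvm.placeholder((n, n), name='A')
C = tvm.compute((n, n), lambda i, j: A(i, j) * 2, name='C')
s = tvm.Schedule(C.op)

AA = s.cache_read(A, "shared", [C])   # expected to create a stage named "A.shared"
CC = s.cache_write(C, "local")        # expected to create a stage named "C.local"
print(AA.op.name, CC.op.name)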
src/schedule/schedule_ops.cc

@@ -23,7 +23,8 @@ using namespace ir;
 // Two private scope marks
 namespace attr {
 constexpr const char* loop_scope = "loop_scope";
-constexpr const char* scan_scope = "scan_scope";
+constexpr const char* scan_update_scope = "scan_update_scope";
+constexpr const char* scan_init_scope = "scan_init_scope";
 }  // namespace attr

 /*!

@@ -280,23 +281,31 @@ Stmt MakeLoop(const Stage& s,
   if (init.defined()) {
     // try to find the location to insert the initialization.
     // Fuse the initialization and provide loop when possible.
-    std::unordered_map<IterVar, int> reduce_state;
+    std::unordered_map<IterVar, int> update_state;
     const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
-    for (IterVar iv : compute->reduce_axis) {
-      reduce_state[iv] = 2;
-    }
-    for (IterVar iv : compute->axis) {
-      reduce_state[iv] = 1;
+    const ScanOpNode* scan = s->op.as<ScanOpNode>();
+    if (compute) {
+      for (IterVar iv : compute->reduce_axis) {
+        update_state[iv] = 2;
+      }
+      for (IterVar iv : compute->axis) {
+        update_state[iv] = 1;
+      }
+    } else if (scan) {
+      update_state[scan->scan_axis] = 2;
+      for (IterVar iv : s->outermost_threads) {
+        update_state[iv] = 1;
+      }
     }
     // find which iter var is related to reduction and which is related to axis.
-    PassDownFlag(s, &reduce_state);
+    PassDownFlag(s, &update_state);
     auto leaf_iter_vars = s->leaf_iter_vars;
     std::unordered_map<IterVar, Expr> init_value_map;
     // first first loop that is related to reduction.
     size_t begin_loop = leaf_iter_vars.size();
     for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
       auto iv = leaf_iter_vars[i];
-      int flag = reduce_state.at(iv);
+      int flag = update_state.at(iv);
       if ((flag & 2) != 0) {
         begin_loop = i; break;
       }

@@ -304,7 +313,7 @@ Stmt MakeLoop(const Stage& s,
   }
   // skip loops that does not relates to axis.
   std::unordered_map<IterVar, bool> skip_iter;
-  for (auto kv : reduce_state) {
+  for (auto kv : update_state) {
     int flag = kv.second;
     if ((flag & 1) == 0) skip_iter[kv.first] = true;
   }

@@ -422,7 +431,10 @@ Stmt MakePipeline(const Stage& s,
   } else if (scan) {
     // Provide is done by the sub operations.
     provide = AttrStmt::make(
-        s->op, attr::scan_scope, scan->scan_axis->var,
+        s->op, attr::scan_update_scope, scan->scan_axis->var,
+        Evaluate::make(0));
+    init = AttrStmt::make(
+        s->op, attr::scan_init_scope, 0,
         Evaluate::make(0));
   } else {
     LOG(FATAL) << "not supported op " << s->op->type_key();

@@ -472,7 +484,9 @@ class InjectAttach : public IRMutator {
     const AttrStmt* op = stmt.as<AttrStmt>();
     if (op != nullptr &&
         op->type_key == attr::loop_scope) {
-      if (op->node == stage_->attach_ivar) {
+      CHECK_NE(producer_.size(), 0U);
+      if (op->node == stage_->attach_ivar &&
+          producer_.back() == stage_->attach_stage->op.get()) {
         CHECK(!found_attach);
         found_attach = true;
         stmt = AttrStmt::make(

@@ -482,6 +496,16 @@ class InjectAttach : public IRMutator {
     }
     return stmt;
   }
+  Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
+    if (op->is_producer) {
+      producer_.push_back(op->func.get());
+      Stmt ret = IRMutator::Mutate_(op, s);
+      producer_.pop_back();
+      return ret;
+    } else {
+      return IRMutator::Mutate_(op, s);
+    }
+  }
   // whether attach point is found
   bool found_attach{false};

@@ -490,6 +514,8 @@ class InjectAttach : public IRMutator {
   const Stage& stage_;
   // domain map
   const Map<IterVar, Range>& dom_map_;
+  // internal stack about realization scope.
+  std::vector<const Node*> producer_;
 };
 // inject the operator's realization on the stmt.

@@ -505,21 +531,11 @@ class InjectScanStep : public IRMutator {
   Stmt Mutate(Stmt stmt) final {
     CHECK(stmt.defined());
     stmt = IRMutator::Mutate(stmt);
-    if (is_init_) {
-      const ProducerConsumer* op = stmt.as<ProducerConsumer>();
-      if (op != nullptr &&
-          op->is_producer &&
-          op->func.same_as(scan_op_)) {
-        stmt = ProducerConsumer::make(
-            op->func, true,
-            MakePipeline(stage_, dom_map_, op->body));
-        found_attach = true;
-      }
-    } else {
-      // update
-      const AttrStmt* op = stmt.as<AttrStmt>();
-      if (op != nullptr &&
-          op->type_key == attr::scan_scope) {
-        if (op->node.same_as(scan_op_)) {
-          found_attach = true;
-          stmt = AttrStmt::make(
+    // update
+    const AttrStmt* op = stmt.as<AttrStmt>();
+    if (op != nullptr &&
+        ((op->type_key == attr::scan_update_scope && !is_init_) ||
+         (op->type_key == attr::scan_init_scope && is_init_))) {
+      if (op->node.same_as(scan_op_)) {
+        found_attach = true;
+        stmt = AttrStmt::make(

@@ -527,7 +543,6 @@ class InjectScanStep : public IRMutator {
             MakePipeline(stage_, dom_map_, op->body));
       }
     }
-    }
     return stmt;
   }

@@ -561,8 +576,15 @@ Stmt InjectInline(const Operation op, Stmt body) {
 class SchedulePostProc : public IRMutator {
  public:
   Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
-    if (to_remove_.count(op->func.get())) {
-      return this->Mutate(op->body);
+    auto it = replace_op_.find(op->func.get());
+    if (it != replace_op_.end()) {
+      Stmt body = this->Mutate(op->body);
+      if (it->second.defined()) {
+        return ProducerConsumer::make(
+            it->second, op->is_producer, body);
+      } else {
+        return body;
+      }
     } else {
       return IRMutator::Mutate_(op, s);
     }

@@ -579,23 +601,40 @@ class SchedulePostProc : public IRMutator {
   Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
     if (op->type_key == attr::loop_scope) {
       return this->Mutate(op->body);
-    } else if (op->type_key == attr::scan_scope) {
+    } else if (op->type_key == attr::scan_init_scope) {
+      return this->Mutate(op->body);
+    } else if (op->type_key == attr::scan_update_scope) {
       const ScanOpNode* scan = op->node.as<ScanOpNode>();
       CHECK(scan);
       var_value_[scan->scan_axis->var.get()] = op->value;
       return this->Mutate(op->body);
     } else if (op->type_key == ir::attr::realize_scope) {
-      if (to_remove_.count(op->node.get())) {
-        return this->Mutate(op->body);
+      auto it = replace_op_.find(op->node.get());
+      if (it != replace_op_.end()) {
+        if (it->second.defined()) {
+          Stmt ret = AttrStmt::make(
+              it->second, op->type_key, op->value, op->body);
+          return this->Mutate_(ret.as<AttrStmt>(), ret);
+        } else {
+          return this->Mutate(op->body);
+        }
       }
     }
     return IRMutator::Mutate_(op, s);
   }

   Stmt Mutate_(const Realize* op, const Stmt& s) final {
     TensorKey key{op->func, op->value_index};
-    if (replace_.count(key)) {
-      return this->Mutate(op->body);
+    auto it = replace_realize_.find(key);
+    if (it != replace_realize_.end()) {
+      if (it->second.defined()) {
+        Stmt ret = Realize::make(
+            it->second->op, it->second->value_index,
+            op->type, op->bounds, op->condition, op->body);
+        return this->Mutate_(ret.as<Realize>(), ret);
+      } else {
+        return this->Mutate(op->body);
+      }
     } else {
       return IRMutator::Mutate_(op, s);
     }

@@ -603,8 +642,8 @@ class SchedulePostProc : public IRMutator {
   Stmt Mutate_(const Provide* op, const Stmt& s) final {
     TensorKey key{op->func, op->value_index};
-    auto it = replace_.find(key);
-    if (it != replace_.end()) {
+    auto it = replace_buffer_.find(key);
+    if (it != replace_buffer_.end()) {
       const Tensor& dst = it->second.first;
       Stmt ret = Provide::make(
           dst->op, dst->value_index, op->value,

@@ -616,10 +655,10 @@ class SchedulePostProc : public IRMutator {
   }

   Expr Mutate_(const Call* op, const Expr& e) final {
-    if (op != nullptr && op->call_type == Call::Halide) {
+    if (op->call_type == Call::Halide) {
       TensorKey key{op->func, op->value_index};
-      auto it = replace_.find(key);
-      if (it != replace_.end()) {
+      auto it = replace_buffer_.find(key);
+      if (it != replace_buffer_.end()) {
         const Tensor& dst = it->second.first;
         Expr ret = Call::make(
             op->type, dst->op->name,

@@ -642,22 +681,32 @@ class SchedulePostProc : public IRMutator {
   void Init(const Schedule& sch) {
     for (Stage s : sch->stages) {
+      if (s->op.as<ScanOpNode>()) {
         const ScanOpNode* scan = s->op.as<ScanOpNode>();
-      if (!scan) continue;
         for (size_t i = 0; i < scan->update.size(); ++i) {
-          Tensor t = s->op.output(i);
+          Tensor t = s->origin_op.output(i);
           AddReplace(scan->init[i], t, Expr());
           AddReplace(scan->update[i], t, scan->scan_axis->var);
           AddReplace(scan->state_placeholder[i], t, Expr());
         }
+      } else if (!s->op.same_as(s->origin_op)) {
+        Tensor target = s->origin_op.output(0);
+        AddReplace(s->op.output(0), target,
+                   Expr(), target, s->origin_op);
+      }
     }
   }

  private:
-  void AddReplace(Tensor src, Tensor dst, Expr head_idx) {
-    replace_[TensorKey{src->op, src->value_index}]
-        = std::make_pair(dst, head_idx);
-    to_remove_.insert(src->op.get());
+  void AddReplace(Tensor src,
+                  Tensor dst,
+                  Expr head_idx,
+                  Tensor repl_realize = Tensor(),
+                  Operation repl_op = Operation()) {
+    TensorKey key{src->op, src->value_index};
+    replace_buffer_[key] = std::make_pair(dst, head_idx);
+    replace_realize_[key] = repl_realize;
+    replace_op_[src->op.get()] = repl_op;
   }
   Array<Expr> RewriteArgs(Expr head, Array<Expr> args) {
     if (!head.defined()) return args;

@@ -670,9 +719,11 @@ class SchedulePostProc : public IRMutator {
   // The scan value
   std::unordered_map<const Variable*, Expr> var_value_;
   // buffer replacement
-  std::unordered_map<TensorKey, std::pair<Tensor, Expr> > replace_;
-  // replaced functions
-  std::unordered_set<const Node*> to_remove_;
+  std::unordered_map<TensorKey, std::pair<Tensor, Expr> > replace_buffer_;
+  // buffere realization to be replaced
+  std::unordered_map<TensorKey, Tensor> replace_realize_;
+  // replace producer consumer.
+  std::unordered_map<const Node*, Operation> replace_op_;
 };

 Stmt ScheduleOps(

@@ -724,7 +775,9 @@ Stmt ScheduleOps(
       InjectAttach mutator(s, dom_map);
       body = mutator.Mutate(body);
       CHECK(mutator.found_attach)
-          << "did not find attachment point";
+          << "did not find attachment point for " << s << " in"
+          << s->attach_stage->op << " x " << body;
     }
   }
   SchedulePostProc post_proc;
tests/cpp/tensor_test.cc

@@ -6,10 +6,10 @@ TEST(Tensor, Basic) {
   using namespace tvm;
   Var m("m"), n("n"), l("l");

-  Tensor A = Placeholder({m, l}, Float(32), "A");
-  Tensor B = Placeholder({n, l}, Float(32), "B");
+  Tensor A = placeholder({m, l}, Float(32), "A");
+  Tensor B = placeholder({n, l}, Float(32), "B");

-  auto C = Compute({m, n}, [&](Var i, Var j) {
+  auto C = compute({m, n}, [&](Var i, Var j) {
       return A[i][j];
     }, "C");

@@ -20,11 +20,11 @@ TEST(Tensor, Basic) {
 TEST(Tensor, Reduce) {
   using namespace tvm;
   Var m("m"), n("n"), l("l");
-  Tensor A = Placeholder({m, l}, Float(32), "A");
-  Tensor B = Placeholder({n, l}, Float(32), "B");
+  Tensor A = placeholder({m, l}, Float(32), "A");
+  Tensor B = placeholder({n, l}, Float(32), "B");
   IterVar rv(Range{0, l}, "k");

-  auto C = Compute({m, n}, [&](Var i, Var j) {
+  auto C = compute({m, n}, [&](Var i, Var j) {
       return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv});
     }, "C");
   LOG(INFO) << C->op.as<ComputeOpNode>()->body;
tests/python/integration/test_gemm.py

@@ -2,17 +2,6 @@ import tvm
 from tvm.addon import nvcc_compiler
 import numpy as np

-@tvm.register_func
-def tvm_callback_cuda_compile(code):
-    ptx = nvcc_compiler.compile_source(code, target="ptx")
-    print(ptx.decode("utf-8"))
-    return ptx
-
-@tvm.register_func
-def tvm_callback_cuda_postproc(code):
-    print(code)
-    return code
-
 def test_gemm():
     # graph
     nn = 1024

@@ -22,21 +11,14 @@ def test_gemm():
     l = n
     A = tvm.placeholder((n, l), name='A')
     B = tvm.placeholder((m, l), name='B')
-    AA = tvm.compute(A.shape, lambda *i: A(*i), name="AA")
-    BB = tvm.compute(B.shape, lambda *i: B(*i), name="BB")
     k = tvm.IterVar((0, l), name='k')
-    CC = tvm.compute(
+    C = tvm.compute(
         (n, m),
-        lambda ii, jj: tvm.sum(AA[ii, k] * BB[jj, k], axis=k),
+        lambda ii, jj: tvm.sum(A[ii, k] * B[jj, k], axis=k),
         name='CC')
-    C = tvm.compute(CC.shape, lambda *i: CC(*i), name="C")
     # schedule
     s = tvm.Schedule(C.op)
     xtile, ytile = 32, 32
-    s[AA].set_scope("shared")
-    s[BB].set_scope("shared")
     scale = 8
     num_thread = 8
     block_factor = scale * num_thread

@@ -45,6 +27,9 @@ def test_gemm():
     block_y = tvm.IterVar(thread_tag="blockIdx.y")
     thread_y = tvm.IterVar((0, num_thread), thread_tag="threadIdx.y")

+    CC = s.cache_write(C, "local")
+    AA = s.cache_read(A, "shared", [CC])
+    BB = s.cache_read(B, "shared", [CC])
     _, yi = s[C].split(C.op.axis[0], factor=block_factor, outer=block_y)
     _, xi = s[C].split(C.op.axis[1], factor=block_factor, outer=block_x)
     s[C].reorder(block_y, block_x, yi, xi)

@@ -92,8 +77,8 @@ def test_gemm():
             c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)

     check_device("cuda")
-    tvm.init_opencl()
-    check_device("opencl")
+    # tvm.init_opencl()
+    # check_device("opencl")

 if __name__ == "__main__":
     test_gemm()
tests/python/unittest/test_lang_schedule.py

@@ -22,13 +22,13 @@ def test_schedule_create():
     json_str = tvm.save_json(s)
     s_loaded = tvm.load_json(json_str)
     assert isinstance(s_loaded, tvm.schedule.Schedule)
-    assert(str(s_loaded.roots[0].body) == str(s.roots[0].body))
+    assert(str(s_loaded.outputs[0].body) == str(s.outputs[0].body))

     # pickle unpickle
     dump = pkl.dumps(s)
     s_loaded = pkl.loads(dump)
     assert isinstance(s_loaded, tvm.schedule.Schedule)
-    assert(str(s_loaded.roots[0].body) == str(s.roots[0].body))
+    assert(str(s_loaded.outputs[0].body) == str(s.outputs[0].body))

 def test_reorder():
     m = tvm.Var('m')
tests/python/unittest/test_schedule_schedule_ops.py

@@ -74,6 +74,20 @@ def test_auto_inline():
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)

+def test_schedule_cache():
+    m = tvm.Var('m')
+    n = tvm.Var('n')
+    A = tvm.placeholder((m, n), name='A')
+    B = tvm.placeholder((m, n), name='B')
+    C = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='C')
+
+    s = tvm.Schedule(C.op)
+    AA = s.cache_read(A, "shared", readers=[C])
+    CC = s.cache_write(C, "shared")
+    s[AA].compute_at(s[CC], CC.op.axis[0])
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
 if __name__ == "__main__":
     test_schedule_scan()

@@ -81,3 +95,4 @@ if __name__ == "__main__":
     test_schedule1()
     test_schedule2()
     test_auto_inline()
+    test_schedule_cache()