Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
d114dfc9
Commit
d114dfc9
authored
Feb 16, 2017
by
Tianqi Chen
Committed by
GitHub
Feb 16, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[SCHEDULE] Mutate dataflow in schedule, refactor Stage (#44)
parent
820a8597
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
312 additions
and
135 deletions
+312
-135
include/tvm/operation.h
+11
-11
include/tvm/schedule.h
+62
-8
python/tvm/build.py
+0
-1
python/tvm/schedule.py
+60
-0
src/api/api_lang.cc
+19
-1
src/lang/operation.cc
+3
-3
src/schedule/auto_inline_elem_wise.cc
+6
-19
src/schedule/bound.cc
+9
-3
src/schedule/schedule_lang.cc
+0
-0
src/schedule/schedule_ops.cc
+112
-59
tests/cpp/tensor_test.cc
+6
-6
tests/python/integration/test_gemm.py
+7
-22
tests/python/unittest/test_lang_schedule.py
+2
-2
tests/python/unittest/test_schedule_schedule_ops.py
+15
-0
No files found.
include/tvm/operation.h
View file @
d114dfc9
...
@@ -136,7 +136,7 @@ using FCompute = std::function<Expr (const Array<Var>& i)>;
...
@@ -136,7 +136,7 @@ using FCompute = std::function<Expr (const Array<Var>& i)>;
* \param dtype the data type of the tensor.
* \param dtype the data type of the tensor.
* \param name The name of the Tensor.
* \param name The name of the Tensor.
*/
*/
Tensor
P
laceholder
(
Array
<
Expr
>
shape
,
Tensor
p
laceholder
(
Array
<
Expr
>
shape
,
Type
dtype
=
Float
(
32
),
Type
dtype
=
Float
(
32
),
std
::
string
name
=
"placeholder"
);
std
::
string
name
=
"placeholder"
);
...
@@ -147,7 +147,7 @@ Tensor Placeholder(Array<Expr> shape,
...
@@ -147,7 +147,7 @@ Tensor Placeholder(Array<Expr> shape,
* \param fcompute The compute function to create the tensor.
* \param fcompute The compute function to create the tensor.
* \param name The optional name of the tensor.
* \param name The optional name of the tensor.
*/
*/
Tensor
C
ompute
(
Array
<
Expr
>
shape
,
FCompute
fcompute
,
std
::
string
name
=
"tensor"
);
Tensor
c
ompute
(
Array
<
Expr
>
shape
,
FCompute
fcompute
,
std
::
string
name
=
"tensor"
);
/*!
/*!
* \brief Construct new tensors by scan over scan_axis.
* \brief Construct new tensors by scan over scan_axis.
...
@@ -158,36 +158,36 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor"
...
@@ -158,36 +158,36 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor"
* \param state_placeholder The placeholder for the states.
* \param state_placeholder The placeholder for the states.
* \param name The optional name of the tensor.
* \param name The optional name of the tensor.
*/
*/
Array
<
Tensor
>
S
can
(
IterVar
scan_axis
,
Array
<
Tensor
>
s
can
(
IterVar
scan_axis
,
Array
<
Tensor
>
init
,
Array
<
Tensor
>
init
,
Array
<
Tensor
>
update
,
Array
<
Tensor
>
update
,
Array
<
Tensor
>
state_placeholder
,
Array
<
Tensor
>
state_placeholder
,
std
::
string
name
=
"scan"
);
std
::
string
name
=
"scan"
);
// same as compute, specialized for different fcompute function
// same as compute, specialized for different fcompute function
inline
Tensor
C
ompute
(
Array
<
Expr
>
shape
,
inline
Tensor
c
ompute
(
Array
<
Expr
>
shape
,
std
::
function
<
Expr
(
Var
)
>
f
,
std
::
function
<
Expr
(
Var
)
>
f
,
std
::
string
name
=
"tensor"
)
{
std
::
string
name
=
"tensor"
)
{
FCompute
fc
=
[
f
]
(
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
]);
};
FCompute
fc
=
[
f
]
(
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
]);
};
return
C
ompute
(
shape
,
fc
,
name
);
return
c
ompute
(
shape
,
fc
,
name
);
}
}
inline
Tensor
C
ompute
(
Array
<
Expr
>
shape
,
inline
Tensor
c
ompute
(
Array
<
Expr
>
shape
,
std
::
function
<
Expr
(
Var
,
Var
)
>
f
,
std
::
function
<
Expr
(
Var
,
Var
)
>
f
,
std
::
string
name
=
"tensor"
)
{
std
::
string
name
=
"tensor"
)
{
FCompute
fc
=
[
f
]
(
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
],
i
[
1
]);
};
FCompute
fc
=
[
f
]
(
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
],
i
[
1
]);
};
return
C
ompute
(
shape
,
fc
,
name
);
return
c
ompute
(
shape
,
fc
,
name
);
}
}
inline
Tensor
C
ompute
(
Array
<
Expr
>
shape
,
inline
Tensor
c
ompute
(
Array
<
Expr
>
shape
,
std
::
function
<
Expr
(
Var
,
Var
,
Var
)
>
f
,
std
::
function
<
Expr
(
Var
,
Var
,
Var
)
>
f
,
std
::
string
name
=
"tensor"
)
{
std
::
string
name
=
"tensor"
)
{
FCompute
fc
=
[
f
]
(
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
],
i
[
1
],
i
[
2
]);
};
FCompute
fc
=
[
f
]
(
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
],
i
[
1
],
i
[
2
]);
};
return
C
ompute
(
shape
,
fc
,
name
);
return
c
ompute
(
shape
,
fc
,
name
);
}
}
inline
Tensor
C
ompute
(
Array
<
Expr
>
shape
,
inline
Tensor
c
ompute
(
Array
<
Expr
>
shape
,
std
::
function
<
Expr
(
Var
,
Var
,
Var
,
Var
)
>
f
,
std
::
function
<
Expr
(
Var
,
Var
,
Var
,
Var
)
>
f
,
std
::
string
name
=
"tensor"
)
{
std
::
string
name
=
"tensor"
)
{
FCompute
fc
=
[
f
]
(
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
],
i
[
1
],
i
[
2
],
i
[
3
]);
};
FCompute
fc
=
[
f
]
(
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
],
i
[
1
],
i
[
2
],
i
[
3
]);
};
return
C
ompute
(
shape
,
fc
,
name
);
return
c
ompute
(
shape
,
fc
,
name
);
}
}
}
// namespace tvm
}
// namespace tvm
...
...
include/tvm/schedule.h
View file @
d114dfc9
...
@@ -132,6 +132,13 @@ class Stage : public NodeRef {
...
@@ -132,6 +132,13 @@ class Stage : public NodeRef {
IterVar
*
p_x_inner
,
IterVar
*
p_y_inner
,
IterVar
*
p_x_inner
,
IterVar
*
p_y_inner
,
Expr
x_factor
,
Expr
y_factor
);
Expr
x_factor
,
Expr
y_factor
);
/*!
/*!
* \brief Specify thread launching group in
* outer most scope of the stage.
* This is only valid for composite operators.
* \param threads The threads to be launched.
*/
Stage
&
outermost_threads
(
Array
<
IterVar
>
threads
);
/*!
* \brief Vectorize iteration.
* \brief Vectorize iteration.
* \param var The axis to be vectorized.
* \param var The axis to be vectorized.
* \return reference to self.
* \return reference to self.
...
@@ -180,6 +187,28 @@ class Schedule : public NodeRef {
...
@@ -180,6 +187,28 @@ class Schedule : public NodeRef {
return
this
->
operator
[](
tensor
->
op
);
return
this
->
operator
[](
tensor
->
op
);
}
}
/*!
/*!
* \brief create a cache read of original tensor for readers.
* This will mutate the body of the readers.
* A new stage will be created for the tensor.
* \param tensor The tensor cached.
* \param scope The scope of the cache.
* \param readers The readers to redirect to the tensor.
* \return The created tensor.
*/
Tensor
cache_read
(
const
Tensor
&
tensor
,
const
std
::
string
&
scope
,
const
Array
<
Operation
>&
readers
);
/*!
* \brief Create a cache write tensor for producing tensor.
* The the tensor will take over body of original tensor op.
* The original tensor's body will be changed to an identity read
* from the corresponding cache.
* \param tensor The tensor to be produced.
* \param scope The scope of the storage.
* \return The created tensor.
*/
Tensor
cache_write
(
const
Tensor
&
tensor
,
const
std
::
string
&
scope
);
/*!
* \brief Normalize the schedule.
* \brief Normalize the schedule.
* This is needed before bound inference.
* This is needed before bound inference.
* Insert necessary RebaseNode to make sure all leaf_iter_vars
* Insert necessary RebaseNode to make sure all leaf_iter_vars
...
@@ -193,6 +222,11 @@ class Schedule : public NodeRef {
...
@@ -193,6 +222,11 @@ class Schedule : public NodeRef {
* \return the pointer to the internal node container
* \return the pointer to the internal node container
*/
*/
inline
const
ScheduleNode
*
operator
->
()
const
;
inline
const
ScheduleNode
*
operator
->
()
const
;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline
ScheduleNode
*
operator
->
();
// declare container type
// declare container type
using
ContainerType
=
ScheduleNode
;
using
ContainerType
=
ScheduleNode
;
};
};
...
@@ -244,10 +278,16 @@ class IterVarAttr : public NodeRef {
...
@@ -244,10 +278,16 @@ class IterVarAttr : public NodeRef {
*/
*/
class
StageNode
:
public
Node
{
class
StageNode
:
public
Node
{
public
:
public
:
/*! \brief The operation to be scheduled */
Operation
op
;
/*! \brief The thread scope level of the stage */
/*! \brief The thread scope level of the stage */
std
::
string
scope
;
std
::
string
scope
;
/*! \brief The operation of stage, can be different from original op. */
Operation
op
;
/*!
* \brief The original operator.
* The op field can change during schedule to alternate the dataflow,
* while origin_op remains fixed.
*/
Operation
origin_op
;
/*! \brief All the nodes in the iter var */
/*! \brief All the nodes in the iter var */
Array
<
IterVar
>
all_iter_vars
;
Array
<
IterVar
>
all_iter_vars
;
/*!
/*!
...
@@ -255,6 +295,11 @@ class StageNode : public Node {
...
@@ -255,6 +295,11 @@ class StageNode : public Node {
* Operations can only be performed in leaves.
* Operations can only be performed in leaves.
*/
*/
Array
<
IterVar
>
leaf_iter_vars
;
Array
<
IterVar
>
leaf_iter_vars
;
/*!
* \brief Specify threads to be launched at the stage.
* This is only valid for composite ops such as Scan.
*/
Array
<
IterVar
>
outermost_threads
;
/*! \brief The relation bwteen of IterVars */
/*! \brief The relation bwteen of IterVars */
Array
<
IterVarRelation
>
relations
;
Array
<
IterVarRelation
>
relations
;
/*! \brief additional attributes about iter var. */
/*! \brief additional attributes about iter var. */
...
@@ -265,17 +310,22 @@ class StageNode : public Node {
...
@@ -265,17 +310,22 @@ class StageNode : public Node {
IterVar
attach_ivar
;
IterVar
attach_ivar
;
/*! \brief The stage this node attaches to */
/*! \brief The stage this node attaches to */
Stage
attach_stage
;
Stage
attach_stage
;
/*! \brief Whether this is an output stage */
bool
is_output
{
false
};
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"scope"
,
&
scope
);
v
->
Visit
(
"scope"
,
&
scope
);
v
->
Visit
(
"op"
,
&
op
);
v
->
Visit
(
"op"
,
&
op
);
v
->
Visit
(
"origin_op"
,
&
origin_op
);
v
->
Visit
(
"all_iter_vars"
,
&
all_iter_vars
);
v
->
Visit
(
"all_iter_vars"
,
&
all_iter_vars
);
v
->
Visit
(
"leaf_iter_vars"
,
&
leaf_iter_vars
);
v
->
Visit
(
"leaf_iter_vars"
,
&
leaf_iter_vars
);
v
->
Visit
(
"outermost_threads"
,
&
outermost_threads
);
v
->
Visit
(
"relations"
,
&
relations
);
v
->
Visit
(
"relations"
,
&
relations
);
v
->
Visit
(
"iter_var_attrs"
,
&
iter_var_attrs
);
v
->
Visit
(
"iter_var_attrs"
,
&
iter_var_attrs
);
v
->
Visit
(
"attach_type"
,
&
attach_type
);
v
->
Visit
(
"attach_type"
,
&
attach_type
);
v
->
Visit
(
"attach_ivar"
,
&
attach_ivar
);
v
->
Visit
(
"attach_ivar"
,
&
attach_ivar
);
v
->
Visit
(
"attach_stage"
,
&
attach_stage
);
v
->
Visit
(
"attach_stage"
,
&
attach_stage
);
v
->
Visit
(
"is_output"
,
&
is_output
);
}
}
static
constexpr
const
char
*
_type_key
=
"Stage"
;
static
constexpr
const
char
*
_type_key
=
"Stage"
;
...
@@ -285,18 +335,18 @@ class StageNode : public Node {
...
@@ -285,18 +335,18 @@ class StageNode : public Node {
/*! \brief node container for schedule */
/*! \brief node container for schedule */
class
ScheduleNode
:
public
Node
{
class
ScheduleNode
:
public
Node
{
public
:
public
:
/*! \brief The
root operations
*/
/*! \brief The
output operations in original data flow graph
*/
Array
<
Operation
>
roo
ts
;
Array
<
Operation
>
outpu
ts
;
/*!
/*!
* \brief list of all stages for non-placeholder ops
* \brief list of all stages for non-placeholder ops
.
*
The stage are ordered in PostDFS order of their op
.
*
The stages are sorted in dependency order
.
*/
*/
Array
<
Stage
>
stages
;
Array
<
Stage
>
stages
;
/*! \brief map of operation to the stages */
/*! \brief map of operation to the stages */
Map
<
Operation
,
Stage
>
stage_map
;
Map
<
Operation
,
Stage
>
stage_map
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"
roots"
,
&
roo
ts
);
v
->
Visit
(
"
outputs"
,
&
outpu
ts
);
v
->
Visit
(
"stages"
,
&
stages
);
v
->
Visit
(
"stages"
,
&
stages
);
v
->
Visit
(
"stage_map"
,
&
stage_map
);
v
->
Visit
(
"stage_map"
,
&
stage_map
);
}
}
...
@@ -412,12 +462,16 @@ inline StageNode* Stage::operator->() {
...
@@ -412,12 +462,16 @@ inline StageNode* Stage::operator->() {
inline
bool
Stage
::
is_scheduled
()
const
{
inline
bool
Stage
::
is_scheduled
()
const
{
const
StageNode
*
n
=
operator
->
();
const
StageNode
*
n
=
operator
->
();
return
!
(
n
->
relations
.
empty
()
&&
n
->
attach_type
==
kNone
);
return
!
(
n
->
relations
.
empty
()
&&
n
->
attach_type
==
kNone
&&
n
->
all_iter_vars
.
same_as
(
n
->
leaf_iter_vars
));
}
}
inline
const
ScheduleNode
*
Schedule
::
operator
->
()
const
{
inline
const
ScheduleNode
*
Schedule
::
operator
->
()
const
{
return
static_cast
<
const
ScheduleNode
*>
(
node_
.
get
());
return
static_cast
<
const
ScheduleNode
*>
(
node_
.
get
());
}
}
inline
ScheduleNode
*
Schedule
::
operator
->
()
{
return
static_cast
<
ScheduleNode
*>
(
node_
.
get
());
}
inline
const
IterVarRelationNode
*
IterVarRelation
::
operator
->
()
const
{
inline
const
IterVarRelationNode
*
IterVarRelation
::
operator
->
()
const
{
return
static_cast
<
const
IterVarRelationNode
*>
(
node_
.
get
());
return
static_cast
<
const
IterVarRelationNode
*>
(
node_
.
get
());
...
...
python/tvm/build.py
View file @
d114dfc9
...
@@ -63,7 +63,6 @@ def build(sch,
...
@@ -63,7 +63,6 @@ def build(sch,
arg_list
.
append
(
x
)
arg_list
.
append
(
x
)
else
:
else
:
raise
ValueError
(
"args must be Tensor, Buffer or Var"
)
raise
ValueError
(
"args must be Tensor, Buffer or Var"
)
# lowering
# lowering
bounds
=
schedule
.
InferBound
(
sch
)
bounds
=
schedule
.
InferBound
(
sch
)
stmt
=
schedule
.
ScheduleOps
(
sch
,
bounds
)
stmt
=
schedule
.
ScheduleOps
(
sch
,
bounds
)
...
...
python/tvm/schedule.py
View file @
d114dfc9
...
@@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs
...
@@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs
from
._ctypes._node
import
NodeBase
,
register_node
from
._ctypes._node
import
NodeBase
,
register_node
from
.
import
_api_internal
from
.
import
_api_internal
from
.
import
tensor
as
_tensor
from
.
import
tensor
as
_tensor
from
.
import
collections
as
_collections
@register_node
@register_node
class
Buffer
(
NodeBase
):
class
Buffer
(
NodeBase
):
...
@@ -41,6 +42,53 @@ class Schedule(NodeBase):
...
@@ -41,6 +42,53 @@ class Schedule(NodeBase):
"""
"""
_api_internal
.
_ScheduleNormalize
(
self
)
_api_internal
.
_ScheduleNormalize
(
self
)
def
cache_read
(
self
,
tensor
,
scope
,
readers
):
"""Create a cache read of original tensor for readers.
This will mutate the body of the readers.
A new cache stage will be created for the tensor.
Call this before doing any split/fuse schedule.
Parameters
----------
tensor : Tensor
The tensor to be cached.
scope : str
The scope of cached
readers : list of Tensor or Operation
The readers to read the cache.
Returns
-------
cache : Tensor
The created cache tensor.
"""
if
isinstance
(
readers
,
(
_tensor
.
Tensor
,
_tensor
.
Operation
)):
readers
=
[
readers
]
readers
=
[
t
.
op
if
isinstance
(
t
,
_tensor
.
Tensor
)
else
t
for
t
in
readers
]
return
_api_internal
.
_ScheduleCacheRead
(
self
,
tensor
,
scope
,
readers
)
def
cache_write
(
self
,
tensor
,
scope
):
"""Create a cache write of original tensor, before storing into tensor.
This will mutate the body of the tensor.
A new cache stage will created before feed into the tensor.
Parameters
----------
tensor : Tensor
The tensor to be feed to.
scope : str
The scope of cached
Returns
-------
cache : Tensor
The created cache tensor.
"""
return
_api_internal
.
_ScheduleCacheWrite
(
self
,
tensor
,
scope
)
@register_node
@register_node
class
Stage
(
NodeBase
):
class
Stage
(
NodeBase
):
"""A Stage represents schedule for one operation."""
"""A Stage represents schedule for one operation."""
...
@@ -104,6 +152,18 @@ class Stage(NodeBase):
...
@@ -104,6 +152,18 @@ class Stage(NodeBase):
"""
"""
return
_api_internal
.
_StageSetScope
(
self
,
scope
)
return
_api_internal
.
_StageSetScope
(
self
,
scope
)
def
outermost_threads
(
self
,
threads
):
"""Force launch threads at outermost scope of the stage.
Parameters
----------
threads : list of threads
The threads to be launched.
"""
if
isinstance
(
threads
,
_collections
.
IterVar
):
threads
=
[
threads
]
_api_internal
.
_StageOutermostThreads
(
self
,
threads
)
def
compute_at
(
self
,
parent
,
scope
):
def
compute_at
(
self
,
parent
,
scope
):
"""Attach the stage at parent's scope
"""Attach the stage at parent's scope
...
...
src/api/api_lang.cc
View file @
d114dfc9
...
@@ -161,7 +161,7 @@ TVM_REGISTER_API(_TensorHash)
...
@@ -161,7 +161,7 @@ TVM_REGISTER_API(_TensorHash)
TVM_REGISTER_API
(
_Placeholder
)
TVM_REGISTER_API
(
_Placeholder
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
P
laceholder
(
args
[
0
],
*
ret
=
p
laceholder
(
args
[
0
],
args
[
1
],
args
[
1
],
args
[
2
]);
args
[
2
]);
});
});
...
@@ -262,6 +262,12 @@ TVM_REGISTER_API(_StageTile)
...
@@ -262,6 +262,12 @@ TVM_REGISTER_API(_StageTile)
*
ret
=
Array
<
IterVar
>
({
x_outer
,
y_outer
,
x_inner
,
y_inner
});
*
ret
=
Array
<
IterVar
>
({
x_outer
,
y_outer
,
x_inner
,
y_inner
});
});
});
TVM_REGISTER_API
(
_StageOutermostThreads
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
.
outermost_threads
(
args
[
1
]);
});
TVM_REGISTER_API
(
_StageUnroll
)
TVM_REGISTER_API
(
_StageUnroll
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
args
[
0
].
operator
Stage
()
...
@@ -280,4 +286,16 @@ TVM_REGISTER_API(_ScheduleNormalize)
...
@@ -280,4 +286,16 @@ TVM_REGISTER_API(_ScheduleNormalize)
.
normalize
();
.
normalize
();
});
});
TVM_REGISTER_API
(
_ScheduleCacheRead
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Schedule
()
.
cache_read
(
args
[
1
],
args
[
2
],
args
[
3
]);
});
TVM_REGISTER_API
(
_ScheduleCacheWrite
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Schedule
()
.
cache_write
(
args
[
1
],
args
[
2
]);
});
}
// namespace tvm
}
// namespace tvm
src/lang/operation.cc
View file @
d114dfc9
...
@@ -53,7 +53,7 @@ Operation PlaceholderOpNode::make(std::string name,
...
@@ -53,7 +53,7 @@ Operation PlaceholderOpNode::make(std::string name,
Tensor
P
laceholder
(
Array
<
Expr
>
shape
,
Type
dtype
,
std
::
string
name
)
{
Tensor
p
laceholder
(
Array
<
Expr
>
shape
,
Type
dtype
,
std
::
string
name
)
{
return
PlaceholderOpNode
::
make
(
name
,
shape
,
dtype
).
output
(
0
);
return
PlaceholderOpNode
::
make
(
name
,
shape
,
dtype
).
output
(
0
);
}
}
...
@@ -82,7 +82,7 @@ Array<Expr> ComputeOpNode::output_shape(size_t i) const {
...
@@ -82,7 +82,7 @@ Array<Expr> ComputeOpNode::output_shape(size_t i) const {
return
Array
<
Expr
>
(
shape
);
return
Array
<
Expr
>
(
shape
);
}
}
Tensor
C
ompute
(
Array
<
Expr
>
shape
,
FCompute
fcompute
,
std
::
string
name
)
{
Tensor
c
ompute
(
Array
<
Expr
>
shape
,
FCompute
fcompute
,
std
::
string
name
)
{
auto
op_node
=
std
::
make_shared
<
ComputeOpNode
>
();
auto
op_node
=
std
::
make_shared
<
ComputeOpNode
>
();
// compute dimension.
// compute dimension.
size_t
ndim
=
shape
.
size
();
size_t
ndim
=
shape
.
size
();
...
@@ -188,7 +188,7 @@ Operation ScanOpNode::make(std::string name,
...
@@ -188,7 +188,7 @@ Operation ScanOpNode::make(std::string name,
return
Operation
(
n
);
return
Operation
(
n
);
}
}
Array
<
Tensor
>
S
can
(
IterVar
scan_axis
,
Array
<
Tensor
>
s
can
(
IterVar
scan_axis
,
Array
<
Tensor
>
init
,
Array
<
Tensor
>
init
,
Array
<
Tensor
>
update
,
Array
<
Tensor
>
update
,
Array
<
Tensor
>
state_placeholder
,
Array
<
Tensor
>
state_placeholder
,
...
...
src/schedule/auto_inline_elem_wise.cc
View file @
d114dfc9
...
@@ -6,9 +6,11 @@
...
@@ -6,9 +6,11 @@
#include <tvm/ir_visitor.h>
#include <tvm/ir_visitor.h>
namespace
tvm
{
namespace
tvm
{
namespace
ir
{
namespace
schedule
{
using
namespace
ir
;
class
ElemWiseDetector
:
public
IRVisitor
{
class
ElemWiseDetector
:
public
ir
::
IRVisitor
{
public
:
public
:
explicit
ElemWiseDetector
(
Array
<
IterVar
>
axis
)
:
axis_
(
axis
)
{}
explicit
ElemWiseDetector
(
Array
<
IterVar
>
axis
)
:
axis_
(
axis
)
{}
...
@@ -25,10 +27,7 @@ class ElemWiseDetector : public IRVisitor {
...
@@ -25,10 +27,7 @@ class ElemWiseDetector : public IRVisitor {
}
}
for
(
size_t
i
=
0
;
i
<
axis_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
axis_
.
size
();
++
i
)
{
// const Variable *v1 = axis_[i]->var.as<Variable>();
// const Variable *v2 = axis[i].as<Variable>();
if
(
!
axis
[
i
].
same_as
(
axis_
[
i
]
->
var
))
{
if
(
!
axis
[
i
].
same_as
(
axis_
[
i
]
->
var
))
{
// if (!(v1 && v2) || (v1 != v2)) {
is_elem_wise_
=
false
;
is_elem_wise_
=
false
;
return
;
return
;
}
}
...
@@ -52,22 +51,10 @@ bool IsElemWise(const Operation& op) {
...
@@ -52,22 +51,10 @@ bool IsElemWise(const Operation& op) {
return
false
;
return
false
;
}
}
}
// namespace ir
namespace
schedule
{
void
AutoInlineElemWise
(
Schedule
sch
)
{
void
AutoInlineElemWise
(
Schedule
sch
)
{
for
(
Stage
s
:
sch
->
stages
)
{
for
(
Stage
s
:
sch
->
stages
)
{
if
(
!
s
.
is_scheduled
()
&&
ir
::
IsElemWise
(
s
->
op
))
{
if
(
!
s
.
is_scheduled
()
&&
IsElemWise
(
s
->
op
)
&&
!
s
->
is_output
)
{
bool
is_root
=
false
;
s
.
compute_inline
();
for
(
auto
r
:
sch
->
roots
)
{
if
(
r
==
s
->
op
)
{
is_root
=
true
;
break
;
}
}
if
(
!
is_root
)
s
.
compute_inline
();
}
}
}
}
}
}
...
...
src/schedule/bound.cc
View file @
d114dfc9
...
@@ -294,7 +294,6 @@ void GatherOpBound(const ScanOpNode* scan,
...
@@ -294,7 +294,6 @@ void GatherOpBound(const ScanOpNode* scan,
const
TensorDom
&
d
=
tmap
.
at
(
output
[
i
]);
const
TensorDom
&
d
=
tmap
.
at
(
output
[
i
]);
time_dom
.
insert
(
time_dom
.
end
(),
d
.
data
[
0
].
begin
(),
d
.
data
[
0
].
end
());
time_dom
.
insert
(
time_dom
.
end
(),
d
.
data
[
0
].
begin
(),
d
.
data
[
0
].
end
());
}
}
LOG
(
INFO
)
<<
time_dom
.
size
();
CHECK
(
!
rmap
->
count
(
scan
->
scan_axis
));
CHECK
(
!
rmap
->
count
(
scan
->
scan_axis
));
Range
sdom
=
scan
->
scan_axis
->
dom
;
Range
sdom
=
scan
->
scan_axis
->
dom
;
Range
r
=
arith
::
Union
(
time_dom
).
cover_range
(
sdom
);
Range
r
=
arith
::
Union
(
time_dom
).
cover_range
(
sdom
);
...
@@ -321,7 +320,7 @@ void GatherOpBound(const Operation& op,
...
@@ -321,7 +320,7 @@ void GatherOpBound(const Operation& op,
const
ComputeOpNode
*
compute
=
op
.
as
<
ComputeOpNode
>
();
const
ComputeOpNode
*
compute
=
op
.
as
<
ComputeOpNode
>
();
const
TensorDom
&
tdom
=
tmap
.
at
(
op
.
output
(
0
));
const
TensorDom
&
tdom
=
tmap
.
at
(
op
.
output
(
0
));
for
(
size_t
i
=
0
;
i
<
compute
->
axis
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
compute
->
axis
.
size
();
++
i
)
{
Range
r
=
arith
::
Union
(
tdom
.
data
[
i
]
).
cover_range
(
compute
->
axis
[
i
]
->
dom
);
Range
r
=
arith
::
Union
(
tdom
.
data
.
at
(
i
)
).
cover_range
(
compute
->
axis
[
i
]
->
dom
);
CHECK
(
!
rmap
->
count
(
compute
->
axis
[
i
]));
CHECK
(
!
rmap
->
count
(
compute
->
axis
[
i
]));
(
*
rmap
)[
compute
->
axis
[
i
]]
=
r
;
(
*
rmap
)[
compute
->
axis
[
i
]]
=
r
;
}
}
...
@@ -392,6 +391,8 @@ void InferRootBound(const Stage& stage,
...
@@ -392,6 +391,8 @@ void InferRootBound(const Stage& stage,
direct_consume_by_parent
=
true
;
direct_consume_by_parent
=
true
;
}
}
}
}
}
else
{
LOG
(
INFO
)
<<
"not in feed graph consumer = "
<<
stage
->
op
;
}
}
}
}
// The relax set
// The relax set
...
@@ -486,7 +487,11 @@ void InferRootBound(const Stage& stage,
...
@@ -486,7 +487,11 @@ void InferRootBound(const Stage& stage,
}
}
FeedGraph
CreateFeedGraph
(
const
Schedule
&
sch
)
{
FeedGraph
CreateFeedGraph
(
const
Schedule
&
sch
)
{
auto
g
=
CreateReadGraph
(
sch
->
roots
);
Array
<
Operation
>
roots
;
for
(
Operation
op
:
sch
->
outputs
)
{
roots
.
push_back
(
sch
->
stage_map
[
op
]
->
op
);
}
auto
g
=
CreateReadGraph
(
roots
);
FeedGraph
fg
;
FeedGraph
fg
;
for
(
auto
kv
:
g
)
{
for
(
auto
kv
:
g
)
{
for
(
Tensor
t
:
kv
.
second
)
{
for
(
Tensor
t
:
kv
.
second
)
{
...
@@ -523,6 +528,7 @@ AttachPath CreateAttachPath(const Schedule& sch) {
...
@@ -523,6 +528,7 @@ AttachPath CreateAttachPath(const Schedule& sch) {
Map
<
IterVar
,
Range
>
InferBound
(
const
Schedule
&
sch
)
{
Map
<
IterVar
,
Range
>
InferBound
(
const
Schedule
&
sch
)
{
FeedGraph
feed_graph
=
CreateFeedGraph
(
sch
);
FeedGraph
feed_graph
=
CreateFeedGraph
(
sch
);
AttachPath
attach_path
=
CreateAttachPath
(
sch
);
AttachPath
attach_path
=
CreateAttachPath
(
sch
);
std
::
unordered_map
<
IterVar
,
Range
>
ret
;
std
::
unordered_map
<
IterVar
,
Range
>
ret
;
for
(
size_t
i
=
sch
->
stages
.
size
();
i
!=
0
;
--
i
)
{
for
(
size_t
i
=
sch
->
stages
.
size
();
i
!=
0
;
--
i
)
{
const
Stage
&
stage
=
sch
->
stages
[
i
-
1
];
const
Stage
&
stage
=
sch
->
stages
[
i
-
1
];
...
...
src/schedule/schedule_lang.cc
View file @
d114dfc9
This diff is collapsed.
Click to expand it.
src/schedule/schedule_ops.cc
View file @
d114dfc9
This diff is collapsed.
Click to expand it.
tests/cpp/tensor_test.cc
View file @
d114dfc9
...
@@ -6,10 +6,10 @@ TEST(Tensor, Basic) {
...
@@ -6,10 +6,10 @@ TEST(Tensor, Basic) {
using
namespace
tvm
;
using
namespace
tvm
;
Var
m
(
"m"
),
n
(
"n"
),
l
(
"l"
);
Var
m
(
"m"
),
n
(
"n"
),
l
(
"l"
);
Tensor
A
=
P
laceholder
({
m
,
l
},
Float
(
32
),
"A"
);
Tensor
A
=
p
laceholder
({
m
,
l
},
Float
(
32
),
"A"
);
Tensor
B
=
P
laceholder
({
n
,
l
},
Float
(
32
),
"B"
);
Tensor
B
=
p
laceholder
({
n
,
l
},
Float
(
32
),
"B"
);
auto
C
=
C
ompute
({
m
,
n
},
[
&
](
Var
i
,
Var
j
)
{
auto
C
=
c
ompute
({
m
,
n
},
[
&
](
Var
i
,
Var
j
)
{
return
A
[
i
][
j
];
return
A
[
i
][
j
];
},
"C"
);
},
"C"
);
...
@@ -20,11 +20,11 @@ TEST(Tensor, Basic) {
...
@@ -20,11 +20,11 @@ TEST(Tensor, Basic) {
TEST
(
Tensor
,
Reduce
)
{
TEST
(
Tensor
,
Reduce
)
{
using
namespace
tvm
;
using
namespace
tvm
;
Var
m
(
"m"
),
n
(
"n"
),
l
(
"l"
);
Var
m
(
"m"
),
n
(
"n"
),
l
(
"l"
);
Tensor
A
=
P
laceholder
({
m
,
l
},
Float
(
32
),
"A"
);
Tensor
A
=
p
laceholder
({
m
,
l
},
Float
(
32
),
"A"
);
Tensor
B
=
P
laceholder
({
n
,
l
},
Float
(
32
),
"B"
);
Tensor
B
=
p
laceholder
({
n
,
l
},
Float
(
32
),
"B"
);
IterVar
rv
(
Range
{
0
,
l
},
"k"
);
IterVar
rv
(
Range
{
0
,
l
},
"k"
);
auto
C
=
C
ompute
({
m
,
n
},
[
&
](
Var
i
,
Var
j
)
{
auto
C
=
c
ompute
({
m
,
n
},
[
&
](
Var
i
,
Var
j
)
{
return
sum
(
max
(
1
+
A
[
i
][
rv
]
+
1
,
B
[
j
][
rv
]),
{
rv
});
return
sum
(
max
(
1
+
A
[
i
][
rv
]
+
1
,
B
[
j
][
rv
]),
{
rv
});
},
"C"
);
},
"C"
);
LOG
(
INFO
)
<<
C
->
op
.
as
<
ComputeOpNode
>
()
->
body
;
LOG
(
INFO
)
<<
C
->
op
.
as
<
ComputeOpNode
>
()
->
body
;
...
...
tests/python/integration/test_gemm.py
View file @
d114dfc9
...
@@ -2,17 +2,6 @@ import tvm
...
@@ -2,17 +2,6 @@ import tvm
from
tvm.addon
import
nvcc_compiler
from
tvm.addon
import
nvcc_compiler
import
numpy
as
np
import
numpy
as
np
@tvm.register_func
def
tvm_callback_cuda_compile
(
code
):
ptx
=
nvcc_compiler
.
compile_source
(
code
,
target
=
"ptx"
)
print
(
ptx
.
decode
(
"utf-8"
))
return
ptx
@tvm.register_func
def
tvm_callback_cuda_postproc
(
code
):
print
(
code
)
return
code
def
test_gemm
():
def
test_gemm
():
# graph
# graph
nn
=
1024
nn
=
1024
...
@@ -22,21 +11,14 @@ def test_gemm():
...
@@ -22,21 +11,14 @@ def test_gemm():
l
=
n
l
=
n
A
=
tvm
.
placeholder
((
n
,
l
),
name
=
'A'
)
A
=
tvm
.
placeholder
((
n
,
l
),
name
=
'A'
)
B
=
tvm
.
placeholder
((
m
,
l
),
name
=
'B'
)
B
=
tvm
.
placeholder
((
m
,
l
),
name
=
'B'
)
AA
=
tvm
.
compute
(
A
.
shape
,
lambda
*
i
:
A
(
*
i
),
name
=
"AA"
)
BB
=
tvm
.
compute
(
B
.
shape
,
lambda
*
i
:
B
(
*
i
),
name
=
"BB"
)
k
=
tvm
.
IterVar
((
0
,
l
),
name
=
'k'
)
k
=
tvm
.
IterVar
((
0
,
l
),
name
=
'k'
)
C
C
=
tvm
.
compute
(
C
=
tvm
.
compute
(
(
n
,
m
),
(
n
,
m
),
lambda
ii
,
jj
:
tvm
.
sum
(
A
A
[
ii
,
k
]
*
B
B
[
jj
,
k
],
axis
=
k
),
lambda
ii
,
jj
:
tvm
.
sum
(
A
[
ii
,
k
]
*
B
[
jj
,
k
],
axis
=
k
),
name
=
'CC'
)
name
=
'CC'
)
C
=
tvm
.
compute
(
CC
.
shape
,
lambda
*
i
:
CC
(
*
i
),
name
=
"C"
)
# schedule
# schedule
s
=
tvm
.
Schedule
(
C
.
op
)
s
=
tvm
.
Schedule
(
C
.
op
)
xtile
,
ytile
=
32
,
32
xtile
,
ytile
=
32
,
32
s
[
AA
]
.
set_scope
(
"shared"
)
s
[
BB
]
.
set_scope
(
"shared"
)
scale
=
8
scale
=
8
num_thread
=
8
num_thread
=
8
block_factor
=
scale
*
num_thread
block_factor
=
scale
*
num_thread
...
@@ -45,6 +27,9 @@ def test_gemm():
...
@@ -45,6 +27,9 @@ def test_gemm():
block_y
=
tvm
.
IterVar
(
thread_tag
=
"blockIdx.y"
)
block_y
=
tvm
.
IterVar
(
thread_tag
=
"blockIdx.y"
)
thread_y
=
tvm
.
IterVar
((
0
,
num_thread
),
thread_tag
=
"threadIdx.y"
)
thread_y
=
tvm
.
IterVar
((
0
,
num_thread
),
thread_tag
=
"threadIdx.y"
)
CC
=
s
.
cache_write
(
C
,
"local"
)
AA
=
s
.
cache_read
(
A
,
"shared"
,
[
CC
])
BB
=
s
.
cache_read
(
B
,
"shared"
,
[
CC
])
_
,
yi
=
s
[
C
]
.
split
(
C
.
op
.
axis
[
0
],
factor
=
block_factor
,
outer
=
block_y
)
_
,
yi
=
s
[
C
]
.
split
(
C
.
op
.
axis
[
0
],
factor
=
block_factor
,
outer
=
block_y
)
_
,
xi
=
s
[
C
]
.
split
(
C
.
op
.
axis
[
1
],
factor
=
block_factor
,
outer
=
block_x
)
_
,
xi
=
s
[
C
]
.
split
(
C
.
op
.
axis
[
1
],
factor
=
block_factor
,
outer
=
block_x
)
s
[
C
]
.
reorder
(
block_y
,
block_x
,
yi
,
xi
)
s
[
C
]
.
reorder
(
block_y
,
block_x
,
yi
,
xi
)
...
@@ -92,8 +77,8 @@ def test_gemm():
...
@@ -92,8 +77,8 @@ def test_gemm():
c
.
asnumpy
(),
np
.
dot
(
a_np
,
b_np
.
T
),
rtol
=
1e-5
)
c
.
asnumpy
(),
np
.
dot
(
a_np
,
b_np
.
T
),
rtol
=
1e-5
)
check_device
(
"cuda"
)
check_device
(
"cuda"
)
#
tvm.init_opencl()
tvm
.
init_opencl
()
#
check_device("opencl")
check_device
(
"opencl"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_gemm
()
test_gemm
()
tests/python/unittest/test_lang_schedule.py
View file @
d114dfc9
...
@@ -22,13 +22,13 @@ def test_schedule_create():
...
@@ -22,13 +22,13 @@ def test_schedule_create():
json_str
=
tvm
.
save_json
(
s
)
json_str
=
tvm
.
save_json
(
s
)
s_loaded
=
tvm
.
load_json
(
json_str
)
s_loaded
=
tvm
.
load_json
(
json_str
)
assert
isinstance
(
s_loaded
,
tvm
.
schedule
.
Schedule
)
assert
isinstance
(
s_loaded
,
tvm
.
schedule
.
Schedule
)
assert
(
str
(
s_loaded
.
roots
[
0
]
.
body
)
==
str
(
s
.
roo
ts
[
0
]
.
body
))
assert
(
str
(
s_loaded
.
outputs
[
0
]
.
body
)
==
str
(
s
.
outpu
ts
[
0
]
.
body
))
# pickle unpickle
# pickle unpickle
dump
=
pkl
.
dumps
(
s
)
dump
=
pkl
.
dumps
(
s
)
s_loaded
=
pkl
.
loads
(
dump
)
s_loaded
=
pkl
.
loads
(
dump
)
assert
isinstance
(
s_loaded
,
tvm
.
schedule
.
Schedule
)
assert
isinstance
(
s_loaded
,
tvm
.
schedule
.
Schedule
)
assert
(
str
(
s_loaded
.
roots
[
0
]
.
body
)
==
str
(
s
.
roo
ts
[
0
]
.
body
))
assert
(
str
(
s_loaded
.
outputs
[
0
]
.
body
)
==
str
(
s
.
outpu
ts
[
0
]
.
body
))
def
test_reorder
():
def
test_reorder
():
m
=
tvm
.
Var
(
'm'
)
m
=
tvm
.
Var
(
'm'
)
...
...
tests/python/unittest/test_schedule_schedule_ops.py
View file @
d114dfc9
...
@@ -74,6 +74,20 @@ def test_auto_inline():
...
@@ -74,6 +74,20 @@ def test_auto_inline():
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
stmt
=
tvm
.
schedule
.
ScheduleOps
(
s
,
bounds
)
stmt
=
tvm
.
schedule
.
ScheduleOps
(
s
,
bounds
)
def
test_schedule_cache
():
m
=
tvm
.
Var
(
'm'
)
n
=
tvm
.
Var
(
'n'
)
A
=
tvm
.
placeholder
((
m
,
n
),
name
=
'A'
)
B
=
tvm
.
placeholder
((
m
,
n
),
name
=
'B'
)
C
=
tvm
.
compute
((
m
,
n
),
lambda
i
,
j
:
A
(
i
,
j
)
*
B
(
i
,
j
),
name
=
'C'
)
s
=
tvm
.
Schedule
(
C
.
op
)
AA
=
s
.
cache_read
(
A
,
"shared"
,
readers
=
[
C
])
CC
=
s
.
cache_write
(
C
,
"shared"
)
s
[
AA
]
.
compute_at
(
s
[
CC
],
CC
.
op
.
axis
[
0
])
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
stmt
=
tvm
.
schedule
.
ScheduleOps
(
s
,
bounds
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_schedule_scan
()
test_schedule_scan
()
...
@@ -81,3 +95,4 @@ if __name__ == "__main__":
...
@@ -81,3 +95,4 @@ if __name__ == "__main__":
test_schedule1
()
test_schedule1
()
test_schedule2
()
test_schedule2
()
test_auto_inline
()
test_auto_inline
()
test_schedule_cache
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment