Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
357ad592
Commit
357ad592
authored
Dec 02, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix Schedule structure, refactor compute to all rely on iter var
parent
3a48b323
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
171 additions
and
175 deletions
+171
-175
include/tvm/expr.h
+5
-4
include/tvm/operation.h
+5
-4
include/tvm/schedule.h
+112
-19
include/tvm/split.h
+0
-65
include/tvm/tensor.h
+2
-4
python/tvm/function.py
+4
-4
src/c_api/c_api_lang.cc
+1
-8
src/lang/expr.cc
+3
-3
src/lang/operation.cc
+15
-17
src/lang/schedule.cc
+22
-0
src/lang/split.cc
+0
-20
src/pass/ir_mutator.cc
+2
-1
src/pass/schedule_ops.cc
+0
-26
No files found.
include/tvm/expr.h
View file @
357ad592
...
...
@@ -133,13 +133,13 @@ std::ostream& operator<<(std::ostream& os, const NodeRef& n); // NOLINT(*)
*/
class
IterVarNode
:
public
Node
{
public
:
/*! \brief The looping variable */
Var
var
;
/*!
* \brief the domain of iteration, if known, can be None
* For the intermediate schedule node, before schedule.
*/
Range
dom
;
/*! \brief The looping variable */
Var
var
;
/*!
* \brief additional tag on the iteration variable,
* set this if this is binded already to a known thread tag.
...
...
@@ -147,12 +147,13 @@ class IterVarNode : public Node {
std
::
string
thread_tag
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"var"
,
&
var
);
v
->
Visit
(
"dom"
,
&
dom
);
v
->
Visit
(
"var"
,
&
var
);
v
->
Visit
(
"thread_tag"
,
&
thread_tag
);
}
static
IterVar
make
(
Var
var
,
Range
dom
,
std
::
string
thread_tag
);
static
IterVar
make
(
Range
dom
,
Var
var
,
std
::
string
thread_tag
);
static
constexpr
const
char
*
_type_key
=
"IterVar"
;
TVM_DECLARE_NODE_TYPE_INFO
(
IterVarNode
);
};
...
...
include/tvm/operation.h
View file @
357ad592
...
...
@@ -17,6 +17,8 @@ namespace tvm {
*/
class
ComputeOpNode
:
public
OperationNode
{
public
:
/*! \brief Iteration variables over the dimensions */
Array
<
IterVar
>
dim_var
;
/*! \brief the compute expression */
Expr
body
;
/*! \brief constructor */
...
...
@@ -25,19 +27,18 @@ class ComputeOpNode : public OperationNode {
size_t
num_outputs
()
const
final
{
return
1
;
}
Array
<
IterVar
>
root_iter_vars
()
const
final
;
std
::
string
output_name
(
size_t
i
)
const
final
;
Type
output_dtype
(
size_t
i
)
const
final
;
Array
<
Expr
>
output_shape
(
size_t
i
)
const
final
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"domain"
,
&
domain
);
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"dim_var"
,
&
dim_var
);
v
->
Visit
(
"body"
,
&
body
);
}
static
Operation
make
(
Domain
domain
,
std
::
string
name
,
Array
<
Var
>
dim_var
,
static
Operation
make
(
std
::
string
name
,
Array
<
IterVar
>
dim_var
,
Expr
body
);
static
constexpr
const
char
*
_type_key
=
"ComputeOp"
;
...
...
include/tvm/schedule.h
View file @
357ad592
...
...
@@ -8,15 +8,14 @@
#include <string>
#include "./base.h"
#include "./split.h"
#include "./operation.h"
namespace
tvm
{
// Node container for Schedule
class
ScheduleNode
;
// Node container for
AttachSpec
class
AttachSpec
Node
;
// Node container for
IterVarRelation
class
IterVarRelation
Node
;
/*! \brief the attachment type */
enum
AttachType
:
int
{
...
...
@@ -38,42 +37,132 @@ class Schedule : public NodeRef {
inline
const
ScheduleNode
*
operator
->
()
const
;
};
/*!
* \brief The schedule relation between IterVars
* can be Split, Fuse.
*/
class
IterVarRelation
:
public
NodeRef
{
public
:
IterVarRelation
()
{}
explicit
IterVarRelation
(
std
::
shared_ptr
<
Node
>
n
)
:
NodeRef
(
n
)
{}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline
const
IterVarRelationNode
*
operator
->
()
const
;
};
// defintion of node containers
/*! \brief represents the schedule of the tensor */
/*!
* \brief represents the schedule of the tensor
*
* A schedule is a Directed acylic hypergraph.
* With each node is represented by a IterVar,
* and each hyper-edge is represented by a IterVarRelation.
*
* The relations can be Split/Fuse.
*
* The current data structure stores the hyper graph in its
* bipartite representation.
*
* The relations connects the IterVars in the graph.
*/
class
ScheduleNode
:
public
Node
{
public
:
/*! \brief The operation to be scheduled */
Operation
op
;
/*! \brief The thread scope level of the schedule */
std
::
string
scope
;
/*! \brief Splits over iteration domains */
Array
<
Split
>
splits
;
/*! \brief All the nodes in the iter var */
Array
<
IterVar
>
all_iter_vars
;
/*!
* \brief The current leafs in the schedule.
* Operations can only be performed in leaves.
*/
Array
<
IterVar
>
leaf_iter_vars
;
/*! \brief The relation bwteen of IterVars */
Array
<
IterVarRelation
>
relations
;
/*! \brief The attachment type of the schedule */
AttachType
attach_type
;
/*!
* \brief The attach point of this schedule, if it is a split
* \note This is not a cyclic dependency,
* because split do not refer back to parent schedule.
* \brief The attach point of this schedule.
*/
Split
attach_parent
;
IterVar
attach_parent
;
/*! \brief the schedules that this schedule depend on */
Array
<
Schedule
>
children
;
// the type key
const
char
*
type_key
()
const
final
{
return
"Schedule"
;
}
const
uint32_t
type_index
()
const
final
{
static
uint32_t
tidx
=
TypeKey2Index
(
type_key
());
return
tidx
;
}
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"scope"
,
&
scope
);
v
->
Visit
(
"op"
,
&
op
);
v
->
Visit
(
"splits"
,
&
splits
);
v
->
Visit
(
"all_iter_vars"
,
&
all_iter_vars
);
v
->
Visit
(
"leaf_iter_vars"
,
&
leaf_iter_vars
);
v
->
Visit
(
"relations"
,
&
relations
);
v
->
Visit
(
"attach_type"
,
&
attach_type
);
v
->
Visit
(
"attach_parent"
,
&
attach_parent
);
v
->
Visit
(
"children"
,
&
children
);
}
static
constexpr
const
char
*
_type_key
=
"Schedule"
;
TVM_DECLARE_NODE_TYPE_INFO
(
ScheduleNode
);
};
/*! \brief base node of iteration var */
class
IterVarRelationNode
:
public
Node
{
};
/*!
* \brief Split the parent domain into product of
* outer and iter.
*/
class
SplitNode
:
public
IterVarRelationNode
{
public
:
/*! \brief The parent domain */
IterVar
parent
;
/*! \brief The outer domain */
IterVar
outer
;
/*! \brief The inner domain */
IterVar
inner
;
/*! \brief The split factor */
Expr
factor
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"parent"
,
&
parent
);
v
->
Visit
(
"outer"
,
&
outer
);
v
->
Visit
(
"inner"
,
&
inner
);
v
->
Visit
(
"factor"
,
&
factor
);
}
static
IterVarRelation
make
(
IterVar
parent
,
IterVar
outer
,
IterVar
inner
,
Expr
factor
);
static
constexpr
const
char
*
_type_key
=
"Split"
;
TVM_DECLARE_NODE_TYPE_INFO
(
SplitNode
);
};
/*!
* \brief Fuse two domains into one domain.
*/
class
FuseNode
:
public
IterVarRelationNode
{
public
:
/*! \brief The outer domain */
IterVar
outer
;
/*! \brief The inner domain */
IterVar
inner
;
/*! \brief The target domain */
IterVar
fused
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"outer"
,
&
outer
);
v
->
Visit
(
"inner"
,
&
inner
);
v
->
Visit
(
"fused"
,
&
fused
);
}
static
IterVarRelation
make
(
IterVar
outer
,
IterVar
inner
,
IterVar
fused
);
static
constexpr
const
char
*
_type_key
=
"Fuse"
;
TVM_DECLARE_NODE_TYPE_INFO
(
FuseNode
);
};
// implementations
...
...
@@ -81,5 +170,9 @@ inline const ScheduleNode* Schedule::operator->() const {
return
static_cast
<
const
ScheduleNode
*>
(
node_
.
get
());
}
inline
const
IterVarRelationNode
*
IterVarRelation
::
operator
->
()
const
{
return
static_cast
<
const
IterVarRelationNode
*>
(
node_
.
get
());
}
}
// namespace tvm
#endif // TVM_SCHEDULE_H_
include/tvm/split.h
deleted
100644 → 0
View file @
3a48b323
/*!
* Copyright (c) 2016 by Contributors
* \file split.h
* \brief Define a split over Domain or RDomain
*/
#ifndef TVM_SPLIT_H_
#define TVM_SPLIT_H_
#include "./base.h"
#include "./expr.h"
namespace
tvm
{
// internal node container for split.
class
SplitNode
;
/*! \brief Split over input domain */
class
Split
:
public
NodeRef
{
public
:
/*! \brief default constructor */
Split
()
{}
explicit
Split
(
std
::
shared_ptr
<
Node
>
n
)
:
NodeRef
(
n
)
{}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline
const
SplitNode
*
operator
->
()
const
;
};
/*!
* \brief base class of split node,
* specifies a split over domain
* split also defines how to generate
*/
class
SplitNode
:
public
Node
{
public
:
/*! \brief the variable to be splitted on */
Var
var
;
};
/*! \brief simple split node that splits over one dimension */
class
DimSplitNode
:
public
SplitNode
{
public
:
/*! \brief The factor of the split */
Expr
factor
;
/*! \brief constructor */
DimSplitNode
()
{}
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"var"
,
&
var
);
v
->
Visit
(
"factor"
,
&
factor
);
}
static
Split
make
(
Var
var
,
Expr
factor
);
static
constexpr
const
char
*
_type_key
=
"DimSplit"
;
TVM_DECLARE_NODE_TYPE_INFO
(
DimSplitNode
);
};
// Implementations of inline functions
inline
const
SplitNode
*
Split
::
operator
->
()
const
{
return
static_cast
<
const
SplitNode
*>
(
node_
.
get
());
}
}
// namespace tvm
#endif // TVM_SPLIT_H_
include/tvm/tensor.h
View file @
357ad592
...
...
@@ -174,12 +174,10 @@ class TensorNode : public FunctionBaseNode {
*/
class
OperationNode
:
public
Node
{
public
:
/*! \brief The domain of iteration of this op. */
Domain
domain
;
/*! \brief iter-Var over the dimensions */
Array
<
Var
>
dim_var
;
/*! \brief optional name of the operation */
std
::
string
name
;
/*! \return the list of iteration variable at root */
virtual
Array
<
IterVar
>
root_iter_vars
()
const
=
0
;
/*! \return number of outputs of this op */
virtual
size_t
num_outputs
()
const
=
0
;
/*! \return name of i-th output */
...
...
python/tvm/function.py
View file @
357ad592
...
...
@@ -83,11 +83,11 @@ def compute(shape, fcompute, name="TensorCompute"):
arg_names
=
fcompute
.
__code__
.
co_varnames
if
ndim
!=
len
(
arg_names
):
raise
ValueError
(
"fcompute do not match dimension"
)
dim_var
=
[
Var
(
x
)
for
x
in
arg_names
]
body
=
fcompute
(
*
dim_var
)
dom
=
[
Range
(
0
,
x
)
for
x
in
shape
]
dim_var
=
[
IterVar
((
0
,
s
),
x
)
for
x
,
s
in
zip
(
arg_names
,
shape
)]
body
=
fcompute
(
*
[
v
.
var
for
v
in
dim_var
])
op_node
=
_function_internal
.
_ComputeOp
(
dom
,
name
,
dim_var
,
body
)
name
,
dim_var
,
body
)
return
_function_internal
.
_Tensor
(
shape
,
name
,
body
.
dtype
,
op_node
,
0
)
...
...
src/c_api/c_api_lang.cc
View file @
357ad592
...
...
@@ -5,7 +5,6 @@
*/
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/split.h>
#include <tvm/schedule.h>
#include "./c_api_registry.h"
...
...
@@ -89,8 +88,7 @@ TVM_REGISTER_API(_ComputeOp)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
*
ret
=
ComputeOpNode
::
make
(
args
.
at
(
0
),
args
.
at
(
1
),
args
.
at
(
2
),
args
.
at
(
3
));
args
.
at
(
2
));
});
...
...
@@ -100,11 +98,6 @@ TVM_REGISTER_API(_IterVar)
});
TVM_REGISTER_API
(
_DimSplit
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
*
ret
=
DimSplitNode
::
make
(
args
.
at
(
0
),
args
.
at
(
1
));
});
TVM_REGISTER_API
(
_Schedule
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
*
ret
=
Schedule
(
args
.
at
(
0
),
args
.
at
(
1
));
...
...
src/lang/expr.cc
View file @
357ad592
...
...
@@ -24,12 +24,12 @@ Range Range::make_with_min_extent(Expr min, Expr extent) {
}
IterVar
::
IterVar
(
Range
dom
,
std
::
string
var_name
,
std
::
string
thread_tag
)
:
IterVar
(
IterVarNode
::
make
(
Var
(
var_name
,
Int
(
32
)),
dom
,
thread_tag
))
{}
:
IterVar
(
IterVarNode
::
make
(
dom
,
Var
(
var_name
,
Int
(
32
))
,
thread_tag
))
{}
IterVar
IterVarNode
::
make
(
Var
var
,
Range
dom
,
std
::
string
thread_tag
)
{
IterVar
IterVarNode
::
make
(
Range
dom
,
Var
var
,
std
::
string
thread_tag
)
{
std
::
shared_ptr
<
IterVarNode
>
n
=
std
::
make_shared
<
IterVarNode
>
();
n
->
var
=
var
;
n
->
dom
=
dom
;
n
->
var
=
var
;
n
->
thread_tag
=
thread_tag
;
return
IterVar
(
n
);
}
...
...
src/lang/operation.cc
View file @
357ad592
...
...
@@ -13,32 +13,25 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
auto
op_node
=
std
::
make_shared
<
ComputeOpNode
>
();
// compute dimension.
size_t
ndim
=
shape
.
size
();
std
::
vector
<
Var
>
dim_index
;
std
::
vector
<
IterVar
>
dim_var
;
std
::
vector
<
Var
>
args
;
for
(
size_t
i
=
0
;
i
<
ndim
;
++
i
)
{
std
::
ostringstream
os
;
os
<<
"dim_var"
<<
i
;
dim_index
.
push_back
(
Var
(
os
.
str
()));
dim_var
.
push_back
(
IterVar
(
Range
(
0
,
shape
[
i
]),
os
.
str
()));
args
.
push_back
(
dim_var
.
back
()
->
var
);
}
std
::
vector
<
Range
>
dom
;
for
(
size_t
i
=
0
;
i
<
ndim
;
++
i
)
{
dom
.
push_back
(
Range
(
0
,
shape
[
i
]));
}
op_node
->
dim_var
=
Array
<
Var
>
(
dim_index
);
op_node
->
domain
=
Domain
(
dom
);
op_node
->
body
=
fcompute
(
op_node
->
dim_var
);
op_node
->
dim_var
=
Array
<
IterVar
>
(
dim_var
);
op_node
->
body
=
fcompute
(
args
);
op_node
->
name
=
name
;
return
Operation
(
op_node
).
output
(
0
);
}
Operation
ComputeOpNode
::
make
(
Domain
domain
,
std
::
string
name
,
Array
<
Var
>
dim_var
,
Operation
ComputeOpNode
::
make
(
std
::
string
name
,
Array
<
IterVar
>
dim_var
,
Expr
body
)
{
auto
n
=
std
::
make_shared
<
ComputeOpNode
>
();
n
->
domain
=
domain
;
n
->
name
=
name
;
n
->
dim_var
=
dim_var
;
n
->
body
=
body
;
...
...
@@ -55,6 +48,10 @@ Tensor Operation::output(size_t i) const {
return
Tensor
(
node
);
}
Array
<
IterVar
>
ComputeOpNode
::
root_iter_vars
()
const
{
return
dim_var
;
}
std
::
string
ComputeOpNode
::
output_name
(
size_t
i
)
const
{
CHECK_EQ
(
i
,
0
);
return
name
;
...
...
@@ -68,8 +65,9 @@ Type ComputeOpNode::output_dtype(size_t i) const {
Array
<
Expr
>
ComputeOpNode
::
output_shape
(
size_t
i
)
const
{
CHECK_EQ
(
i
,
0
);
std
::
vector
<
Expr
>
shape
;
for
(
size_t
i
=
0
;
i
<
domain
.
size
();
++
i
)
{
shape
.
push_back
(
domain
[
i
]
->
extent
);
for
(
size_t
i
=
0
;
i
<
dim_var
.
size
();
++
i
)
{
const
Range
&
r
=
dim_var
[
i
]
->
dom
;
shape
.
push_back
(
r
->
extent
);
}
return
Array
<
Expr
>
(
shape
);
}
...
...
src/lang/schedule.cc
View file @
357ad592
...
...
@@ -13,6 +13,28 @@ Schedule::Schedule(Operation op, std::string scope) {
node_
=
n
;
}
IterVarRelation
SplitNode
::
make
(
IterVar
parent
,
IterVar
outer
,
IterVar
inner
,
Expr
factor
)
{
auto
n
=
std
::
make_shared
<
SplitNode
>
();
n
->
parent
=
parent
;
n
->
outer
=
outer
;
n
->
inner
=
inner
;
n
->
factor
=
factor
;
return
IterVarRelation
(
n
);
}
IterVarRelation
FuseNode
::
make
(
IterVar
outer
,
IterVar
inner
,
IterVar
fused
)
{
auto
n
=
std
::
make_shared
<
FuseNode
>
();
n
->
outer
=
outer
;
n
->
inner
=
inner
;
n
->
fused
=
fused
;
return
IterVarRelation
(
n
);
}
TVM_REGISTER_NODE_TYPE
(
ScheduleNode
);
TVM_REGISTER_NODE_TYPE
(
SplitNode
);
TVM_REGISTER_NODE_TYPE
(
FuseNode
);
}
// namespace tvm
src/lang/split.cc
deleted
100644 → 0
View file @
3a48b323
/*!
* Copyright (c) 2016 by Contributors
* \file split.cc
*/
#include <tvm/split.h>
namespace
tvm
{
Split
DimSplitNode
::
make
(
Var
var
,
Expr
factor
)
{
auto
n
=
std
::
make_shared
<
DimSplitNode
>
();
CHECK_EQ
(
factor
.
type
().
lanes
(),
1
);
n
->
var
=
var
;
n
->
factor
=
factor
;
return
Split
(
n
);
}
TVM_REGISTER_NODE_TYPE
(
DimSplitNode
);
}
// namespace tvm
src/pass/ir_mutator.cc
View file @
357ad592
...
...
@@ -53,7 +53,8 @@ inline Array<IterVar> MutateRDom(Array<IterVar> rdom, IRMutator *m) {
if
(
!
r
->
min
.
same_as
(
new_min
))
changed
=
true
;
if
(
!
r
->
extent
.
same_as
(
new_extent
))
changed
=
true
;
new_dom
[
i
]
=
IterVarNode
::
make
(
v
->
var
,
Range
::
make_with_min_extent
(
new_min
,
new_extent
),
v
->
thread_tag
);
Range
::
make_with_min_extent
(
new_min
,
new_extent
),
v
->
var
,
v
->
thread_tag
);
}
if
(
!
changed
)
{
return
rdom
;
...
...
src/pass/schedule_ops.cc
View file @
357ad592
...
...
@@ -38,32 +38,6 @@ Stmt MakeLoop(std::vector<Stmt>&& nest, Stmt body) {
return
body
;
}
void
MakeLoop
(
const
DimSplitNode
*
op
,
const
Split
&
s
,
Scope
<
AttrKey
,
Expr
>*
pscope
,
std
::
vector
<
Stmt
>*
nest
)
{
auto
&
scope
=
*
pscope
;
Expr
out_min
=
scope
[{
op
->
var
,
"min"
}];
Expr
out_ext
=
scope
[{
op
->
var
,
"extent"
}];
Expr
stride
=
op
->
factor
;
Var
offset
(
s
->
var
->
name_hint
+
".offset"
,
Int
(
32
));
// for loop with stride
// TODO(tqchen) split the loop to deal with tails
nest
->
emplace_back
(
For
::
make
(
offset
,
out_min
,
out_ext
,
ForType
::
Parallel
,
DeviceAPI
::
None
,
Stmt
()));
Expr
in_min
=
offset
+
out_min
;
Expr
in_ext
=
min
(
stride
,
out_ext
-
offset
);
// declare min and extent of the corresponding variable
nest
->
emplace_back
(
AttrStmt
::
make
(
op
->
var
,
"min"
,
in_min
,
Stmt
()));
nest
->
emplace_back
(
AttrStmt
::
make
(
op
->
var
,
"extent"
,
in_ext
,
Stmt
()));
// declare this is the loop
nest
->
emplace_back
(
AttrStmt
::
make
(
s
,
"split"
,
0
,
Stmt
()));
// setup the scope.
pscope
->
Push
({
op
->
var
,
"min"
},
in_min
);
pscope
->
Push
({
op
->
var
,
"extent"
},
in_ext
);
}
Stmt
MakePipeline
(
const
Schedule
&
sch
,
Stmt
body
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment