Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
605813e4
Commit
605813e4
authored
Nov 27, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
schedule over operation
parent
cac1b5a8
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
47 additions
and
48 deletions
+47
-48
include/tvm/expr.h
+2
-0
include/tvm/operation.h
+5
-6
include/tvm/schedule.h
+6
-6
include/tvm/split.h
+5
-9
include/tvm/tensor.h
+6
-6
python/tvm/function.py
+4
-1
src/c_api/c_api_lang.cc
+1
-1
src/lang/operation.cc
+6
-6
src/lang/schedule.cc
+2
-2
src/lang/split.cc
+3
-5
src/lang/tensor.cc
+4
-4
tests/python/test_inline.py
+1
-1
tests/python/test_schedule.py
+2
-1
No files found.
include/tvm/expr.h
View file @
605813e4
...
@@ -37,6 +37,8 @@ class Var : public Halide::VarExpr {
...
@@ -37,6 +37,8 @@ class Var : public Halide::VarExpr {
public
:
public
:
explicit
Var
(
const
std
::
string
&
name_hint
=
"v"
,
explicit
Var
(
const
std
::
string
&
name_hint
=
"v"
,
Type
t
=
Int
(
32
))
:
VarExpr
(
name_hint
,
t
)
{}
Type
t
=
Int
(
32
))
:
VarExpr
(
name_hint
,
t
)
{}
explicit
Var
(
std
::
shared_ptr
<
Node
>
n
)
:
VarExpr
(
n
)
{}
};
};
}
// namespace tvm
}
// namespace tvm
...
...
include/tvm/operation.h
View file @
605813e4
...
@@ -10,7 +10,6 @@
...
@@ -10,7 +10,6 @@
#include "./expr.h"
#include "./expr.h"
#include "./domain.h"
#include "./domain.h"
namespace
tvm
{
namespace
tvm
{
// internal node container for Operation
// internal node container for Operation
...
@@ -38,15 +37,15 @@ class OperationNode : public Node {
...
@@ -38,15 +37,15 @@ class OperationNode : public Node {
Domain
domain
;
Domain
domain
;
/*! \brief optional name of the operation */
/*! \brief optional name of the operation */
std
::
string
name
;
std
::
string
name
;
/*! \brief index iteration variables on the domain of operation. */
Array
<
Var
>
iter_var
;
};
};
/*!
/*!
* \brief A Compute op that compute a tensor o
ver certain range
.
* \brief A Compute op that compute a tensor o
n certain domain
.
*/
*/
class
ComputeOpNode
:
public
OperationNode
{
class
ComputeOpNode
:
public
OperationNode
{
public
:
public
:
/*! \brief iter-Var over the dimensions */
Array
<
Var
>
dim_var
;
/*! \brief the compute expression */
/*! \brief the compute expression */
Expr
body
;
Expr
body
;
/*! \brief constructor */
/*! \brief constructor */
...
@@ -58,12 +57,12 @@ class ComputeOpNode : public OperationNode {
...
@@ -58,12 +57,12 @@ class ComputeOpNode : public OperationNode {
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"domain"
,
&
domain
);
v
->
Visit
(
"domain"
,
&
domain
);
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"
iter_var"
,
&
iter
_var
);
v
->
Visit
(
"
dim_var"
,
&
dim
_var
);
v
->
Visit
(
"body"
,
&
body
);
v
->
Visit
(
"body"
,
&
body
);
}
}
static
Operation
make
(
Domain
domain
,
static
Operation
make
(
Domain
domain
,
std
::
string
name
,
std
::
string
name
,
Array
<
Var
>
iter
_var
,
Array
<
Var
>
dim
_var
,
Expr
body
);
Expr
body
);
};
};
...
...
include/tvm/schedule.h
View file @
605813e4
...
@@ -9,7 +9,7 @@
...
@@ -9,7 +9,7 @@
#include <string>
#include <string>
#include "./base.h"
#include "./base.h"
#include "./split.h"
#include "./split.h"
#include "./
tensor
.h"
#include "./
operation
.h"
namespace
tvm
{
namespace
tvm
{
...
@@ -30,7 +30,7 @@ class Schedule : public NodeRef {
...
@@ -30,7 +30,7 @@ class Schedule : public NodeRef {
public
:
public
:
Schedule
()
{}
Schedule
()
{}
explicit
Schedule
(
std
::
shared_ptr
<
Node
>
n
)
:
NodeRef
(
n
)
{}
explicit
Schedule
(
std
::
shared_ptr
<
Node
>
n
)
:
NodeRef
(
n
)
{}
Schedule
(
Tensor
tensor
,
std
::
string
scope
);
Schedule
(
Operation
op
,
std
::
string
scope
);
/*!
/*!
* \brief access the internal node container
* \brief access the internal node container
* \return the pointer to the internal node container
* \return the pointer to the internal node container
...
@@ -77,11 +77,11 @@ class AttachSpecNode : public Node {
...
@@ -77,11 +77,11 @@ class AttachSpecNode : public Node {
/*! \brief represents the schedule of the tensor */
/*! \brief represents the schedule of the tensor */
class
ScheduleNode
:
public
Node
{
class
ScheduleNode
:
public
Node
{
public
:
public
:
/*! \brief T
ensor
to be scheduled */
/*! \brief T
he operation
to be scheduled */
Tensor
tensor
;
Operation
op
;
/*! \brief The thread scope level of the schedule */
/*! \brief The thread scope level of the schedule */
std
::
string
scope
;
std
::
string
scope
;
/*! \brief Splits over
domains or r
domains */
/*! \brief Splits over
iteration
domains */
Array
<
Split
>
splits
;
Array
<
Split
>
splits
;
/*! \brief attach specifications */
/*! \brief attach specifications */
Array
<
AttachSpec
>
attachs
;
Array
<
AttachSpec
>
attachs
;
...
@@ -90,7 +90,7 @@ class ScheduleNode : public Node {
...
@@ -90,7 +90,7 @@ class ScheduleNode : public Node {
}
}
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"scope"
,
&
scope
);
v
->
Visit
(
"scope"
,
&
scope
);
v
->
Visit
(
"
tensor"
,
&
tensor
);
v
->
Visit
(
"
op"
,
&
op
);
v
->
Visit
(
"splits"
,
&
splits
);
v
->
Visit
(
"splits"
,
&
splits
);
v
->
Visit
(
"attachs"
,
&
attachs
);
v
->
Visit
(
"attachs"
,
&
attachs
);
}
}
...
...
include/tvm/split.h
View file @
605813e4
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
#define TVM_SPLIT_H_
#define TVM_SPLIT_H_
#include "./base.h"
#include "./base.h"
#include "./expr.h"
#include "./domain.h"
#include "./domain.h"
namespace
tvm
{
namespace
tvm
{
...
@@ -34,15 +35,13 @@ class Split : public NodeRef {
...
@@ -34,15 +35,13 @@ class Split : public NodeRef {
*/
*/
class
SplitNode
:
public
Node
{
class
SplitNode
:
public
Node
{
public
:
public
:
/*! \brief
whether the split is over reduction domain
*/
/*! \brief
the variable to be splitted on
*/
bool
split_over_rdom
{
false
}
;
Var
var
;
};
};
/*! \brief simple split node that splits over one dimension */
/*! \brief simple split node that splits over one dimension */
class
DimSplitNode
:
public
SplitNode
{
class
DimSplitNode
:
public
SplitNode
{
public
:
public
:
/*! \brief The dimension to split on */
int
dim_index
;
/*! \brief The factor of the split */
/*! \brief The factor of the split */
Expr
factor
;
Expr
factor
;
/*! \brief constructor */
/*! \brief constructor */
...
@@ -51,13 +50,10 @@ class DimSplitNode : public SplitNode {
...
@@ -51,13 +50,10 @@ class DimSplitNode : public SplitNode {
return
"DimSplit"
;
return
"DimSplit"
;
}
}
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"split_over_rdom"
,
&
split_over_rdom
);
v
->
Visit
(
"var"
,
&
var
);
v
->
Visit
(
"dim_index"
,
&
dim_index
);
v
->
Visit
(
"factor"
,
&
factor
);
v
->
Visit
(
"factor"
,
&
factor
);
}
}
static
Split
make
(
int
dim_index
,
static
Split
make
(
Var
var
,
Expr
factor
);
Expr
factor
,
bool
over_rdom
);
};
};
// Implementations of inline functions
// Implementations of inline functions
...
...
include/tvm/tensor.h
View file @
605813e4
...
@@ -130,9 +130,9 @@ class TensorNode : public FunctionBaseNode {
...
@@ -130,9 +130,9 @@ class TensorNode : public FunctionBaseNode {
/*! \brief data type in the content of the tensor */
/*! \brief data type in the content of the tensor */
Type
dtype
;
Type
dtype
;
/*! \brief the source operation, can be None */
/*! \brief the source operation, can be None */
Operation
source_
op
;
Operation
op
;
/*! \brief the output index from source operation */
/*! \brief the output index from source operation */
int
sourc
e_index
{
0
};
int
valu
e_index
{
0
};
/*! \brief constructor */
/*! \brief constructor */
TensorNode
()
{}
TensorNode
()
{}
const
char
*
type_key
()
const
final
{
const
char
*
type_key
()
const
final
{
...
@@ -142,8 +142,8 @@ class TensorNode : public FunctionBaseNode {
...
@@ -142,8 +142,8 @@ class TensorNode : public FunctionBaseNode {
v
->
Visit
(
"shape"
,
&
shape
);
v
->
Visit
(
"shape"
,
&
shape
);
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"dtype"
,
&
dtype
);
v
->
Visit
(
"dtype"
,
&
dtype
);
v
->
Visit
(
"
source_op"
,
&
source_
op
);
v
->
Visit
(
"
op"
,
&
op
);
v
->
Visit
(
"
source_index"
,
&
sourc
e_index
);
v
->
Visit
(
"
value_index"
,
&
valu
e_index
);
}
}
const
std
::
string
&
func_name
()
const
final
{
const
std
::
string
&
func_name
()
const
final
{
return
name
;
return
name
;
...
@@ -154,8 +154,8 @@ class TensorNode : public FunctionBaseNode {
...
@@ -154,8 +154,8 @@ class TensorNode : public FunctionBaseNode {
static
Tensor
make
(
Array
<
Expr
>
shape
,
static
Tensor
make
(
Array
<
Expr
>
shape
,
std
::
string
name
,
std
::
string
name
,
Type
dtype
,
Type
dtype
,
Operation
source_
op
,
Operation
op
,
int
sourc
e_index
);
int
valu
e_index
);
};
};
// implementations
// implementations
...
...
python/tvm/function.py
View file @
605813e4
...
@@ -91,7 +91,10 @@ def compute(shape, fcompute, name="TensorCompute"):
...
@@ -91,7 +91,10 @@ def compute(shape, fcompute, name="TensorCompute"):
The created tensor
The created tensor
"""
"""
ndim
=
len
(
shape
)
ndim
=
len
(
shape
)
dim_var
=
[
Var
(
"dim_var
%
d"
%
i
)
for
i
in
range
(
ndim
)]
arg_names
=
fcompute
.
__code__
.
co_varnames
if
ndim
!=
len
(
arg_names
):
raise
ValueError
(
"fcompute do not match dimension"
)
dim_var
=
[
Var
(
x
)
for
x
in
arg_names
]
body
=
fcompute
(
*
dim_var
)
body
=
fcompute
(
*
dim_var
)
dom
=
[
Range
(
0
,
x
)
for
x
in
shape
]
dom
=
[
Range
(
0
,
x
)
for
x
in
shape
]
op_node
=
_function_internal
.
_ComputeOp
(
op_node
=
_function_internal
.
_ComputeOp
(
...
...
src/c_api/c_api_lang.cc
View file @
605813e4
...
@@ -102,7 +102,7 @@ TVM_REGISTER_API(_RDomain)
...
@@ -102,7 +102,7 @@ TVM_REGISTER_API(_RDomain)
TVM_REGISTER_API
(
_DimSplit
)
TVM_REGISTER_API
(
_DimSplit
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
*
ret
=
DimSplitNode
::
make
(
args
.
at
(
0
),
args
.
at
(
1
)
,
args
.
at
(
2
)
);
*
ret
=
DimSplitNode
::
make
(
args
.
at
(
0
),
args
.
at
(
1
));
});
});
TVM_REGISTER_API
(
_Schedule
)
TVM_REGISTER_API
(
_Schedule
)
...
...
src/lang/operation.cc
View file @
605813e4
...
@@ -28,24 +28,24 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
...
@@ -28,24 +28,24 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
dom
.
push_back
(
Range
(
0
,
shape
[
i
]));
dom
.
push_back
(
Range
(
0
,
shape
[
i
]));
}
}
op_node
->
iter
_var
=
Array
<
Var
>
(
dim_index
);
op_node
->
dim
_var
=
Array
<
Var
>
(
dim_index
);
op_node
->
domain
=
Domain
(
dom
);
op_node
->
domain
=
Domain
(
dom
);
op_node
->
body
=
fcompute
(
op_node
->
iter
_var
);
op_node
->
body
=
fcompute
(
op_node
->
dim
_var
);
op_node
->
name
=
name
;
op_node
->
name
=
name
;
node
->
dtype
=
op_node
->
body
.
type
();
node
->
dtype
=
op_node
->
body
.
type
();
node
->
source_
op
=
Operation
(
op_node
);
node
->
op
=
Operation
(
op_node
);
node
->
sourc
e_index
=
0
;
node
->
valu
e_index
=
0
;
return
Tensor
(
node
);
return
Tensor
(
node
);
}
}
Operation
ComputeOpNode
::
make
(
Domain
domain
,
Operation
ComputeOpNode
::
make
(
Domain
domain
,
std
::
string
name
,
std
::
string
name
,
Array
<
Var
>
iter
_var
,
Array
<
Var
>
dim
_var
,
Expr
body
)
{
Expr
body
)
{
auto
n
=
std
::
make_shared
<
ComputeOpNode
>
();
auto
n
=
std
::
make_shared
<
ComputeOpNode
>
();
n
->
domain
=
domain
;
n
->
domain
=
domain
;
n
->
name
=
name
;
n
->
name
=
name
;
n
->
iter_var
=
iter
_var
;
n
->
dim_var
=
dim
_var
;
n
->
body
=
body
;
n
->
body
=
body
;
return
Operation
(
n
);
return
Operation
(
n
);
}
}
...
...
src/lang/schedule.cc
View file @
605813e4
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
namespace
tvm
{
namespace
tvm
{
Schedule
::
Schedule
(
Tensor
tensor
,
std
::
string
scope
)
{
Schedule
::
Schedule
(
Operation
op
,
std
::
string
scope
)
{
auto
n
=
std
::
make_shared
<
ScheduleNode
>
();
auto
n
=
std
::
make_shared
<
ScheduleNode
>
();
n
->
tensor
=
tensor
;
n
->
op
=
op
;
n
->
scope
=
scope
;
n
->
scope
=
scope
;
node_
=
n
;
node_
=
n
;
}
}
...
...
src/lang/split.cc
View file @
605813e4
...
@@ -6,13 +6,11 @@
...
@@ -6,13 +6,11 @@
namespace
tvm
{
namespace
tvm
{
Split
DimSplitNode
::
make
(
int
dim_index
,
Split
DimSplitNode
::
make
(
Var
var
,
Expr
factor
,
Expr
factor
)
{
bool
over_rdom
)
{
auto
n
=
std
::
make_shared
<
DimSplitNode
>
();
auto
n
=
std
::
make_shared
<
DimSplitNode
>
();
CHECK_EQ
(
factor
.
type
().
lanes
(),
1
);
CHECK_EQ
(
factor
.
type
().
lanes
(),
1
);
n
->
split_over_rdom
=
over_rdom
;
n
->
var
=
var
;
n
->
dim_index
=
dim_index
;
n
->
factor
=
factor
;
n
->
factor
=
factor
;
return
Split
(
n
);
return
Split
(
n
);
}
}
...
...
src/lang/tensor.cc
View file @
605813e4
...
@@ -30,14 +30,14 @@ Expr Tensor::operator()(Array<Expr> indices) const {
...
@@ -30,14 +30,14 @@ Expr Tensor::operator()(Array<Expr> indices) const {
Tensor
TensorNode
::
make
(
Array
<
Expr
>
shape
,
Tensor
TensorNode
::
make
(
Array
<
Expr
>
shape
,
std
::
string
name
,
std
::
string
name
,
Type
dtype
,
Type
dtype
,
Operation
source_
op
,
Operation
op
,
int
sourc
e_index
)
{
int
valu
e_index
)
{
auto
n
=
std
::
make_shared
<
TensorNode
>
();
auto
n
=
std
::
make_shared
<
TensorNode
>
();
n
->
shape
=
shape
;
n
->
shape
=
shape
;
n
->
name
=
name
;
n
->
name
=
name
;
n
->
dtype
=
dtype
;
n
->
dtype
=
dtype
;
n
->
source_op
=
source_
op
;
n
->
op
=
op
;
n
->
source_index
=
sourc
e_index
;
n
->
value_index
=
valu
e_index
;
return
Tensor
(
n
);
return
Tensor
(
n
);
}
}
...
...
tests/python/test_inline.py
View file @
605813e4
...
@@ -7,7 +7,7 @@ def test_inline():
...
@@ -7,7 +7,7 @@ def test_inline():
X
=
T
(
100
)
X
=
T
(
100
)
stmt
=
tvm
.
make
.
Evaluate
(
T
(
10
)
+
11
*
T
(
100
))
stmt
=
tvm
.
make
.
Evaluate
(
T
(
10
)
+
11
*
T
(
100
))
stmt
=
tvm
.
ir_pass
.
Inline
(
stmt
=
tvm
.
ir_pass
.
Inline
(
T
,
T
.
source_op
.
iter_var
,
T
.
source_
op
.
body
,
stmt
)
T
,
T
.
op
.
dim_var
,
T
.
op
.
body
,
stmt
)
print
(
stmt
)
print
(
stmt
)
assert
(
tvm
.
ir_pass
.
VerifySSA
(
stmt
))
assert
(
tvm
.
ir_pass
.
VerifySSA
(
stmt
))
...
...
tests/python/test_schedule.py
View file @
605813e4
...
@@ -9,10 +9,11 @@ def test_schedule_create():
...
@@ -9,10 +9,11 @@ def test_schedule_create():
T
=
tvm
.
compute
((
m
,
n
,
l
),
lambda
i
,
j
,
k
:
A
(
i
,
k
)
*
B
(
j
,
k
))
T
=
tvm
.
compute
((
m
,
n
,
l
),
lambda
i
,
j
,
k
:
A
(
i
,
k
)
*
B
(
j
,
k
))
sch
=
tvm
.
Schedule
(
T
,
scope
=
"shared"
)
sch
=
tvm
.
Schedule
(
T
,
scope
=
"shared"
)
tk1
=
tvm
.
Split
(
0
,
10
)
tk1
=
tvm
.
Split
(
T
.
op
.
dim_var
[
0
]
,
10
)
assert
isinstance
(
sch
,
tvm
.
schedule
.
Schedule
)
assert
isinstance
(
sch
,
tvm
.
schedule
.
Schedule
)
assert
isinstance
(
tk1
,
tvm
.
schedule
.
DimSplit
)
assert
isinstance
(
tk1
,
tvm
.
schedule
.
DimSplit
)
print
(
tk1
.
var
)
print
(
sch
.
scope
)
print
(
sch
.
scope
)
print
(
sch
.
attachs
)
print
(
sch
.
attachs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment