wenyuanbo / tic / Commits / 0c72ca97

Commit 0c72ca97, authored Dec 05, 2016 by tqchen
    Finish schedule operation

parent 59bb0dd4
Showing 10 changed files with 271 additions and 31 deletions.
HalideIR                        +1   -1
include/tvm/schedule.h          +7   -1
python/tvm/function.py          +4   -4
python/tvm/schedule.py          +96  -3
python/tvm/tensor.py            +2   -0
src/c_api/c_api_lang.cc         +49  -0
src/c_api/c_api_registry.h      +1   -1
src/lang/expr.cc                +2   -0
src/lang/schedule.cc            +87  -7
tests/python/test_schedule.py   +22  -14
HalideIR @ ea1a81be

-Subproject commit 29fd3defa3dbf810e52dbc2ecd3933604989dcc8
+Subproject commit ea1a81be8baa43665f6ebd4d75d51c081283ebc8
include/tvm/schedule.h

@@ -50,16 +50,19 @@ class Schedule : public NodeRef {
    * \brief specify the schedule to be computed at the parent schedule's scope.
    * \param parent The parent schedule.
    * \param scope The iteration point to carry the schedule.
+   * \return reference to self.
    */
   Schedule& compute_at(Schedule parent, IterVar scope);   // NOLINT(*)
   /*!
    * \brief Compute the function inline, attach it at parent.
    * \param parent The parent schedule to be attached to.
+   * \return reference to self.
    */
   Schedule& compute_inline(Schedule parent);   // NOLINT(*)
   /*!
    * \brief Compute the function at root, attach it to its parent.
    * \param parent The parent schedule to be attached to.
+   * \return reference to self.
    */
   Schedule& compute_root(Schedule parent);   // NOLINT(*)
   /*!
@@ -68,7 +71,7 @@ class Schedule : public NodeRef {
    * \param p_outer The result outer domain
    * \param p_inner The result inner domain.
    * \param factor The split factor of the loop.
-   * \param outer The generated
+   * \return reference to self.
    */
   Schedule& split(IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor);   // NOLINT(*)
   /*!
@@ -80,6 +83,7 @@ class Schedule : public NodeRef {
    * \param p_inner The result inner domain.
    * \param factor Optional, the factor of the split,
    *  factor must be provided such that factor * outer.extent >= parent.extent.
+   * \return reference to self.
    */
   Schedule& split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor = Expr());   // NOLINT(*)
   /*!
@@ -87,11 +91,13 @@ class Schedule : public NodeRef {
    * \param inner The inner domain to be fused
    * \param outer The outer domain to be fused.
    * \param p_target The result target domain.
+   * \return reference to self.
    */
   Schedule& fuse(IterVar inner, IterVar outer, IterVar* p_target);   // NOLINT(*)
   /*!
    * \brief Reorder the iteration
    * \param order The order of iteration variable.
+   * \return reference to self.
    */
   Schedule& reorder(const Array<IterVar>& order);   // NOLINT(*)
 };
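The factor constraint in the second split overload is plain loop arithmetic: splitting an axis of extent n by factor yields an inner loop of extent factor and an outer loop of extent ceil(n / factor), so factor * outer.extent >= parent.extent always holds. A minimal sketch of that arithmetic in Python (illustrative only, not the TVM API):

    import math

    def split_extents(parent_extent, factor):
        # inner iterates over the factor; outer covers the remaining iterations
        outer = math.ceil(parent_extent / factor)
        return outer, factor

    outer, inner = split_extents(100, 8)   # outer = 13, inner = 8
    assert inner * outer >= 100            # factor * outer.extent >= parent.extent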
python/tvm/function.py

@@ -79,6 +79,9 @@ def compute(shape, fcompute, name="TensorCompute"):
     tensor: tensor.Tensor
         The created tensor
     """
+    if isinstance(shape, _expr.Expr):
+        shape = (shape,)
+
     ndim = len(shape)
     arg_names = fcompute.__code__.co_varnames
     if ndim != len(arg_names):
@@ -86,6 +89,7 @@ def compute(shape, fcompute, name="TensorCompute"):
     dim_var = [IterVar((0, s), x) for x, s in zip(arg_names, shape)]
     body = fcompute(*[v.var for v in dim_var])
+    body = convert(body)
     op_node = _function_internal._ComputeOp(
         name, dim_var, body)
     return _function_internal._Tensor(
@@ -174,8 +178,4 @@ def Schedule(tensor, scope="global"):
     return _function_internal._Schedule(tensor, scope)
 
-
-def Split(dim, factor, over_rdom=False):
-    return _function_internal._DimSplit(dim, factor, over_rdom)
-
 _init_function_module("tvm")
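The new isinstance guard lets compute take a bare Expr as the shape and normalize it to a 1-tuple. That is what test_reorder below relies on; a sketch of the call it enables, using the same names as the tests in this commit:

    m = tvm.Var('m')
    A = tvm.placeholder((m,), name='A')
    # shape is a single Expr rather than a tuple; compute now wraps it as (m,)
    T = tvm.compute(m, lambda i: A[i + 1])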
python/tvm/schedule.py

@@ -4,13 +4,106 @@ from ._ctypes._api import NodeBase, register_node
 from . import _function_internal
 
 @register_node
-class DimSplit(NodeBase):
+class Split(NodeBase):
     pass
 
 @register_node
-class AttachSpec(NodeBase):
+class Fuse(NodeBase):
     pass
 
 @register_node
 class Schedule(NodeBase):
-    pass
+    def split(self, parent, factor=None, outer=None):
+        """Split the schedule either by factor, or by an outer variable
+        providing the outer scope.
+
+        Parameters
+        ----------
+        parent : IterVar
+            The parent iter var.
+        factor : Expr, optional
+            The splitting factor.
+        outer : IterVar, optional
+            The outer split variable.
+
+        Returns
+        -------
+        outer : IterVar
+            The outer variable of iteration.
+        inner : IterVar
+            The inner variable of iteration.
+        """
+        if outer is not None:
+            if outer.thread_tag == '':
+                raise ValueError("split by outer must have special thread_tag")
+            if outer.dom is None:
+                raise ValueError("split by outer must have specified domain")
+            inner = _function_internal._ScheduleSplitByOuter(self, parent, outer, factor)
+        else:
+            if factor is None:
+                raise ValueError("either outer or factor need to be provided")
+            outer, inner = _function_internal._ScheduleSplitByFactor(self, parent, factor)
+        return outer, inner
+
+    def fuse(self, inner, outer):
+        """Fuse inner and outer into a single iteration variable.
+
+        Parameters
+        ----------
+        inner : IterVar
+            The inner variable of iteration.
+        outer : IterVar
+            The outer variable of iteration.
+
+        Returns
+        -------
+        fused : IterVar
+            The fused variable of iteration.
+        """
+        return _function_internal._ScheduleFuse(self, inner, outer)
+
+    def compute_at(self, parent, scope):
+        """Attach the schedule at parent's scope.
+
+        Parameters
+        ----------
+        parent : Schedule
+            The parent schedule.
+        scope : IterVar
+            The loop scope to be attached to.
+        """
+        _function_internal._ScheduleComputeAt(self, parent, scope)
+
+    def compute_inline(self, parent):
+        """Attach the schedule at parent, and mark it as inline.
+
+        Parameters
+        ----------
+        parent : Schedule
+            The parent schedule.
+        """
+        _function_internal._ScheduleComputeInline(self, parent)
+
+    def compute_root(self, parent):
+        """Attach the schedule at parent, and mark it as root.
+
+        Parameters
+        ----------
+        parent : Schedule
+            The parent schedule.
+        """
+        _function_internal._ScheduleComputeRoot(self, parent)
+
+    def reorder(self, *args):
+        """Reorder the iteration in the specified order.
+
+        Parameters
+        ----------
+        args : list of IterVar
+            The order to be ordered.
+        """
+        _function_internal._ScheduleReorder(self, args)
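Together these wrappers form the scheduling surface exercised by the tests below. A sketch of a typical call sequence, with identifiers borrowed from tests/python/test_schedule.py in this commit:

    sch = tvm.Schedule(T.op, scope="shared")
    xo, xi = sch.split(T.op.dim_var[0], factor=10)  # outer/inner pair from axis 0
    xi1, xi2 = sch.split(xi, factor=2)              # split the inner axis again
    sch.reorder(xi2, xi1)                           # swap the two innermost loops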
python/tvm/tensor.py

@@ -7,6 +7,8 @@ from . import expr as _expr
 class TensorSlice(SliceBase, _expr.ExprOp):
     """Auxiliary data structure for enable slicing syntax from tensor."""
     def __init__(self, tensor, indices):
+        if not isinstance(indices, tuple):
+            indices = (indices,)
         self.tensor = tensor
         self.indices = indices
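The two added lines normalize scalar indexing so that t[i] and t[i, j] both end up storing a tuple of indices. The same idiom in isolation (plain Python, for illustration):

    def normalize(indices):
        # a scalar index becomes a 1-tuple, so later code handles all cases uniformly
        if not isinstance(indices, tuple):
            indices = (indices,)
        return indices

    assert normalize(3) == (3,)
    assert normalize((1, 2)) == (1, 2)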
src/c_api/c_api_lang.cc

@@ -103,4 +103,53 @@ TVM_REGISTER_API(_Schedule)
     *ret = Schedule(args.at(0), args.at(1));
   });
 
+TVM_REGISTER_API(_ScheduleSplitByFactor)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    IterVar outer, inner;
+    args.at(0).operator Schedule()
+        .split(args.at(1), &outer, &inner, args.at(2));
+    *ret = Array<IterVar>({outer, inner});
+  });
+
+TVM_REGISTER_API(_ScheduleSplitByOuter)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    IterVar inner;
+    args.at(0).operator Schedule()
+        .split(args.at(1), args.at(2), &inner, args.at(3));
+    *ret = inner;
+  });
+
+TVM_REGISTER_API(_ScheduleFuse)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    IterVar fused;
+    args.at(0).operator Schedule()
+        .fuse(args.at(1), args.at(2), &fused);
+    *ret = fused;
+  });
+
+TVM_REGISTER_API(_ScheduleComputeAt)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    args.at(0).operator Schedule()
+        .compute_at(args.at(1), args.at(2));
+  });
+
+TVM_REGISTER_API(_ScheduleComputeInline)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    args.at(0).operator Schedule()
+        .compute_inline(args.at(1));
+  });
+
+TVM_REGISTER_API(_ScheduleComputeRoot)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    args.at(0).operator Schedule()
+        .compute_root(args.at(1));
+  });
+
+TVM_REGISTER_API(_ScheduleReorder)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    args.at(0).operator Schedule()
+        .reorder(args.at(1));
+  });
 }  // namespace tvm
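Each name registered here becomes an attribute of _function_internal on the Python side, which is how the wrappers in schedule.py reach the C++ methods. A rough Python model of that table-driven dispatch (a hypothetical sketch, not the actual ctypes bridge):

    # TVM_REGISTER_API(name).set_body(fn) amounts to a global name -> body table
    api_table = {}

    def register_api(name, body):
        api_table[name] = body

    register_api("_ScheduleFuse",
                 lambda sch, inner, outer: sch.fuse(inner, outer))

    # _function_internal._ScheduleFuse(sch, inner, outer) then resolves to:
    #     api_table["_ScheduleFuse"](sch, inner, outer)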
src/c_api/c_api_registry.h

@@ -115,7 +115,7 @@ class APIVariantValue {
     CHECK_EQ(type_id, kNodeHandle);
     // use dynamic RTTI for safety
     CHECK(dynamic_cast<typename T::ContainerType*>(sptr.get()))
-        << "wrong type specified ";
+        << "wrong type specified, expected " << typeid(typename T::ContainerType).name();
     return T(sptr);
   }
   inline operator Expr() const {
src/lang/expr.cc

@@ -57,7 +57,9 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
     if (op->var->name_hint.length() != 0) {
       p->stream << op->var->name_hint << ", ";
     }
+    if (op->dom.defined()) {
       p->stream << op->dom;
+    }
     if (op->thread_tag.length() != 0) {
       p->stream << ", " << op->thread_tag;
     }
src/lang/schedule.cc

@@ -17,12 +17,38 @@ size_t FindIterVar(ArrayNode* array_node, const IterVar& v) {
   return array_node->data.size();
 }
 
-size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) {
-  size_t pos = Find(leaf_iter_vars, parent);
+size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) {
+  size_t pos = FindIterVar(leaf_vars, v);
+  if (pos < leaf_vars->data.size()) return pos;
+
+  if (FindIterVar(all_vars, v) < all_vars->data.size()) {
+    LOG(FATAL) << "Operate on iter var " << v
+               << " that has already been split";
+  } else {
+    LOG(FATAL) << "Operate on iter var " << v
+               << " that is not part of the schedule";
+  }
+  return 0;
+}
+
+void Split(ScheduleNode* self, IterVar parent, IterVar outer, IterVar inner, Expr factor) {
+  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+  size_t pos = FindLeafVar(all_vars, leaf_vars, parent);
+  self->relations.push_back(SplitNode::make(parent, outer, inner, factor));
+  // add vars to all vars
+  all_vars->data.push_back(outer.node_);
+  all_vars->data.push_back(inner.node_);
+  // replace parent at its position by outer followed by inner
+  leaf_vars->data.erase(leaf_vars->data.begin() + pos);
+  leaf_vars->data.insert(leaf_vars->data.begin() + pos, inner.node_);
+  leaf_vars->data.insert(leaf_vars->data.begin() + pos, outer.node_);
+}
 }  // namespace
 
 Schedule::Schedule(Operation op, std::string scope) {
   auto n = std::make_shared<ScheduleNode>();
   n->op = op;
@@ -36,6 +62,14 @@ Schedule& Schedule::compute_at(Schedule parent, IterVar scope) { // NOLINT(*)
   CHECK_EQ((*this)->attach_type, kNone);
   (*this)->attach_type = kScope;
   (*this)->attach_parent = scope;
+  bool found = false;
+  for (size_t i = 0; i < parent->leaf_iter_vars.size(); ++i) {
+    if (scope == parent->leaf_iter_vars[i]) {
+      found = true; break;
+    }
+  }
+  CHECK(found)
+      << "Cannot compute at an iteration variable that is not part of parent leaf vars";
   parent->children.push_back(*this);
   return *this;
 }
@@ -56,17 +90,63 @@ Schedule& Schedule::compute_root(Schedule parent) { // NOLINT(*)
 Schedule& Schedule::split(
     IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor) { // NOLINT(*)
-  ScheduleNode* self = operator->();
-  ArrayNode* leaf_iter_vars = self->leaf_iter_vars.CopyOnWrite();
-  CHECK(pos != leaf_iter_vars->data.size())
-      << "Cannot find IterVar " << parent << " in the active leaf vars"
-      << " this means "
+  // placeholders for the split results.
+  IterVar outer(Range(), parent->var->name_hint + ".outer");
+  IterVar inner(Range(), parent->var->name_hint + ".inner");
+  *p_outer = outer;
+  *p_inner = inner;
+  Split(operator->(), parent, outer, inner, factor);
   return *this;
 }
+
+Schedule& Schedule::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor) { // NOLINT(*)
+  // placeholder for the split result.
+  IterVar inner(Range(), parent->var->name_hint + ".inner");
+  *p_inner = inner;
+  Split(operator->(), parent, outer, inner, factor);
+  return *this;
+}
+
+Schedule& Schedule::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT(*)
+  IterVar fused(Range(), outer->var->name_hint + "." + inner->var->name_hint + ".fused");
+  ScheduleNode* self = operator->();
+  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+  self->relations.push_back(FuseNode::make(inner, outer, fused));
+  all_vars->data.push_back(fused.node_);
+  size_t pos_inner = FindLeafVar(all_vars, leaf_vars, inner);
+  size_t pos_outer = FindLeafVar(all_vars, leaf_vars, outer);
+  CHECK_EQ(pos_inner, pos_outer + 1)
+      << "Can only fuse iterations that are consecutive between each other";
+  leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer,
+                        leaf_vars->data.begin() + pos_inner + 1);
+  leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer, fused.node_);
+  *p_target = fused;
+  return *this;
+}
+
+Schedule& Schedule::reorder(const Array<IterVar>& order) { // NOLINT(*)
+  ScheduleNode* self = operator->();
+  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+  std::vector<size_t> pos;
+  for (size_t i = 0; i < order.size(); ++i) {
+    pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i]));
+  }
+  std::vector<std::shared_ptr<Node> > temp;
+  for (size_t i = 0; i < pos.size(); ++i) {
+    temp.emplace_back(leaf_vars->data[pos[i]]);
+  }
+  std::sort(pos.begin(), pos.end());
+  for (size_t i = 0; i < pos.size(); ++i) {
+    leaf_vars->data[pos[i]] = temp[i];
+  }
+  return *this;
+}
 
 IterVarRelation SplitNode::make(
     IterVar parent, IterVar outer,
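The leaf_iter_vars bookkeeping in Split and fuse is positional list surgery: split replaces one leaf with two, and fuse collapses two adjacent leaves into one. A small Python model of the two operations (a sketch to illustrate the data structure, not TVM code):

    def split_leaf(leaf, parent, outer, inner):
        # replace parent with (outer, inner) at the same position
        pos = leaf.index(parent)
        return leaf[:pos] + [outer, inner] + leaf[pos + 1:]

    def fuse_leaf(leaf, outer, inner, fused):
        # outer must sit immediately before inner; both collapse into fused
        pos = leaf.index(outer)
        assert leaf[pos + 1] == inner
        return leaf[:pos] + [fused] + leaf[pos + 2:]

    leaf = ["i", "j", "k"]
    leaf = split_leaf(leaf, "j", "j.outer", "j.inner")       # i, j.outer, j.inner, k
    leaf = fuse_leaf(leaf, "j.outer", "j.inner", "j.fused")  # i, j.fused, k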
tests/python/test_schedule.py

@@ -6,28 +6,36 @@ def test_schedule_create():
     l = tvm.Var('l')
     A = tvm.placeholder((m, l), name='A')
     B = tvm.placeholder((n, l), name='B')
+    AA = tvm.compute((m, l), lambda i, j: A[i, j])
     T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
-    Tsch = tvm.Schedule(T.op, scope="shared")
-    Asch = tvm.Schedule(A.op)
-    xo, xi = sch.split(sch.dim_var[0], factor)
-    Asch.compute_at(Tsch, xi)
-    xf = sch.fuse(xo, xi)
-    tk1 = tvm.Split(T.op.dim_var[0], 10)
-    assert isinstance(sch, tvm.schedule.Schedule)
-    assert isinstance(tk1, tvm.schedule.DimSplit)
-    print(tk1.var)
-    print(sch.scope)
-    print(sch.attachs)
+    sch_T = tvm.Schedule(T.op, scope="shared")
+    sch_A = tvm.Schedule(AA.op, scope="global")
+    xo, xi = sch_T.split(T.op.dim_var[0], factor=10)
+    xi1, xi2 = sch_T.split(xi, factor=2)
+    sch_A.compute_at(sch_T, xi1)
+    xo, xi = sch_A.split(AA.op.dim_var[0], factor=10)
+    sch_T.reorder(xi2, xi1)
+    assert T.op.dim_var[1] in sch_T.leaf_iter_vars
+
+
+def test_reorder():
+    m = tvm.Var('m')
+    A = tvm.placeholder((m,), name='A')
+    T = tvm.compute(m, lambda i: A[i + 1])
+    sch_T = tvm.Schedule(T.op, scope="shared")
+    xo, xi = sch_T.split(T.op.dim_var[0], factor=10)
+    xi1, xi2 = sch_T.split(xi, factor=2)
+    order = (xi2, xi1, xo)
+    assert tuple(sch_T.leaf_iter_vars) != order
+    sch_T.reorder(*order)
+    assert tuple(sch_T.leaf_iter_vars) == order
 
 if __name__ == "__main__":
     test_schedule_create()
+    test_reorder()