Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
5d2ccd66
Commit
5d2ccd66
authored
6 years ago
by
Tianqi Chen
Committed by
GitHub
6 years ago
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[SCHEDULE] Fuse support for 0 rank tensor (#1328)
parent
0134fabb
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
115 additions
and
15 deletions
+115
-15
include/tvm/schedule.h
+36
-3
python/tvm/schedule.py
+7
-4
src/api/api_lang.cc
+1
-1
src/schedule/message_passing.cc
+11
-0
src/schedule/schedule_lang.cc
+37
-1
tests/python/integration/test_ewise.py
+4
-5
tests/python/unittest/test_lang_schedule.py
+14
-0
topi/tests/python/test_topi_broadcast.py
+5
-1
No files found.
include/tvm/schedule.h
View file @
5d2ccd66
...
...
@@ -130,6 +130,20 @@ class Stage : public NodeRef {
*/
EXPORT
Stage
&
fuse
(
IterVar
outer
,
IterVar
inner
,
IterVar
*
p_target
);
// NOLINT(*)
/*!
* \brief Fuse all the axes together into a single axis.
*
* \param axes All the axes to be fused.
* \param p_target The result target domain.
*
* \note axes can be an empty array,
* in that case, a singleton itervar is created and
* inserted to the outermost loop.
* The fuse of empty array is used to support zero-dimension tensors.
*
* \return reference to self.
*/
EXPORT
Stage
&
fuse
(
const
Array
<
IterVar
>&
axes
,
IterVar
*
p_target
);
// NOLINT(*)
/*!
* \brief Reorder the iteration
* \param order The order of iteration variable.
* \return reference to self.
...
...
@@ -151,9 +165,9 @@ class Stage : public NodeRef {
* \return reference to self.
*/
EXPORT
Stage
&
tile
(
IterVar
x_parent
,
IterVar
y_parent
,
// NOLINT(*)
Expr
x_factor
,
Expr
y_factor
,
IterVar
*
p_x_outer
,
IterVar
*
p_y_outer
,
IterVar
*
p_x_inner
,
IterVar
*
p_y_inner
);
Expr
x_factor
,
Expr
y_factor
,
IterVar
*
p_x_outer
,
IterVar
*
p_y_outer
,
IterVar
*
p_x_inner
,
IterVar
*
p_y_inner
);
/*!
* \brief Vectorize iteration.
* \param var The axis to be vectorized.
...
...
@@ -674,6 +688,25 @@ class RebaseNode : public IterVarRelationNode {
};
/*!
* \brief Singleton iterator [0, 1)
*/
class
SingletonNode
:
public
IterVarRelationNode
{
public
:
/*! \brief The singleton iterator */
IterVar
iter
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"iter"
,
&
iter
);
}
static
IterVarRelation
make
(
IterVar
iter
);
static
constexpr
const
char
*
_type_key
=
"Singleton"
;
TVM_DECLARE_NODE_TYPE_INFO
(
SingletonNode
,
IterVarRelationNode
);
};
// implementations
inline
const
StageNode
*
Stage
::
operator
->
()
const
{
return
static_cast
<
const
StageNode
*>
(
node_
.
get
());
...
...
This diff is collapsed.
Click to expand it.
python/tvm/schedule.py
View file @
5d2ccd66
...
...
@@ -153,6 +153,12 @@ class Fuse(NodeBase):
@register_node
class
Singleton
(
NodeBase
):
"""Singleton axis."""
pass
@register_node
class
IterVar
(
NodeBase
,
_expr
.
ExprOp
):
"""Represent iteration variable.
...
...
@@ -380,10 +386,7 @@ class Stage(NodeBase):
fused : IterVar
The fused variable of iteration.
"""
assert
len
(
args
)
>=
1
,
"Length of the arguments must be >=1 for fuse."
fused
=
args
[
0
]
for
i
in
range
(
1
,
len
(
args
)):
fused
=
_api_internal
.
_StageFuse
(
self
,
fused
,
args
[
i
])
fused
=
_api_internal
.
_StageFuse
(
self
,
args
)
return
fused
def
set_scope
(
self
,
scope
):
...
...
This diff is collapsed.
Click to expand it.
src/api/api_lang.cc
View file @
5d2ccd66
...
...
@@ -350,7 +350,7 @@ TVM_REGISTER_API("_StageFuse")
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
IterVar
fused
;
args
[
0
].
operator
Stage
()
.
fuse
(
args
[
1
],
args
[
2
],
&
fused
);
.
fuse
(
args
[
1
],
&
fused
);
*
ret
=
fused
;
});
...
...
This diff is collapsed.
Click to expand it.
src/schedule/message_passing.cc
View file @
5d2ccd66
...
...
@@ -82,6 +82,8 @@ void PassDownDomain(const Stage& stage,
Update
(
p_state
,
r
->
rebased
,
Range
::
make_by_min_extent
(
0
,
state
.
at
(
r
->
parent
)
->
extent
));
}
else
if
(
const
SingletonNode
*
s
=
rel
.
as
<
SingletonNode
>
())
{
Update
(
p_state
,
s
->
iter
,
Range
::
make_by_min_extent
(
0
,
1
));
}
else
{
LOG
(
FATAL
)
<<
"unknown relation type"
;
}
...
...
@@ -147,6 +149,7 @@ void PassUpIndex(const Stage& stage,
}
else
{
state
[
s
->
parent
]
=
value
;
}
}
else
if
(
rel
.
as
<
SingletonNode
>
())
{
}
else
{
LOG
(
FATAL
)
<<
"unknown relation type"
;
}
...
...
@@ -192,6 +195,8 @@ void PassDownIndex(const Stage& stage,
Expr
parent_min
=
dom_map
.
at
(
s
->
parent
)
->
min
;
CHECK
(
is_zero
(
parent_min
));
state
[
s
->
rebased
]
=
value
;
}
else
if
(
const
SingletonNode
*
s
=
rel
.
as
<
SingletonNode
>
())
{
state
[
s
->
iter
]
=
make_zero
(
s
->
iter
->
var
.
type
());
}
else
{
LOG
(
FATAL
)
<<
"unknown relation type"
;
}
...
...
@@ -296,6 +301,7 @@ void PassUpDomain(const Stage& stage,
state
.
at
(
r
->
rebased
),
&
parent
);
state
[
r
->
parent
]
=
parent
;
}
else
if
(
rel
.
as
<
SingletonNode
>
())
{
}
else
{
LOG
(
FATAL
)
<<
"unknown relation type"
;
}
...
...
@@ -344,6 +350,7 @@ void PassUpBitMaskOr(const Stage& stage,
}
else
{
state
[
s
->
parent
]
|=
state
[
s
->
rebased
];
}
}
else
if
(
rel
.
as
<
SingletonNode
>
())
{
}
else
{
LOG
(
FATAL
)
<<
"unknown relation type"
;
}
...
...
@@ -390,6 +397,8 @@ void PassDownBitMaskOr(const Stage& stage,
}
else
{
state
[
s
->
rebased
]
|=
state
.
at
(
s
->
parent
);
}
}
else
if
(
const
SingletonNode
*
s
=
rel
.
as
<
SingletonNode
>
())
{
state
[
s
->
iter
]
=
0
;
}
else
{
LOG
(
FATAL
)
<<
"unknown relation type"
;
}
...
...
@@ -438,6 +447,8 @@ void PassUpBoundCheck(const Stage& s,
}
else
if
(
rel
.
as
<
RebaseNode
>
())
{
const
RebaseNode
*
s
=
rel
.
as
<
RebaseNode
>
();
state
[
s
->
parent
]
=
state
.
at
(
s
->
rebased
);
}
else
if
(
rel
.
as
<
SingletonNode
>
())
{
// nop
}
else
{
LOG
(
FATAL
)
<<
"unknown relation type"
;
}
...
...
This diff is collapsed.
Click to expand it.
src/schedule/schedule_lang.cc
View file @
5d2ccd66
...
...
@@ -237,7 +237,6 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT
IterVar
fused
=
IterVarNode
::
make
(
Range
(),
Var
(
fused_name
,
outer
->
var
.
type
()),
iter_type
);
*
p_target
=
fused
;
ArrayNode
*
all_vars
=
self
->
all_iter_vars
.
CopyOnWrite
();
ArrayNode
*
leaf_vars
=
self
->
leaf_iter_vars
.
CopyOnWrite
();
...
...
@@ -255,6 +254,31 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT
leaf_vars
->
data
.
begin
()
+
pos_inner
+
1
);
leaf_vars
->
data
.
insert
(
leaf_vars
->
data
.
begin
()
+
pos_outer
,
fused
.
node_
);
*
p_target
=
fused
;
return
*
this
;
}
Stage
&
Stage
::
fuse
(
const
Array
<
IterVar
>&
axes
,
IterVar
*
p_target
)
{
// NOLINT(*)
if
(
axes
.
size
()
!=
0
)
{
IterVar
fused
=
axes
[
0
];
for
(
size_t
i
=
1
;
i
<
axes
.
size
();
++
i
)
{
this
->
fuse
(
fused
,
axes
[
i
],
&
fused
);
}
*
p_target
=
std
::
move
(
fused
);
}
else
{
StageNode
*
self
=
operator
->
();
// special handle fuse empty array.
// insert at the outer most loop
IterVar
singleton
=
IterVarNode
::
make
(
Range
::
make_by_min_extent
(
0
,
1
),
Var
(
"singleton"
,
Int
(
32
)),
kDataPar
);
self
->
relations
.
push_back
(
SingletonNode
::
make
(
singleton
));
ArrayNode
*
all_vars
=
self
->
all_iter_vars
.
CopyOnWrite
();
ArrayNode
*
leaf_vars
=
self
->
leaf_iter_vars
.
CopyOnWrite
();
all_vars
->
data
.
push_back
(
singleton
.
node_
);
leaf_vars
->
data
.
insert
(
leaf_vars
->
data
.
begin
(),
singleton
.
node_
);
*
p_target
=
singleton
;
}
return
*
this
;
}
...
...
@@ -732,11 +756,18 @@ IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) {
return
IterVarRelation
(
n
);
}
IterVarRelation
SingletonNode
::
make
(
IterVar
iter
)
{
auto
n
=
std
::
make_shared
<
SingletonNode
>
();
n
->
iter
=
iter
;
return
IterVarRelation
(
n
);
}
TVM_REGISTER_NODE_TYPE
(
StageNode
);
TVM_REGISTER_NODE_TYPE
(
IterVarAttrNode
);
TVM_REGISTER_NODE_TYPE
(
SplitNode
);
TVM_REGISTER_NODE_TYPE
(
FuseNode
);
TVM_REGISTER_NODE_TYPE
(
RebaseNode
);
TVM_REGISTER_NODE_TYPE
(
SingletonNode
);
TVM_REGISTER_NODE_TYPE
(
ScheduleNode
);
// Printer
...
...
@@ -778,6 +809,11 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p
->
print
(
op
->
rebased
);
p
->
stream
<<
')'
;
})
.
set_dispatch
<
SingletonNode
>
([](
const
SingletonNode
*
op
,
IRPrinter
*
p
)
{
p
->
stream
<<
"singleton("
;
p
->
print
(
op
->
iter
);
p
->
stream
<<
')'
;
})
.
set_dispatch
<
ScheduleNode
>
([](
const
ScheduleNode
*
op
,
IRPrinter
*
p
)
{
p
->
stream
<<
"schedule("
<<
op
<<
")"
;
});
...
...
This diff is collapsed.
Click to expand it.
tests/python/integration/test_ewise.py
View file @
5d2ccd66
...
...
@@ -44,10 +44,10 @@ def test_multiple_cache_write():
n
=
tvm
.
convert
(
1024
)
A0
=
tvm
.
placeholder
((
n
,),
name
=
'A0'
,
dtype
=
"float32"
)
A1
=
tvm
.
placeholder
((
n
,),
name
=
'A1'
,
dtype
=
"float32"
)
B0
,
B1
=
tvm
.
compute
((
n
,),
lambda
*
i
:
(
A0
(
*
i
)
+
A1
(
*
i
),
A0
(
*
i
)
*
A1
(
*
i
)),
B0
,
B1
=
tvm
.
compute
((
n
,),
lambda
*
i
:
(
A0
(
*
i
)
+
A1
(
*
i
),
A0
(
*
i
)
*
A1
(
*
i
)),
name
=
'B'
)
C
=
tvm
.
compute
((
n
,),
lambda
*
i
:
B0
(
*
i
)
+
B1
(
*
i
),
C
=
tvm
.
compute
((
n
,),
lambda
*
i
:
B0
(
*
i
)
+
B1
(
*
i
),
name
=
'C'
)
s
=
tvm
.
create_schedule
(
C
.
op
)
# create iter var and assign them tags.
...
...
@@ -76,7 +76,7 @@ def test_multiple_cache_write():
c
=
tvm
.
nd
.
array
(
np
.
zeros
(
n
,
dtype
=
C
.
dtype
),
ctx
)
func
(
a0
,
a1
,
c
)
np
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
a0
.
asnumpy
()
+
a1
.
asnumpy
()
+
(
a0
.
asnumpy
()
*
a1
.
asnumpy
()),
c
.
asnumpy
(),
a0
.
asnumpy
()
+
a1
.
asnumpy
()
+
(
a0
.
asnumpy
()
*
a1
.
asnumpy
()),
rtol
=
1e-5
)
check_device
(
"cuda"
,
"llvm"
)
...
...
@@ -235,7 +235,6 @@ def try_warp_memory():
f
(
a
,
b
)
np
.
testing
.
assert_allclose
(
b
.
asnumpy
(),
a
.
asnumpy
()
+
3
,
rtol
=
1e-6
)
check_device
(
"cuda"
)
...
...
This diff is collapsed.
Click to expand it.
tests/python/unittest/test_lang_schedule.py
View file @
5d2ccd66
...
...
@@ -84,6 +84,19 @@ def test_fuse():
assert
any
(
isinstance
(
x
,
tvm
.
schedule
.
Fuse
)
for
x
in
s
[
T
]
.
relations
)
assert
tuple
(
s
[
T
]
.
leaf_iter_vars
)
==
(
fused
,
xi
,
yi
)
def
test_singleton
():
A
=
tvm
.
placeholder
((),
name
=
'A'
)
T
=
tvm
.
compute
((),
lambda
:
A
()
+
1
)
s
=
tvm
.
create_schedule
(
T
.
op
)
fused
=
s
[
T
]
.
fuse
()
assert
any
(
isinstance
(
x
,
tvm
.
schedule
.
Singleton
)
for
x
in
s
[
T
]
.
relations
)
assert
tuple
(
s
[
T
]
.
leaf_iter_vars
)
==
(
fused
,)
dump
=
pkl
.
dumps
(
s
)
s_loaded
=
pkl
.
loads
(
dump
)
assert
isinstance
(
s_loaded
,
tvm
.
schedule
.
Schedule
)
def
test_vectorize
():
m
=
tvm
.
var
(
'm'
)
n
=
tvm
.
var
(
'n'
)
...
...
@@ -174,6 +187,7 @@ def test_tensor_intrin():
if
__name__
==
"__main__"
:
test_singleton
()
test_pragma
()
test_tensor_intrin
()
test_rfactor
()
...
...
This diff is collapsed.
Click to expand it.
topi/tests/python/test_topi_broadcast.py
View file @
5d2ccd66
...
...
@@ -94,6 +94,8 @@ def test_broadcast_to():
def
test_add
():
verify_broadcast_binary_ele
(
(),
(),
topi
.
add
,
np
.
add
)
verify_broadcast_binary_ele
(
(
5
,
2
,
3
),
(
2
,
1
),
topi
.
add
,
np
.
add
)
def
test_subtract
():
...
...
@@ -114,6 +116,8 @@ def test_divide():
verify_broadcast_binary_ele
(
None
,
(
10
,),
topi
.
divide
,
np
.
divide
,
rhs_min
=
0.0001
)
verify_broadcast_binary_ele
(
(),
None
,
topi
.
divide
,
np
.
divide
,
rhs_min
=
0.0001
)
verify_broadcast_binary_ele
(
(
2
,
3
,
1
,
32
),
(
64
,
32
),
topi
.
divide
,
np
.
divide
,
rhs_min
=
0.0001
)
def
test_maximum_minmum
():
...
...
@@ -157,10 +161,10 @@ def test_shift():
if
__name__
==
"__main__"
:
test_add
()
test_shift
()
test_cmp
()
test_mod
()
test_add
()
test_subtract
()
test_multiply
()
test_divide
()
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment