Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
03b09f74
Commit
03b09f74
authored
May 15, 2017
by
Tianqi Chen
Committed by
GitHub
May 15, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] Improve SSA conversion, add forbid list in loop-par (#142)
parent
867ad378
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
167 additions
and
113 deletions
+167
-113
python/tvm/build.py
+2
-2
src/pass/loop_partition.cc
+17
-2
src/pass/ssa.cc
+128
-109
tests/python/unittest/test_build_lower.py
+20
-0
No files found.
python/tvm/build.py
View file @
03b09f74
...
...
@@ -90,10 +90,10 @@ def lower(sch,
sch
=
sch
.
normalize
()
bounds
=
schedule
.
InferBound
(
sch
)
stmt
=
schedule
.
ScheduleOps
(
sch
,
bounds
)
if
not
simple_mode
:
stmt
=
ir_pass
.
LoopPartition
(
stmt
)
stmt
=
ir_pass
.
StorageFlatten
(
stmt
,
binds
)
stmt
=
ir_pass
.
CanonicalSimplify
(
stmt
)
if
not
simple_mode
:
stmt
=
ir_pass
.
LoopPartition
(
stmt
)
stmt
=
ir_pass
.
VectorizeLoop
(
stmt
)
stmt
=
ir_pass
.
InjectVirtualThread
(
stmt
)
stmt
=
ir_pass
.
StorageRewrite
(
stmt
)
...
...
src/pass/loop_partition.cc
View file @
03b09f74
...
...
@@ -52,7 +52,7 @@ class CandidateSelector : public IRVisitor {
const
Variable
*
var
=
op
->
loop_var
.
get
();
record_
.
insert
({
var
,
false
});
IRVisitor
::
Visit_
(
op
);
if
(
record_
.
at
(
var
))
{
if
(
record_
.
at
(
var
)
&&
!
no_split_
)
{
candidates
.
insert
(
op
);
}
record_
.
erase
(
var
);
...
...
@@ -70,7 +70,7 @@ class CandidateSelector : public IRVisitor {
if
((
scope
.
rank
==
0
)
&&
!
is_const
(
op
->
value
))
{
record_
.
insert
({
var
.
get
(),
false
});
IRVisitor
::
Visit_
(
op
);
if
(
record_
.
at
(
var
.
get
()))
{
if
(
record_
.
at
(
var
.
get
())
&&
!
no_split_
)
{
candidates
.
insert
(
op
);
}
record_
.
erase
(
var
.
get
());
...
...
@@ -80,11 +80,25 @@ class CandidateSelector : public IRVisitor {
IRVisitor
::
Visit_
(
op
);
}
void
Visit_
(
const
Block
*
op
)
{
bool
temp
=
no_split_
;
this
->
Visit
(
op
->
first
);
// erase the no split state of first when visit rest.
std
::
swap
(
temp
,
no_split_
);
this
->
Visit
(
op
->
rest
);
// restore the no split flag.
no_split_
=
no_split_
||
temp
;
}
void
Visit_
(
const
Call
*
op
)
{
if
(
op
->
is_intrinsic
(
Call
::
likely
))
{
in_likely_
=
true
;
IRVisitor
::
Visit_
(
op
);
in_likely_
=
false
;
}
else
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_thread_allreduce
))
{
// no split if the body contains allreduce.
no_split_
=
true
;
return
;
}
else
{
IRVisitor
::
Visit_
(
op
);
}
...
...
@@ -100,6 +114,7 @@ class CandidateSelector : public IRVisitor {
private
:
bool
in_likely_
;
bool
no_split_
{
false
};
std
::
unordered_map
<
const
Variable
*
,
VarIsUsed
>
record_
;
};
...
...
src/pass/ssa.cc
View file @
03b09f74
/*!
* Copyright (c) 2016 by Contributors
* SSA related checks and pass.
*
* SSA requires each varaible to be only defined once.
* \file ssa.cc
*/
#include <tvm/ir.h>
...
...
@@ -14,138 +16,155 @@
namespace
tvm
{
namespace
ir
{
namespace
{
// global functor to get var definition from
struct
FGetVarDef
{
using
FType
=
IRFunctor
<
VarExpr
(
const
NodeRef
&
)
>
;
static
FType
&
vtable
()
{
// NOLINT(*)
static
FType
inst
;
return
inst
;
}
};
TVM_STATIC_IR_FUNCTOR
(
FGetVarDef
,
vtable
)
.
set_dispatch
<
Let
>
([](
const
Let
*
op
)
{
return
op
->
var
;
})
.
set_dispatch
<
LetStmt
>
([](
const
LetStmt
*
op
)
{
return
op
->
var
;
})
.
set_dispatch
<
For
>
([](
const
For
*
op
)
{
return
op
->
loop_var
;
})
.
set_dispatch
<
Allocate
>
([](
const
Allocate
*
op
)
{
return
op
->
buffer_var
;
});
struct
FSetVarDef
{
using
FTypeExpr
=
IRFunctor
<
Expr
(
const
NodeRef
&
,
VarExpr
)
>
;
using
FTypeStmt
=
IRFunctor
<
Stmt
(
const
NodeRef
&
,
VarExpr
)
>
;
static
FTypeExpr
&
vtable_expr
()
{
// NOLINT(*)
static
FTypeExpr
inst
;
return
inst
;
}
static
FTypeStmt
&
vtable_stmt
()
{
// NOLINT(*)
static
FTypeStmt
inst
;
return
inst
;
}
};
TVM_STATIC_IR_FUNCTOR
(
FSetVarDef
,
vtable_expr
)
.
set_dispatch
<
Let
>
([](
const
Let
*
op
,
VarExpr
var
)
{
std
::
shared_ptr
<
Let
>
x
=
std
::
make_shared
<
Let
>
(
*
op
);
x
->
var
=
var
;
return
Expr
(
x
);
});
TVM_STATIC_IR_FUNCTOR
(
FSetVarDef
,
vtable_stmt
)
.
set_dispatch
<
LetStmt
>
([](
const
LetStmt
*
op
,
VarExpr
var
)
{
std
::
shared_ptr
<
LetStmt
>
x
=
std
::
make_shared
<
LetStmt
>
(
*
op
);
x
->
var
=
var
;
return
Stmt
(
x
);
})
.
set_dispatch
<
For
>
([](
const
For
*
op
,
VarExpr
var
)
{
std
::
shared_ptr
<
For
>
x
=
std
::
make_shared
<
For
>
(
*
op
);
x
->
loop_var
=
var
;
return
Stmt
(
x
);
});
class
IRVerifySSA
:
public
IRVisitor
{
class
IRVerifySSA
final
:
public
IRVisitor
{
public
:
bool
is_ssa
{
true
};
void
Visit
(
const
NodeRef
&
n
)
final
{
if
(
!
is_ssa
)
return
;
static
auto
&
fget_var_def
=
FGetVarDef
::
vtable
();
if
(
fget_var_def
.
can_dispatch
(
n
))
{
VarExpr
v
=
fget_var_def
(
n
);
if
(
defined_
.
count
(
v
.
get
())
!=
0
)
{
is_ssa
=
false
;
return
;
}
else
{
defined_
[
v
.
get
()]
=
1
;
}
}
IRVisitor
::
Visit
(
n
);
}
void
Visit_
(
const
Let
*
op
)
final
{
MarkDef
(
op
->
var
.
get
());
IRVisitor
::
Visit_
(
op
);
}
void
Visit_
(
const
LetStmt
*
op
)
final
{
MarkDef
(
op
->
var
.
get
());
IRVisitor
::
Visit_
(
op
);
}
void
Visit_
(
const
For
*
op
)
final
{
MarkDef
(
op
->
loop_var
.
get
());
IRVisitor
::
Visit_
(
op
);
}
void
Visit_
(
const
Allocate
*
op
)
final
{
MarkDef
(
op
->
buffer_var
.
get
());
IRVisitor
::
Visit_
(
op
);
}
private
:
void
MarkDef
(
const
Variable
*
v
)
{
if
(
defined_
.
count
(
v
)
!=
0
)
{
is_ssa
=
false
;
return
;
}
else
{
defined_
[
v
]
=
1
;
}
}
std
::
unordered_map
<
const
Variable
*
,
int
>
defined_
;
};
class
IRConvertSSA
:
public
IRMutator
{
class
IRConvertSSA
final
:
public
IRMutator
{
public
:
Expr
Mutate
(
Expr
expr
)
final
{
static
auto
&
fget_var_def
=
FGetVarDef
::
vtable
();
static
auto
&
fset_var_def
=
FSetVarDef
::
vtable_expr
();
if
(
fget_var_def
.
can_dispatch
(
expr
))
{
VarExpr
v
=
fget_var_def
(
expr
);
VarExpr
new_var
=
v
;
if
(
defined_
.
count
(
v
.
get
())
!=
0
)
{
CHECK
(
expr
.
as
<
Allocate
>
()
==
nullptr
)
<<
"One allocation in two places, cannot rename buffer in allocate"
;
new_var
=
Variable
::
make
(
v
->
type
,
v
->
name_hint
);
}
else
{
defined_
.
insert
(
v
.
get
());
}
Expr
Mutate_
(
const
Variable
*
op
,
const
Expr
&
e
)
final
{
if
(
scope_
.
count
(
op
))
{
return
scope_
[
op
].
back
();
}
else
{
return
e
;
}
}
Expr
Mutate_
(
const
Let
*
op
,
const
Expr
&
e
)
final
{
const
VarExpr
&
v
=
op
->
var
;
if
(
defined_
.
count
(
v
.
get
()))
{
Expr
value
=
IRMutator
::
Mutate
(
op
->
value
);
VarExpr
new_var
=
Variable
::
make
(
v
.
type
(),
v
->
name_hint
);
scope_
[
v
.
get
()].
push_back
(
new_var
);
Expr
new_expr
=
IRMutator
::
Mutate
(
expr
);
Expr
body
=
IRMutator
::
Mutate
(
op
->
body
);
scope_
[
v
.
get
()].
pop_back
();
if
(
!
new_var
.
same_as
(
v
))
{
return
fset_var_def
(
new_expr
,
new_var
);
}
else
{
return
new_expr
;
}
}
else
if
(
expr
.
as
<
Variable
>
())
{
const
Variable
*
v
=
expr
.
as
<
Variable
>
();
if
(
scope_
.
count
(
v
)
!=
0
)
{
return
scope_
[
v
].
back
();
}
else
{
return
expr
;
}
return
Let
::
make
(
new_var
,
value
,
body
);
}
else
{
Expr
e
=
IRMutator
::
Mutate
(
expr
);
return
e
;
defined_
.
insert
(
v
.
get
()
);
return
IRMutator
::
Mutate_
(
op
,
e
)
;
}
}
Stmt
Mutate
(
Stmt
stmt
)
final
{
static
auto
&
fget_var_def
=
FGetVarDef
::
vtable
();
static
auto
&
fset_var_def
=
FSetVarDef
::
vtable_stmt
();
if
(
fget_var_def
.
can_dispatch
(
stmt
))
{
VarExpr
v
=
fget_var_def
(
stmt
);
VarExpr
new_var
=
v
;
if
(
defined_
.
count
(
v
.
get
())
!=
0
)
{
new_var
=
Variable
::
make
(
v
->
type
,
v
->
name_hint
);
}
else
{
defined_
.
insert
(
v
.
get
());
}
Expr
Mutate_
(
const
Load
*
op
,
const
Expr
&
e
)
final
{
Expr
expr
=
IRMutator
::
Mutate_
(
op
,
e
);
op
=
expr
.
as
<
Load
>
();
if
(
scope_
.
count
(
op
->
buffer_var
.
get
()))
{
return
Load
::
make
(
op
->
type
,
scope_
[
op
->
buffer_var
.
get
()].
back
(),
op
->
index
,
op
->
predicate
);
}
else
{
return
expr
;
}
}
Stmt
Mutate_
(
const
Store
*
op
,
const
Stmt
&
s
)
final
{
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
op
=
stmt
.
as
<
Store
>
();
if
(
scope_
.
count
(
op
->
buffer_var
.
get
()))
{
return
Store
::
make
(
scope_
[
op
->
buffer_var
.
get
()].
back
(),
op
->
value
,
op
->
index
,
op
->
predicate
);
}
else
{
return
stmt
;
}
}
Stmt
Mutate_
(
const
LetStmt
*
op
,
const
Stmt
&
s
)
final
{
const
VarExpr
&
v
=
op
->
var
;
if
(
defined_
.
count
(
v
.
get
()))
{
Expr
value
=
IRMutator
::
Mutate
(
op
->
value
);
VarExpr
new_var
=
Variable
::
make
(
v
.
type
(),
v
->
name_hint
);
scope_
[
v
.
get
()].
push_back
(
new_var
);
Stmt
new_stmt
=
IRMutator
::
Mutate
(
stmt
);
Stmt
body
=
IRMutator
::
Mutate
(
op
->
body
);
scope_
[
v
.
get
()].
pop_back
();
if
(
!
new_var
.
same_as
(
v
))
{
return
fset_var_def
(
new_stmt
,
new_var
);
return
LetStmt
::
make
(
new_var
,
value
,
body
);
}
else
{
defined_
.
insert
(
v
.
get
());
return
IRMutator
::
Mutate_
(
op
,
s
);
}
}
Stmt
Mutate_
(
const
For
*
op
,
const
Stmt
&
s
)
final
{
const
VarExpr
&
v
=
op
->
loop_var
;
if
(
defined_
.
count
(
v
.
get
()))
{
VarExpr
new_var
=
Variable
::
make
(
v
.
type
(),
v
->
name_hint
);
scope_
[
v
.
get
()].
push_back
(
new_var
);
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
scope_
[
v
.
get
()].
pop_back
();
op
=
stmt
.
as
<
For
>
();
return
For
::
make
(
new_var
,
op
->
min
,
op
->
extent
,
op
->
for_type
,
op
->
device_api
,
op
->
body
);
}
else
{
defined_
.
insert
(
v
.
get
());
return
IRMutator
::
Mutate_
(
op
,
s
);
}
}
Stmt
Mutate_
(
const
Allocate
*
op
,
const
Stmt
&
s
)
final
{
const
VarExpr
&
v
=
op
->
buffer_var
;
if
(
defined_
.
count
(
v
.
get
()))
{
VarExpr
new_var
=
Variable
::
make
(
v
.
type
(),
v
->
name_hint
);
scope_
[
v
.
get
()].
push_back
(
new_var
);
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
scope_
[
v
.
get
()].
pop_back
();
op
=
stmt
.
as
<
Allocate
>
();
return
Allocate
::
make
(
new_var
,
op
->
type
,
op
->
extents
,
op
->
condition
,
op
->
body
,
op
->
new_expr
,
op
->
free_function
);
}
else
{
defined_
.
insert
(
v
.
get
());
return
IRMutator
::
Mutate_
(
op
,
s
);
}
}
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
const
Variable
*
v
=
op
->
node
.
as
<
Variable
>
())
{
if
(
op
->
attr_key
==
attr
::
storage_scope
)
{
const
Allocate
*
alloc
=
op
->
body
.
as
<
Allocate
>
();
if
(
alloc
&&
op
->
node
.
same_as
(
alloc
->
buffer_var
))
{
Stmt
new_alloc
=
Mutate
(
op
->
body
);
if
(
new_alloc
.
same_as
(
op
->
body
))
return
s
;
alloc
=
new_alloc
.
as
<
Allocate
>
();
CHECK
(
alloc
);
return
AttrStmt
::
make
(
alloc
->
buffer_var
,
op
->
attr_key
,
op
->
value
,
new_alloc
);
}
}
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
op
=
stmt
.
as
<
AttrStmt
>
();
if
(
scope_
.
count
(
v
)
&&
scope_
[
v
].
size
()
!=
0
)
{
return
AttrStmt
::
make
(
scope_
[
v
].
back
(),
op
->
attr_key
,
op
->
value
,
op
->
body
);
}
else
{
return
new_
stmt
;
return
stmt
;
}
}
else
{
return
IRMutator
::
Mutate
(
stmt
);
return
IRMutator
::
Mutate
_
(
op
,
s
);
}
}
...
...
tests/python/unittest/test_build_lower.py
0 → 100644
View file @
03b09f74
import
tvm
def
test_lower_rfactor
():
n
=
tvm
.
var
(
"n"
)
m
=
tvm
.
var
(
"m"
)
A
=
tvm
.
placeholder
((
n
,
m
),
name
=
'A'
)
k
=
tvm
.
reduce_axis
((
0
,
m
),
"k"
)
B
=
tvm
.
compute
((
n
,),
lambda
i
:
tvm
.
sum
(
A
[
i
,
k
],
axis
=
k
),
name
=
"B"
)
s
=
tvm
.
create_schedule
(
B
.
op
)
ko
,
ki
=
s
[
B
]
.
split
(
B
.
op
.
reduce_axis
[
0
],
factor
=
16
)
BF
=
s
.
rfactor
(
B
,
ki
)
xo
,
xi
=
s
[
B
]
.
split
(
s
[
B
]
.
op
.
axis
[
0
],
factor
=
32
)
s
[
B
.
op
]
.
bind
(
xo
,
tvm
.
thread_axis
(
"blockIdx.x"
))
s
[
B
.
op
]
.
bind
(
xi
,
tvm
.
thread_axis
(
"threadIdx.y"
))
s
[
B
]
.
bind
(
s
[
B
]
.
op
.
reduce_axis
[
0
],
tvm
.
thread_axis
(
"threadIdx.x"
))
s
[
BF
]
.
compute_at
(
s
[
B
],
s
[
B
]
.
op
.
reduce_axis
[
0
])
fapi
=
tvm
.
lower
(
s
,
[
A
,
B
])
if
__name__
==
"__main__"
:
test_lower_rfactor
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment