Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
0876e9e9
Commit
0876e9e9
authored
Apr 09, 2017
by
Tianqi Chen
Committed by
GitHub
Apr 09, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[IR] Rename attr_key in AttrStmt (#83)
parent
8f51c5fd
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
54 additions
and
54 deletions
+54
-54
HalideIR
+1
-1
src/arithmetic/canonical.cc
+2
-2
src/codegen/codegen_c.cc
+3
-3
src/codegen/llvm/codegen_llvm.cc
+1
-1
src/codegen/verilog/verilog_ir.cc
+9
-9
src/pass/inject_virtual_thread.cc
+2
-2
src/pass/ir_mutator.cc
+1
-1
src/pass/lift_allocate.cc
+4
-4
src/pass/lower_thread_allreduce.cc
+2
-2
src/pass/narrow_channel_access.cc
+5
-5
src/pass/split_host_device.cc
+6
-6
src/pass/split_pipeline.cc
+4
-4
src/pass/storage_flatten.cc
+3
-3
src/pass/storage_sync.cc
+1
-1
src/schedule/schedule_ops.cc
+10
-10
No files found.
HalideIR
@
d024efd8
Subproject commit
59fdca16978b6184bab87fbff7a00c95f180468
6
Subproject commit
d024efd80694556c1239c4435c5b3e70853a489
6
src/arithmetic/canonical.cc
View file @
0876e9e9
...
@@ -286,8 +286,8 @@ class Canonical::Internal : public IRMutator {
...
@@ -286,8 +286,8 @@ class Canonical::Internal : public IRMutator {
}
}
// AttrStmt
// AttrStmt
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
{
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
{
if
(
op
->
type
_key
==
attr
::
thread_extent
||
if
(
op
->
attr
_key
==
attr
::
thread_extent
||
op
->
type
_key
==
attr
::
virtual_thread
)
{
op
->
attr
_key
==
attr
::
virtual_thread
)
{
++
level_counter_
;
++
level_counter_
;
IterVar
iv
(
op
->
node
.
node_
);
IterVar
iv
(
op
->
node
.
node_
);
CHECK_NE
(
iv
->
thread_tag
.
length
(),
0U
);
CHECK_NE
(
iv
->
thread_tag
.
length
(),
0U
);
...
...
src/codegen/codegen_c.cc
View file @
0876e9e9
...
@@ -654,7 +654,7 @@ void CodeGenC::VisitStmt_(const Allocate* op) {
...
@@ -654,7 +654,7 @@ void CodeGenC::VisitStmt_(const Allocate* op) {
}
}
void
CodeGenC
::
VisitStmt_
(
const
AttrStmt
*
op
)
{
void
CodeGenC
::
VisitStmt_
(
const
AttrStmt
*
op
)
{
if
(
op
->
type
_key
==
ir
::
attr
::
thread_extent
)
{
if
(
op
->
attr
_key
==
ir
::
attr
::
thread_extent
)
{
IterVar
iv
(
op
->
node
.
node_
);
IterVar
iv
(
op
->
node
.
node_
);
if
(
iv
->
thread_tag
.
length
()
!=
0
)
{
if
(
iv
->
thread_tag
.
length
()
!=
0
)
{
if
(
!
var_idmap_
.
count
(
iv
->
var
.
get
()))
{
if
(
!
var_idmap_
.
count
(
iv
->
var
.
get
()))
{
...
@@ -667,11 +667,11 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) {
...
@@ -667,11 +667,11 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) {
stream
<<
";
\n
"
;
stream
<<
";
\n
"
;
}
}
}
}
}
else
if
(
op
->
type
_key
==
ir
::
attr
::
storage_scope
)
{
}
else
if
(
op
->
attr
_key
==
ir
::
attr
::
storage_scope
)
{
const
Variable
*
v
=
op
->
node
.
as
<
Variable
>
();
const
Variable
*
v
=
op
->
node
.
as
<
Variable
>
();
CHECK
(
v
);
CHECK
(
v
);
alloc_storage_scope_
[
v
]
=
op
->
value
.
as
<
StringImm
>
()
->
value
;
alloc_storage_scope_
[
v
]
=
op
->
value
.
as
<
StringImm
>
()
->
value
;
}
else
if
(
op
->
type
_key
==
ir
::
attr
::
volatile_scope
)
{
}
else
if
(
op
->
attr
_key
==
ir
::
attr
::
volatile_scope
)
{
const
Variable
*
v
=
op
->
node
.
as
<
Variable
>
();
const
Variable
*
v
=
op
->
node
.
as
<
Variable
>
();
CHECK
(
v
);
CHECK
(
v
);
volatile_buf_
.
insert
(
v
);
volatile_buf_
.
insert
(
v
);
...
...
src/codegen/llvm/codegen_llvm.cc
View file @
0876e9e9
...
@@ -1245,7 +1245,7 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) {
...
@@ -1245,7 +1245,7 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) {
}
}
void
CodeGenLLVM
::
VisitStmt_
(
const
AttrStmt
*
op
)
{
void
CodeGenLLVM
::
VisitStmt_
(
const
AttrStmt
*
op
)
{
if
(
op
->
type
_key
==
ir
::
attr
::
storage_scope
)
{
if
(
op
->
attr
_key
==
ir
::
attr
::
storage_scope
)
{
const
Variable
*
v
=
op
->
node
.
as
<
Variable
>
();
const
Variable
*
v
=
op
->
node
.
as
<
Variable
>
();
CHECK
(
v
);
CHECK
(
v
);
alloc_storage_scope_
[
v
]
=
op
->
value
.
as
<
StringImm
>
()
->
value
;
alloc_storage_scope_
[
v
]
=
op
->
value
.
as
<
StringImm
>
()
->
value
;
...
...
src/codegen/verilog/verilog_ir.cc
View file @
0876e9e9
...
@@ -93,20 +93,20 @@ class PipelineExtractor: public IRVisitor {
...
@@ -93,20 +93,20 @@ class PipelineExtractor: public IRVisitor {
}
}
void
Visit_
(
const
AttrStmt
*
op
)
final
{
void
Visit_
(
const
AttrStmt
*
op
)
final
{
if
(
op
->
type
_key
==
attr
::
pipeline_stage_scope
)
{
if
(
op
->
attr
_key
==
attr
::
pipeline_stage_scope
)
{
CHECK
(
!
in_pipeline_stage_
);
CHECK
(
!
in_pipeline_stage_
);
in_pipeline_stage_
=
true
;
in_pipeline_stage_
=
true
;
trigger_
.
emplace_back
(
std
::
make_pair
(
loop_
.
size
(),
op
));
trigger_
.
emplace_back
(
std
::
make_pair
(
loop_
.
size
(),
op
));
IRVisitor
::
Visit_
(
op
);
IRVisitor
::
Visit_
(
op
);
trigger_
.
pop_back
();
trigger_
.
pop_back
();
in_pipeline_stage_
=
false
;
in_pipeline_stage_
=
false
;
}
else
if
(
op
->
type
_key
==
attr
::
channel_read_advance
||
}
else
if
(
op
->
attr
_key
==
attr
::
channel_read_advance
||
op
->
type
_key
==
attr
::
channel_write_advance
)
{
op
->
attr
_key
==
attr
::
channel_write_advance
)
{
trigger_
.
emplace_back
(
std
::
make_pair
(
loop_
.
size
(),
op
));
trigger_
.
emplace_back
(
std
::
make_pair
(
loop_
.
size
(),
op
));
IRVisitor
::
Visit_
(
op
);
IRVisitor
::
Visit_
(
op
);
trigger_
.
pop_back
();
trigger_
.
pop_back
();
}
else
if
(
op
->
type
_key
==
attr
::
channel_read_scope
||
}
else
if
(
op
->
attr
_key
==
attr
::
channel_read_scope
||
op
->
type
_key
==
attr
::
channel_write_scope
)
{
op
->
attr
_key
==
attr
::
channel_write_scope
)
{
Channel
ch
(
op
->
node
.
node_
);
Channel
ch
(
op
->
node
.
node_
);
ChannelEntry
&
cb
=
cmap_
[
ch
->
handle_var
.
get
()];
ChannelEntry
&
cb
=
cmap_
[
ch
->
handle_var
.
get
()];
if
(
cb
.
node
!=
nullptr
)
{
if
(
cb
.
node
!=
nullptr
)
{
...
@@ -115,7 +115,7 @@ class PipelineExtractor: public IRVisitor {
...
@@ -115,7 +115,7 @@ class PipelineExtractor: public IRVisitor {
cb
.
node
=
std
::
make_shared
<
ChannelBlockNode
>
();
cb
.
node
=
std
::
make_shared
<
ChannelBlockNode
>
();
cb
.
node
->
channel
=
ch
;
cb
.
node
->
channel
=
ch
;
}
}
if
(
op
->
type
_key
==
attr
::
channel_read_scope
)
{
if
(
op
->
attr
_key
==
attr
::
channel_read_scope
)
{
CHECK_EQ
(
cb
.
read_ref_count
,
0
)
CHECK_EQ
(
cb
.
read_ref_count
,
0
)
<<
"One channel can only be read from one consumer"
;
<<
"One channel can only be read from one consumer"
;
++
cb
.
read_ref_count
;
++
cb
.
read_ref_count
;
...
@@ -173,7 +173,7 @@ class PipelineExtractor: public IRVisitor {
...
@@ -173,7 +173,7 @@ class PipelineExtractor: public IRVisitor {
for
(
const
auto
&
e
:
trigger_
)
{
for
(
const
auto
&
e
:
trigger_
)
{
const
AttrStmt
*
attr
=
e
.
second
;
const
AttrStmt
*
attr
=
e
.
second
;
Channel
ch
;
Channel
ch
;
if
(
attr
->
type
_key
==
attr
::
pipeline_stage_scope
)
{
if
(
attr
->
attr
_key
==
attr
::
pipeline_stage_scope
)
{
ch
=
arg_write
;
ch
=
arg_write
;
if
(
!
ch
.
defined
())
continue
;
if
(
!
ch
.
defined
())
continue
;
}
else
{
}
else
{
...
@@ -195,10 +195,10 @@ class PipelineExtractor: public IRVisitor {
...
@@ -195,10 +195,10 @@ class PipelineExtractor: public IRVisitor {
trigger
->
signal_index
=
static_cast
<
int
>
(
cb
.
node
->
ctrl_signals
.
size
());
trigger
->
signal_index
=
static_cast
<
int
>
(
cb
.
node
->
ctrl_signals
.
size
());
// Grab the advance constant size.
// Grab the advance constant size.
int
trigger_size
;
int
trigger_size
;
if
(
attr
->
type
_key
==
attr
::
pipeline_stage_scope
)
{
if
(
attr
->
attr
_key
==
attr
::
pipeline_stage_scope
)
{
cb
.
node
->
ctrl_signals
.
push_back
(
cb
.
node
->
ctrl_signals
.
push_back
(
ControlSignalNode
::
make
(
kComputeFinish
,
0
));
ControlSignalNode
::
make
(
kComputeFinish
,
0
));
}
else
if
(
attr
->
type
_key
==
attr
::
channel_read_advance
)
{
}
else
if
(
attr
->
attr
_key
==
attr
::
channel_read_advance
)
{
CHECK
(
arith
::
GetConstInt
(
attr
->
value
,
&
trigger_size
))
CHECK
(
arith
::
GetConstInt
(
attr
->
value
,
&
trigger_size
))
<<
"Only support constant advance size"
;
<<
"Only support constant advance size"
;
cb
.
node
->
ctrl_signals
.
push_back
(
cb
.
node
->
ctrl_signals
.
push_back
(
...
...
src/pass/inject_virtual_thread.cc
View file @
0876e9e9
...
@@ -200,7 +200,7 @@ class VTInjector : public IRMutator {
...
@@ -200,7 +200,7 @@ class VTInjector : public IRMutator {
body
.
same_as
(
op
->
body
))
{
body
.
same_as
(
op
->
body
))
{
return
s
;
return
s
;
}
else
{
}
else
{
return
AttrStmt
::
make
(
op
->
node
,
op
->
type
_key
,
value
,
body
);
return
AttrStmt
::
make
(
op
->
node
,
op
->
attr
_key
,
value
,
body
);
}
}
}
}
}
}
...
@@ -388,7 +388,7 @@ class VirtualThreadInjector : public IRMutator {
...
@@ -388,7 +388,7 @@ class VirtualThreadInjector : public IRMutator {
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
op
=
stmt
.
as
<
AttrStmt
>
();
op
=
stmt
.
as
<
AttrStmt
>
();
if
(
op
->
type
_key
==
attr
::
virtual_thread
)
{
if
(
op
->
attr
_key
==
attr
::
virtual_thread
)
{
IterVar
iv
(
op
->
node
.
node_
);
IterVar
iv
(
op
->
node
.
node_
);
int
nthread
=
static_cast
<
int
>
(
op
->
value
.
as
<
IntImm
>
()
->
value
);
int
nthread
=
static_cast
<
int
>
(
op
->
value
.
as
<
IntImm
>
()
->
value
);
VarTouchedAnalysis
vs
;
VarTouchedAnalysis
vs
;
...
...
src/pass/ir_mutator.cc
View file @
0876e9e9
...
@@ -68,7 +68,7 @@ Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) {
...
@@ -68,7 +68,7 @@ Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) {
body
.
same_as
(
op
->
body
))
{
body
.
same_as
(
op
->
body
))
{
return
s
;
return
s
;
}
else
{
}
else
{
return
AttrStmt
::
make
(
op
->
node
,
op
->
type
_key
,
value
,
body
);
return
AttrStmt
::
make
(
op
->
node
,
op
->
attr
_key
,
value
,
body
);
}
}
}
}
...
...
src/pass/lift_allocate.cc
View file @
0876e9e9
...
@@ -25,9 +25,9 @@ class AllocateLifter : public IRMutator {
...
@@ -25,9 +25,9 @@ class AllocateLifter : public IRMutator {
}
}
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
CHECK
(
op
->
type
_key
!=
attr
::
virtual_thread
)
CHECK
(
op
->
attr
_key
!=
attr
::
virtual_thread
)
<<
"InjectVirtualThread before LiftStorageAlloc"
;
<<
"InjectVirtualThread before LiftStorageAlloc"
;
if
(
op
->
type
_key
==
attr
::
storage_scope
)
{
if
(
op
->
attr
_key
==
attr
::
storage_scope
)
{
StorageScope
sc
=
StorageScope
::
make
(
op
->
value
.
as
<
StringImm
>
()
->
value
);
StorageScope
sc
=
StorageScope
::
make
(
op
->
value
.
as
<
StringImm
>
()
->
value
);
allocs_
[
sc
].
emplace_back
(
allocs_
[
sc
].
emplace_back
(
AttrStmt
::
make
(
AttrStmt
::
make
(
...
@@ -35,7 +35,7 @@ class AllocateLifter : public IRMutator {
...
@@ -35,7 +35,7 @@ class AllocateLifter : public IRMutator {
op
->
value
,
Evaluate
::
make
(
0
)));
op
->
value
,
Evaluate
::
make
(
0
)));
storage_scope_
[
op
->
node
.
get
()]
=
sc
;
storage_scope_
[
op
->
node
.
get
()]
=
sc
;
return
this
->
Mutate
(
op
->
body
);
return
this
->
Mutate
(
op
->
body
);
}
else
if
(
op
->
type
_key
==
attr
::
thread_extent
)
{
}
else
if
(
op
->
attr
_key
==
attr
::
thread_extent
)
{
IterVar
iv
(
op
->
node
.
node_
);
IterVar
iv
(
op
->
node
.
node_
);
ThreadScope
ts
=
ThreadScope
::
make
(
iv
->
thread_tag
);
ThreadScope
ts
=
ThreadScope
::
make
(
iv
->
thread_tag
);
curr_thread_scope_
.
push_back
(
ts
);
curr_thread_scope_
.
push_back
(
ts
);
...
@@ -55,7 +55,7 @@ class AllocateLifter : public IRMutator {
...
@@ -55,7 +55,7 @@ class AllocateLifter : public IRMutator {
Stmt
body
=
MergeNest
(
vec
,
op
->
body
);
Stmt
body
=
MergeNest
(
vec
,
op
->
body
);
vec
.
clear
();
vec
.
clear
();
return
AttrStmt
::
make
(
return
AttrStmt
::
make
(
op
->
node
,
op
->
type
_key
,
op
->
value
,
body
);
op
->
node
,
op
->
attr
_key
,
op
->
value
,
body
);
}
}
}
}
return
stmt
;
return
stmt
;
...
...
src/pass/lower_thread_allreduce.cc
View file @
0876e9e9
...
@@ -20,12 +20,12 @@ class ThreadAllreduceBuilder : public IRMutator {
...
@@ -20,12 +20,12 @@ class ThreadAllreduceBuilder : public IRMutator {
:
warp_size_
(
warp_size
)
{}
:
warp_size_
(
warp_size
)
{}
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
type
_key
==
attr
::
thread_extent
)
{
if
(
op
->
attr
_key
==
attr
::
thread_extent
)
{
thread_extents_
.
push_back
(
op
);
thread_extents_
.
push_back
(
op
);
Stmt
ret
=
IRMutator
::
Mutate_
(
op
,
s
);
Stmt
ret
=
IRMutator
::
Mutate_
(
op
,
s
);
thread_extents_
.
pop_back
();
thread_extents_
.
pop_back
();
return
ret
;
return
ret
;
}
else
if
(
op
->
type
_key
==
attr
::
storage_scope
)
{
}
else
if
(
op
->
attr
_key
==
attr
::
storage_scope
)
{
Stmt
ret
=
IRMutator
::
Mutate_
(
op
,
s
);
Stmt
ret
=
IRMutator
::
Mutate_
(
op
,
s
);
op
=
ret
.
as
<
AttrStmt
>
();
op
=
ret
.
as
<
AttrStmt
>
();
const
Variable
*
v
=
op
->
node
.
as
<
Variable
>
();
const
Variable
*
v
=
op
->
node
.
as
<
Variable
>
();
...
...
src/pass/narrow_channel_access.cc
View file @
0876e9e9
...
@@ -107,14 +107,14 @@ class ChannelAccessRewriter : public IRMutator {
...
@@ -107,14 +107,14 @@ class ChannelAccessRewriter : public IRMutator {
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
Stmt
ret
;
Stmt
ret
;
const
AttrStmt
*
adv
=
op
->
body
.
as
<
AttrStmt
>
();
const
AttrStmt
*
adv
=
op
->
body
.
as
<
AttrStmt
>
();
if
((
op
->
type
_key
==
ir
::
attr
::
channel_read_scope
&&
if
((
op
->
attr
_key
==
ir
::
attr
::
channel_read_scope
&&
adv
&&
adv
->
type
_key
==
ir
::
attr
::
channel_read_advance
)
||
adv
&&
adv
->
attr
_key
==
ir
::
attr
::
channel_read_advance
)
||
(
op
->
type
_key
==
ir
::
attr
::
channel_write_scope
&&
(
op
->
attr
_key
==
ir
::
attr
::
channel_write_scope
&&
adv
&&
adv
->
type
_key
==
ir
::
attr
::
channel_write_advance
))
{
adv
&&
adv
->
attr
_key
==
ir
::
attr
::
channel_write_advance
))
{
RewriteEntry
e
;
RewriteEntry
e
;
e
.
window
=
op
;
e
.
window
=
op
;
e
.
advance
=
adv
;
e
.
advance
=
adv
;
e
.
read_access
=
op
->
type
_key
==
ir
::
attr
::
channel_read_scope
;
e
.
read_access
=
op
->
attr
_key
==
ir
::
attr
::
channel_read_scope
;
tasks_
.
push_back
(
e
);
tasks_
.
push_back
(
e
);
ret
=
IRMutator
::
Mutate_
(
op
,
s
);
ret
=
IRMutator
::
Mutate_
(
op
,
s
);
if
(
tasks_
.
back
().
rewrite_success
)
{
if
(
tasks_
.
back
().
rewrite_success
)
{
...
...
src/pass/split_host_device.cc
View file @
0876e9e9
...
@@ -18,7 +18,7 @@ namespace ir {
...
@@ -18,7 +18,7 @@ namespace ir {
class
IRUseDefAnalysis
:
public
IRMutator
{
class
IRUseDefAnalysis
:
public
IRMutator
{
public
:
public
:
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
type
_key
==
attr
::
thread_extent
)
{
if
(
op
->
attr
_key
==
attr
::
thread_extent
)
{
IterVar
iv
(
op
->
node
.
node_
);
IterVar
iv
(
op
->
node
.
node_
);
CHECK_NE
(
iv
->
thread_tag
.
length
(),
0U
);
CHECK_NE
(
iv
->
thread_tag
.
length
(),
0U
);
// thread_extent can appear multiple times
// thread_extent can appear multiple times
...
@@ -35,9 +35,9 @@ class IRUseDefAnalysis : public IRMutator {
...
@@ -35,9 +35,9 @@ class IRUseDefAnalysis : public IRMutator {
}
}
Stmt
body
=
this
->
Mutate
(
op
->
body
);
Stmt
body
=
this
->
Mutate
(
op
->
body
);
if
(
value
.
same_as
(
value
)
&&
body
.
same_as
(
body
))
return
s
;
if
(
value
.
same_as
(
value
)
&&
body
.
same_as
(
body
))
return
s
;
return
AttrStmt
::
make
(
op
->
node
,
op
->
type
_key
,
value
,
body
);
return
AttrStmt
::
make
(
op
->
node
,
op
->
attr
_key
,
value
,
body
);
}
else
if
(
op
->
type
_key
==
attr
::
channel_write_scope
||
}
else
if
(
op
->
attr
_key
==
attr
::
channel_write_scope
||
op
->
type
_key
==
attr
::
channel_read_scope
)
{
op
->
attr
_key
==
attr
::
channel_read_scope
)
{
Channel
ch
(
op
->
node
.
node_
);
Channel
ch
(
op
->
node
.
node_
);
if
(
!
use_count_
.
count
(
ch
->
handle_var
.
get
()))
{
if
(
!
use_count_
.
count
(
ch
->
handle_var
.
get
()))
{
this
->
HandleDef
(
ch
->
handle_var
.
get
());
this
->
HandleDef
(
ch
->
handle_var
.
get
());
...
@@ -147,8 +147,8 @@ class IRUseDefAnalysis : public IRMutator {
...
@@ -147,8 +147,8 @@ class IRUseDefAnalysis : public IRMutator {
class
HostDeviceSplitter
:
public
IRMutator
{
class
HostDeviceSplitter
:
public
IRMutator
{
public
:
public
:
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
type
_key
==
attr
::
thread_extent
||
if
(
op
->
attr
_key
==
attr
::
thread_extent
||
op
->
type
_key
==
attr
::
pipeline_exec_scope
)
{
op
->
attr
_key
==
attr
::
pipeline_exec_scope
)
{
return
SplitDeviceFunc
(
s
);
return
SplitDeviceFunc
(
s
);
}
}
return
IRMutator
::
Mutate_
(
op
,
s
);
return
IRMutator
::
Mutate_
(
op
,
s
);
...
...
src/pass/split_pipeline.cc
View file @
0876e9e9
...
@@ -77,7 +77,7 @@ class MarkChannelAccess : public IRMutator {
...
@@ -77,7 +77,7 @@ class MarkChannelAccess : public IRMutator {
}
}
}
}
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
type
_key
==
ir
::
attr
::
storage_scope
)
{
if
(
op
->
attr
_key
==
ir
::
attr
::
storage_scope
)
{
Var
buf_var
(
op
->
node
.
node_
);
Var
buf_var
(
op
->
node
.
node_
);
if
(
cmap_
.
count
(
buf_var
.
get
()))
return
Mutate
(
op
->
body
);
if
(
cmap_
.
count
(
buf_var
.
get
()))
return
Mutate
(
op
->
body
);
}
}
...
@@ -223,7 +223,7 @@ class StageSplitter : public IRMutator {
...
@@ -223,7 +223,7 @@ class StageSplitter : public IRMutator {
nest
.
emplace_back
(
IfThenElse
::
make
(
op
->
condition
,
no_op
));
nest
.
emplace_back
(
IfThenElse
::
make
(
op
->
condition
,
no_op
));
}
else
if
(
const
AttrStmt
*
op
=
s
.
as
<
AttrStmt
>
())
{
}
else
if
(
const
AttrStmt
*
op
=
s
.
as
<
AttrStmt
>
())
{
nest
.
emplace_back
(
AttrStmt
::
make
(
nest
.
emplace_back
(
AttrStmt
::
make
(
op
->
node
,
op
->
type
_key
,
op
->
value
,
no_op
));
op
->
node
,
op
->
attr
_key
,
op
->
value
,
no_op
));
}
else
if
(
s
.
as
<
ProducerConsumer
>
())
{
}
else
if
(
s
.
as
<
ProducerConsumer
>
())
{
}
else
if
(
s
.
as
<
Block
>
())
{
}
else
if
(
s
.
as
<
Block
>
())
{
}
else
if
(
const
Allocate
*
op
=
s
.
as
<
Allocate
>
())
{
}
else
if
(
const
Allocate
*
op
=
s
.
as
<
Allocate
>
())
{
...
@@ -266,7 +266,7 @@ class PipelineSplitter : public IRMutator {
...
@@ -266,7 +266,7 @@ class PipelineSplitter : public IRMutator {
:
split_load_
(
split_load
)
{}
:
split_load_
(
split_load
)
{}
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
type
_key
==
ir
::
attr
::
pipeline_exec_scope
)
{
if
(
op
->
attr
_key
==
ir
::
attr
::
pipeline_exec_scope
)
{
CHECK_LE
(
env_
.
size
(),
1U
);
CHECK_LE
(
env_
.
size
(),
1U
);
const
ProducerConsumer
*
env
=
nullptr
;
const
ProducerConsumer
*
env
=
nullptr
;
if
(
env_
.
size
()
==
1
)
{
if
(
env_
.
size
()
==
1
)
{
...
@@ -276,7 +276,7 @@ class PipelineSplitter : public IRMutator {
...
@@ -276,7 +276,7 @@ class PipelineSplitter : public IRMutator {
op
->
body
,
env
);
op
->
body
,
env
);
if
(
body
.
same_as
(
op
->
body
))
return
s
;
if
(
body
.
same_as
(
op
->
body
))
return
s
;
return
AttrStmt
::
make
(
return
AttrStmt
::
make
(
op
->
node
,
op
->
type
_key
,
op
->
value
,
body
);
op
->
node
,
op
->
attr
_key
,
op
->
value
,
body
);
}
else
{
}
else
{
return
IRMutator
::
Mutate_
(
op
,
s
);
return
IRMutator
::
Mutate_
(
op
,
s
);
}
}
...
...
src/pass/storage_flatten.cc
View file @
0876e9e9
...
@@ -40,17 +40,17 @@ class StorageFlattener : public IRMutator {
...
@@ -40,17 +40,17 @@ class StorageFlattener : public IRMutator {
}
}
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
type
_key
==
attr
::
realize_scope
)
{
if
(
op
->
attr
_key
==
attr
::
realize_scope
)
{
storage_scope_
[
op
->
node
.
get
()]
=
op
->
value
.
as
<
StringImm
>
()
->
value
;
storage_scope_
[
op
->
node
.
get
()]
=
op
->
value
.
as
<
StringImm
>
()
->
value
;
return
this
->
Mutate
(
op
->
body
);
return
this
->
Mutate
(
op
->
body
);
}
else
if
(
op
->
type
_key
==
attr
::
thread_extent
)
{
}
else
if
(
op
->
attr
_key
==
attr
::
thread_extent
)
{
IterVar
iv
(
op
->
node
.
node_
);
IterVar
iv
(
op
->
node
.
node_
);
ThreadScope
ts
=
ThreadScope
::
make
(
iv
->
thread_tag
);
ThreadScope
ts
=
ThreadScope
::
make
(
iv
->
thread_tag
);
curr_thread_scope_
.
push_back
(
ts
);
curr_thread_scope_
.
push_back
(
ts
);
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
curr_thread_scope_
.
pop_back
();
curr_thread_scope_
.
pop_back
();
return
stmt
;
return
stmt
;
}
else
if
(
op
->
type
_key
==
attr
::
extern_op_scope
)
{
}
else
if
(
op
->
attr
_key
==
attr
::
extern_op_scope
)
{
return
HandleExternOp
(
op
);
return
HandleExternOp
(
op
);
}
}
return
IRMutator
::
Mutate_
(
op
,
s
);
return
IRMutator
::
Mutate_
(
op
,
s
);
...
...
src/pass/storage_sync.cc
View file @
0876e9e9
...
@@ -57,7 +57,7 @@ class StorageSyncPlanner : public IRVisitor {
...
@@ -57,7 +57,7 @@ class StorageSyncPlanner : public IRVisitor {
allow_load_
=
false
;
allow_load_
=
false
;
}
}
void
Visit_
(
const
AttrStmt
*
op
)
final
{
void
Visit_
(
const
AttrStmt
*
op
)
final
{
if
(
op
->
type
_key
==
"storage_scope"
)
{
if
(
op
->
attr
_key
==
"storage_scope"
)
{
const
Variable
*
buf
=
op
->
node
.
as
<
Variable
>
();
const
Variable
*
buf
=
op
->
node
.
as
<
Variable
>
();
storage_scope_
[
buf
]
=
storage_scope_
[
buf
]
=
StorageScope
::
make
(
op
->
value
.
as
<
StringImm
>
()
->
value
);
StorageScope
::
make
(
op
->
value
.
as
<
StringImm
>
()
->
value
);
...
...
src/schedule/schedule_ops.cc
View file @
0876e9e9
...
@@ -55,7 +55,7 @@ class InjectAttach : public IRMutator {
...
@@ -55,7 +55,7 @@ class InjectAttach : public IRMutator {
stmt
=
IRMutator
::
Mutate
(
stmt
);
stmt
=
IRMutator
::
Mutate
(
stmt
);
const
AttrStmt
*
op
=
stmt
.
as
<
AttrStmt
>
();
const
AttrStmt
*
op
=
stmt
.
as
<
AttrStmt
>
();
if
(
op
!=
nullptr
&&
if
(
op
!=
nullptr
&&
op
->
type
_key
==
attr
::
loop_scope
)
{
op
->
attr
_key
==
attr
::
loop_scope
)
{
if
(
attach_spec_
->
attach_type
==
kScope
&&
if
(
attach_spec_
->
attach_type
==
kScope
&&
op
->
node
==
attach_spec_
->
attach_ivar
)
{
op
->
node
==
attach_spec_
->
attach_ivar
)
{
CHECK
(
!
found_attach
)
CHECK
(
!
found_attach
)
...
@@ -63,7 +63,7 @@ class InjectAttach : public IRMutator {
...
@@ -63,7 +63,7 @@ class InjectAttach : public IRMutator {
<<
" in multiple places in the IR"
;
<<
" in multiple places in the IR"
;
found_attach
=
true
;
found_attach
=
true
;
stmt
=
AttrStmt
::
make
(
stmt
=
AttrStmt
::
make
(
op
->
node
,
op
->
type
_key
,
op
->
value
,
op
->
node
,
op
->
attr
_key
,
op
->
value
,
MakePipeline
(
stage_
,
dom_map_
,
op
->
body
));
MakePipeline
(
stage_
,
dom_map_
,
op
->
body
));
}
}
}
}
...
@@ -97,12 +97,12 @@ class InjectScanStep : public IRMutator {
...
@@ -97,12 +97,12 @@ class InjectScanStep : public IRMutator {
// update
// update
const
AttrStmt
*
op
=
stmt
.
as
<
AttrStmt
>
();
const
AttrStmt
*
op
=
stmt
.
as
<
AttrStmt
>
();
if
(
op
!=
nullptr
&&
if
(
op
!=
nullptr
&&
((
op
->
type
_key
==
attr
::
scan_update_scope
&&
!
is_init_
)
||
((
op
->
attr
_key
==
attr
::
scan_update_scope
&&
!
is_init_
)
||
(
op
->
type
_key
==
attr
::
scan_init_scope
&&
is_init_
)))
{
(
op
->
attr
_key
==
attr
::
scan_init_scope
&&
is_init_
)))
{
if
(
op
->
node
.
same_as
(
scan_op_
))
{
if
(
op
->
node
.
same_as
(
scan_op_
))
{
found_attach
=
true
;
found_attach
=
true
;
stmt
=
AttrStmt
::
make
(
stmt
=
AttrStmt
::
make
(
op
->
node
,
op
->
type
_key
,
op
->
value
,
op
->
node
,
op
->
attr
_key
,
op
->
value
,
MakePipeline
(
stage_
,
dom_map_
,
op
->
body
));
MakePipeline
(
stage_
,
dom_map_
,
op
->
body
));
}
}
}
}
...
@@ -150,20 +150,20 @@ class SchedulePostProc : public IRMutator {
...
@@ -150,20 +150,20 @@ class SchedulePostProc : public IRMutator {
}
}
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
type
_key
==
attr
::
loop_scope
||
if
(
op
->
attr
_key
==
attr
::
loop_scope
||
op
->
type
_key
==
attr
::
scan_init_scope
)
{
op
->
attr
_key
==
attr
::
scan_init_scope
)
{
return
this
->
Mutate
(
op
->
body
);
return
this
->
Mutate
(
op
->
body
);
}
else
if
(
op
->
type
_key
==
attr
::
scan_update_scope
)
{
}
else
if
(
op
->
attr
_key
==
attr
::
scan_update_scope
)
{
const
ScanOpNode
*
scan
=
op
->
node
.
as
<
ScanOpNode
>
();
const
ScanOpNode
*
scan
=
op
->
node
.
as
<
ScanOpNode
>
();
CHECK
(
scan
);
CHECK
(
scan
);
var_value_
[
scan
->
scan_axis
->
var
.
get
()]
=
op
->
value
;
var_value_
[
scan
->
scan_axis
->
var
.
get
()]
=
op
->
value
;
return
this
->
Mutate
(
op
->
body
);
return
this
->
Mutate
(
op
->
body
);
}
else
if
(
op
->
type
_key
==
ir
::
attr
::
realize_scope
)
{
}
else
if
(
op
->
attr
_key
==
ir
::
attr
::
realize_scope
)
{
auto
it
=
replace_op_
.
find
(
op
->
node
.
get
());
auto
it
=
replace_op_
.
find
(
op
->
node
.
get
());
if
(
it
!=
replace_op_
.
end
())
{
if
(
it
!=
replace_op_
.
end
())
{
if
(
it
->
second
.
defined
())
{
if
(
it
->
second
.
defined
())
{
Stmt
ret
=
AttrStmt
::
make
(
Stmt
ret
=
AttrStmt
::
make
(
it
->
second
,
op
->
type
_key
,
op
->
value
,
op
->
body
);
it
->
second
,
op
->
attr
_key
,
op
->
value
,
op
->
body
);
return
this
->
Mutate
(
ret
);
return
this
->
Mutate
(
ret
);
}
else
{
}
else
{
return
this
->
Mutate
(
op
->
body
);
return
this
->
Mutate
(
op
->
body
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment