Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
5b8ff8d0
Commit
5b8ff8d0
authored
Feb 04, 2019
by
Alexey Romanov
Committed by
Tianqi Chen
Feb 03, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Remove duplicate as Checks and CHECK value (#2531)
parent
74b035a2
Show whitespace changes
Inline
Side-by-side
Showing
25 changed files
with
66 additions
and
72 deletions
+66
-72
src/codegen/codegen_c.cc
+2
-3
src/codegen/stackvm/codegen_stackvm.cc
+2
-2
src/common/socket.h
+2
-2
src/op/hybrid_op.cc
+1
-1
src/pass/ir_util.cc
+14
-14
src/pass/loop_partition.cc
+3
-2
src/pass/split_host_device.cc
+1
-1
src/pass/storage_rewrite.cc
+3
-3
src/pass/verify_gpu_code.cc
+10
-8
src/relay/backend/interpreter.cc
+1
-1
src/relay/ir/base.cc
+1
-1
src/relay/op/type_relations.cc
+1
-1
src/relay/pass/fold_scale_axis.cc
+1
-1
src/relay/pass/gradient.cc
+1
-1
src/relay/pass/to_anf.cc
+1
-1
src/relay/pass/type_infer.cc
+1
-1
src/relay/pass/type_solver.cc
+1
-1
src/runtime/rpc/rpc_session.cc
+0
-1
src/runtime/stackvm/stackvm.cc
+0
-2
src/runtime/stackvm/stackvm.h
+1
-1
src/schedule/graph.cc
+13
-13
src/schedule/message_passing.cc
+3
-6
src/schedule/schedule_dataflow_rewrite.cc
+1
-1
src/schedule/schedule_lang.cc
+1
-2
src/schedule/schedule_ops.cc
+1
-2
No files found.
src/codegen/codegen_c.cc
View file @
5b8ff8d0
...
@@ -791,10 +791,9 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) {
...
@@ -791,10 +791,9 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) {
void
CodeGenC
::
VisitStmt_
(
const
AssertStmt
*
op
)
{
void
CodeGenC
::
VisitStmt_
(
const
AssertStmt
*
op
)
{
std
::
string
cond
=
PrintExpr
(
op
->
condition
);
std
::
string
cond
=
PrintExpr
(
op
->
condition
);
PrintIndent
();
PrintIndent
();
if
(
op
->
message
.
as
<
StringImm
>
())
{
if
(
const
auto
*
str
=
op
->
message
.
as
<
StringImm
>
())
{
// GLOG style check
// GLOG style check
stream
<<
"CHECK("
<<
cond
<<
") <<
\"
"
stream
<<
"CHECK("
<<
cond
<<
") <<
\"
"
<<
str
->
value
<<
"
\"
;
\n
"
;
<<
op
->
message
.
as
<
StringImm
>
()
->
value
<<
"
\"
;
\n
"
;
}
else
{
}
else
{
stream
<<
"assert("
<<
cond
<<
");
\n
"
;
stream
<<
"assert("
<<
cond
<<
");
\n
"
;
}
}
...
...
src/codegen/stackvm/codegen_stackvm.cc
View file @
5b8ff8d0
...
@@ -470,8 +470,8 @@ void CodeGenStackVM::VisitExpr_(const Select *op) {
...
@@ -470,8 +470,8 @@ void CodeGenStackVM::VisitExpr_(const Select *op) {
}
}
void
CodeGenStackVM
::
VisitStmt_
(
const
AssertStmt
*
op
)
{
void
CodeGenStackVM
::
VisitStmt_
(
const
AssertStmt
*
op
)
{
if
(
op
->
message
.
as
<
StringImm
>
())
{
if
(
const
auto
*
str
=
op
->
message
.
as
<
StringImm
>
())
{
int
sid
=
this
->
GetStrID
(
op
->
message
.
as
<
StringImm
>
()
->
value
);
int
sid
=
this
->
GetStrID
(
str
->
value
);
this
->
Push
(
op
->
condition
);
this
->
Push
(
op
->
condition
);
this
->
PushOp
(
StackVM
::
ASSERT
,
sid
);
this
->
PushOp
(
StackVM
::
ASSERT
,
sid
);
}
}
...
...
src/common/socket.h
View file @
5b8ff8d0
...
@@ -42,13 +42,13 @@ inline std::string GetHostName() {
...
@@ -42,13 +42,13 @@ inline std::string GetHostName() {
}
}
/*!
/*!
* \brief Common data structure fornetwork address.
* \brief Common data structure for
network address.
*/
*/
struct
SockAddr
{
struct
SockAddr
{
sockaddr_storage
addr
;
sockaddr_storage
addr
;
SockAddr
()
{}
SockAddr
()
{}
/*!
/*!
* \brief construc address by url and port
* \brief construc
t
address by url and port
* \param url The url of the address
* \param url The url of the address
* \param port The port of the address.
* \param port The port of the address.
*/
*/
...
...
src/op/hybrid_op.cc
View file @
5b8ff8d0
...
@@ -435,7 +435,7 @@ Stmt ApplySchedule(const Stage &stage,
...
@@ -435,7 +435,7 @@ Stmt ApplySchedule(const Stage &stage,
// Gather rebased variables
// Gather rebased variables
std
::
unordered_map
<
IterVar
,
IterVar
>
rebased
;
std
::
unordered_map
<
IterVar
,
IterVar
>
rebased
;
for
(
auto
rel
:
stage
->
relations
)
{
for
(
auto
rel
:
stage
->
relations
)
{
if
(
auto
rebase
=
rel
.
as
<
RebaseNode
>
())
{
if
(
const
auto
*
rebase
=
rel
.
as
<
RebaseNode
>
())
{
rebased
[
rebase
->
rebased
]
=
rebase
->
parent
;
rebased
[
rebase
->
rebased
]
=
rebase
->
parent
;
CHECK
(
rebase
->
parent
->
dom
.
defined
());
CHECK
(
rebase
->
parent
->
dom
.
defined
());
CHECK
(
dom_map
.
count
(
rebase
->
rebased
));
CHECK
(
dom_map
.
count
(
rebase
->
rebased
));
...
...
src/pass/ir_util.cc
View file @
5b8ff8d0
...
@@ -12,39 +12,39 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
...
@@ -12,39 +12,39 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
// use reverse iteration
// use reverse iteration
for
(
auto
ri
=
nest
.
rbegin
();
ri
!=
nest
.
rend
();
++
ri
)
{
for
(
auto
ri
=
nest
.
rbegin
();
ri
!=
nest
.
rend
();
++
ri
)
{
Stmt
s
=
*
ri
;
Stmt
s
=
*
ri
;
if
(
s
.
as
<
For
>
())
{
if
(
const
auto
*
for_
=
s
.
as
<
For
>
())
{
auto
n
=
make_node
<
For
>
(
*
s
.
as
<
For
>
()
);
auto
n
=
make_node
<
For
>
(
*
for_
);
CHECK
(
is_no_op
(
n
->
body
));
CHECK
(
is_no_op
(
n
->
body
));
n
->
body
=
body
;
n
->
body
=
body
;
body
=
Stmt
(
n
);
body
=
Stmt
(
n
);
}
else
if
(
s
.
as
<
LetStmt
>
())
{
}
else
if
(
const
auto
*
let
=
s
.
as
<
LetStmt
>
())
{
auto
n
=
make_node
<
LetStmt
>
(
*
s
.
as
<
LetStmt
>
()
);
auto
n
=
make_node
<
LetStmt
>
(
*
let
);
CHECK
(
is_no_op
(
n
->
body
));
CHECK
(
is_no_op
(
n
->
body
));
n
->
body
=
body
;
n
->
body
=
body
;
body
=
Stmt
(
n
);
body
=
Stmt
(
n
);
}
else
if
(
s
.
as
<
AttrStmt
>
())
{
}
else
if
(
const
auto
*
attr
=
s
.
as
<
AttrStmt
>
())
{
auto
n
=
make_node
<
AttrStmt
>
(
*
s
.
as
<
AttrStmt
>
()
);
auto
n
=
make_node
<
AttrStmt
>
(
*
attr
);
CHECK
(
is_no_op
(
n
->
body
));
CHECK
(
is_no_op
(
n
->
body
));
n
->
body
=
body
;
n
->
body
=
body
;
body
=
Stmt
(
n
);
body
=
Stmt
(
n
);
}
else
if
(
s
.
as
<
IfThenElse
>
())
{
}
else
if
(
const
auto
*
ite
=
s
.
as
<
IfThenElse
>
())
{
auto
n
=
make_node
<
IfThenElse
>
(
*
s
.
as
<
IfThenElse
>
()
);
auto
n
=
make_node
<
IfThenElse
>
(
*
ite
);
CHECK
(
is_no_op
(
n
->
then_case
));
CHECK
(
is_no_op
(
n
->
then_case
));
CHECK
(
!
n
->
else_case
.
defined
());
CHECK
(
!
n
->
else_case
.
defined
());
n
->
then_case
=
body
;
n
->
then_case
=
body
;
body
=
Stmt
(
n
);
body
=
Stmt
(
n
);
}
else
if
(
s
.
as
<
Block
>
())
{
}
else
if
(
const
auto
*
block
=
s
.
as
<
Block
>
())
{
auto
n
=
make_node
<
Block
>
(
*
s
.
as
<
Block
>
()
);
auto
n
=
make_node
<
Block
>
(
*
block
);
CHECK
(
is_no_op
(
n
->
rest
));
CHECK
(
is_no_op
(
n
->
rest
));
n
->
rest
=
body
;
n
->
rest
=
body
;
body
=
Stmt
(
n
);
body
=
Stmt
(
n
);
}
else
if
(
s
.
as
<
AssertStmt
>
())
{
}
else
if
(
const
auto
*
assert_
=
s
.
as
<
AssertStmt
>
())
{
auto
n
=
make_node
<
AssertStmt
>
(
*
s
.
as
<
AssertStmt
>
()
);
auto
n
=
make_node
<
AssertStmt
>
(
*
assert_
);
CHECK
(
is_no_op
(
n
->
body
));
CHECK
(
is_no_op
(
n
->
body
));
n
->
body
=
body
;
n
->
body
=
body
;
body
=
Stmt
(
n
);
body
=
Stmt
(
n
);
}
else
if
(
s
.
as
<
Allocate
>
())
{
}
else
if
(
const
auto
*
alloc
=
s
.
as
<
Allocate
>
())
{
auto
n
=
make_node
<
Allocate
>
(
*
s
.
as
<
Allocate
>
()
);
auto
n
=
make_node
<
Allocate
>
(
*
alloc
);
CHECK
(
is_no_op
(
n
->
body
));
CHECK
(
is_no_op
(
n
->
body
));
n
->
body
=
body
;
n
->
body
=
body
;
body
=
Stmt
(
n
);
body
=
Stmt
(
n
);
...
...
src/pass/loop_partition.cc
View file @
5b8ff8d0
...
@@ -326,7 +326,8 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
...
@@ -326,7 +326,8 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Expr
body_begin
;
Expr
body_begin
;
Stmt
pre_stmt
;
Stmt
pre_stmt
;
if
(
true_itrv
.
as
<
arith
::
IntervalSet
>
()
->
i
.
has_lower_bound
())
{
arith
::
Interval
true_itrv_i
=
true_itrv
.
as
<
arith
::
IntervalSet
>
()
->
i
;
if
(
true_itrv_i
.
has_lower_bound
())
{
body_begin
=
ir
::
Simplify
(
true_itrv
.
min
());
body_begin
=
ir
::
Simplify
(
true_itrv
.
min
());
if
(
!
can_prove
(
body_begin
==
min
))
{
if
(
!
can_prove
(
body_begin
==
min
))
{
Expr
cond
=
(
body_begin
-
min
>=
0
);
Expr
cond
=
(
body_begin
-
min
>=
0
);
...
@@ -347,7 +348,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
...
@@ -347,7 +348,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Expr
post_doubt_begin
;
Expr
post_doubt_begin
;
Stmt
post_stmt
;
Stmt
post_stmt
;
if
(
true_itrv
.
as
<
arith
::
IntervalSet
>
()
->
i
.
has_upper_bound
())
{
if
(
true_itrv
_
i
.
has_upper_bound
())
{
post_doubt_begin
=
ir
::
Simplify
(
true_itrv
.
max
()
+
1
);
post_doubt_begin
=
ir
::
Simplify
(
true_itrv
.
max
()
+
1
);
if
(
!
can_prove
(
true_itrv
.
max
()
==
max
))
{
if
(
!
can_prove
(
true_itrv
.
max
()
==
max
))
{
// require the extent to be non-negative
// require the extent to be non-negative
...
...
src/pass/split_host_device.cc
View file @
5b8ff8d0
...
@@ -34,7 +34,7 @@ class IRUseDefAnalysis : public IRMutator {
...
@@ -34,7 +34,7 @@ class IRUseDefAnalysis : public IRMutator {
value
=
this
->
Mutate
(
value
);
value
=
this
->
Mutate
(
value
);
}
}
Stmt
body
=
this
->
Mutate
(
op
->
body
);
Stmt
body
=
this
->
Mutate
(
op
->
body
);
if
(
value
.
same_as
(
value
)
&&
body
.
same_as
(
body
))
return
s
;
if
(
value
.
same_as
(
op
->
value
)
&&
body
.
same_as
(
op
->
body
))
return
s
;
return
AttrStmt
::
make
(
op
->
node
,
op
->
attr_key
,
value
,
body
);
return
AttrStmt
::
make
(
op
->
node
,
op
->
attr_key
,
value
,
body
);
}
else
if
(
op
->
attr_key
==
attr
::
channel_write_scope
||
}
else
if
(
op
->
attr_key
==
attr
::
channel_write_scope
||
op
->
attr_key
==
attr
::
channel_read_scope
)
{
op
->
attr_key
==
attr
::
channel_read_scope
)
{
...
...
src/pass/storage_rewrite.cc
View file @
5b8ff8d0
...
@@ -718,10 +718,10 @@ class StoragePlanRewriter : public IRMutator {
...
@@ -718,10 +718,10 @@ class StoragePlanRewriter : public IRMutator {
src_entry
->
attach_scope_
==
thread_scope_
&&
src_entry
->
attach_scope_
==
thread_scope_
&&
src_entry
->
elem_type
==
ae
.
alloc
->
type
.
element_of
()
&&
src_entry
->
elem_type
==
ae
.
alloc
->
type
.
element_of
()
&&
visitor
.
Check
(
s
.
stmt
,
var
,
src
))
{
visitor
.
Check
(
s
.
stmt
,
var
,
src
))
{
uint64_t
const_nbits
=
static_cast
<
uint64_t
>
(
uint64_t
const_nbits
=
ae
.
alloc
->
constant_allocation_size
(
)
*
static_cast
<
uint64_t
>
(
ae
.
alloc
->
constant_allocation_size
()
)
*
ae
.
alloc
->
type
.
bits
()
*
ae
.
alloc
->
type
.
bits
()
*
ae
.
alloc
->
type
.
lanes
()
)
;
ae
.
alloc
->
type
.
lanes
();
if
(
src_entry
->
const_nbits
==
const_nbits
&&
!
inplace_found
)
{
if
(
src_entry
->
const_nbits
==
const_nbits
&&
!
inplace_found
)
{
// successfully inplace
// successfully inplace
dst_entry
=
src_entry
;
dst_entry
=
src_entry
;
...
...
src/pass/verify_gpu_code.cc
View file @
5b8ff8d0
...
@@ -73,9 +73,10 @@ class GPUCodeVerifier : public IRVisitor {
...
@@ -73,9 +73,10 @@ class GPUCodeVerifier : public IRVisitor {
void
Visit_
(
const
AttrStmt
*
op
)
{
void
Visit_
(
const
AttrStmt
*
op
)
{
if
(
op
->
attr_key
==
attr
::
storage_scope
)
{
if
(
op
->
attr_key
==
attr
::
storage_scope
)
{
if
(
op
->
value
.
as
<
StringImm
>
()
->
value
==
"local"
)
{
std
::
string
op_value
=
op
->
value
.
as
<
StringImm
>
()
->
value
;
if
(
op_value
==
"local"
)
{
visited_local_buffers_
.
insert
(
op
->
node
.
as
<
tvm
::
Variable
>
());
visited_local_buffers_
.
insert
(
op
->
node
.
as
<
tvm
::
Variable
>
());
}
else
if
(
op
->
value
.
as
<
StringImm
>
()
->
value
==
"shared"
)
{
}
else
if
(
op
_
value
==
"shared"
)
{
visited_shared_buffers_
.
insert
(
op
->
node
.
as
<
tvm
::
Variable
>
());
visited_shared_buffers_
.
insert
(
op
->
node
.
as
<
tvm
::
Variable
>
());
}
}
}
else
if
(
op
->
attr_key
==
attr
::
thread_extent
)
{
}
else
if
(
op
->
attr_key
==
attr
::
thread_extent
)
{
...
@@ -159,18 +160,19 @@ bool VerifyGPUCode(Stmt stmt,
...
@@ -159,18 +160,19 @@ bool VerifyGPUCode(Stmt stmt,
int64_t
max_thread_z
=
INT64_MAX
;
int64_t
max_thread_z
=
INT64_MAX
;
for
(
auto
iter
:
constraints
)
{
for
(
auto
iter
:
constraints
)
{
const
IntImm
*
val
=
iter
.
second
.
as
<
IntImm
>
();
if
(
iter
.
first
==
"max_local_memory_per_block"
)
if
(
iter
.
first
==
"max_local_memory_per_block"
)
max_local_memory_per_block
=
(
iter
.
second
).
as
<
IntImm
>
()
->
value
;
max_local_memory_per_block
=
val
->
value
;
else
if
(
iter
.
first
==
"max_shared_memory_per_block"
)
else
if
(
iter
.
first
==
"max_shared_memory_per_block"
)
max_shared_memory_per_block
=
(
iter
.
second
).
as
<
IntImm
>
()
->
value
;
max_shared_memory_per_block
=
val
->
value
;
else
if
(
iter
.
first
==
"max_threads_per_block"
)
else
if
(
iter
.
first
==
"max_threads_per_block"
)
max_threads_per_block
=
(
iter
.
second
).
as
<
IntImm
>
()
->
value
;
max_threads_per_block
=
val
->
value
;
else
if
(
iter
.
first
==
"max_thread_x"
)
else
if
(
iter
.
first
==
"max_thread_x"
)
max_thread_x
=
(
iter
.
second
).
as
<
IntImm
>
()
->
value
;
max_thread_x
=
val
->
value
;
else
if
(
iter
.
first
==
"max_thread_y"
)
else
if
(
iter
.
first
==
"max_thread_y"
)
max_thread_y
=
(
iter
.
second
).
as
<
IntImm
>
()
->
value
;
max_thread_y
=
val
->
value
;
else
if
(
iter
.
first
==
"max_thread_z"
)
else
if
(
iter
.
first
==
"max_thread_z"
)
max_thread_z
=
(
iter
.
second
).
as
<
IntImm
>
()
->
value
;
max_thread_z
=
val
->
value
;
else
else
LOG
(
FATAL
)
<<
"Invalid check item: "
<<
iter
.
first
;
LOG
(
FATAL
)
<<
"Invalid check item: "
<<
iter
.
first
;
}
}
...
...
src/relay/backend/interpreter.cc
View file @
5b8ff8d0
...
@@ -379,7 +379,7 @@ class Interpreter :
...
@@ -379,7 +379,7 @@ class Interpreter :
//
//
// We have some functions cotaining chunks of operators
// We have some functions cotaining chunks of operators
// which will be loaded into operator map.
// which will be loaded into operator map.
if
(
auto
op_node
=
call
->
op
.
as
<
OpNode
>
())
{
if
(
const
auto
*
op_node
=
call
->
op
.
as
<
OpNode
>
())
{
LOG
(
FATAL
)
<<
"found "
<<
op_node
->
name
LOG
(
FATAL
)
<<
"found "
<<
op_node
->
name
<<
"; operators should be removed by future passes; try "
<<
"; operators should be removed by future passes; try "
"fusing and lowering"
;
"fusing and lowering"
;
...
...
src/relay/ir/base.cc
View file @
5b8ff8d0
...
@@ -20,8 +20,8 @@ NodePtr<SourceNameNode> GetSourceNameNode(const std::string& name) {
...
@@ -20,8 +20,8 @@ NodePtr<SourceNameNode> GetSourceNameNode(const std::string& name) {
auto
sn
=
source_map
.
find
(
name
);
auto
sn
=
source_map
.
find
(
name
);
if
(
sn
==
source_map
.
end
())
{
if
(
sn
==
source_map
.
end
())
{
NodePtr
<
SourceNameNode
>
n
=
make_node
<
SourceNameNode
>
();
NodePtr
<
SourceNameNode
>
n
=
make_node
<
SourceNameNode
>
();
n
->
name
=
std
::
move
(
name
);
source_map
[
name
]
=
n
;
source_map
[
name
]
=
n
;
n
->
name
=
std
::
move
(
name
);
return
n
;
return
n
;
}
else
{
}
else
{
return
sn
->
second
;
return
sn
->
second
;
...
...
src/relay/op/type_relations.cc
View file @
5b8ff8d0
...
@@ -15,7 +15,7 @@ namespace tvm {
...
@@ -15,7 +15,7 @@ namespace tvm {
namespace
relay
{
namespace
relay
{
TensorType
ToTensorType
(
const
Type
&
t
)
{
TensorType
ToTensorType
(
const
Type
&
t
)
{
if
(
auto
tt_node
=
t
.
as
<
TensorTypeNode
>
())
{
if
(
const
auto
*
tt_node
=
t
.
as
<
TensorTypeNode
>
())
{
return
GetRef
<
TensorType
>
(
tt_node
);
return
GetRef
<
TensorType
>
(
tt_node
);
}
else
{
}
else
{
return
TensorType
(
nullptr
);
return
TensorType
(
nullptr
);
...
...
src/relay/pass/fold_scale_axis.cc
View file @
5b8ff8d0
...
@@ -361,7 +361,7 @@ Expr AddSubForwardRewrite(const Call& ref_call,
...
@@ -361,7 +361,7 @@ Expr AddSubForwardRewrite(const Call& ref_call,
rnode
->
scale
=
slhs
->
scale
;
rnode
->
scale
=
slhs
->
scale
;
rnode
->
axes
=
slhs
->
axes
;
rnode
->
axes
=
slhs
->
axes
;
}
else
{
}
else
{
CHECK
(
s
l
hs
!=
nullptr
);
CHECK
(
s
r
hs
!=
nullptr
);
CHECK
(
MatchBroadcastToLeftAxes
(
trhs
,
tlhs
,
srhs
->
axes
));
CHECK
(
MatchBroadcastToLeftAxes
(
trhs
,
tlhs
,
srhs
->
axes
));
Expr
scale
=
ExpandBiasToMatchAxis
(
Expr
scale
=
ExpandBiasToMatchAxis
(
srhs
->
scale
,
trhs
->
shape
.
size
(),
srhs
->
axes
);
srhs
->
scale
,
trhs
->
shape
.
size
(),
srhs
->
axes
);
...
...
src/relay/pass/gradient.cc
View file @
5b8ff8d0
...
@@ -61,7 +61,7 @@ Type WithGradientType(const Type& t) {
...
@@ -61,7 +61,7 @@ Type WithGradientType(const Type& t) {
//! \brief if the expression is a GlobalVar, transform to it's expression.
//! \brief if the expression is a GlobalVar, transform to it's expression.
Expr
DeGlobal
(
const
Module
&
mod
,
const
Expr
&
e
)
{
Expr
DeGlobal
(
const
Module
&
mod
,
const
Expr
&
e
)
{
if
(
auto
x
=
e
.
as
<
GlobalVarNode
>
())
{
if
(
const
auto
*
x
=
e
.
as
<
GlobalVarNode
>
())
{
return
mod
->
Lookup
(
GetRef
<
GlobalVar
>
(
x
))
->
body
;
return
mod
->
Lookup
(
GetRef
<
GlobalVar
>
(
x
))
->
body
;
}
else
{
}
else
{
return
e
;
return
e
;
...
...
src/relay/pass/to_anf.cc
View file @
5b8ff8d0
...
@@ -385,7 +385,7 @@ Expr ToANFAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
...
@@ -385,7 +385,7 @@ Expr ToANFAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
}
}
Expr
ToANF
(
const
Expr
&
e
,
const
Module
&
m
,
std
::
set
<
GlobalVar
>*
gv
)
{
Expr
ToANF
(
const
Expr
&
e
,
const
Module
&
m
,
std
::
set
<
GlobalVar
>*
gv
)
{
if
(
auto
f
=
e
.
as
<
FunctionNode
>
())
{
if
(
const
auto
*
f
=
e
.
as
<
FunctionNode
>
())
{
return
FunctionNode
::
make
(
f
->
params
,
return
FunctionNode
::
make
(
f
->
params
,
ToANFAux
(
f
->
body
,
m
,
gv
),
ToANFAux
(
f
->
body
,
m
,
gv
),
f
->
ret_type
,
f
->
ret_type
,
...
...
src/relay/pass/type_infer.cc
View file @
5b8ff8d0
...
@@ -386,7 +386,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
...
@@ -386,7 +386,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
}
}
for
(
auto
cs
:
fn_ty
->
type_constraints
)
{
for
(
auto
cs
:
fn_ty
->
type_constraints
)
{
if
(
auto
tr
=
cs
.
as
<
TypeRelationNode
>
())
{
if
(
const
auto
*
tr
=
cs
.
as
<
TypeRelationNode
>
())
{
solver_
.
AddConstraint
(
solver_
.
AddConstraint
(
TypeRelationNode
::
make
(
tr
->
func
,
tr
->
args
,
tr
->
num_inputs
,
call
->
attrs
),
TypeRelationNode
::
make
(
tr
->
func
,
tr
->
args
,
tr
->
num_inputs
,
call
->
attrs
),
GetRef
<
Call
>
(
call
));
GetRef
<
Call
>
(
call
));
...
...
src/relay/pass/type_solver.cc
View file @
5b8ff8d0
...
@@ -376,7 +376,7 @@ void TypeSolver::ReportError(const Error& err, const NodeRef& location) {
...
@@ -376,7 +376,7 @@ void TypeSolver::ReportError(const Error& err, const NodeRef& location) {
// Add type constraint to the solver.
// Add type constraint to the solver.
void
TypeSolver
::
AddConstraint
(
const
TypeConstraint
&
constraint
,
const
NodeRef
&
loc
)
{
void
TypeSolver
::
AddConstraint
(
const
TypeConstraint
&
constraint
,
const
NodeRef
&
loc
)
{
if
(
auto
*
op
=
constraint
.
as
<
TypeRelationNode
>
())
{
if
(
const
auto
*
op
=
constraint
.
as
<
TypeRelationNode
>
())
{
// create a new relation node.
// create a new relation node.
RelationNode
*
rnode
=
arena_
.
make
<
RelationNode
>
();
RelationNode
*
rnode
=
arena_
.
make
<
RelationNode
>
();
rnode
->
location
=
loc
;
rnode
->
location
=
loc
;
...
...
src/runtime/rpc/rpc_session.cc
View file @
5b8ff8d0
...
@@ -486,7 +486,6 @@ class RPCSession::EventHandler : public dmlc::Stream {
...
@@ -486,7 +486,6 @@ class RPCSession::EventHandler : public dmlc::Stream {
arg_recv_stage_
=
1
;
arg_recv_stage_
=
1
;
this
->
RequestBytes
(
len
);
this
->
RequestBytes
(
len
);
break
;
break
;
break
;
}
}
case
kArrayHandle
:
{
case
kArrayHandle
:
{
temp_array_
.
reset
(
new
RPCDataArrayBuffer
());
temp_array_
.
reset
(
new
RPCDataArrayBuffer
());
...
...
src/runtime/stackvm/stackvm.cc
View file @
5b8ff8d0
...
@@ -406,7 +406,6 @@ void StackVM::Run(State* s) const {
...
@@ -406,7 +406,6 @@ void StackVM::Run(State* s) const {
case
intrinsic
:
:
kArrByteOffset
:
{
case
intrinsic
:
:
kArrByteOffset
:
{
stack
[
sp
].
v_int64
=
static_cast
<
int64_t
>
(
stack
[
sp
].
v_int64
=
static_cast
<
int64_t
>
(
arr
[
index
].
byte_offset
);
break
;
arr
[
index
].
byte_offset
);
break
;
break
;
}
}
case
intrinsic
:
:
kArrDeviceId
:
{
case
intrinsic
:
:
kArrDeviceId
:
{
stack
[
sp
].
v_int64
=
arr
[
index
].
ctx
.
device_id
;
break
;
stack
[
sp
].
v_int64
=
arr
[
index
].
ctx
.
device_id
;
break
;
...
@@ -531,7 +530,6 @@ const PackedFunc& StackVM::GetExtern(State* s, int fid) const {
...
@@ -531,7 +530,6 @@ const PackedFunc& StackVM::GetExtern(State* s, int fid) const {
if
(
f
==
nullptr
)
{
if
(
f
==
nullptr
)
{
CHECK
(
s
->
mod_ctx
!=
nullptr
)
CHECK
(
s
->
mod_ctx
!=
nullptr
)
<<
"No local context is set in stackvm"
;
<<
"No local context is set in stackvm"
;
CHECK
(
s
->
mod_ctx
!=
nullptr
);
const
PackedFunc
*
pf
=
s
->
mod_ctx
->
GetFuncFromEnv
(
extern_func_name
[
fid
]);
const
PackedFunc
*
pf
=
s
->
mod_ctx
->
GetFuncFromEnv
(
extern_func_name
[
fid
]);
CHECK
(
pf
!=
nullptr
);
CHECK
(
pf
!=
nullptr
);
f
=
*
pf
;
f
=
*
pf
;
...
...
src/runtime/stackvm/stackvm.h
View file @
5b8ff8d0
...
@@ -331,7 +331,7 @@ class StackVM {
...
@@ -331,7 +331,7 @@ class StackVM {
case
EQ_I64
:
return
EQ_F64
;
case
EQ_I64
:
return
EQ_F64
;
case
LT_I64
:
return
LT_F64
;
case
LT_I64
:
return
LT_F64
;
case
LE_I64
:
return
LE_F64
;
case
LE_I64
:
return
LE_F64
;
case
MOD_I64
:
LOG
(
FATAL
)
<<
"cannot handle mod for float"
;
case
MOD_I64
:
LOG
(
FATAL
)
<<
"cannot handle mod for float"
;
return
ADD_F64
;
default:
LOG
(
FATAL
)
<<
"cannot handle op "
<<
code
;
return
ADD_F64
;
default:
LOG
(
FATAL
)
<<
"cannot handle op "
<<
code
;
return
ADD_F64
;
}
}
}
}
...
...
src/schedule/graph.cc
View file @
5b8ff8d0
...
@@ -223,9 +223,9 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
...
@@ -223,9 +223,9 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
}
}
for
(
Operation
op
:
ops
)
{
for
(
Operation
op
:
ops
)
{
if
(
op
.
as
<
ScanOpNode
>
())
{
if
(
const
auto
*
scan_op
=
op
.
as
<
ScanOpNode
>
())
{
const
auto
&
update
=
op
.
as
<
ScanOpNode
>
()
->
update
;
const
auto
&
update
=
scan_op
->
update
;
const
auto
&
init
=
op
.
as
<
ScanOpNode
>
()
->
init
;
const
auto
&
init
=
scan_op
->
init
;
for
(
size_t
i
=
0
;
i
<
update
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
update
.
size
();
++
i
)
{
Tensor
t
=
op
.
output
(
i
);
Tensor
t
=
op
.
output
(
i
);
for
(
int
k
=
1
;
k
<
static_cast
<
int
>
(
update
[
i
]
->
shape
.
size
());
++
k
)
{
for
(
int
k
=
1
;
k
<
static_cast
<
int
>
(
update
[
i
]
->
shape
.
size
());
++
k
)
{
...
@@ -235,9 +235,9 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
...
@@ -235,9 +235,9 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
TensorDimKey
(
init
[
i
],
k
));
TensorDimKey
(
init
[
i
],
k
));
}
}
}
}
}
else
if
(
op
.
as
<
ComputeOpNode
>
())
{
}
else
if
(
const
auto
*
compute_op
=
op
.
as
<
ComputeOpNode
>
())
{
std
::
unordered_map
<
const
Node
*
,
TensorDimKey
>
vmap
;
std
::
unordered_map
<
const
Node
*
,
TensorDimKey
>
vmap
;
const
auto
&
axis
=
op
.
as
<
ComputeOpNode
>
()
->
axis
;
const
auto
&
axis
=
compute_op
->
axis
;
Tensor
t
=
op
.
output
(
0
);
Tensor
t
=
op
.
output
(
0
);
for
(
size_t
i
=
0
;
i
<
axis
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
axis
.
size
();
++
i
)
{
vmap
[
axis
[
i
]
->
var
.
get
()]
=
TensorDimKey
(
t
,
i
);
vmap
[
axis
[
i
]
->
var
.
get
()]
=
TensorDimKey
(
t
,
i
);
...
@@ -260,7 +260,7 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
...
@@ -260,7 +260,7 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
}
}
}
}
};
};
for
(
auto
&
e
:
op
.
as
<
ComputeOpNode
>
()
->
body
)
{
for
(
auto
&
e
:
compute_op
->
body
)
{
ir
::
PostOrderVisit
(
e
,
fvisit
);
ir
::
PostOrderVisit
(
e
,
fvisit
);
}
}
}
}
...
@@ -312,19 +312,19 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
...
@@ -312,19 +312,19 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
// prop exact reach back.
// prop exact reach back.
for
(
size_t
i
=
0
;
i
<
body
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
body
.
size
();
++
i
)
{
const
Operation
&
op
=
body
[
i
];
const
Operation
&
op
=
body
[
i
];
if
(
op
.
as
<
ScanOpNode
>
())
{
if
(
const
auto
*
scan_op
=
op
.
as
<
ScanOpNode
>
())
{
const
auto
&
update
=
op
.
as
<
ScanOpNode
>
()
->
update
;
const
auto
&
update
=
scan_op
->
update
;
const
auto
&
init
=
op
.
as
<
ScanOpNode
>
()
->
init
;
const
auto
&
init
=
scan_op
->
init
;
for
(
size_t
i
=
0
;
i
<
update
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
update
.
size
();
++
i
)
{
Tensor
t
=
op
.
output
(
i
);
Tensor
t
=
op
.
output
(
i
);
for
(
size_t
k
=
1
;
i
<
update
[
i
]
->
shape
.
size
();
++
k
)
{
for
(
size_t
k
=
1
;
k
<
update
[
i
]
->
shape
.
size
();
++
k
)
{
f_merge_key
(
TensorDimKey
(
t
,
k
),
TensorDimKey
(
update
[
i
],
k
));
f_merge_key
(
TensorDimKey
(
t
,
k
),
TensorDimKey
(
update
[
i
],
k
));
f_merge_key
(
TensorDimKey
(
t
,
k
),
TensorDimKey
(
init
[
i
],
k
));
f_merge_key
(
TensorDimKey
(
t
,
k
),
TensorDimKey
(
init
[
i
],
k
));
}
}
}
}
}
else
if
(
op
.
as
<
ComputeOpNode
>
())
{
}
else
if
(
const
auto
*
compute_op
=
op
.
as
<
ComputeOpNode
>
())
{
std
::
unordered_map
<
const
Node
*
,
std
::
vector
<
TensorDimKey
>
>
vmap
;
std
::
unordered_map
<
const
Node
*
,
std
::
vector
<
TensorDimKey
>
>
vmap
;
const
auto
&
axis
=
op
.
as
<
ComputeOpNode
>
()
->
axis
;
const
auto
&
axis
=
compute_op
->
axis
;
for
(
size_t
i
=
0
;
i
<
axis
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
axis
.
size
();
++
i
)
{
std
::
vector
<
TensorDimKey
>
keys
;
std
::
vector
<
TensorDimKey
>
keys
;
for
(
int
j
=
0
;
j
<
op
->
num_outputs
();
++
j
)
{
for
(
int
j
=
0
;
j
<
op
->
num_outputs
();
++
j
)
{
...
@@ -352,7 +352,7 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
...
@@ -352,7 +352,7 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
}
}
}
}
};
};
for
(
auto
&
e
:
op
.
as
<
ComputeOpNode
>
()
->
body
)
{
for
(
auto
&
e
:
compute_op
->
body
)
{
ir
::
PostOrderVisit
(
e
,
fvisit
);
ir
::
PostOrderVisit
(
e
,
fvisit
);
}
}
}
}
...
...
src/schedule/message_passing.cc
View file @
5b8ff8d0
...
@@ -419,8 +419,7 @@ void PassUpBoundCheck(const Stage& s,
...
@@ -419,8 +419,7 @@ void PassUpBoundCheck(const Stage& s,
using
HalideIR
::
Internal
::
can_prove
;
using
HalideIR
::
Internal
::
can_prove
;
for
(
size_t
i
=
s
->
relations
.
size
();
i
!=
0
;
--
i
)
{
for
(
size_t
i
=
s
->
relations
.
size
();
i
!=
0
;
--
i
)
{
IterVarRelation
rel
=
s
->
relations
[
i
-
1
];
IterVarRelation
rel
=
s
->
relations
[
i
-
1
];
if
(
rel
.
as
<
SplitNode
>
())
{
if
(
const
SplitNode
*
s
=
rel
.
as
<
SplitNode
>
())
{
const
SplitNode
*
s
=
rel
.
as
<
SplitNode
>
();
bool
outer
=
state
.
at
(
s
->
outer
);
bool
outer
=
state
.
at
(
s
->
outer
);
bool
inner
=
state
.
at
(
s
->
inner
);
bool
inner
=
state
.
at
(
s
->
inner
);
...
@@ -439,13 +438,11 @@ void PassUpBoundCheck(const Stage& s,
...
@@ -439,13 +438,11 @@ void PassUpBoundCheck(const Stage& s,
}
else
{
}
else
{
state
[
s
->
parent
]
=
true
;
state
[
s
->
parent
]
=
true
;
}
}
}
else
if
(
rel
.
as
<
FuseNode
>
())
{
}
else
if
(
const
FuseNode
*
s
=
rel
.
as
<
FuseNode
>
())
{
const
FuseNode
*
s
=
rel
.
as
<
FuseNode
>
();
bool
fused
=
state
.
at
(
s
->
fused
);
bool
fused
=
state
.
at
(
s
->
fused
);
state
[
s
->
outer
]
=
fused
;
state
[
s
->
outer
]
=
fused
;
state
[
s
->
inner
]
=
fused
;
state
[
s
->
inner
]
=
fused
;
}
else
if
(
rel
.
as
<
RebaseNode
>
())
{
}
else
if
(
const
RebaseNode
*
s
=
rel
.
as
<
RebaseNode
>
())
{
const
RebaseNode
*
s
=
rel
.
as
<
RebaseNode
>
();
state
[
s
->
parent
]
=
state
.
at
(
s
->
rebased
);
state
[
s
->
parent
]
=
state
.
at
(
s
->
rebased
);
}
else
if
(
rel
.
as
<
SingletonNode
>
())
{
}
else
if
(
rel
.
as
<
SingletonNode
>
())
{
// nop
// nop
...
...
src/schedule/schedule_dataflow_rewrite.cc
View file @
5b8ff8d0
...
@@ -544,7 +544,7 @@ void InjectInline(ScheduleNode* sch) {
...
@@ -544,7 +544,7 @@ void InjectInline(ScheduleNode* sch) {
const
ComputeOpNode
*
compute
=
s
->
op
.
as
<
ComputeOpNode
>
();
const
ComputeOpNode
*
compute
=
s
->
op
.
as
<
ComputeOpNode
>
();
if
(
compute
)
{
if
(
compute
)
{
if
(
!
new_body
[
j
].
size
())
{
if
(
!
new_body
[
j
].
size
())
{
new_body
[
j
]
=
s
->
op
.
as
<
ComputeOpNode
>
()
->
body
;
new_body
[
j
]
=
compute
->
body
;
}
}
if
(
new_body
[
j
][
0
]
->
is_type
<
ir
::
Reduce
>
())
{
if
(
new_body
[
j
][
0
]
->
is_type
<
ir
::
Reduce
>
())
{
// specially handle reduction inline for multiplre reductions.
// specially handle reduction inline for multiplre reductions.
...
...
src/schedule/schedule_lang.cc
View file @
5b8ff8d0
...
@@ -710,8 +710,7 @@ Schedule ScheduleNode::make(Array<Operation> ops) {
...
@@ -710,8 +710,7 @@ Schedule ScheduleNode::make(Array<Operation> ops) {
n
->
stages
.
push_back
(
stage
);
n
->
stages
.
push_back
(
stage
);
n
->
stage_map
.
Set
(
op
,
stage
);
n
->
stage_map
.
Set
(
op
,
stage
);
// mark scan updates.
// mark scan updates.
if
(
op
.
as
<
ScanOpNode
>
())
{
if
(
const
ScanOpNode
*
scan
=
op
.
as
<
ScanOpNode
>
())
{
const
ScanOpNode
*
scan
=
op
.
as
<
ScanOpNode
>
();
Array
<
Tensor
>
inputs
;
Array
<
Tensor
>
inputs
;
for
(
Tensor
t
:
scan
->
state_placeholder
)
{
for
(
Tensor
t
:
scan
->
state_placeholder
)
{
inputs
.
push_back
(
t
);
inputs
.
push_back
(
t
);
...
...
src/schedule/schedule_ops.cc
View file @
5b8ff8d0
...
@@ -304,8 +304,7 @@ class SchedulePostProc : public IRMutator {
...
@@ -304,8 +304,7 @@ class SchedulePostProc : public IRMutator {
}
}
}
}
// Specially add replacements for scan op.
// Specially add replacements for scan op.
if
(
s
->
op
.
as
<
ScanOpNode
>
())
{
if
(
const
ScanOpNode
*
scan
=
s
->
op
.
as
<
ScanOpNode
>
())
{
const
ScanOpNode
*
scan
=
s
->
op
.
as
<
ScanOpNode
>
();
for
(
size_t
i
=
0
;
i
<
scan
->
update
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
scan
->
update
.
size
();
++
i
)
{
Tensor
t
=
s
->
origin_op
.
output
(
i
);
Tensor
t
=
s
->
origin_op
.
output
(
i
);
AddReplace
(
scan
->
init
[
i
],
t
);
AddReplace
(
scan
->
init
[
i
],
t
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment