Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
5b8ff8d0
Commit
5b8ff8d0
authored
Feb 04, 2019
by
Alexey Romanov
Committed by
Tianqi Chen
Feb 03, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Remove duplicate as Checks and CHECK value (#2531)
parent
74b035a2
Hide whitespace changes
Inline
Side-by-side
Showing
25 changed files
with
88 additions
and
94 deletions
+88
-94
src/codegen/codegen_c.cc
+2
-3
src/codegen/stackvm/codegen_stackvm.cc
+2
-2
src/common/socket.h
+2
-2
src/op/hybrid_op.cc
+1
-1
src/pass/ir_util.cc
+14
-14
src/pass/loop_partition.cc
+3
-2
src/pass/split_host_device.cc
+1
-1
src/pass/storage_rewrite.cc
+3
-3
src/pass/verify_gpu_code.cc
+10
-8
src/relay/backend/interpreter.cc
+1
-1
src/relay/ir/base.cc
+1
-1
src/relay/op/type_relations.cc
+1
-1
src/relay/pass/fold_scale_axis.cc
+1
-1
src/relay/pass/gradient.cc
+1
-1
src/relay/pass/to_anf.cc
+1
-1
src/relay/pass/type_infer.cc
+1
-1
src/relay/pass/type_solver.cc
+1
-1
src/runtime/rpc/rpc_session.cc
+22
-23
src/runtime/stackvm/stackvm.cc
+0
-2
src/runtime/stackvm/stackvm.h
+1
-1
src/schedule/graph.cc
+13
-13
src/schedule/message_passing.cc
+3
-6
src/schedule/schedule_dataflow_rewrite.cc
+1
-1
src/schedule/schedule_lang.cc
+1
-2
src/schedule/schedule_ops.cc
+1
-2
No files found.
src/codegen/codegen_c.cc
View file @
5b8ff8d0
...
...
@@ -791,10 +791,9 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) {
void
CodeGenC
::
VisitStmt_
(
const
AssertStmt
*
op
)
{
std
::
string
cond
=
PrintExpr
(
op
->
condition
);
PrintIndent
();
if
(
op
->
message
.
as
<
StringImm
>
())
{
if
(
const
auto
*
str
=
op
->
message
.
as
<
StringImm
>
())
{
// GLOG style check
stream
<<
"CHECK("
<<
cond
<<
") <<
\"
"
<<
op
->
message
.
as
<
StringImm
>
()
->
value
<<
"
\"
;
\n
"
;
stream
<<
"CHECK("
<<
cond
<<
") <<
\"
"
<<
str
->
value
<<
"
\"
;
\n
"
;
}
else
{
stream
<<
"assert("
<<
cond
<<
");
\n
"
;
}
...
...
src/codegen/stackvm/codegen_stackvm.cc
View file @
5b8ff8d0
...
...
@@ -470,8 +470,8 @@ void CodeGenStackVM::VisitExpr_(const Select *op) {
}
void
CodeGenStackVM
::
VisitStmt_
(
const
AssertStmt
*
op
)
{
if
(
op
->
message
.
as
<
StringImm
>
())
{
int
sid
=
this
->
GetStrID
(
op
->
message
.
as
<
StringImm
>
()
->
value
);
if
(
const
auto
*
str
=
op
->
message
.
as
<
StringImm
>
())
{
int
sid
=
this
->
GetStrID
(
str
->
value
);
this
->
Push
(
op
->
condition
);
this
->
PushOp
(
StackVM
::
ASSERT
,
sid
);
}
...
...
src/common/socket.h
View file @
5b8ff8d0
...
...
@@ -42,13 +42,13 @@ inline std::string GetHostName() {
}
/*!
* \brief Common data structure fornetwork address.
* \brief Common data structure for
network address.
*/
struct
SockAddr
{
sockaddr_storage
addr
;
SockAddr
()
{}
/*!
* \brief construc address by url and port
* \brief construc
t
address by url and port
* \param url The url of the address
* \param port The port of the address.
*/
...
...
src/op/hybrid_op.cc
View file @
5b8ff8d0
...
...
@@ -435,7 +435,7 @@ Stmt ApplySchedule(const Stage &stage,
// Gather rebased variables
std
::
unordered_map
<
IterVar
,
IterVar
>
rebased
;
for
(
auto
rel
:
stage
->
relations
)
{
if
(
auto
rebase
=
rel
.
as
<
RebaseNode
>
())
{
if
(
const
auto
*
rebase
=
rel
.
as
<
RebaseNode
>
())
{
rebased
[
rebase
->
rebased
]
=
rebase
->
parent
;
CHECK
(
rebase
->
parent
->
dom
.
defined
());
CHECK
(
dom_map
.
count
(
rebase
->
rebased
));
...
...
src/pass/ir_util.cc
View file @
5b8ff8d0
...
...
@@ -12,39 +12,39 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
// use reverse iteration
for
(
auto
ri
=
nest
.
rbegin
();
ri
!=
nest
.
rend
();
++
ri
)
{
Stmt
s
=
*
ri
;
if
(
s
.
as
<
For
>
())
{
auto
n
=
make_node
<
For
>
(
*
s
.
as
<
For
>
()
);
if
(
const
auto
*
for_
=
s
.
as
<
For
>
())
{
auto
n
=
make_node
<
For
>
(
*
for_
);
CHECK
(
is_no_op
(
n
->
body
));
n
->
body
=
body
;
body
=
Stmt
(
n
);
}
else
if
(
s
.
as
<
LetStmt
>
())
{
auto
n
=
make_node
<
LetStmt
>
(
*
s
.
as
<
LetStmt
>
()
);
}
else
if
(
const
auto
*
let
=
s
.
as
<
LetStmt
>
())
{
auto
n
=
make_node
<
LetStmt
>
(
*
let
);
CHECK
(
is_no_op
(
n
->
body
));
n
->
body
=
body
;
body
=
Stmt
(
n
);
}
else
if
(
s
.
as
<
AttrStmt
>
())
{
auto
n
=
make_node
<
AttrStmt
>
(
*
s
.
as
<
AttrStmt
>
()
);
}
else
if
(
const
auto
*
attr
=
s
.
as
<
AttrStmt
>
())
{
auto
n
=
make_node
<
AttrStmt
>
(
*
attr
);
CHECK
(
is_no_op
(
n
->
body
));
n
->
body
=
body
;
body
=
Stmt
(
n
);
}
else
if
(
s
.
as
<
IfThenElse
>
())
{
auto
n
=
make_node
<
IfThenElse
>
(
*
s
.
as
<
IfThenElse
>
()
);
}
else
if
(
const
auto
*
ite
=
s
.
as
<
IfThenElse
>
())
{
auto
n
=
make_node
<
IfThenElse
>
(
*
ite
);
CHECK
(
is_no_op
(
n
->
then_case
));
CHECK
(
!
n
->
else_case
.
defined
());
n
->
then_case
=
body
;
body
=
Stmt
(
n
);
}
else
if
(
s
.
as
<
Block
>
())
{
auto
n
=
make_node
<
Block
>
(
*
s
.
as
<
Block
>
()
);
}
else
if
(
const
auto
*
block
=
s
.
as
<
Block
>
())
{
auto
n
=
make_node
<
Block
>
(
*
block
);
CHECK
(
is_no_op
(
n
->
rest
));
n
->
rest
=
body
;
body
=
Stmt
(
n
);
}
else
if
(
s
.
as
<
AssertStmt
>
())
{
auto
n
=
make_node
<
AssertStmt
>
(
*
s
.
as
<
AssertStmt
>
()
);
}
else
if
(
const
auto
*
assert_
=
s
.
as
<
AssertStmt
>
())
{
auto
n
=
make_node
<
AssertStmt
>
(
*
assert_
);
CHECK
(
is_no_op
(
n
->
body
));
n
->
body
=
body
;
body
=
Stmt
(
n
);
}
else
if
(
s
.
as
<
Allocate
>
())
{
auto
n
=
make_node
<
Allocate
>
(
*
s
.
as
<
Allocate
>
()
);
}
else
if
(
const
auto
*
alloc
=
s
.
as
<
Allocate
>
())
{
auto
n
=
make_node
<
Allocate
>
(
*
alloc
);
CHECK
(
is_no_op
(
n
->
body
));
n
->
body
=
body
;
body
=
Stmt
(
n
);
...
...
src/pass/loop_partition.cc
View file @
5b8ff8d0
...
...
@@ -326,7 +326,8 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Expr
body_begin
;
Stmt
pre_stmt
;
if
(
true_itrv
.
as
<
arith
::
IntervalSet
>
()
->
i
.
has_lower_bound
())
{
arith
::
Interval
true_itrv_i
=
true_itrv
.
as
<
arith
::
IntervalSet
>
()
->
i
;
if
(
true_itrv_i
.
has_lower_bound
())
{
body_begin
=
ir
::
Simplify
(
true_itrv
.
min
());
if
(
!
can_prove
(
body_begin
==
min
))
{
Expr
cond
=
(
body_begin
-
min
>=
0
);
...
...
@@ -347,7 +348,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Expr
post_doubt_begin
;
Stmt
post_stmt
;
if
(
true_itrv
.
as
<
arith
::
IntervalSet
>
()
->
i
.
has_upper_bound
())
{
if
(
true_itrv
_
i
.
has_upper_bound
())
{
post_doubt_begin
=
ir
::
Simplify
(
true_itrv
.
max
()
+
1
);
if
(
!
can_prove
(
true_itrv
.
max
()
==
max
))
{
// require the extent to be non-negative
...
...
src/pass/split_host_device.cc
View file @
5b8ff8d0
...
...
@@ -34,7 +34,7 @@ class IRUseDefAnalysis : public IRMutator {
value
=
this
->
Mutate
(
value
);
}
Stmt
body
=
this
->
Mutate
(
op
->
body
);
if
(
value
.
same_as
(
value
)
&&
body
.
same_as
(
body
))
return
s
;
if
(
value
.
same_as
(
op
->
value
)
&&
body
.
same_as
(
op
->
body
))
return
s
;
return
AttrStmt
::
make
(
op
->
node
,
op
->
attr_key
,
value
,
body
);
}
else
if
(
op
->
attr_key
==
attr
::
channel_write_scope
||
op
->
attr_key
==
attr
::
channel_read_scope
)
{
...
...
src/pass/storage_rewrite.cc
View file @
5b8ff8d0
...
...
@@ -718,10 +718,10 @@ class StoragePlanRewriter : public IRMutator {
src_entry
->
attach_scope_
==
thread_scope_
&&
src_entry
->
elem_type
==
ae
.
alloc
->
type
.
element_of
()
&&
visitor
.
Check
(
s
.
stmt
,
var
,
src
))
{
uint64_t
const_nbits
=
static_cast
<
uint64_t
>
(
ae
.
alloc
->
constant_allocation_size
(
)
*
uint64_t
const_nbits
=
static_cast
<
uint64_t
>
(
ae
.
alloc
->
constant_allocation_size
()
)
*
ae
.
alloc
->
type
.
bits
()
*
ae
.
alloc
->
type
.
lanes
()
)
;
ae
.
alloc
->
type
.
lanes
();
if
(
src_entry
->
const_nbits
==
const_nbits
&&
!
inplace_found
)
{
// successfully inplace
dst_entry
=
src_entry
;
...
...
src/pass/verify_gpu_code.cc
View file @
5b8ff8d0
...
...
@@ -73,9 +73,10 @@ class GPUCodeVerifier : public IRVisitor {
void
Visit_
(
const
AttrStmt
*
op
)
{
if
(
op
->
attr_key
==
attr
::
storage_scope
)
{
if
(
op
->
value
.
as
<
StringImm
>
()
->
value
==
"local"
)
{
std
::
string
op_value
=
op
->
value
.
as
<
StringImm
>
()
->
value
;
if
(
op_value
==
"local"
)
{
visited_local_buffers_
.
insert
(
op
->
node
.
as
<
tvm
::
Variable
>
());
}
else
if
(
op
->
value
.
as
<
StringImm
>
()
->
value
==
"shared"
)
{
}
else
if
(
op
_
value
==
"shared"
)
{
visited_shared_buffers_
.
insert
(
op
->
node
.
as
<
tvm
::
Variable
>
());
}
}
else
if
(
op
->
attr_key
==
attr
::
thread_extent
)
{
...
...
@@ -159,18 +160,19 @@ bool VerifyGPUCode(Stmt stmt,
int64_t
max_thread_z
=
INT64_MAX
;
for
(
auto
iter
:
constraints
)
{
const
IntImm
*
val
=
iter
.
second
.
as
<
IntImm
>
();
if
(
iter
.
first
==
"max_local_memory_per_block"
)
max_local_memory_per_block
=
(
iter
.
second
).
as
<
IntImm
>
()
->
value
;
max_local_memory_per_block
=
val
->
value
;
else
if
(
iter
.
first
==
"max_shared_memory_per_block"
)
max_shared_memory_per_block
=
(
iter
.
second
).
as
<
IntImm
>
()
->
value
;
max_shared_memory_per_block
=
val
->
value
;
else
if
(
iter
.
first
==
"max_threads_per_block"
)
max_threads_per_block
=
(
iter
.
second
).
as
<
IntImm
>
()
->
value
;
max_threads_per_block
=
val
->
value
;
else
if
(
iter
.
first
==
"max_thread_x"
)
max_thread_x
=
(
iter
.
second
).
as
<
IntImm
>
()
->
value
;
max_thread_x
=
val
->
value
;
else
if
(
iter
.
first
==
"max_thread_y"
)
max_thread_y
=
(
iter
.
second
).
as
<
IntImm
>
()
->
value
;
max_thread_y
=
val
->
value
;
else
if
(
iter
.
first
==
"max_thread_z"
)
max_thread_z
=
(
iter
.
second
).
as
<
IntImm
>
()
->
value
;
max_thread_z
=
val
->
value
;
else
LOG
(
FATAL
)
<<
"Invalid check item: "
<<
iter
.
first
;
}
...
...
src/relay/backend/interpreter.cc
View file @
5b8ff8d0
...
...
@@ -379,7 +379,7 @@ class Interpreter :
//
// We have some functions cotaining chunks of operators
// which will be loaded into operator map.
if
(
auto
op_node
=
call
->
op
.
as
<
OpNode
>
())
{
if
(
const
auto
*
op_node
=
call
->
op
.
as
<
OpNode
>
())
{
LOG
(
FATAL
)
<<
"found "
<<
op_node
->
name
<<
"; operators should be removed by future passes; try "
"fusing and lowering"
;
...
...
src/relay/ir/base.cc
View file @
5b8ff8d0
...
...
@@ -20,8 +20,8 @@ NodePtr<SourceNameNode> GetSourceNameNode(const std::string& name) {
auto
sn
=
source_map
.
find
(
name
);
if
(
sn
==
source_map
.
end
())
{
NodePtr
<
SourceNameNode
>
n
=
make_node
<
SourceNameNode
>
();
n
->
name
=
std
::
move
(
name
);
source_map
[
name
]
=
n
;
n
->
name
=
std
::
move
(
name
);
return
n
;
}
else
{
return
sn
->
second
;
...
...
src/relay/op/type_relations.cc
View file @
5b8ff8d0
...
...
@@ -15,7 +15,7 @@ namespace tvm {
namespace
relay
{
TensorType
ToTensorType
(
const
Type
&
t
)
{
if
(
auto
tt_node
=
t
.
as
<
TensorTypeNode
>
())
{
if
(
const
auto
*
tt_node
=
t
.
as
<
TensorTypeNode
>
())
{
return
GetRef
<
TensorType
>
(
tt_node
);
}
else
{
return
TensorType
(
nullptr
);
...
...
src/relay/pass/fold_scale_axis.cc
View file @
5b8ff8d0
...
...
@@ -361,7 +361,7 @@ Expr AddSubForwardRewrite(const Call& ref_call,
rnode
->
scale
=
slhs
->
scale
;
rnode
->
axes
=
slhs
->
axes
;
}
else
{
CHECK
(
s
l
hs
!=
nullptr
);
CHECK
(
s
r
hs
!=
nullptr
);
CHECK
(
MatchBroadcastToLeftAxes
(
trhs
,
tlhs
,
srhs
->
axes
));
Expr
scale
=
ExpandBiasToMatchAxis
(
srhs
->
scale
,
trhs
->
shape
.
size
(),
srhs
->
axes
);
...
...
src/relay/pass/gradient.cc
View file @
5b8ff8d0
...
...
@@ -61,7 +61,7 @@ Type WithGradientType(const Type& t) {
//! \brief if the expression is a GlobalVar, transform to it's expression.
Expr
DeGlobal
(
const
Module
&
mod
,
const
Expr
&
e
)
{
if
(
auto
x
=
e
.
as
<
GlobalVarNode
>
())
{
if
(
const
auto
*
x
=
e
.
as
<
GlobalVarNode
>
())
{
return
mod
->
Lookup
(
GetRef
<
GlobalVar
>
(
x
))
->
body
;
}
else
{
return
e
;
...
...
src/relay/pass/to_anf.cc
View file @
5b8ff8d0
...
...
@@ -385,7 +385,7 @@ Expr ToANFAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
}
Expr
ToANF
(
const
Expr
&
e
,
const
Module
&
m
,
std
::
set
<
GlobalVar
>*
gv
)
{
if
(
auto
f
=
e
.
as
<
FunctionNode
>
())
{
if
(
const
auto
*
f
=
e
.
as
<
FunctionNode
>
())
{
return
FunctionNode
::
make
(
f
->
params
,
ToANFAux
(
f
->
body
,
m
,
gv
),
f
->
ret_type
,
...
...
src/relay/pass/type_infer.cc
View file @
5b8ff8d0
...
...
@@ -386,7 +386,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
}
for
(
auto
cs
:
fn_ty
->
type_constraints
)
{
if
(
auto
tr
=
cs
.
as
<
TypeRelationNode
>
())
{
if
(
const
auto
*
tr
=
cs
.
as
<
TypeRelationNode
>
())
{
solver_
.
AddConstraint
(
TypeRelationNode
::
make
(
tr
->
func
,
tr
->
args
,
tr
->
num_inputs
,
call
->
attrs
),
GetRef
<
Call
>
(
call
));
...
...
src/relay/pass/type_solver.cc
View file @
5b8ff8d0
...
...
@@ -376,7 +376,7 @@ void TypeSolver::ReportError(const Error& err, const NodeRef& location) {
// Add type constraint to the solver.
void
TypeSolver
::
AddConstraint
(
const
TypeConstraint
&
constraint
,
const
NodeRef
&
loc
)
{
if
(
auto
*
op
=
constraint
.
as
<
TypeRelationNode
>
())
{
if
(
const
auto
*
op
=
constraint
.
as
<
TypeRelationNode
>
())
{
// create a new relation node.
RelationNode
*
rnode
=
arena_
.
make
<
RelationNode
>
();
rnode
->
location
=
loc
;
...
...
src/runtime/rpc/rpc_session.cc
View file @
5b8ff8d0
...
...
@@ -486,29 +486,28 @@ class RPCSession::EventHandler : public dmlc::Stream {
arg_recv_stage_
=
1
;
this
->
RequestBytes
(
len
);
break
;
break
;
}
case
kArrayHandle
:
{
temp_array_
.
reset
(
new
RPCDataArrayBuffer
());
uint64_t
handle
;
this
->
Read
(
&
handle
);
DLTensor
&
tensor
=
temp_array_
->
tensor
;
tensor
.
data
=
reinterpret_cast
<
void
*>
(
handle
);
this
->
Read
(
&
(
tensor
.
ctx
));
this
->
Read
(
&
(
tensor
.
ndim
));
this
->
Read
(
&
(
tensor
.
dtype
));
temp_array_
->
shape
.
resize
(
tensor
.
ndim
);
tensor
.
shape
=
temp_array_
->
shape
.
data
();
arg_recv_stage_
=
1
;
tensor
.
strides
=
nullptr
;
tensor
.
byte_offset
=
0
;
this
->
RequestBytes
(
sizeof
(
int64_t
)
*
tensor
.
ndim
);
break
;
}
default:
{
LOG
(
FATAL
)
<<
"RPC cannot handle type "
<<
TypeCode2Str
(
tcode
);
break
;
}
}
case
kArrayHandle
:
{
temp_array_
.
reset
(
new
RPCDataArrayBuffer
());
uint64_t
handle
;
this
->
Read
(
&
handle
);
DLTensor
&
tensor
=
temp_array_
->
tensor
;
tensor
.
data
=
reinterpret_cast
<
void
*>
(
handle
);
this
->
Read
(
&
(
tensor
.
ctx
));
this
->
Read
(
&
(
tensor
.
ndim
));
this
->
Read
(
&
(
tensor
.
dtype
));
temp_array_
->
shape
.
resize
(
tensor
.
ndim
);
tensor
.
shape
=
temp_array_
->
shape
.
data
();
arg_recv_stage_
=
1
;
tensor
.
strides
=
nullptr
;
tensor
.
byte_offset
=
0
;
this
->
RequestBytes
(
sizeof
(
int64_t
)
*
tensor
.
ndim
);
break
;
}
default:
{
LOG
(
FATAL
)
<<
"RPC cannot handle type "
<<
TypeCode2Str
(
tcode
);
break
;
}
}
}
else
{
CHECK_EQ
(
arg_recv_stage_
,
1
);
...
...
src/runtime/stackvm/stackvm.cc
View file @
5b8ff8d0
...
...
@@ -406,7 +406,6 @@ void StackVM::Run(State* s) const {
case
intrinsic
:
:
kArrByteOffset
:
{
stack
[
sp
].
v_int64
=
static_cast
<
int64_t
>
(
arr
[
index
].
byte_offset
);
break
;
break
;
}
case
intrinsic
:
:
kArrDeviceId
:
{
stack
[
sp
].
v_int64
=
arr
[
index
].
ctx
.
device_id
;
break
;
...
...
@@ -531,7 +530,6 @@ const PackedFunc& StackVM::GetExtern(State* s, int fid) const {
if
(
f
==
nullptr
)
{
CHECK
(
s
->
mod_ctx
!=
nullptr
)
<<
"No local context is set in stackvm"
;
CHECK
(
s
->
mod_ctx
!=
nullptr
);
const
PackedFunc
*
pf
=
s
->
mod_ctx
->
GetFuncFromEnv
(
extern_func_name
[
fid
]);
CHECK
(
pf
!=
nullptr
);
f
=
*
pf
;
...
...
src/runtime/stackvm/stackvm.h
View file @
5b8ff8d0
...
...
@@ -331,7 +331,7 @@ class StackVM {
case
EQ_I64
:
return
EQ_F64
;
case
LT_I64
:
return
LT_F64
;
case
LE_I64
:
return
LE_F64
;
case
MOD_I64
:
LOG
(
FATAL
)
<<
"cannot handle mod for float"
;
case
MOD_I64
:
LOG
(
FATAL
)
<<
"cannot handle mod for float"
;
return
ADD_F64
;
default:
LOG
(
FATAL
)
<<
"cannot handle op "
<<
code
;
return
ADD_F64
;
}
}
...
...
src/schedule/graph.cc
View file @
5b8ff8d0
...
...
@@ -223,9 +223,9 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
}
for
(
Operation
op
:
ops
)
{
if
(
op
.
as
<
ScanOpNode
>
())
{
const
auto
&
update
=
op
.
as
<
ScanOpNode
>
()
->
update
;
const
auto
&
init
=
op
.
as
<
ScanOpNode
>
()
->
init
;
if
(
const
auto
*
scan_op
=
op
.
as
<
ScanOpNode
>
())
{
const
auto
&
update
=
scan_op
->
update
;
const
auto
&
init
=
scan_op
->
init
;
for
(
size_t
i
=
0
;
i
<
update
.
size
();
++
i
)
{
Tensor
t
=
op
.
output
(
i
);
for
(
int
k
=
1
;
k
<
static_cast
<
int
>
(
update
[
i
]
->
shape
.
size
());
++
k
)
{
...
...
@@ -235,9 +235,9 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
TensorDimKey
(
init
[
i
],
k
));
}
}
}
else
if
(
op
.
as
<
ComputeOpNode
>
())
{
}
else
if
(
const
auto
*
compute_op
=
op
.
as
<
ComputeOpNode
>
())
{
std
::
unordered_map
<
const
Node
*
,
TensorDimKey
>
vmap
;
const
auto
&
axis
=
op
.
as
<
ComputeOpNode
>
()
->
axis
;
const
auto
&
axis
=
compute_op
->
axis
;
Tensor
t
=
op
.
output
(
0
);
for
(
size_t
i
=
0
;
i
<
axis
.
size
();
++
i
)
{
vmap
[
axis
[
i
]
->
var
.
get
()]
=
TensorDimKey
(
t
,
i
);
...
...
@@ -260,7 +260,7 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
}
}
};
for
(
auto
&
e
:
op
.
as
<
ComputeOpNode
>
()
->
body
)
{
for
(
auto
&
e
:
compute_op
->
body
)
{
ir
::
PostOrderVisit
(
e
,
fvisit
);
}
}
...
...
@@ -312,19 +312,19 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
// prop exact reach back.
for
(
size_t
i
=
0
;
i
<
body
.
size
();
++
i
)
{
const
Operation
&
op
=
body
[
i
];
if
(
op
.
as
<
ScanOpNode
>
())
{
const
auto
&
update
=
op
.
as
<
ScanOpNode
>
()
->
update
;
const
auto
&
init
=
op
.
as
<
ScanOpNode
>
()
->
init
;
if
(
const
auto
*
scan_op
=
op
.
as
<
ScanOpNode
>
())
{
const
auto
&
update
=
scan_op
->
update
;
const
auto
&
init
=
scan_op
->
init
;
for
(
size_t
i
=
0
;
i
<
update
.
size
();
++
i
)
{
Tensor
t
=
op
.
output
(
i
);
for
(
size_t
k
=
1
;
i
<
update
[
i
]
->
shape
.
size
();
++
k
)
{
for
(
size_t
k
=
1
;
k
<
update
[
i
]
->
shape
.
size
();
++
k
)
{
f_merge_key
(
TensorDimKey
(
t
,
k
),
TensorDimKey
(
update
[
i
],
k
));
f_merge_key
(
TensorDimKey
(
t
,
k
),
TensorDimKey
(
init
[
i
],
k
));
}
}
}
else
if
(
op
.
as
<
ComputeOpNode
>
())
{
}
else
if
(
const
auto
*
compute_op
=
op
.
as
<
ComputeOpNode
>
())
{
std
::
unordered_map
<
const
Node
*
,
std
::
vector
<
TensorDimKey
>
>
vmap
;
const
auto
&
axis
=
op
.
as
<
ComputeOpNode
>
()
->
axis
;
const
auto
&
axis
=
compute_op
->
axis
;
for
(
size_t
i
=
0
;
i
<
axis
.
size
();
++
i
)
{
std
::
vector
<
TensorDimKey
>
keys
;
for
(
int
j
=
0
;
j
<
op
->
num_outputs
();
++
j
)
{
...
...
@@ -352,7 +352,7 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
}
}
};
for
(
auto
&
e
:
op
.
as
<
ComputeOpNode
>
()
->
body
)
{
for
(
auto
&
e
:
compute_op
->
body
)
{
ir
::
PostOrderVisit
(
e
,
fvisit
);
}
}
...
...
src/schedule/message_passing.cc
View file @
5b8ff8d0
...
...
@@ -419,8 +419,7 @@ void PassUpBoundCheck(const Stage& s,
using
HalideIR
::
Internal
::
can_prove
;
for
(
size_t
i
=
s
->
relations
.
size
();
i
!=
0
;
--
i
)
{
IterVarRelation
rel
=
s
->
relations
[
i
-
1
];
if
(
rel
.
as
<
SplitNode
>
())
{
const
SplitNode
*
s
=
rel
.
as
<
SplitNode
>
();
if
(
const
SplitNode
*
s
=
rel
.
as
<
SplitNode
>
())
{
bool
outer
=
state
.
at
(
s
->
outer
);
bool
inner
=
state
.
at
(
s
->
inner
);
...
...
@@ -439,13 +438,11 @@ void PassUpBoundCheck(const Stage& s,
}
else
{
state
[
s
->
parent
]
=
true
;
}
}
else
if
(
rel
.
as
<
FuseNode
>
())
{
const
FuseNode
*
s
=
rel
.
as
<
FuseNode
>
();
}
else
if
(
const
FuseNode
*
s
=
rel
.
as
<
FuseNode
>
())
{
bool
fused
=
state
.
at
(
s
->
fused
);
state
[
s
->
outer
]
=
fused
;
state
[
s
->
inner
]
=
fused
;
}
else
if
(
rel
.
as
<
RebaseNode
>
())
{
const
RebaseNode
*
s
=
rel
.
as
<
RebaseNode
>
();
}
else
if
(
const
RebaseNode
*
s
=
rel
.
as
<
RebaseNode
>
())
{
state
[
s
->
parent
]
=
state
.
at
(
s
->
rebased
);
}
else
if
(
rel
.
as
<
SingletonNode
>
())
{
// nop
...
...
src/schedule/schedule_dataflow_rewrite.cc
View file @
5b8ff8d0
...
...
@@ -544,7 +544,7 @@ void InjectInline(ScheduleNode* sch) {
const
ComputeOpNode
*
compute
=
s
->
op
.
as
<
ComputeOpNode
>
();
if
(
compute
)
{
if
(
!
new_body
[
j
].
size
())
{
new_body
[
j
]
=
s
->
op
.
as
<
ComputeOpNode
>
()
->
body
;
new_body
[
j
]
=
compute
->
body
;
}
if
(
new_body
[
j
][
0
]
->
is_type
<
ir
::
Reduce
>
())
{
// specially handle reduction inline for multiplre reductions.
...
...
src/schedule/schedule_lang.cc
View file @
5b8ff8d0
...
...
@@ -710,8 +710,7 @@ Schedule ScheduleNode::make(Array<Operation> ops) {
n
->
stages
.
push_back
(
stage
);
n
->
stage_map
.
Set
(
op
,
stage
);
// mark scan updates.
if
(
op
.
as
<
ScanOpNode
>
())
{
const
ScanOpNode
*
scan
=
op
.
as
<
ScanOpNode
>
();
if
(
const
ScanOpNode
*
scan
=
op
.
as
<
ScanOpNode
>
())
{
Array
<
Tensor
>
inputs
;
for
(
Tensor
t
:
scan
->
state_placeholder
)
{
inputs
.
push_back
(
t
);
...
...
src/schedule/schedule_ops.cc
View file @
5b8ff8d0
...
...
@@ -304,8 +304,7 @@ class SchedulePostProc : public IRMutator {
}
}
// Specially add replacements for scan op.
if
(
s
->
op
.
as
<
ScanOpNode
>
())
{
const
ScanOpNode
*
scan
=
s
->
op
.
as
<
ScanOpNode
>
();
if
(
const
ScanOpNode
*
scan
=
s
->
op
.
as
<
ScanOpNode
>
())
{
for
(
size_t
i
=
0
;
i
<
scan
->
update
.
size
();
++
i
)
{
Tensor
t
=
s
->
origin_op
.
output
(
i
);
AddReplace
(
scan
->
init
[
i
],
t
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment