Commit 5b8ff8d0 by Alexey Romanov Committed by Tianqi Chen

Remove duplicate as Checks and CHECK value (#2531)

parent 74b035a2
...@@ -791,10 +791,9 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) { ...@@ -791,10 +791,9 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) {
void CodeGenC::VisitStmt_(const AssertStmt* op) { void CodeGenC::VisitStmt_(const AssertStmt* op) {
std::string cond = PrintExpr(op->condition); std::string cond = PrintExpr(op->condition);
PrintIndent(); PrintIndent();
if (op->message.as<StringImm>()) { if (const auto* str = op->message.as<StringImm>()) {
// GLOG style check // GLOG style check
stream << "CHECK(" << cond << ") << \"" stream << "CHECK(" << cond << ") << \"" << str->value << "\";\n";
<< op->message.as<StringImm>()->value << "\";\n";
} else { } else {
stream << "assert(" << cond << ");\n"; stream << "assert(" << cond << ");\n";
} }
......
...@@ -470,8 +470,8 @@ void CodeGenStackVM::VisitExpr_(const Select *op) { ...@@ -470,8 +470,8 @@ void CodeGenStackVM::VisitExpr_(const Select *op) {
} }
void CodeGenStackVM::VisitStmt_(const AssertStmt *op) { void CodeGenStackVM::VisitStmt_(const AssertStmt *op) {
if (op->message.as<StringImm>()) { if (const auto* str = op->message.as<StringImm>()) {
int sid = this->GetStrID(op->message.as<StringImm>()->value); int sid = this->GetStrID(str->value);
this->Push(op->condition); this->Push(op->condition);
this->PushOp(StackVM::ASSERT, sid); this->PushOp(StackVM::ASSERT, sid);
} }
......
...@@ -42,13 +42,13 @@ inline std::string GetHostName() { ...@@ -42,13 +42,13 @@ inline std::string GetHostName() {
} }
/*! /*!
* \brief Common data structure fornetwork address. * \brief Common data structure for network address.
*/ */
struct SockAddr { struct SockAddr {
sockaddr_storage addr; sockaddr_storage addr;
SockAddr() {} SockAddr() {}
/*! /*!
* \brief construc address by url and port * \brief construct address by url and port
* \param url The url of the address * \param url The url of the address
* \param port The port of the address. * \param port The port of the address.
*/ */
......
...@@ -435,7 +435,7 @@ Stmt ApplySchedule(const Stage &stage, ...@@ -435,7 +435,7 @@ Stmt ApplySchedule(const Stage &stage,
// Gather rebased variables // Gather rebased variables
std::unordered_map<IterVar, IterVar> rebased; std::unordered_map<IterVar, IterVar> rebased;
for (auto rel : stage->relations) { for (auto rel : stage->relations) {
if (auto rebase = rel.as<RebaseNode>()) { if (const auto* rebase = rel.as<RebaseNode>()) {
rebased[rebase->rebased] = rebase->parent; rebased[rebase->rebased] = rebase->parent;
CHECK(rebase->parent->dom.defined()); CHECK(rebase->parent->dom.defined());
CHECK(dom_map.count(rebase->rebased)); CHECK(dom_map.count(rebase->rebased));
......
...@@ -12,39 +12,39 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) { ...@@ -12,39 +12,39 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
// use reverse iteration // use reverse iteration
for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
Stmt s = *ri; Stmt s = *ri;
if (s.as<For>()) { if (const auto* for_ = s.as<For>()) {
auto n = make_node<For>(*s.as<For>()); auto n = make_node<For>(*for_);
CHECK(is_no_op(n->body)); CHECK(is_no_op(n->body));
n->body = body; n->body = body;
body = Stmt(n); body = Stmt(n);
} else if (s.as<LetStmt>()) { } else if (const auto* let = s.as<LetStmt>()) {
auto n = make_node<LetStmt>(*s.as<LetStmt>()); auto n = make_node<LetStmt>(*let);
CHECK(is_no_op(n->body)); CHECK(is_no_op(n->body));
n->body = body; n->body = body;
body = Stmt(n); body = Stmt(n);
} else if (s.as<AttrStmt>()) { } else if (const auto* attr = s.as<AttrStmt>()) {
auto n = make_node<AttrStmt>(*s.as<AttrStmt>()); auto n = make_node<AttrStmt>(*attr);
CHECK(is_no_op(n->body)); CHECK(is_no_op(n->body));
n->body = body; n->body = body;
body = Stmt(n); body = Stmt(n);
} else if (s.as<IfThenElse>()) { } else if (const auto* ite = s.as<IfThenElse>()) {
auto n = make_node<IfThenElse>(*s.as<IfThenElse>()); auto n = make_node<IfThenElse>(*ite);
CHECK(is_no_op(n->then_case)); CHECK(is_no_op(n->then_case));
CHECK(!n->else_case.defined()); CHECK(!n->else_case.defined());
n->then_case = body; n->then_case = body;
body = Stmt(n); body = Stmt(n);
} else if (s.as<Block>()) { } else if (const auto* block = s.as<Block>()) {
auto n = make_node<Block>(*s.as<Block>()); auto n = make_node<Block>(*block);
CHECK(is_no_op(n->rest)); CHECK(is_no_op(n->rest));
n->rest = body; n->rest = body;
body = Stmt(n); body = Stmt(n);
} else if (s.as<AssertStmt>()) { } else if (const auto* assert_ = s.as<AssertStmt>()) {
auto n = make_node<AssertStmt>(*s.as<AssertStmt>()); auto n = make_node<AssertStmt>(*assert_);
CHECK(is_no_op(n->body)); CHECK(is_no_op(n->body));
n->body = body; n->body = body;
body = Stmt(n); body = Stmt(n);
} else if (s.as<Allocate>()) { } else if (const auto* alloc = s.as<Allocate>()) {
auto n = make_node<Allocate>(*s.as<Allocate>()); auto n = make_node<Allocate>(*alloc);
CHECK(is_no_op(n->body)); CHECK(is_no_op(n->body));
n->body = body; n->body = body;
body = Stmt(n); body = Stmt(n);
......
...@@ -326,7 +326,8 @@ Stmt LoopPartitioner::TryPartition(const Node* node, ...@@ -326,7 +326,8 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Expr body_begin; Expr body_begin;
Stmt pre_stmt; Stmt pre_stmt;
if (true_itrv.as<arith::IntervalSet>()->i.has_lower_bound()) { arith::Interval true_itrv_i = true_itrv.as<arith::IntervalSet>()->i;
if (true_itrv_i.has_lower_bound()) {
body_begin = ir::Simplify(true_itrv.min()); body_begin = ir::Simplify(true_itrv.min());
if (!can_prove(body_begin == min)) { if (!can_prove(body_begin == min)) {
Expr cond = (body_begin - min >= 0); Expr cond = (body_begin - min >= 0);
...@@ -347,7 +348,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, ...@@ -347,7 +348,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Expr post_doubt_begin; Expr post_doubt_begin;
Stmt post_stmt; Stmt post_stmt;
if (true_itrv.as<arith::IntervalSet>()->i.has_upper_bound()) { if (true_itrv_i.has_upper_bound()) {
post_doubt_begin = ir::Simplify(true_itrv.max() + 1); post_doubt_begin = ir::Simplify(true_itrv.max() + 1);
if (!can_prove(true_itrv.max() == max)) { if (!can_prove(true_itrv.max() == max)) {
// require the extent to be non-negative // require the extent to be non-negative
......
...@@ -34,7 +34,7 @@ class IRUseDefAnalysis : public IRMutator { ...@@ -34,7 +34,7 @@ class IRUseDefAnalysis : public IRMutator {
value = this->Mutate(value); value = this->Mutate(value);
} }
Stmt body = this->Mutate(op->body); Stmt body = this->Mutate(op->body);
if (value.same_as(value) && body.same_as(body)) return s; if (value.same_as(op->value) && body.same_as(op->body)) return s;
return AttrStmt::make(op->node, op->attr_key, value, body); return AttrStmt::make(op->node, op->attr_key, value, body);
} else if (op->attr_key == attr::channel_write_scope || } else if (op->attr_key == attr::channel_write_scope ||
op->attr_key == attr::channel_read_scope) { op->attr_key == attr::channel_read_scope) {
......
...@@ -718,10 +718,10 @@ class StoragePlanRewriter : public IRMutator { ...@@ -718,10 +718,10 @@ class StoragePlanRewriter : public IRMutator {
src_entry->attach_scope_ == thread_scope_ && src_entry->attach_scope_ == thread_scope_ &&
src_entry->elem_type == ae.alloc->type.element_of() && src_entry->elem_type == ae.alloc->type.element_of() &&
visitor.Check(s.stmt, var, src)) { visitor.Check(s.stmt, var, src)) {
uint64_t const_nbits = static_cast<uint64_t>( uint64_t const_nbits =
ae.alloc->constant_allocation_size() * static_cast<uint64_t>(ae.alloc->constant_allocation_size()) *
ae.alloc->type.bits() * ae.alloc->type.bits() *
ae.alloc->type.lanes()); ae.alloc->type.lanes();
if (src_entry->const_nbits == const_nbits && !inplace_found) { if (src_entry->const_nbits == const_nbits && !inplace_found) {
// successfully inplace // successfully inplace
dst_entry = src_entry; dst_entry = src_entry;
......
...@@ -73,9 +73,10 @@ class GPUCodeVerifier : public IRVisitor { ...@@ -73,9 +73,10 @@ class GPUCodeVerifier : public IRVisitor {
void Visit_(const AttrStmt *op) { void Visit_(const AttrStmt *op) {
if (op->attr_key == attr::storage_scope) { if (op->attr_key == attr::storage_scope) {
if (op->value.as<StringImm>()->value == "local") { std::string op_value = op->value.as<StringImm>()->value;
if (op_value == "local") {
visited_local_buffers_.insert(op->node.as<tvm::Variable>()); visited_local_buffers_.insert(op->node.as<tvm::Variable>());
} else if (op->value.as<StringImm>()->value == "shared") { } else if (op_value == "shared") {
visited_shared_buffers_.insert(op->node.as<tvm::Variable>()); visited_shared_buffers_.insert(op->node.as<tvm::Variable>());
} }
} else if (op->attr_key == attr::thread_extent) { } else if (op->attr_key == attr::thread_extent) {
...@@ -159,18 +160,19 @@ bool VerifyGPUCode(Stmt stmt, ...@@ -159,18 +160,19 @@ bool VerifyGPUCode(Stmt stmt,
int64_t max_thread_z = INT64_MAX; int64_t max_thread_z = INT64_MAX;
for (auto iter : constraints) { for (auto iter : constraints) {
const IntImm* val = iter.second.as<IntImm>();
if (iter.first == "max_local_memory_per_block") if (iter.first == "max_local_memory_per_block")
max_local_memory_per_block = (iter.second).as<IntImm>()->value; max_local_memory_per_block = val->value;
else if (iter.first == "max_shared_memory_per_block") else if (iter.first == "max_shared_memory_per_block")
max_shared_memory_per_block = (iter.second).as<IntImm>()->value; max_shared_memory_per_block = val->value;
else if (iter.first == "max_threads_per_block") else if (iter.first == "max_threads_per_block")
max_threads_per_block = (iter.second).as<IntImm>()->value; max_threads_per_block = val->value;
else if (iter.first == "max_thread_x") else if (iter.first == "max_thread_x")
max_thread_x = (iter.second).as<IntImm>()->value; max_thread_x = val->value;
else if (iter.first == "max_thread_y") else if (iter.first == "max_thread_y")
max_thread_y = (iter.second).as<IntImm>()->value; max_thread_y = val->value;
else if (iter.first == "max_thread_z") else if (iter.first == "max_thread_z")
max_thread_z = (iter.second).as<IntImm>()->value; max_thread_z = val->value;
else else
LOG(FATAL) << "Invalid check item: " << iter.first; LOG(FATAL) << "Invalid check item: " << iter.first;
} }
......
...@@ -379,7 +379,7 @@ class Interpreter : ...@@ -379,7 +379,7 @@ class Interpreter :
// //
// We have some functions containing chunks of operators // We have some functions containing chunks of operators
// which will be loaded into operator map. // which will be loaded into operator map.
if (auto op_node = call->op.as<OpNode>()) { if (const auto* op_node = call->op.as<OpNode>()) {
LOG(FATAL) << "found " << op_node->name LOG(FATAL) << "found " << op_node->name
<< "; operators should be removed by future passes; try " << "; operators should be removed by future passes; try "
"fusing and lowering"; "fusing and lowering";
......
...@@ -20,8 +20,8 @@ NodePtr<SourceNameNode> GetSourceNameNode(const std::string& name) { ...@@ -20,8 +20,8 @@ NodePtr<SourceNameNode> GetSourceNameNode(const std::string& name) {
auto sn = source_map.find(name); auto sn = source_map.find(name);
if (sn == source_map.end()) { if (sn == source_map.end()) {
NodePtr<SourceNameNode> n = make_node<SourceNameNode>(); NodePtr<SourceNameNode> n = make_node<SourceNameNode>();
n->name = std::move(name);
source_map[name] = n; source_map[name] = n;
n->name = std::move(name);
return n; return n;
} else { } else {
return sn->second; return sn->second;
......
...@@ -15,7 +15,7 @@ namespace tvm { ...@@ -15,7 +15,7 @@ namespace tvm {
namespace relay { namespace relay {
TensorType ToTensorType(const Type& t) { TensorType ToTensorType(const Type& t) {
if (auto tt_node = t.as<TensorTypeNode>()) { if (const auto* tt_node = t.as<TensorTypeNode>()) {
return GetRef<TensorType>(tt_node); return GetRef<TensorType>(tt_node);
} else { } else {
return TensorType(nullptr); return TensorType(nullptr);
......
...@@ -361,7 +361,7 @@ Expr AddSubForwardRewrite(const Call& ref_call, ...@@ -361,7 +361,7 @@ Expr AddSubForwardRewrite(const Call& ref_call,
rnode->scale = slhs->scale; rnode->scale = slhs->scale;
rnode->axes = slhs->axes; rnode->axes = slhs->axes;
} else { } else {
CHECK(slhs != nullptr); CHECK(srhs != nullptr);
CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes)); CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes));
Expr scale = ExpandBiasToMatchAxis( Expr scale = ExpandBiasToMatchAxis(
srhs->scale, trhs->shape.size(), srhs->axes); srhs->scale, trhs->shape.size(), srhs->axes);
......
...@@ -61,7 +61,7 @@ Type WithGradientType(const Type& t) { ...@@ -61,7 +61,7 @@ Type WithGradientType(const Type& t) {
//! \brief if the expression is a GlobalVar, transform to its expression. //! \brief if the expression is a GlobalVar, transform to its expression.
Expr DeGlobal(const Module& mod, const Expr& e) { Expr DeGlobal(const Module& mod, const Expr& e) {
if (auto x = e.as<GlobalVarNode>()) { if (const auto* x = e.as<GlobalVarNode>()) {
return mod->Lookup(GetRef<GlobalVar>(x))->body; return mod->Lookup(GetRef<GlobalVar>(x))->body;
} else { } else {
return e; return e;
......
...@@ -385,7 +385,7 @@ Expr ToANFAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) { ...@@ -385,7 +385,7 @@ Expr ToANFAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
} }
Expr ToANF(const Expr& e, const Module& m, std::set<GlobalVar>* gv) { Expr ToANF(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
if (auto f = e.as<FunctionNode>()) { if (const auto* f = e.as<FunctionNode>()) {
return FunctionNode::make(f->params, return FunctionNode::make(f->params,
ToANFAux(f->body, m, gv), ToANFAux(f->body, m, gv),
f->ret_type, f->ret_type,
......
...@@ -386,7 +386,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -386,7 +386,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
} }
for (auto cs : fn_ty->type_constraints) { for (auto cs : fn_ty->type_constraints) {
if (auto tr = cs.as<TypeRelationNode>()) { if (const auto* tr = cs.as<TypeRelationNode>()) {
solver_.AddConstraint( solver_.AddConstraint(
TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs), TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs),
GetRef<Call>(call)); GetRef<Call>(call));
......
...@@ -376,7 +376,7 @@ void TypeSolver::ReportError(const Error& err, const NodeRef& location) { ...@@ -376,7 +376,7 @@ void TypeSolver::ReportError(const Error& err, const NodeRef& location) {
// Add type constraint to the solver. // Add type constraint to the solver.
void TypeSolver::AddConstraint(const TypeConstraint& constraint, const NodeRef& loc) { void TypeSolver::AddConstraint(const TypeConstraint& constraint, const NodeRef& loc) {
if (auto *op = constraint.as<TypeRelationNode>()) { if (const auto* op = constraint.as<TypeRelationNode>()) {
// create a new relation node. // create a new relation node.
RelationNode* rnode = arena_.make<RelationNode>(); RelationNode* rnode = arena_.make<RelationNode>();
rnode->location = loc; rnode->location = loc;
......
...@@ -486,29 +486,28 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -486,29 +486,28 @@ class RPCSession::EventHandler : public dmlc::Stream {
arg_recv_stage_ = 1; arg_recv_stage_ = 1;
this->RequestBytes(len); this->RequestBytes(len);
break; break;
break; }
} case kArrayHandle: {
case kArrayHandle: { temp_array_.reset(new RPCDataArrayBuffer());
temp_array_.reset(new RPCDataArrayBuffer()); uint64_t handle;
uint64_t handle; this->Read(&handle);
this->Read(&handle); DLTensor& tensor = temp_array_->tensor;
DLTensor& tensor = temp_array_->tensor; tensor.data = reinterpret_cast<void*>(handle);
tensor.data = reinterpret_cast<void*>(handle); this->Read(&(tensor.ctx));
this->Read(&(tensor.ctx)); this->Read(&(tensor.ndim));
this->Read(&(tensor.ndim)); this->Read(&(tensor.dtype));
this->Read(&(tensor.dtype)); temp_array_->shape.resize(tensor.ndim);
temp_array_->shape.resize(tensor.ndim); tensor.shape = temp_array_->shape.data();
tensor.shape = temp_array_->shape.data(); arg_recv_stage_ = 1;
arg_recv_stage_ = 1; tensor.strides = nullptr;
tensor.strides = nullptr; tensor.byte_offset = 0;
tensor.byte_offset = 0; this->RequestBytes(sizeof(int64_t) * tensor.ndim);
this->RequestBytes(sizeof(int64_t) * tensor.ndim); break;
break; }
} default: {
default: { LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode); break;
break; }
}
} }
} else { } else {
CHECK_EQ(arg_recv_stage_, 1); CHECK_EQ(arg_recv_stage_, 1);
......
...@@ -406,7 +406,6 @@ void StackVM::Run(State* s) const { ...@@ -406,7 +406,6 @@ void StackVM::Run(State* s) const {
case intrinsic::kArrByteOffset: { case intrinsic::kArrByteOffset: {
stack[sp].v_int64 = static_cast<int64_t>( stack[sp].v_int64 = static_cast<int64_t>(
arr[index].byte_offset); break; arr[index].byte_offset); break;
break;
} }
case intrinsic::kArrDeviceId: { case intrinsic::kArrDeviceId: {
stack[sp].v_int64 = arr[index].ctx.device_id; break; stack[sp].v_int64 = arr[index].ctx.device_id; break;
...@@ -531,7 +530,6 @@ const PackedFunc& StackVM::GetExtern(State* s, int fid) const { ...@@ -531,7 +530,6 @@ const PackedFunc& StackVM::GetExtern(State* s, int fid) const {
if (f == nullptr) { if (f == nullptr) {
CHECK(s->mod_ctx != nullptr) CHECK(s->mod_ctx != nullptr)
<< "No local context is set in stackvm"; << "No local context is set in stackvm";
CHECK(s->mod_ctx != nullptr);
const PackedFunc* pf = s->mod_ctx->GetFuncFromEnv(extern_func_name[fid]); const PackedFunc* pf = s->mod_ctx->GetFuncFromEnv(extern_func_name[fid]);
CHECK(pf != nullptr); CHECK(pf != nullptr);
f = *pf; f = *pf;
......
...@@ -331,7 +331,7 @@ class StackVM { ...@@ -331,7 +331,7 @@ class StackVM {
case EQ_I64: return EQ_F64; case EQ_I64: return EQ_F64;
case LT_I64: return LT_F64; case LT_I64: return LT_F64;
case LE_I64: return LE_F64; case LE_I64: return LE_F64;
case MOD_I64: LOG(FATAL) << "cannot handle mod for float"; case MOD_I64: LOG(FATAL) << "cannot handle mod for float"; return ADD_F64;
default: LOG(FATAL) << "cannot handle op " << code; return ADD_F64; default: LOG(FATAL) << "cannot handle op " << code; return ADD_F64;
} }
} }
......
...@@ -223,9 +223,9 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) { ...@@ -223,9 +223,9 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
} }
for (Operation op : ops) { for (Operation op : ops) {
if (op.as<ScanOpNode>()) { if (const auto* scan_op = op.as<ScanOpNode>()) {
const auto& update = op.as<ScanOpNode>()->update; const auto& update = scan_op->update;
const auto& init = op.as<ScanOpNode>()->init; const auto& init = scan_op->init;
for (size_t i = 0; i < update.size(); ++i) { for (size_t i = 0; i < update.size(); ++i) {
Tensor t = op.output(i); Tensor t = op.output(i);
for (int k = 1; k < static_cast<int>(update[i]->shape.size()); ++k) { for (int k = 1; k < static_cast<int>(update[i]->shape.size()); ++k) {
...@@ -235,9 +235,9 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) { ...@@ -235,9 +235,9 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
TensorDimKey(init[i], k)); TensorDimKey(init[i], k));
} }
} }
} else if (op.as<ComputeOpNode>()) { } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
std::unordered_map<const Node*, TensorDimKey> vmap; std::unordered_map<const Node*, TensorDimKey> vmap;
const auto& axis = op.as<ComputeOpNode>()->axis; const auto& axis = compute_op->axis;
Tensor t = op.output(0); Tensor t = op.output(0);
for (size_t i = 0; i < axis.size(); ++i) { for (size_t i = 0; i < axis.size(); ++i) {
vmap[axis[i]->var.get()] = TensorDimKey(t, i); vmap[axis[i]->var.get()] = TensorDimKey(t, i);
...@@ -260,7 +260,7 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) { ...@@ -260,7 +260,7 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
} }
} }
}; };
for (auto& e : op.as<ComputeOpNode>()->body) { for (auto& e : compute_op->body) {
ir::PostOrderVisit(e, fvisit); ir::PostOrderVisit(e, fvisit);
} }
} }
...@@ -312,19 +312,19 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) { ...@@ -312,19 +312,19 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
// prop exact reach back. // prop exact reach back.
for (size_t i = 0; i < body.size(); ++i) { for (size_t i = 0; i < body.size(); ++i) {
const Operation& op = body[i]; const Operation& op = body[i];
if (op.as<ScanOpNode>()) { if (const auto* scan_op = op.as<ScanOpNode>()) {
const auto& update = op.as<ScanOpNode>()->update; const auto& update = scan_op->update;
const auto& init = op.as<ScanOpNode>()->init; const auto& init = scan_op->init;
for (size_t i = 0; i < update.size(); ++i) { for (size_t i = 0; i < update.size(); ++i) {
Tensor t = op.output(i); Tensor t = op.output(i);
for (size_t k = 1; i < update[i]->shape.size(); ++k) { for (size_t k = 1; k < update[i]->shape.size(); ++k) {
f_merge_key(TensorDimKey(t, k), TensorDimKey(update[i], k)); f_merge_key(TensorDimKey(t, k), TensorDimKey(update[i], k));
f_merge_key(TensorDimKey(t, k), TensorDimKey(init[i], k)); f_merge_key(TensorDimKey(t, k), TensorDimKey(init[i], k));
} }
} }
} else if (op.as<ComputeOpNode>()) { } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
std::unordered_map<const Node*, std::vector<TensorDimKey> > vmap; std::unordered_map<const Node*, std::vector<TensorDimKey> > vmap;
const auto& axis = op.as<ComputeOpNode>()->axis; const auto& axis = compute_op->axis;
for (size_t i = 0; i < axis.size(); ++i) { for (size_t i = 0; i < axis.size(); ++i) {
std::vector<TensorDimKey> keys; std::vector<TensorDimKey> keys;
for (int j = 0; j < op->num_outputs(); ++j) { for (int j = 0; j < op->num_outputs(); ++j) {
...@@ -352,7 +352,7 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) { ...@@ -352,7 +352,7 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
} }
} }
}; };
for (auto& e : op.as<ComputeOpNode>()->body) { for (auto& e : compute_op->body) {
ir::PostOrderVisit(e, fvisit); ir::PostOrderVisit(e, fvisit);
} }
} }
......
...@@ -419,8 +419,7 @@ void PassUpBoundCheck(const Stage& s, ...@@ -419,8 +419,7 @@ void PassUpBoundCheck(const Stage& s,
using HalideIR::Internal::can_prove; using HalideIR::Internal::can_prove;
for (size_t i = s->relations.size(); i != 0; --i) { for (size_t i = s->relations.size(); i != 0; --i) {
IterVarRelation rel = s->relations[i - 1]; IterVarRelation rel = s->relations[i - 1];
if (rel.as<SplitNode>()) { if (const SplitNode* s = rel.as<SplitNode>()) {
const SplitNode* s = rel.as<SplitNode>();
bool outer = state.at(s->outer); bool outer = state.at(s->outer);
bool inner = state.at(s->inner); bool inner = state.at(s->inner);
...@@ -439,13 +438,11 @@ void PassUpBoundCheck(const Stage& s, ...@@ -439,13 +438,11 @@ void PassUpBoundCheck(const Stage& s,
} else { } else {
state[s->parent] = true; state[s->parent] = true;
} }
} else if (rel.as<FuseNode>()) { } else if (const FuseNode* s = rel.as<FuseNode>()) {
const FuseNode* s = rel.as<FuseNode>();
bool fused = state.at(s->fused); bool fused = state.at(s->fused);
state[s->outer] = fused; state[s->outer] = fused;
state[s->inner] = fused; state[s->inner] = fused;
} else if (rel.as<RebaseNode>()) { } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
const RebaseNode* s = rel.as<RebaseNode>();
state[s->parent] = state.at(s->rebased); state[s->parent] = state.at(s->rebased);
} else if (rel.as<SingletonNode>()) { } else if (rel.as<SingletonNode>()) {
// nop // nop
......
...@@ -544,7 +544,7 @@ void InjectInline(ScheduleNode* sch) { ...@@ -544,7 +544,7 @@ void InjectInline(ScheduleNode* sch) {
const ComputeOpNode* compute = s->op.as<ComputeOpNode>(); const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
if (compute) { if (compute) {
if (!new_body[j].size()) { if (!new_body[j].size()) {
new_body[j] = s->op.as<ComputeOpNode>()->body; new_body[j] = compute->body;
} }
if (new_body[j][0]->is_type<ir::Reduce>()) { if (new_body[j][0]->is_type<ir::Reduce>()) {
// specially handle reduction inline for multiple reductions. // specially handle reduction inline for multiple reductions.
......
...@@ -710,8 +710,7 @@ Schedule ScheduleNode::make(Array<Operation> ops) { ...@@ -710,8 +710,7 @@ Schedule ScheduleNode::make(Array<Operation> ops) {
n->stages.push_back(stage); n->stages.push_back(stage);
n->stage_map.Set(op, stage); n->stage_map.Set(op, stage);
// mark scan updates. // mark scan updates.
if (op.as<ScanOpNode>()) { if (const ScanOpNode* scan = op.as<ScanOpNode>()) {
const ScanOpNode* scan = op.as<ScanOpNode>();
Array<Tensor> inputs; Array<Tensor> inputs;
for (Tensor t : scan->state_placeholder) { for (Tensor t : scan->state_placeholder) {
inputs.push_back(t); inputs.push_back(t);
......
...@@ -304,8 +304,7 @@ class SchedulePostProc : public IRMutator { ...@@ -304,8 +304,7 @@ class SchedulePostProc : public IRMutator {
} }
} }
// Specially add replacements for scan op. // Specially add replacements for scan op.
if (s->op.as<ScanOpNode>()) { if (const ScanOpNode* scan = s->op.as<ScanOpNode>()) {
const ScanOpNode* scan = s->op.as<ScanOpNode>();
for (size_t i = 0; i < scan->update.size(); ++i) { for (size_t i = 0; i < scan->update.size(); ++i) {
Tensor t = s->origin_op.output(i); Tensor t = s->origin_op.output(i);
AddReplace(scan->init[i], t); AddReplace(scan->init[i], t);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment