Commit 5b8ff8d0 by Alexey Romanov, committed by Tianqi Chen

Remove duplicate as Checks and CHECK value (#2531)

parent 74b035a2
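
Most hunks below apply a single C++ idiom: where the old code tested a downcast and then repeated the same cast to use the result (two `as<T>()` calls, plus a few duplicated CHECKs and `break`s elsewhere), the new code declares the pointer in the `if` condition, so the cast happens once and the result is bound to a name. A minimal standalone sketch of the before/after, with hypothetical types and `dynamic_cast` standing in for TVM's `as<T>()`:

    #include <iostream>
    #include <string>

    struct Expr { virtual ~Expr() = default; };
    struct StringImm : Expr {
      std::string value;
      explicit StringImm(std::string v) : value(std::move(v)) {}
    };

    void Emit(const Expr& e) {
      // Before: if (dynamic_cast<const StringImm*>(&e)) { ... } and then the
      // same cast again in the body. After: the condition binds the pointer
      // once; the body reuses it.
      if (const auto* str = dynamic_cast<const StringImm*>(&e)) {
        std::cout << "CHECK(cond) << \"" << str->value << "\";\n";
      } else {
        std::cout << "assert(cond);\n";
      }
    }

    int main() {
      Emit(StringImm("shape mismatch"));  // takes the CHECK branch
      Emit(Expr{});                       // takes the assert branch
    }
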
@@ -791,10 +791,9 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) {
 void CodeGenC::VisitStmt_(const AssertStmt* op) {
   std::string cond = PrintExpr(op->condition);
   PrintIndent();
-  if (op->message.as<StringImm>()) {
+  if (const auto* str = op->message.as<StringImm>()) {
     // GLOG style check
-    stream << "CHECK(" << cond << ") << \""
-           << op->message.as<StringImm>()->value << "\";\n";
+    stream << "CHECK(" << cond << ") << \"" << str->value << "\";\n";
   } else {
     stream << "assert(" << cond << ");\n";
   }

@@ -470,8 +470,8 @@ void CodeGenStackVM::VisitExpr_(const Select *op) {
 }
 void CodeGenStackVM::VisitStmt_(const AssertStmt *op) {
-  if (op->message.as<StringImm>()) {
-    int sid = this->GetStrID(op->message.as<StringImm>()->value);
+  if (const auto* str = op->message.as<StringImm>()) {
+    int sid = this->GetStrID(str->value);
     this->Push(op->condition);
     this->PushOp(StackVM::ASSERT, sid);
   }

@@ -42,13 +42,13 @@ inline std::string GetHostName() {
 }
 /*!
- * \brief Common data structure fornetwork address.
+ * \brief Common data structure for network address.
  */
 struct SockAddr {
   sockaddr_storage addr;
   SockAddr() {}
   /*!
-   * \brief construc address by url and port
+   * \brief construct address by url and port
    * \param url The url of the address
    * \param port The port of the address.
    */

@@ -435,7 +435,7 @@ Stmt ApplySchedule(const Stage &stage,
   // Gather rebased variables
   std::unordered_map<IterVar, IterVar> rebased;
   for (auto rel : stage->relations) {
-    if (auto rebase = rel.as<RebaseNode>()) {
+    if (const auto* rebase = rel.as<RebaseNode>()) {
      rebased[rebase->rebased] = rebase->parent;
      CHECK(rebase->parent->dom.defined());
      CHECK(dom_map.count(rebase->rebased));

@@ -12,39 +12,39 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
   // use reverse iteration
   for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
     Stmt s = *ri;
-    if (s.as<For>()) {
-      auto n = make_node<For>(*s.as<For>());
+    if (const auto* for_ = s.as<For>()) {
+      auto n = make_node<For>(*for_);
       CHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
-    } else if (s.as<LetStmt>()) {
-      auto n = make_node<LetStmt>(*s.as<LetStmt>());
+    } else if (const auto* let = s.as<LetStmt>()) {
+      auto n = make_node<LetStmt>(*let);
       CHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
-    } else if (s.as<AttrStmt>()) {
-      auto n = make_node<AttrStmt>(*s.as<AttrStmt>());
+    } else if (const auto* attr = s.as<AttrStmt>()) {
+      auto n = make_node<AttrStmt>(*attr);
       CHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
-    } else if (s.as<IfThenElse>()) {
-      auto n = make_node<IfThenElse>(*s.as<IfThenElse>());
+    } else if (const auto* ite = s.as<IfThenElse>()) {
+      auto n = make_node<IfThenElse>(*ite);
       CHECK(is_no_op(n->then_case));
       CHECK(!n->else_case.defined());
       n->then_case = body;
       body = Stmt(n);
-    } else if (s.as<Block>()) {
-      auto n = make_node<Block>(*s.as<Block>());
+    } else if (const auto* block = s.as<Block>()) {
+      auto n = make_node<Block>(*block);
       CHECK(is_no_op(n->rest));
       n->rest = body;
       body = Stmt(n);
-    } else if (s.as<AssertStmt>()) {
-      auto n = make_node<AssertStmt>(*s.as<AssertStmt>());
+    } else if (const auto* assert_ = s.as<AssertStmt>()) {
+      auto n = make_node<AssertStmt>(*assert_);
       CHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
-    } else if (s.as<Allocate>()) {
-      auto n = make_node<Allocate>(*s.as<Allocate>());
+    } else if (const auto* alloc = s.as<Allocate>()) {
+      auto n = make_node<Allocate>(*alloc);
       CHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);

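Each branch in the hunk above clones the nest element (`make_node<T>(*ptr)` copy-constructs a fresh node) and splices the accumulated `body` into its no-op child slot, so reverse iteration wraps the innermost statement in successively outer ones. A condensed standalone sketch of that splice, using a hypothetical single-child `Stmt` in place of TVM's node types:

    #include <iostream>
    #include <memory>
    #include <string>
    #include <vector>

    // Hypothetical single-child statement; stands in for For/LetStmt/etc.
    struct Stmt {
      std::string tag;
      std::shared_ptr<Stmt> body;  // placeholder until spliced
    };

    std::shared_ptr<Stmt> MergeNest(const std::vector<std::shared_ptr<Stmt>>& nest,
                                    std::shared_ptr<Stmt> body) {
      // Reverse iteration: the last nest element becomes the innermost wrapper.
      for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
        auto n = std::make_shared<Stmt>(**ri);  // clone, like make_node<T>(*ptr)
        n->body = body;                         // splice the accumulated body
        body = n;
      }
      return body;
    }

    int main() {
      auto merged = MergeNest({std::make_shared<Stmt>(Stmt{"for", nullptr}),
                               std::make_shared<Stmt>(Stmt{"let", nullptr})},
                              std::make_shared<Stmt>(Stmt{"store", nullptr}));
      for (auto s = merged; s; s = s->body) std::cout << s->tag << "\n";
      // prints: for, let, store
    }
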
@@ -326,7 +326,8 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
   Expr body_begin;
   Stmt pre_stmt;
-  if (true_itrv.as<arith::IntervalSet>()->i.has_lower_bound()) {
+  arith::Interval true_itrv_i = true_itrv.as<arith::IntervalSet>()->i;
+  if (true_itrv_i.has_lower_bound()) {
     body_begin = ir::Simplify(true_itrv.min());
     if (!can_prove(body_begin == min)) {
       Expr cond = (body_begin - min >= 0);
@@ -347,7 +348,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
   Expr post_doubt_begin;
   Stmt post_stmt;
-  if (true_itrv.as<arith::IntervalSet>()->i.has_upper_bound()) {
+  if (true_itrv_i.has_upper_bound()) {
     post_doubt_begin = ir::Simplify(true_itrv.max() + 1);
     if (!can_prove(true_itrv.max() == max)) {
       // require the extent to be non-negative

@@ -34,7 +34,7 @@ class IRUseDefAnalysis : public IRMutator {
       value = this->Mutate(value);
     }
     Stmt body = this->Mutate(op->body);
-    if (value.same_as(value) && body.same_as(body)) return s;
+    if (value.same_as(op->value) && body.same_as(op->body)) return s;
     return AttrStmt::make(op->node, op->attr_key, value, body);
   } else if (op->attr_key == attr::channel_write_scope ||
              op->attr_key == attr::channel_read_scope) {

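The one-line fix above is a behavioral bug fix, not a cleanup: `value.same_as(value)` and `body.same_as(body)` are trivially true, so the early return always fired and the rebuilt `AttrStmt` below it was unreachable. Comparing against `op->value` and `op->body` restores the usual copy-on-write mutator idiom, sketched here with hypothetical types:

    #include <iostream>
    #include <memory>

    // Hypothetical node with two shared-ownership children.
    struct AttrStmt {
      std::shared_ptr<const int> value, body;
    };

    // Copy-on-write rewrite step: rebuild only when a child actually changed.
    AttrStmt Mutate(const AttrStmt& op, std::shared_ptr<const int> value,
                    std::shared_ptr<const int> body) {
      // The buggy test compared each child with itself (always true), so the
      // original node was always returned and mutations were discarded.
      if (value == op.value && body == op.body) return op;  // nothing changed
      return AttrStmt{std::move(value), std::move(body)};   // rebuild node
    }

    int main() {
      auto v = std::make_shared<const int>(1);
      auto b = std::make_shared<const int>(2);
      AttrStmt op{v, b};
      AttrStmt mutated = Mutate(op, std::make_shared<const int>(3), b);
      std::cout << *mutated.value << "\n";  // 3: the change is preserved
    }
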
@@ -718,10 +718,10 @@ class StoragePlanRewriter : public IRMutator {
             src_entry->attach_scope_ == thread_scope_ &&
             src_entry->elem_type == ae.alloc->type.element_of() &&
             visitor.Check(s.stmt, var, src)) {
-          uint64_t const_nbits = static_cast<uint64_t>(
-              ae.alloc->constant_allocation_size() *
+          uint64_t const_nbits =
+              static_cast<uint64_t>(ae.alloc->constant_allocation_size()) *
               ae.alloc->type.bits() *
-              ae.alloc->type.lanes());
+              ae.alloc->type.lanes();
           if (src_entry->const_nbits == const_nbits && !inplace_found) {
             // successfully inplace
             dst_entry = src_entry;

@@ -73,9 +73,10 @@ class GPUCodeVerifier : public IRVisitor {
   void Visit_(const AttrStmt *op) {
     if (op->attr_key == attr::storage_scope) {
-      if (op->value.as<StringImm>()->value == "local") {
+      std::string op_value = op->value.as<StringImm>()->value;
+      if (op_value == "local") {
         visited_local_buffers_.insert(op->node.as<tvm::Variable>());
-      } else if (op->value.as<StringImm>()->value == "shared") {
+      } else if (op_value == "shared") {
         visited_shared_buffers_.insert(op->node.as<tvm::Variable>());
       }
     } else if (op->attr_key == attr::thread_extent) {
@@ -159,18 +160,19 @@ bool VerifyGPUCode(Stmt stmt,
   int64_t max_thread_z = INT64_MAX;
   for (auto iter : constraints) {
+    const IntImm* val = iter.second.as<IntImm>();
     if (iter.first == "max_local_memory_per_block")
-      max_local_memory_per_block = (iter.second).as<IntImm>()->value;
+      max_local_memory_per_block = val->value;
     else if (iter.first == "max_shared_memory_per_block")
-      max_shared_memory_per_block = (iter.second).as<IntImm>()->value;
+      max_shared_memory_per_block = val->value;
     else if (iter.first == "max_threads_per_block")
-      max_threads_per_block = (iter.second).as<IntImm>()->value;
+      max_threads_per_block = val->value;
    else if (iter.first == "max_thread_x")
-      max_thread_x = (iter.second).as<IntImm>()->value;
+      max_thread_x = val->value;
    else if (iter.first == "max_thread_y")
-      max_thread_y = (iter.second).as<IntImm>()->value;
+      max_thread_y = val->value;
    else if (iter.first == "max_thread_z")
-      max_thread_z = (iter.second).as<IntImm>()->value;
+      max_thread_z = val->value;
    else
      LOG(FATAL) << "Invalid check item: " << iter.first;
  }

@@ -379,7 +379,7 @@ class Interpreter :
     //
     // We have some functions cotaining chunks of operators
     // which will be loaded into operator map.
-    if (auto op_node = call->op.as<OpNode>()) {
+    if (const auto* op_node = call->op.as<OpNode>()) {
       LOG(FATAL) << "found " << op_node->name
                  << "; operators should be removed by future passes; try "
                     "fusing and lowering";

@@ -20,8 +20,8 @@ NodePtr<SourceNameNode> GetSourceNameNode(const std::string& name) {
   auto sn = source_map.find(name);
   if (sn == source_map.end()) {
     NodePtr<SourceNameNode> n = make_node<SourceNameNode>();
-    n->name = std::move(name);
     source_map[name] = n;
+    n->name = std::move(name);
     return n;
   } else {
     return sn->second;

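The reorder above removes a use-after-move hazard: `std::move(name)` before `source_map[name]` would read a moved-from string if `name` were a by-value parameter (with the actual `const std::string&` parameter the move degrades to a copy, so the swap mainly makes the intent safe). A standalone illustration of the by-value failure mode:

    #include <iostream>
    #include <map>
    #include <string>

    int main() {
      std::map<std::string, int> table;
      std::string name = "main";
      // Moving first leaves `name` in a valid but unspecified (typically
      // empty) state, so the map is keyed on the wrong string; inserting
      // before moving avoids this.
      std::string stolen = std::move(name);
      table[name] = 1;  // likely creates table[""] rather than table["main"]
      std::cout << table.count("main") << "\n";  // prints 0: the entry is lost
    }
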
@@ -15,7 +15,7 @@ namespace tvm {
 namespace relay {
 TensorType ToTensorType(const Type& t) {
-  if (auto tt_node = t.as<TensorTypeNode>()) {
+  if (const auto* tt_node = t.as<TensorTypeNode>()) {
     return GetRef<TensorType>(tt_node);
   } else {
     return TensorType(nullptr);

@@ -361,7 +361,7 @@ Expr AddSubForwardRewrite(const Call& ref_call,
     rnode->scale = slhs->scale;
     rnode->axes = slhs->axes;
   } else {
-    CHECK(slhs != nullptr);
+    CHECK(srhs != nullptr);
     CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes));
     Expr scale = ExpandBiasToMatchAxis(
         srhs->scale, trhs->shape.size(), srhs->axes);

@@ -61,7 +61,7 @@ Type WithGradientType(const Type& t) {
 //! \brief if the expression is a GlobalVar, transform to it's expression.
 Expr DeGlobal(const Module& mod, const Expr& e) {
-  if (auto x = e.as<GlobalVarNode>()) {
+  if (const auto* x = e.as<GlobalVarNode>()) {
     return mod->Lookup(GetRef<GlobalVar>(x))->body;
   } else {
     return e;

@@ -385,7 +385,7 @@ Expr ToANFAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
 }
 Expr ToANF(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
-  if (auto f = e.as<FunctionNode>()) {
+  if (const auto* f = e.as<FunctionNode>()) {
     return FunctionNode::make(f->params,
                               ToANFAux(f->body, m, gv),
                               f->ret_type,

@@ -386,7 +386,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
     }
     for (auto cs : fn_ty->type_constraints) {
-      if (auto tr = cs.as<TypeRelationNode>()) {
+      if (const auto* tr = cs.as<TypeRelationNode>()) {
         solver_.AddConstraint(
             TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs),
             GetRef<Call>(call));

@@ -376,7 +376,7 @@ void TypeSolver::ReportError(const Error& err, const NodeRef& location) {
 // Add type constraint to the solver.
 void TypeSolver::AddConstraint(const TypeConstraint& constraint, const NodeRef& loc) {
-  if (auto *op = constraint.as<TypeRelationNode>()) {
+  if (const auto* op = constraint.as<TypeRelationNode>()) {
     // create a new relation node.
     RelationNode* rnode = arena_.make<RelationNode>();
     rnode->location = loc;

@@ -486,29 +486,28 @@ class RPCSession::EventHandler : public dmlc::Stream {
         arg_recv_stage_ = 1;
         this->RequestBytes(len);
         break;
-        break;
       }
-      case kArrayHandle: {
-        temp_array_.reset(new RPCDataArrayBuffer());
-        uint64_t handle;
-        this->Read(&handle);
-        DLTensor& tensor = temp_array_->tensor;
-        tensor.data = reinterpret_cast<void*>(handle);
-        this->Read(&(tensor.ctx));
-        this->Read(&(tensor.ndim));
-        this->Read(&(tensor.dtype));
-        temp_array_->shape.resize(tensor.ndim);
-        tensor.shape = temp_array_->shape.data();
-        arg_recv_stage_ = 1;
-        tensor.strides = nullptr;
-        tensor.byte_offset = 0;
-        this->RequestBytes(sizeof(int64_t) * tensor.ndim);
-        break;
-      }
-      default: {
-        LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
-        break;
-      }
-    }
+      case kArrayHandle: {
+        temp_array_.reset(new RPCDataArrayBuffer());
+        uint64_t handle;
+        this->Read(&handle);
+        DLTensor& tensor = temp_array_->tensor;
+        tensor.data = reinterpret_cast<void*>(handle);
+        this->Read(&(tensor.ctx));
+        this->Read(&(tensor.ndim));
+        this->Read(&(tensor.dtype));
+        temp_array_->shape.resize(tensor.ndim);
+        tensor.shape = temp_array_->shape.data();
+        arg_recv_stage_ = 1;
+        tensor.strides = nullptr;
+        tensor.byte_offset = 0;
+        this->RequestBytes(sizeof(int64_t) * tensor.ndim);
+        break;
+      }
+      default: {
+        LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
+        break;
+      }
+    }
   } else {
     CHECK_EQ(arg_recv_stage_, 1);

@@ -406,7 +406,6 @@ void StackVM::Run(State* s) const {
       case intrinsic::kArrByteOffset: {
         stack[sp].v_int64 = static_cast<int64_t>(
             arr[index].byte_offset); break;
-        break;
       }
       case intrinsic::kArrDeviceId: {
         stack[sp].v_int64 = arr[index].ctx.device_id; break;
@@ -531,7 +530,6 @@ const PackedFunc& StackVM::GetExtern(State* s, int fid) const {
   if (f == nullptr) {
     CHECK(s->mod_ctx != nullptr)
         << "No local context is set in stackvm";
-    CHECK(s->mod_ctx != nullptr);
     const PackedFunc* pf = s->mod_ctx->GetFuncFromEnv(extern_func_name[fid]);
     CHECK(pf != nullptr);
     f = *pf;

@@ -331,7 +331,7 @@ class StackVM {
       case EQ_I64: return EQ_F64;
       case LT_I64: return LT_F64;
       case LE_I64: return LE_F64;
-      case MOD_I64: LOG(FATAL) << "cannot handle mod for float";
+      case MOD_I64: LOG(FATAL) << "cannot handle mod for float"; return ADD_F64;
      default: LOG(FATAL) << "cannot handle op " << code; return ADD_F64;
     }
   }

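The `return ADD_F64;` added after `LOG(FATAL)` above is never reached at runtime; like the one already present in the `default` case, it presumably placates compilers that cannot see that the fatal logger aborts and would otherwise warn that a value-returning function can fall off the end. A small illustration with a stand-in `Fatal()`:

    #include <cstdlib>
    #include <iostream>

    // Stand-in for LOG(FATAL): aborts, but is not marked [[noreturn]], so
    // the compiler must assume control can continue past the call.
    void Fatal(const char* msg) { std::cerr << msg << "\n"; std::abort(); }

    int ToFloatOp(int code) {
      switch (code) {
        case 0: return 100;  // normal mapping, e.g. ADD_I64 -> ADD_F64
        case 1: Fatal("cannot handle mod for float");
                return 100;  // unreachable; silences -Wreturn-type
        default: Fatal("cannot handle op");
                 return 100;  // unreachable, same reason
      }
    }

    int main() { std::cout << ToFloatOp(0) << "\n"; }
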
@@ -223,9 +223,9 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
   }
   for (Operation op : ops) {
-    if (op.as<ScanOpNode>()) {
-      const auto& update = op.as<ScanOpNode>()->update;
-      const auto& init = op.as<ScanOpNode>()->init;
+    if (const auto* scan_op = op.as<ScanOpNode>()) {
+      const auto& update = scan_op->update;
+      const auto& init = scan_op->init;
       for (size_t i = 0; i < update.size(); ++i) {
         Tensor t = op.output(i);
         for (int k = 1; k < static_cast<int>(update[i]->shape.size()); ++k) {
@@ -235,9 +235,9 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
                       TensorDimKey(init[i], k));
         }
       }
-    } else if (op.as<ComputeOpNode>()) {
+    } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
       std::unordered_map<const Node*, TensorDimKey> vmap;
-      const auto& axis = op.as<ComputeOpNode>()->axis;
+      const auto& axis = compute_op->axis;
       Tensor t = op.output(0);
       for (size_t i = 0; i < axis.size(); ++i) {
         vmap[axis[i]->var.get()] = TensorDimKey(t, i);
@@ -260,7 +260,7 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
           }
         }
       };
-      for (auto& e : op.as<ComputeOpNode>()->body) {
+      for (auto& e : compute_op->body) {
         ir::PostOrderVisit(e, fvisit);
       }
     }
@@ -312,19 +312,19 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
   // prop exact reach back.
   for (size_t i = 0; i < body.size(); ++i) {
     const Operation& op = body[i];
-    if (op.as<ScanOpNode>()) {
-      const auto& update = op.as<ScanOpNode>()->update;
-      const auto& init = op.as<ScanOpNode>()->init;
+    if (const auto* scan_op = op.as<ScanOpNode>()) {
+      const auto& update = scan_op->update;
+      const auto& init = scan_op->init;
       for (size_t i = 0; i < update.size(); ++i) {
         Tensor t = op.output(i);
-        for (size_t k = 1; i < update[i]->shape.size(); ++k) {
+        for (size_t k = 1; k < update[i]->shape.size(); ++k) {
           f_merge_key(TensorDimKey(t, k), TensorDimKey(update[i], k));
           f_merge_key(TensorDimKey(t, k), TensorDimKey(init[i], k));
         }
       }
-    } else if (op.as<ComputeOpNode>()) {
+    } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
       std::unordered_map<const Node*, std::vector<TensorDimKey> > vmap;
-      const auto& axis = op.as<ComputeOpNode>()->axis;
+      const auto& axis = compute_op->axis;
       for (size_t i = 0; i < axis.size(); ++i) {
         std::vector<TensorDimKey> keys;
         for (int j = 0; j < op->num_outputs(); ++j) {
@@ -352,7 +352,7 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
       }
     }
   };
-      for (auto& e : op.as<ComputeOpNode>()->body) {
+      for (auto& e : compute_op->body) {
         ir::PostOrderVisit(e, fvisit);
       }
     }

@@ -419,8 +419,7 @@ void PassUpBoundCheck(const Stage& s,
   using HalideIR::Internal::can_prove;
   for (size_t i = s->relations.size(); i != 0; --i) {
     IterVarRelation rel = s->relations[i - 1];
-    if (rel.as<SplitNode>()) {
-      const SplitNode* s = rel.as<SplitNode>();
+    if (const SplitNode* s = rel.as<SplitNode>()) {
       bool outer = state.at(s->outer);
       bool inner = state.at(s->inner);
@@ -439,13 +438,11 @@ void PassUpBoundCheck(const Stage& s,
       } else {
         state[s->parent] = true;
       }
-    } else if (rel.as<FuseNode>()) {
-      const FuseNode* s = rel.as<FuseNode>();
+    } else if (const FuseNode* s = rel.as<FuseNode>()) {
       bool fused = state.at(s->fused);
       state[s->outer] = fused;
       state[s->inner] = fused;
-    } else if (rel.as<RebaseNode>()) {
-      const RebaseNode* s = rel.as<RebaseNode>();
+    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
       state[s->parent] = state.at(s->rebased);
     } else if (rel.as<SingletonNode>()) {
       // nop

@@ -544,7 +544,7 @@ void InjectInline(ScheduleNode* sch) {
     const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
     if (compute) {
       if (!new_body[j].size()) {
-        new_body[j] = s->op.as<ComputeOpNode>()->body;
+        new_body[j] = compute->body;
       }
       if (new_body[j][0]->is_type<ir::Reduce>()) {
         // specially handle reduction inline for multiplre reductions.

@@ -710,8 +710,7 @@ Schedule ScheduleNode::make(Array<Operation> ops) {
     n->stages.push_back(stage);
     n->stage_map.Set(op, stage);
     // mark scan updates.
-    if (op.as<ScanOpNode>()) {
-      const ScanOpNode* scan = op.as<ScanOpNode>();
+    if (const ScanOpNode* scan = op.as<ScanOpNode>()) {
       Array<Tensor> inputs;
       for (Tensor t : scan->state_placeholder) {
         inputs.push_back(t);

@@ -304,8 +304,7 @@ class SchedulePostProc : public IRMutator {
       }
     }
     // Specially add replacements for scan op.
-    if (s->op.as<ScanOpNode>()) {
-      const ScanOpNode* scan = s->op.as<ScanOpNode>();
+    if (const ScanOpNode* scan = s->op.as<ScanOpNode>()) {
      for (size_t i = 0; i < scan->update.size(); ++i) {
        Tensor t = s->origin_op.output(i);
        AddReplace(scan->init[i], t);