Commit 0876e9e9 by Tianqi Chen Committed by GitHub

[IR] Rename attr_key in AttrStmt (#83)

parent 8f51c5fd
Subproject commit 59fdca16978b6184bab87fbff7a00c95f1804686
Subproject commit d024efd80694556c1239c4435c5b3e70853a4896
......@@ -286,8 +286,8 @@ class Canonical::Internal : public IRMutator {
}
// AttrStmt
Stmt Mutate_(const AttrStmt* op, const Stmt& s) {
if (op->type_key == attr::thread_extent ||
op->type_key == attr::virtual_thread) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
++level_counter_;
IterVar iv(op->node.node_);
CHECK_NE(iv->thread_tag.length(), 0U);
......
......@@ -654,7 +654,7 @@ void CodeGenC::VisitStmt_(const Allocate* op) {
}
void CodeGenC::VisitStmt_(const AttrStmt* op) {
if (op->type_key == ir::attr::thread_extent) {
if (op->attr_key == ir::attr::thread_extent) {
IterVar iv(op->node.node_);
if (iv->thread_tag.length() != 0) {
if (!var_idmap_.count(iv->var.get())) {
......@@ -667,11 +667,11 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) {
stream << ";\n";
}
}
} else if (op->type_key == ir::attr::storage_scope) {
} else if (op->attr_key == ir::attr::storage_scope) {
const Variable* v = op->node.as<Variable>();
CHECK(v);
alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
} else if (op->type_key == ir::attr::volatile_scope) {
} else if (op->attr_key == ir::attr::volatile_scope) {
const Variable* v = op->node.as<Variable>();
CHECK(v);
volatile_buf_.insert(v);
......
......@@ -1245,7 +1245,7 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) {
}
void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
if (op->type_key == ir::attr::storage_scope) {
if (op->attr_key == ir::attr::storage_scope) {
const Variable* v = op->node.as<Variable>();
CHECK(v);
alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
......
......@@ -93,20 +93,20 @@ class PipelineExtractor: public IRVisitor {
}
void Visit_(const AttrStmt* op) final {
if (op->type_key == attr::pipeline_stage_scope) {
if (op->attr_key == attr::pipeline_stage_scope) {
CHECK(!in_pipeline_stage_);
in_pipeline_stage_ = true;
trigger_.emplace_back(std::make_pair(loop_.size(), op));
IRVisitor::Visit_(op);
trigger_.pop_back();
in_pipeline_stage_ = false;
} else if (op->type_key == attr::channel_read_advance ||
op->type_key == attr::channel_write_advance) {
} else if (op->attr_key == attr::channel_read_advance ||
op->attr_key == attr::channel_write_advance) {
trigger_.emplace_back(std::make_pair(loop_.size(), op));
IRVisitor::Visit_(op);
trigger_.pop_back();
} else if (op->type_key == attr::channel_read_scope ||
op->type_key == attr::channel_write_scope) {
} else if (op->attr_key == attr::channel_read_scope ||
op->attr_key == attr::channel_write_scope) {
Channel ch(op->node.node_);
ChannelEntry& cb = cmap_[ch->handle_var.get()];
if (cb.node != nullptr) {
......@@ -115,7 +115,7 @@ class PipelineExtractor: public IRVisitor {
cb.node = std::make_shared<ChannelBlockNode>();
cb.node->channel = ch;
}
if (op->type_key == attr::channel_read_scope) {
if (op->attr_key == attr::channel_read_scope) {
CHECK_EQ(cb.read_ref_count, 0)
<< "One channel can only be read from one consumer";
++cb.read_ref_count;
......@@ -173,7 +173,7 @@ class PipelineExtractor: public IRVisitor {
for (const auto& e : trigger_) {
const AttrStmt* attr = e.second;
Channel ch;
if (attr->type_key == attr::pipeline_stage_scope) {
if (attr->attr_key == attr::pipeline_stage_scope) {
ch = arg_write;
if (!ch.defined()) continue;
} else {
......@@ -195,10 +195,10 @@ class PipelineExtractor: public IRVisitor {
trigger->signal_index = static_cast<int>(cb.node->ctrl_signals.size());
// Grab the advance constant size.
int trigger_size;
if (attr->type_key == attr::pipeline_stage_scope) {
if (attr->attr_key == attr::pipeline_stage_scope) {
cb.node->ctrl_signals.push_back(
ControlSignalNode::make(kComputeFinish, 0));
} else if (attr->type_key == attr::channel_read_advance) {
} else if (attr->attr_key == attr::channel_read_advance) {
CHECK(arith::GetConstInt(attr->value, &trigger_size))
<< "Only support constant advance size";
cb.node->ctrl_signals.push_back(
......
......@@ -200,7 +200,7 @@ class VTInjector : public IRMutator {
body.same_as(op->body)) {
return s;
} else {
return AttrStmt::make(op->node, op->type_key, value, body);
return AttrStmt::make(op->node, op->attr_key, value, body);
}
}
}
......@@ -388,7 +388,7 @@ class VirtualThreadInjector : public IRMutator {
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<AttrStmt>();
if (op->type_key == attr::virtual_thread) {
if (op->attr_key == attr::virtual_thread) {
IterVar iv(op->node.node_);
int nthread = static_cast<int>(op->value.as<IntImm>()->value);
VarTouchedAnalysis vs;
......
......@@ -68,7 +68,7 @@ Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) {
body.same_as(op->body)) {
return s;
} else {
return AttrStmt::make(op->node, op->type_key, value, body);
return AttrStmt::make(op->node, op->attr_key, value, body);
}
}
......
......@@ -25,9 +25,9 @@ class AllocateLifter : public IRMutator {
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
CHECK(op->type_key != attr::virtual_thread)
CHECK(op->attr_key != attr::virtual_thread)
<< "InjectVirtualThread before LiftStorageAlloc";
if (op->type_key == attr::storage_scope) {
if (op->attr_key == attr::storage_scope) {
StorageScope sc = StorageScope::make(op->value.as<StringImm>()->value);
allocs_[sc].emplace_back(
AttrStmt::make(
......@@ -35,7 +35,7 @@ class AllocateLifter : public IRMutator {
op->value, Evaluate::make(0)));
storage_scope_[op->node.get()] = sc;
return this->Mutate(op->body);
} else if (op->type_key == attr::thread_extent) {
} else if (op->attr_key == attr::thread_extent) {
IterVar iv(op->node.node_);
ThreadScope ts = ThreadScope::make(iv->thread_tag);
curr_thread_scope_.push_back(ts);
......@@ -55,7 +55,7 @@ class AllocateLifter : public IRMutator {
Stmt body = MergeNest(vec, op->body);
vec.clear();
return AttrStmt::make(
op->node, op->type_key, op->value, body);
op->node, op->attr_key, op->value, body);
}
}
return stmt;
......
......@@ -20,12 +20,12 @@ class ThreadAllreduceBuilder : public IRMutator {
: warp_size_(warp_size) {}
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
if (op->type_key == attr::thread_extent) {
if (op->attr_key == attr::thread_extent) {
thread_extents_.push_back(op);
Stmt ret = IRMutator::Mutate_(op, s);
thread_extents_.pop_back();
return ret;
} else if (op->type_key == attr::storage_scope) {
} else if (op->attr_key == attr::storage_scope) {
Stmt ret = IRMutator::Mutate_(op, s);
op = ret.as<AttrStmt>();
const Variable* v = op->node.as<Variable>();
......
......@@ -107,14 +107,14 @@ class ChannelAccessRewriter : public IRMutator {
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt ret;
const AttrStmt* adv = op->body.as<AttrStmt>();
if ((op->type_key == ir::attr::channel_read_scope &&
adv && adv->type_key == ir::attr::channel_read_advance) ||
(op->type_key == ir::attr::channel_write_scope &&
adv && adv->type_key == ir::attr::channel_write_advance)) {
if ((op->attr_key == ir::attr::channel_read_scope &&
adv && adv->attr_key == ir::attr::channel_read_advance) ||
(op->attr_key == ir::attr::channel_write_scope &&
adv && adv->attr_key == ir::attr::channel_write_advance)) {
RewriteEntry e;
e.window = op;
e.advance = adv;
e.read_access = op->type_key == ir::attr::channel_read_scope;
e.read_access = op->attr_key == ir::attr::channel_read_scope;
tasks_.push_back(e);
ret = IRMutator::Mutate_(op, s);
if (tasks_.back().rewrite_success) {
......
......@@ -18,7 +18,7 @@ namespace ir {
class IRUseDefAnalysis : public IRMutator {
public:
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
if (op->type_key == attr::thread_extent) {
if (op->attr_key == attr::thread_extent) {
IterVar iv(op->node.node_);
CHECK_NE(iv->thread_tag.length(), 0U);
// thread_extent can appear multiple times
......@@ -35,9 +35,9 @@ class IRUseDefAnalysis : public IRMutator {
}
Stmt body = this->Mutate(op->body);
if (value.same_as(value) && body.same_as(body)) return s;
return AttrStmt::make(op->node, op->type_key, value, body);
} else if (op->type_key == attr::channel_write_scope ||
op->type_key == attr::channel_read_scope) {
return AttrStmt::make(op->node, op->attr_key, value, body);
} else if (op->attr_key == attr::channel_write_scope ||
op->attr_key == attr::channel_read_scope) {
Channel ch(op->node.node_);
if (!use_count_.count(ch->handle_var.get())) {
this->HandleDef(ch->handle_var.get());
......@@ -147,8 +147,8 @@ class IRUseDefAnalysis : public IRMutator {
class HostDeviceSplitter : public IRMutator {
public:
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
if (op->type_key == attr::thread_extent ||
op->type_key == attr::pipeline_exec_scope) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::pipeline_exec_scope) {
return SplitDeviceFunc(s);
}
return IRMutator::Mutate_(op, s);
......
......@@ -77,7 +77,7 @@ class MarkChannelAccess : public IRMutator {
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == ir::attr::storage_scope) {
if (op->attr_key == ir::attr::storage_scope) {
Var buf_var(op->node.node_);
if (cmap_.count(buf_var.get())) return Mutate(op->body);
}
......@@ -223,7 +223,7 @@ class StageSplitter : public IRMutator {
nest.emplace_back(IfThenElse::make(op->condition, no_op));
} else if (const AttrStmt* op = s.as<AttrStmt>()) {
nest.emplace_back(AttrStmt::make(
op->node, op->type_key, op->value, no_op));
op->node, op->attr_key, op->value, no_op));
} else if (s.as<ProducerConsumer>()) {
} else if (s.as<Block>()) {
} else if (const Allocate* op = s.as<Allocate>()) {
......@@ -266,7 +266,7 @@ class PipelineSplitter : public IRMutator {
: split_load_(split_load) {}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == ir::attr::pipeline_exec_scope) {
if (op->attr_key == ir::attr::pipeline_exec_scope) {
CHECK_LE(env_.size(), 1U);
const ProducerConsumer* env = nullptr;
if (env_.size() == 1) {
......@@ -276,7 +276,7 @@ class PipelineSplitter : public IRMutator {
op->body, env);
if (body.same_as(op->body)) return s;
return AttrStmt::make(
op->node, op->type_key, op->value, body);
op->node, op->attr_key, op->value, body);
} else {
return IRMutator::Mutate_(op, s);
}
......
......@@ -40,17 +40,17 @@ class StorageFlattener : public IRMutator {
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == attr::realize_scope) {
if (op->attr_key == attr::realize_scope) {
storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
return this->Mutate(op->body);
} else if (op->type_key == attr::thread_extent) {
} else if (op->attr_key == attr::thread_extent) {
IterVar iv(op->node.node_);
ThreadScope ts = ThreadScope::make(iv->thread_tag);
curr_thread_scope_.push_back(ts);
Stmt stmt = IRMutator::Mutate_(op, s);
curr_thread_scope_.pop_back();
return stmt;
} else if (op->type_key == attr::extern_op_scope) {
} else if (op->attr_key == attr::extern_op_scope) {
return HandleExternOp(op);
}
return IRMutator::Mutate_(op, s);
......
......@@ -57,7 +57,7 @@ class StorageSyncPlanner : public IRVisitor {
allow_load_ = false;
}
void Visit_(const AttrStmt* op) final {
if (op->type_key == "storage_scope") {
if (op->attr_key == "storage_scope") {
const Variable* buf = op->node.as<Variable>();
storage_scope_[buf] =
StorageScope::make(op->value.as<StringImm>()->value);
......
......@@ -55,7 +55,7 @@ class InjectAttach : public IRMutator {
stmt = IRMutator::Mutate(stmt);
const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr &&
op->type_key == attr::loop_scope) {
op->attr_key == attr::loop_scope) {
if (attach_spec_->attach_type == kScope &&
op->node == attach_spec_->attach_ivar) {
CHECK(!found_attach)
......@@ -63,7 +63,7 @@ class InjectAttach : public IRMutator {
<< " in multiple places in the IR";
found_attach = true;
stmt = AttrStmt::make(
op->node, op->type_key, op->value,
op->node, op->attr_key, op->value,
MakePipeline(stage_, dom_map_, op->body));
}
}
......@@ -97,12 +97,12 @@ class InjectScanStep : public IRMutator {
// update
const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr &&
((op->type_key == attr::scan_update_scope && !is_init_) ||
(op->type_key == attr::scan_init_scope && is_init_))) {
((op->attr_key == attr::scan_update_scope && !is_init_) ||
(op->attr_key == attr::scan_init_scope && is_init_))) {
if (op->node.same_as(scan_op_)) {
found_attach = true;
stmt = AttrStmt::make(
op->node, op->type_key, op->value,
op->node, op->attr_key, op->value,
MakePipeline(stage_, dom_map_, op->body));
}
}
......@@ -150,20 +150,20 @@ class SchedulePostProc : public IRMutator {
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == attr::loop_scope ||
op->type_key == attr::scan_init_scope) {
if (op->attr_key == attr::loop_scope ||
op->attr_key == attr::scan_init_scope) {
return this->Mutate(op->body);
} else if (op->type_key == attr::scan_update_scope) {
} else if (op->attr_key == attr::scan_update_scope) {
const ScanOpNode* scan = op->node.as<ScanOpNode>();
CHECK(scan);
var_value_[scan->scan_axis->var.get()] = op->value;
return this->Mutate(op->body);
} else if (op->type_key == ir::attr::realize_scope) {
} else if (op->attr_key == ir::attr::realize_scope) {
auto it = replace_op_.find(op->node.get());
if (it != replace_op_.end()) {
if (it->second.defined()) {
Stmt ret = AttrStmt::make(
it->second, op->type_key, op->value, op->body);
it->second, op->attr_key, op->value, op->body);
return this->Mutate(ret);
} else {
return this->Mutate(op->body);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment