Commit 0876e9e9 by Tianqi Chen Committed by GitHub

[IR] Rename attr_key in AttrStmt (#83)

parent 8f51c5fd
Subproject commit 59fdca16978b6184bab87fbff7a00c95f1804686 Subproject commit d024efd80694556c1239c4435c5b3e70853a4896
...@@ -286,8 +286,8 @@ class Canonical::Internal : public IRMutator { ...@@ -286,8 +286,8 @@ class Canonical::Internal : public IRMutator {
} }
// AttrStmt // AttrStmt
Stmt Mutate_(const AttrStmt* op, const Stmt& s) { Stmt Mutate_(const AttrStmt* op, const Stmt& s) {
if (op->type_key == attr::thread_extent || if (op->attr_key == attr::thread_extent ||
op->type_key == attr::virtual_thread) { op->attr_key == attr::virtual_thread) {
++level_counter_; ++level_counter_;
IterVar iv(op->node.node_); IterVar iv(op->node.node_);
CHECK_NE(iv->thread_tag.length(), 0U); CHECK_NE(iv->thread_tag.length(), 0U);
......
...@@ -654,7 +654,7 @@ void CodeGenC::VisitStmt_(const Allocate* op) { ...@@ -654,7 +654,7 @@ void CodeGenC::VisitStmt_(const Allocate* op) {
} }
void CodeGenC::VisitStmt_(const AttrStmt* op) { void CodeGenC::VisitStmt_(const AttrStmt* op) {
if (op->type_key == ir::attr::thread_extent) { if (op->attr_key == ir::attr::thread_extent) {
IterVar iv(op->node.node_); IterVar iv(op->node.node_);
if (iv->thread_tag.length() != 0) { if (iv->thread_tag.length() != 0) {
if (!var_idmap_.count(iv->var.get())) { if (!var_idmap_.count(iv->var.get())) {
...@@ -667,11 +667,11 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) { ...@@ -667,11 +667,11 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) {
stream << ";\n"; stream << ";\n";
} }
} }
} else if (op->type_key == ir::attr::storage_scope) { } else if (op->attr_key == ir::attr::storage_scope) {
const Variable* v = op->node.as<Variable>(); const Variable* v = op->node.as<Variable>();
CHECK(v); CHECK(v);
alloc_storage_scope_[v] = op->value.as<StringImm>()->value; alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
} else if (op->type_key == ir::attr::volatile_scope) { } else if (op->attr_key == ir::attr::volatile_scope) {
const Variable* v = op->node.as<Variable>(); const Variable* v = op->node.as<Variable>();
CHECK(v); CHECK(v);
volatile_buf_.insert(v); volatile_buf_.insert(v);
......
...@@ -1245,7 +1245,7 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) { ...@@ -1245,7 +1245,7 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) {
} }
void CodeGenLLVM::VisitStmt_(const AttrStmt* op) { void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
if (op->type_key == ir::attr::storage_scope) { if (op->attr_key == ir::attr::storage_scope) {
const Variable* v = op->node.as<Variable>(); const Variable* v = op->node.as<Variable>();
CHECK(v); CHECK(v);
alloc_storage_scope_[v] = op->value.as<StringImm>()->value; alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
......
...@@ -93,20 +93,20 @@ class PipelineExtractor: public IRVisitor { ...@@ -93,20 +93,20 @@ class PipelineExtractor: public IRVisitor {
} }
void Visit_(const AttrStmt* op) final { void Visit_(const AttrStmt* op) final {
if (op->type_key == attr::pipeline_stage_scope) { if (op->attr_key == attr::pipeline_stage_scope) {
CHECK(!in_pipeline_stage_); CHECK(!in_pipeline_stage_);
in_pipeline_stage_ = true; in_pipeline_stage_ = true;
trigger_.emplace_back(std::make_pair(loop_.size(), op)); trigger_.emplace_back(std::make_pair(loop_.size(), op));
IRVisitor::Visit_(op); IRVisitor::Visit_(op);
trigger_.pop_back(); trigger_.pop_back();
in_pipeline_stage_ = false; in_pipeline_stage_ = false;
} else if (op->type_key == attr::channel_read_advance || } else if (op->attr_key == attr::channel_read_advance ||
op->type_key == attr::channel_write_advance) { op->attr_key == attr::channel_write_advance) {
trigger_.emplace_back(std::make_pair(loop_.size(), op)); trigger_.emplace_back(std::make_pair(loop_.size(), op));
IRVisitor::Visit_(op); IRVisitor::Visit_(op);
trigger_.pop_back(); trigger_.pop_back();
} else if (op->type_key == attr::channel_read_scope || } else if (op->attr_key == attr::channel_read_scope ||
op->type_key == attr::channel_write_scope) { op->attr_key == attr::channel_write_scope) {
Channel ch(op->node.node_); Channel ch(op->node.node_);
ChannelEntry& cb = cmap_[ch->handle_var.get()]; ChannelEntry& cb = cmap_[ch->handle_var.get()];
if (cb.node != nullptr) { if (cb.node != nullptr) {
...@@ -115,7 +115,7 @@ class PipelineExtractor: public IRVisitor { ...@@ -115,7 +115,7 @@ class PipelineExtractor: public IRVisitor {
cb.node = std::make_shared<ChannelBlockNode>(); cb.node = std::make_shared<ChannelBlockNode>();
cb.node->channel = ch; cb.node->channel = ch;
} }
if (op->type_key == attr::channel_read_scope) { if (op->attr_key == attr::channel_read_scope) {
CHECK_EQ(cb.read_ref_count, 0) CHECK_EQ(cb.read_ref_count, 0)
<< "One channel can only be read from one consumer"; << "One channel can only be read from one consumer";
++cb.read_ref_count; ++cb.read_ref_count;
...@@ -173,7 +173,7 @@ class PipelineExtractor: public IRVisitor { ...@@ -173,7 +173,7 @@ class PipelineExtractor: public IRVisitor {
for (const auto& e : trigger_) { for (const auto& e : trigger_) {
const AttrStmt* attr = e.second; const AttrStmt* attr = e.second;
Channel ch; Channel ch;
if (attr->type_key == attr::pipeline_stage_scope) { if (attr->attr_key == attr::pipeline_stage_scope) {
ch = arg_write; ch = arg_write;
if (!ch.defined()) continue; if (!ch.defined()) continue;
} else { } else {
...@@ -195,10 +195,10 @@ class PipelineExtractor: public IRVisitor { ...@@ -195,10 +195,10 @@ class PipelineExtractor: public IRVisitor {
trigger->signal_index = static_cast<int>(cb.node->ctrl_signals.size()); trigger->signal_index = static_cast<int>(cb.node->ctrl_signals.size());
// Grab the advance constant size. // Grab the advance constant size.
int trigger_size; int trigger_size;
if (attr->type_key == attr::pipeline_stage_scope) { if (attr->attr_key == attr::pipeline_stage_scope) {
cb.node->ctrl_signals.push_back( cb.node->ctrl_signals.push_back(
ControlSignalNode::make(kComputeFinish, 0)); ControlSignalNode::make(kComputeFinish, 0));
} else if (attr->type_key == attr::channel_read_advance) { } else if (attr->attr_key == attr::channel_read_advance) {
CHECK(arith::GetConstInt(attr->value, &trigger_size)) CHECK(arith::GetConstInt(attr->value, &trigger_size))
<< "Only support constant advance size"; << "Only support constant advance size";
cb.node->ctrl_signals.push_back( cb.node->ctrl_signals.push_back(
......
...@@ -200,7 +200,7 @@ class VTInjector : public IRMutator { ...@@ -200,7 +200,7 @@ class VTInjector : public IRMutator {
body.same_as(op->body)) { body.same_as(op->body)) {
return s; return s;
} else { } else {
return AttrStmt::make(op->node, op->type_key, value, body); return AttrStmt::make(op->node, op->attr_key, value, body);
} }
} }
} }
...@@ -388,7 +388,7 @@ class VirtualThreadInjector : public IRMutator { ...@@ -388,7 +388,7 @@ class VirtualThreadInjector : public IRMutator {
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<AttrStmt>(); op = stmt.as<AttrStmt>();
if (op->type_key == attr::virtual_thread) { if (op->attr_key == attr::virtual_thread) {
IterVar iv(op->node.node_); IterVar iv(op->node.node_);
int nthread = static_cast<int>(op->value.as<IntImm>()->value); int nthread = static_cast<int>(op->value.as<IntImm>()->value);
VarTouchedAnalysis vs; VarTouchedAnalysis vs;
......
...@@ -68,7 +68,7 @@ Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) { ...@@ -68,7 +68,7 @@ Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) {
body.same_as(op->body)) { body.same_as(op->body)) {
return s; return s;
} else { } else {
return AttrStmt::make(op->node, op->type_key, value, body); return AttrStmt::make(op->node, op->attr_key, value, body);
} }
} }
......
...@@ -25,9 +25,9 @@ class AllocateLifter : public IRMutator { ...@@ -25,9 +25,9 @@ class AllocateLifter : public IRMutator {
} }
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
CHECK(op->type_key != attr::virtual_thread) CHECK(op->attr_key != attr::virtual_thread)
<< "InjectVirtualThread before LiftStorageAlloc"; << "InjectVirtualThread before LiftStorageAlloc";
if (op->type_key == attr::storage_scope) { if (op->attr_key == attr::storage_scope) {
StorageScope sc = StorageScope::make(op->value.as<StringImm>()->value); StorageScope sc = StorageScope::make(op->value.as<StringImm>()->value);
allocs_[sc].emplace_back( allocs_[sc].emplace_back(
AttrStmt::make( AttrStmt::make(
...@@ -35,7 +35,7 @@ class AllocateLifter : public IRMutator { ...@@ -35,7 +35,7 @@ class AllocateLifter : public IRMutator {
op->value, Evaluate::make(0))); op->value, Evaluate::make(0)));
storage_scope_[op->node.get()] = sc; storage_scope_[op->node.get()] = sc;
return this->Mutate(op->body); return this->Mutate(op->body);
} else if (op->type_key == attr::thread_extent) { } else if (op->attr_key == attr::thread_extent) {
IterVar iv(op->node.node_); IterVar iv(op->node.node_);
ThreadScope ts = ThreadScope::make(iv->thread_tag); ThreadScope ts = ThreadScope::make(iv->thread_tag);
curr_thread_scope_.push_back(ts); curr_thread_scope_.push_back(ts);
...@@ -55,7 +55,7 @@ class AllocateLifter : public IRMutator { ...@@ -55,7 +55,7 @@ class AllocateLifter : public IRMutator {
Stmt body = MergeNest(vec, op->body); Stmt body = MergeNest(vec, op->body);
vec.clear(); vec.clear();
return AttrStmt::make( return AttrStmt::make(
op->node, op->type_key, op->value, body); op->node, op->attr_key, op->value, body);
} }
} }
return stmt; return stmt;
......
...@@ -20,12 +20,12 @@ class ThreadAllreduceBuilder : public IRMutator { ...@@ -20,12 +20,12 @@ class ThreadAllreduceBuilder : public IRMutator {
: warp_size_(warp_size) {} : warp_size_(warp_size) {}
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
if (op->type_key == attr::thread_extent) { if (op->attr_key == attr::thread_extent) {
thread_extents_.push_back(op); thread_extents_.push_back(op);
Stmt ret = IRMutator::Mutate_(op, s); Stmt ret = IRMutator::Mutate_(op, s);
thread_extents_.pop_back(); thread_extents_.pop_back();
return ret; return ret;
} else if (op->type_key == attr::storage_scope) { } else if (op->attr_key == attr::storage_scope) {
Stmt ret = IRMutator::Mutate_(op, s); Stmt ret = IRMutator::Mutate_(op, s);
op = ret.as<AttrStmt>(); op = ret.as<AttrStmt>();
const Variable* v = op->node.as<Variable>(); const Variable* v = op->node.as<Variable>();
......
...@@ -107,14 +107,14 @@ class ChannelAccessRewriter : public IRMutator { ...@@ -107,14 +107,14 @@ class ChannelAccessRewriter : public IRMutator {
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt ret; Stmt ret;
const AttrStmt* adv = op->body.as<AttrStmt>(); const AttrStmt* adv = op->body.as<AttrStmt>();
if ((op->type_key == ir::attr::channel_read_scope && if ((op->attr_key == ir::attr::channel_read_scope &&
adv && adv->type_key == ir::attr::channel_read_advance) || adv && adv->attr_key == ir::attr::channel_read_advance) ||
(op->type_key == ir::attr::channel_write_scope && (op->attr_key == ir::attr::channel_write_scope &&
adv && adv->type_key == ir::attr::channel_write_advance)) { adv && adv->attr_key == ir::attr::channel_write_advance)) {
RewriteEntry e; RewriteEntry e;
e.window = op; e.window = op;
e.advance = adv; e.advance = adv;
e.read_access = op->type_key == ir::attr::channel_read_scope; e.read_access = op->attr_key == ir::attr::channel_read_scope;
tasks_.push_back(e); tasks_.push_back(e);
ret = IRMutator::Mutate_(op, s); ret = IRMutator::Mutate_(op, s);
if (tasks_.back().rewrite_success) { if (tasks_.back().rewrite_success) {
......
...@@ -18,7 +18,7 @@ namespace ir { ...@@ -18,7 +18,7 @@ namespace ir {
class IRUseDefAnalysis : public IRMutator { class IRUseDefAnalysis : public IRMutator {
public: public:
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
if (op->type_key == attr::thread_extent) { if (op->attr_key == attr::thread_extent) {
IterVar iv(op->node.node_); IterVar iv(op->node.node_);
CHECK_NE(iv->thread_tag.length(), 0U); CHECK_NE(iv->thread_tag.length(), 0U);
// thread_extent can appear multiple times // thread_extent can appear multiple times
...@@ -35,9 +35,9 @@ class IRUseDefAnalysis : public IRMutator { ...@@ -35,9 +35,9 @@ class IRUseDefAnalysis : public IRMutator {
} }
Stmt body = this->Mutate(op->body); Stmt body = this->Mutate(op->body);
if (value.same_as(value) && body.same_as(body)) return s; if (value.same_as(value) && body.same_as(body)) return s;
return AttrStmt::make(op->node, op->type_key, value, body); return AttrStmt::make(op->node, op->attr_key, value, body);
} else if (op->type_key == attr::channel_write_scope || } else if (op->attr_key == attr::channel_write_scope ||
op->type_key == attr::channel_read_scope) { op->attr_key == attr::channel_read_scope) {
Channel ch(op->node.node_); Channel ch(op->node.node_);
if (!use_count_.count(ch->handle_var.get())) { if (!use_count_.count(ch->handle_var.get())) {
this->HandleDef(ch->handle_var.get()); this->HandleDef(ch->handle_var.get());
...@@ -147,8 +147,8 @@ class IRUseDefAnalysis : public IRMutator { ...@@ -147,8 +147,8 @@ class IRUseDefAnalysis : public IRMutator {
class HostDeviceSplitter : public IRMutator { class HostDeviceSplitter : public IRMutator {
public: public:
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
if (op->type_key == attr::thread_extent || if (op->attr_key == attr::thread_extent ||
op->type_key == attr::pipeline_exec_scope) { op->attr_key == attr::pipeline_exec_scope) {
return SplitDeviceFunc(s); return SplitDeviceFunc(s);
} }
return IRMutator::Mutate_(op, s); return IRMutator::Mutate_(op, s);
......
...@@ -77,7 +77,7 @@ class MarkChannelAccess : public IRMutator { ...@@ -77,7 +77,7 @@ class MarkChannelAccess : public IRMutator {
} }
} }
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == ir::attr::storage_scope) { if (op->attr_key == ir::attr::storage_scope) {
Var buf_var(op->node.node_); Var buf_var(op->node.node_);
if (cmap_.count(buf_var.get())) return Mutate(op->body); if (cmap_.count(buf_var.get())) return Mutate(op->body);
} }
...@@ -223,7 +223,7 @@ class StageSplitter : public IRMutator { ...@@ -223,7 +223,7 @@ class StageSplitter : public IRMutator {
nest.emplace_back(IfThenElse::make(op->condition, no_op)); nest.emplace_back(IfThenElse::make(op->condition, no_op));
} else if (const AttrStmt* op = s.as<AttrStmt>()) { } else if (const AttrStmt* op = s.as<AttrStmt>()) {
nest.emplace_back(AttrStmt::make( nest.emplace_back(AttrStmt::make(
op->node, op->type_key, op->value, no_op)); op->node, op->attr_key, op->value, no_op));
} else if (s.as<ProducerConsumer>()) { } else if (s.as<ProducerConsumer>()) {
} else if (s.as<Block>()) { } else if (s.as<Block>()) {
} else if (const Allocate* op = s.as<Allocate>()) { } else if (const Allocate* op = s.as<Allocate>()) {
...@@ -266,7 +266,7 @@ class PipelineSplitter : public IRMutator { ...@@ -266,7 +266,7 @@ class PipelineSplitter : public IRMutator {
: split_load_(split_load) {} : split_load_(split_load) {}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == ir::attr::pipeline_exec_scope) { if (op->attr_key == ir::attr::pipeline_exec_scope) {
CHECK_LE(env_.size(), 1U); CHECK_LE(env_.size(), 1U);
const ProducerConsumer* env = nullptr; const ProducerConsumer* env = nullptr;
if (env_.size() == 1) { if (env_.size() == 1) {
...@@ -276,7 +276,7 @@ class PipelineSplitter : public IRMutator { ...@@ -276,7 +276,7 @@ class PipelineSplitter : public IRMutator {
op->body, env); op->body, env);
if (body.same_as(op->body)) return s; if (body.same_as(op->body)) return s;
return AttrStmt::make( return AttrStmt::make(
op->node, op->type_key, op->value, body); op->node, op->attr_key, op->value, body);
} else { } else {
return IRMutator::Mutate_(op, s); return IRMutator::Mutate_(op, s);
} }
......
...@@ -40,17 +40,17 @@ class StorageFlattener : public IRMutator { ...@@ -40,17 +40,17 @@ class StorageFlattener : public IRMutator {
} }
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == attr::realize_scope) { if (op->attr_key == attr::realize_scope) {
storage_scope_[op->node.get()] = op->value.as<StringImm>()->value; storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
return this->Mutate(op->body); return this->Mutate(op->body);
} else if (op->type_key == attr::thread_extent) { } else if (op->attr_key == attr::thread_extent) {
IterVar iv(op->node.node_); IterVar iv(op->node.node_);
ThreadScope ts = ThreadScope::make(iv->thread_tag); ThreadScope ts = ThreadScope::make(iv->thread_tag);
curr_thread_scope_.push_back(ts); curr_thread_scope_.push_back(ts);
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = IRMutator::Mutate_(op, s);
curr_thread_scope_.pop_back(); curr_thread_scope_.pop_back();
return stmt; return stmt;
} else if (op->type_key == attr::extern_op_scope) { } else if (op->attr_key == attr::extern_op_scope) {
return HandleExternOp(op); return HandleExternOp(op);
} }
return IRMutator::Mutate_(op, s); return IRMutator::Mutate_(op, s);
......
...@@ -57,7 +57,7 @@ class StorageSyncPlanner : public IRVisitor { ...@@ -57,7 +57,7 @@ class StorageSyncPlanner : public IRVisitor {
allow_load_ = false; allow_load_ = false;
} }
void Visit_(const AttrStmt* op) final { void Visit_(const AttrStmt* op) final {
if (op->type_key == "storage_scope") { if (op->attr_key == "storage_scope") {
const Variable* buf = op->node.as<Variable>(); const Variable* buf = op->node.as<Variable>();
storage_scope_[buf] = storage_scope_[buf] =
StorageScope::make(op->value.as<StringImm>()->value); StorageScope::make(op->value.as<StringImm>()->value);
......
...@@ -55,7 +55,7 @@ class InjectAttach : public IRMutator { ...@@ -55,7 +55,7 @@ class InjectAttach : public IRMutator {
stmt = IRMutator::Mutate(stmt); stmt = IRMutator::Mutate(stmt);
const AttrStmt* op = stmt.as<AttrStmt>(); const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr && if (op != nullptr &&
op->type_key == attr::loop_scope) { op->attr_key == attr::loop_scope) {
if (attach_spec_->attach_type == kScope && if (attach_spec_->attach_type == kScope &&
op->node == attach_spec_->attach_ivar) { op->node == attach_spec_->attach_ivar) {
CHECK(!found_attach) CHECK(!found_attach)
...@@ -63,7 +63,7 @@ class InjectAttach : public IRMutator { ...@@ -63,7 +63,7 @@ class InjectAttach : public IRMutator {
<< " in multiple places in the IR"; << " in multiple places in the IR";
found_attach = true; found_attach = true;
stmt = AttrStmt::make( stmt = AttrStmt::make(
op->node, op->type_key, op->value, op->node, op->attr_key, op->value,
MakePipeline(stage_, dom_map_, op->body)); MakePipeline(stage_, dom_map_, op->body));
} }
} }
...@@ -97,12 +97,12 @@ class InjectScanStep : public IRMutator { ...@@ -97,12 +97,12 @@ class InjectScanStep : public IRMutator {
// update // update
const AttrStmt* op = stmt.as<AttrStmt>(); const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr && if (op != nullptr &&
((op->type_key == attr::scan_update_scope && !is_init_) || ((op->attr_key == attr::scan_update_scope && !is_init_) ||
(op->type_key == attr::scan_init_scope && is_init_))) { (op->attr_key == attr::scan_init_scope && is_init_))) {
if (op->node.same_as(scan_op_)) { if (op->node.same_as(scan_op_)) {
found_attach = true; found_attach = true;
stmt = AttrStmt::make( stmt = AttrStmt::make(
op->node, op->type_key, op->value, op->node, op->attr_key, op->value,
MakePipeline(stage_, dom_map_, op->body)); MakePipeline(stage_, dom_map_, op->body));
} }
} }
...@@ -150,20 +150,20 @@ class SchedulePostProc : public IRMutator { ...@@ -150,20 +150,20 @@ class SchedulePostProc : public IRMutator {
} }
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == attr::loop_scope || if (op->attr_key == attr::loop_scope ||
op->type_key == attr::scan_init_scope) { op->attr_key == attr::scan_init_scope) {
return this->Mutate(op->body); return this->Mutate(op->body);
} else if (op->type_key == attr::scan_update_scope) { } else if (op->attr_key == attr::scan_update_scope) {
const ScanOpNode* scan = op->node.as<ScanOpNode>(); const ScanOpNode* scan = op->node.as<ScanOpNode>();
CHECK(scan); CHECK(scan);
var_value_[scan->scan_axis->var.get()] = op->value; var_value_[scan->scan_axis->var.get()] = op->value;
return this->Mutate(op->body); return this->Mutate(op->body);
} else if (op->type_key == ir::attr::realize_scope) { } else if (op->attr_key == ir::attr::realize_scope) {
auto it = replace_op_.find(op->node.get()); auto it = replace_op_.find(op->node.get());
if (it != replace_op_.end()) { if (it != replace_op_.end()) {
if (it->second.defined()) { if (it->second.defined()) {
Stmt ret = AttrStmt::make( Stmt ret = AttrStmt::make(
it->second, op->type_key, op->value, op->body); it->second, op->attr_key, op->value, op->body);
return this->Mutate(ret); return this->Mutate(ret);
} else { } else {
return this->Mutate(op->body); return this->Mutate(op->body);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment