Commit 591afad9 by Tianqi Chen, committed by GitHub

[SCHEDULE] Remap the cached bind_scope. (#272)

* [SCHEDULE] Remap the cached bind_scope.

* more fix
parent 2e373de4
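
At a glance, the hunks below do four things. In the tensorize matcher (TensorIntrinMatcher), var_remap_ substitutions (to r->min) are now recorded for the matched compute and reduction axes, and the Reduce axis list is rebuilt from axis_remap_, keeping only remapped axes. In MakeTensorize, the single bind_nest is split into input_bind_nest and output_bind_nest so that reduce_init is wrapped only by the output bindings. StorageFlattener's buffer_bind_scope handler gets a more descriptive failure message when the bound buffer cannot be found. Finally, SchedulePostProc gains the remap named in the commit title: the tensor cached inside a buffer_bind_scope annotation is rewritten when its producing op has been replaced.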
@@ -172,19 +172,15 @@ class TensorIntrinMatcher final : public IRMutator {
   Expr Mutate_(const Reduce* op, const Expr& e) final {
     Expr expr = IRMutator::Mutate_(op, e);
     op = expr.as<Reduce>();
-    Array<IterVar> axis = op->axis;
+    Array<IterVar> axis;
     for (size_t i = 0; i < op->axis.size(); ++i) {
       auto it = axis_remap_.find(op->axis[i]);
       if (it != axis_remap_.end()) {
-        axis.Set(i, it->second);
+        axis.push_back(it->second);
       }
     }
-    if (!axis.same_as(op->axis)) {
-      return Reduce::make(
-          op->combiner, op->source, axis, op->condition, op->value_index);
-    } else {
-      return e;
-    }
+    return Reduce::make(
+        op->combiner, op->source, axis, op->condition, op->value_index);
   }
 
   void Init(const ComputeOpNode* self,
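
Note: the rewritten Mutate_ builds a fresh axis list containing only the IterVars that have an entry in axis_remap_, instead of copying op->axis and patching matched entries in place, presumably because axes the intrinsic fully covers are substituted away via var_remap_ elsewhere in this diff. A minimal sketch of that filter-and-remap pattern in plain standard C++ (the string-keyed axis/axis_remap containers are illustrative stand-ins, not TVM's Array<IterVar> or the matcher's actual state):

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Stand-ins for the op's reduction axes and the matcher's axis_remap_.
  std::vector<std::string> axis = {"ko", "ki"};
  std::unordered_map<std::string, std::string> axis_remap = {{"ki", "k_intrin"}};

  // Old approach: copy the axis list and overwrite matched entries in place.
  std::vector<std::string> patched = axis;
  for (size_t i = 0; i < axis.size(); ++i) {
    auto it = axis_remap.find(axis[i]);
    if (it != axis_remap.end()) patched[i] = it->second;
  }

  // New approach: rebuild the list, keeping only axes that have a remapping.
  std::vector<std::string> remapped;
  for (const std::string& a : axis) {
    auto it = axis_remap.find(a);
    if (it != axis_remap.end()) remapped.push_back(it->second);
  }

  std::cout << "patched: " << patched.size()      // 2 entries, "ko" untouched
            << ", remapped: " << remapped.size()  // 1 entry, only "k_intrin"
            << "\n";
  return 0;
}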
@@ -192,6 +188,7 @@ class TensorIntrinMatcher final : public IRMutator {
             const std::unordered_map<IterVar, Range>& out_dom,
             const std::unordered_map<Tensor, Array<Range> >& in_region,
             const TensorIntrin& intrin) {
+    CHECK(self == stage->op.get());
     // input remap.
     Array<Tensor> inputs = self->InputTensors();
     CHECK_EQ(inputs.size(), intrin->inputs.size());
@@ -204,7 +201,8 @@ class TensorIntrinMatcher final : public IRMutator {
       e.start = e.region.size() - e.tensor.ndim();
       for (size_t i = 0; i < e.start; ++i) {
         CHECK(is_one(e.region[i]->extent))
-            << "Tensorize: Input dimension mismatch with tensor intrin "
+            << "Tensorize " << intrin->name << ":"
+            << " Input dimension mismatch with tensor intrin "
             << " expected shape=" << e.tensor->shape
             << ", given region=" << e.region;
       }
@@ -223,6 +221,7 @@ class TensorIntrinMatcher final : public IRMutator {
           << "Tensorize: Output mismatch with tensor intrin "
           << " intrin-dim=" << intrin_compute->axis.size()
           << ", tensorize-dim=" << self->axis.size();
+      var_remap_[self->axis[i]->var.get()] = r->min;
     }
     // Assume we tensorize at regin axis i [min, min + extent)
     // The corresponding intrinsic axis is j [0, extent)
@@ -244,6 +243,7 @@ class TensorIntrinMatcher final : public IRMutator {
           << "Tensorize: Reduction mismatch with tensor intrin "
           << " intrin-dim=" << intrin_compute->reduce_axis.size()
           << ", tensorize-dim=" << self->reduce_axis.size();
+      var_remap_[self->reduce_axis[i]->var.get()] = r->min;
     }
     for (size_t i = axis_start; i < self->reduce_axis.size(); ++i) {
       IterVar iv = self->reduce_axis[i];
@@ -328,7 +328,7 @@ Stmt MakeTensorize(const ComputeOpNode* self,
   VerifyTensorizeBody(self, stage, out_dom, in_region, intrin);
   // Start bind data.
   Stmt nop = Evaluate::make(0);
-  std::vector<Stmt> bind_nest;
+  std::vector<Stmt> input_bind_nest, output_bind_nest;
   Array<Tensor> inputs = self->InputTensors();
   CHECK_EQ(inputs.size(), intrin->inputs.size())
       << "Tensorize failed: input size mismatch ";
@@ -345,7 +345,7 @@ Stmt MakeTensorize(const ComputeOpNode* self,
       tuple.push_back(r->min);
       tuple.push_back(r->extent);
     }
-    bind_nest.emplace_back(AttrStmt::make(
+    input_bind_nest.emplace_back(AttrStmt::make(
         bind_spec, ir::attr::buffer_bind_scope,
         Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
   }
@@ -365,7 +365,7 @@ Stmt MakeTensorize(const ComputeOpNode* self,
     Tensor tensor = stage->op.output(i - intrin->inputs.size());
     Buffer buffer = intrin->buffers[i];
     Array<NodeRef> bind_spec{buffer, tensor};
-    bind_nest.emplace_back(AttrStmt::make(
+    output_bind_nest.emplace_back(AttrStmt::make(
         bind_spec, ir::attr::buffer_bind_scope,
         Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
   }
@@ -400,11 +400,12 @@ Stmt MakeTensorize(const ComputeOpNode* self,
     CHECK_EQ(n.init_predicates.size(), 0U);
     CHECK(intrin->body.defined())
         << "Normal store op for intrin " << intrin << " is not defined";
-    Stmt body = ir::MergeNest(bind_nest, intrin->body);
+    Stmt body = MergeNest(output_bind_nest, intrin->body);
+    body = MergeNest(input_bind_nest, body);
     body = Substitute(body, vmap);
-    body = ir::MergeNest(binder.asserts(), body);
+    body = MergeNest(binder.asserts(), body);
     body = Substitute(body, n.main_vmap);
-    return ir::MergeNest(nest, body);
+    return MergeNest(nest, body);
   } else {
     // Need to split reduction
     CHECK(intrin->reduce_init.defined())
@@ -419,14 +420,15 @@ Stmt MakeTensorize(const ComputeOpNode* self,
     std::vector<std::vector<Stmt> > init_nest(
         n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
     init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
-    Stmt init = MergeNest(bind_nest, intrin->reduce_init);
+    Stmt init = MergeNest(output_bind_nest, intrin->reduce_init);
     init = Substitute(init, n.init_vmap);
     init = MergeNest(init_nest, init);
     // The update
     std::vector<std::vector<Stmt> > update_nest(
         n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
     update_nest.emplace_back(op::MakeIfNest(n.main_predicates));
-    Stmt update = MergeNest(bind_nest, intrin->reduce_update);
+    Stmt update = MergeNest(output_bind_nest, intrin->reduce_update);
+    update = MergeNest(input_bind_nest, update);
     update = Substitute(update, vmap);
     update = MergeNest(binder.asserts(), update);
     update = Substitute(update, n.main_vmap);
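
Note: with bind_nest split in two, the intrinsic body is wrapped first by the output buffer bindings and then by the input bindings (so the inputs end up outermost), while reduce_init is wrapped only by the output bindings, presumably because no inputs are read during initialization. A rough sketch of that wrap-in-order idea in plain C++, where merge_nest is a hypothetical stand-in for TVM's MergeNest (first nest element outermost), not the real helper:

#include <iostream>
#include <string>
#include <vector>

// Wrap `body` with each entry of `nest`; the first entry becomes the
// outermost scope, mirroring how nested AttrStmts compose.
std::string merge_nest(const std::vector<std::string>& nest, std::string body) {
  for (auto it = nest.rbegin(); it != nest.rend(); ++it) {
    body = *it + "{ " + body + " }";
  }
  return body;
}

int main() {
  std::vector<std::string> input_bind_nest = {"bind(inA)", "bind(inB)"};
  std::vector<std::string> output_bind_nest = {"bind(out)"};

  // Same order as MakeTensorize: outputs wrap the body first, then inputs
  // wrap the result, so the input bindings form the outer scope.
  std::string body = merge_nest(output_bind_nest, "intrin_body");
  body = merge_nest(input_bind_nest, body);
  std::cout << body << "\n";
  // Prints: bind(inA){ bind(inB){ bind(out){ intrin_body } } }
  return 0;
}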
......
@@ -117,7 +117,6 @@ class StorageFlattener : public IRMutator {
           Array<Expr>(), Expr(),
           key.GetName(), skey.to_string(),
           align, 0);
-
       buf_map_[key] = e;
       Stmt body = this->Mutate(op->body);
       buf_map_[key].released = true;
@@ -239,7 +238,8 @@ class StorageFlattener : public IRMutator {
       CHECK(buffer && tensor);
       CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
       TensorKey key{tensor->op, tensor->value_index};
-      CHECK(buf_map_.count(key));
+      CHECK(buf_map_.count(key))
+          << "Cannot find buffer of " << tensor->op << " value=" << tensor->value_index;
       const BufferEntry& be = buf_map_.at(key);
       CHECK(!be.released);
       CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2);
......
@@ -181,6 +181,19 @@ class SchedulePostProc : public IRMutator {
           return this->Mutate(op->body);
         }
       }
+    } else if (op->attr_key == ir::attr::buffer_bind_scope) {
+      Array<NodeRef> tuple(op->node.node_);
+      Tensor tensor(tuple[1].node_);
+      auto it = replace_op_.find(tensor->op.get());
+      if (it != replace_op_.end()) {
+        if (it->second.defined()) {
+          return AttrStmt::make(
+              Array<NodeRef>{tuple[0], it->second.output(tensor->value_index)},
+              op->attr_key, op->value, Mutate(op->body));
+        } else {
+          return this->Mutate(op->body);
+        }
+      }
     }
     return IRMutator::Mutate_(op, s);
   }
......
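
Note: the new buffer_bind_scope branch in SchedulePostProc is the remap the commit title refers to. The AttrStmt's node holds a {buffer, tensor} pair, and if the tensor's producing op was replaced during post-processing, the annotation is rebuilt against the corresponding output of the replacement op; otherwise the cached binding would keep pointing at the stale op. A rough sketch of that lookup-and-rebuild pattern in plain standard C++ (the Op/Tensor structs and replace_op map are illustrative stand-ins, not TVM's node system):

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Op { std::string name; };

struct Tensor {
  std::shared_ptr<Op> op;  // producing operation
  int value_index;         // which output of that op
};

// If the tensor's op has a recorded replacement, rebind the cached tensor
// to the same output index of the replacement op; otherwise keep it.
Tensor remap_bind_tensor(
    const Tensor& tensor,
    const std::unordered_map<const Op*, std::shared_ptr<Op> >& replace_op) {
  auto it = replace_op.find(tensor.op.get());
  if (it != replace_op.end() && it->second) {
    return Tensor{it->second, tensor.value_index};
  }
  return tensor;
}

int main() {
  auto old_op = std::make_shared<Op>(Op{"cached_stage"});
  auto new_op = std::make_shared<Op>(Op{"rewritten_stage"});
  std::unordered_map<const Op*, std::shared_ptr<Op> > replace_op = {
      {old_op.get(), new_op}};

  Tensor bound{old_op, 0};
  Tensor remapped = remap_bind_tensor(bound, replace_op);
  std::cout << remapped.op->name << "\n";  // prints: rewritten_stage
  return 0;
}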