Commit 591afad9 by Tianqi Chen Committed by GitHub

[SCHEDULE] Remap the cached bind_scope. (#272)

* [SCHEDULE] Remap the cached bind_scope.

* more fix
parent 2e373de4
......@@ -172,19 +172,15 @@ class TensorIntrinMatcher final : public IRMutator {
Expr Mutate_(const Reduce* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Reduce>();
Array<IterVar> axis = op->axis;
Array<IterVar> axis;
for (size_t i = 0; i < op->axis.size(); ++i) {
auto it = axis_remap_.find(op->axis[i]);
if (it != axis_remap_.end()) {
axis.Set(i, it->second);
axis.push_back(it->second);
}
}
if (!axis.same_as(op->axis)) {
return Reduce::make(
op->combiner, op->source, axis, op->condition, op->value_index);
} else {
return e;
}
return Reduce::make(
op->combiner, op->source, axis, op->condition, op->value_index);
}
void Init(const ComputeOpNode* self,
......@@ -192,6 +188,7 @@ class TensorIntrinMatcher final : public IRMutator {
const std::unordered_map<IterVar, Range>& out_dom,
const std::unordered_map<Tensor, Array<Range> >& in_region,
const TensorIntrin& intrin) {
CHECK(self == stage->op.get());
// input remap.
Array<Tensor> inputs = self->InputTensors();
CHECK_EQ(inputs.size(), intrin->inputs.size());
......@@ -204,7 +201,8 @@ class TensorIntrinMatcher final : public IRMutator {
e.start = e.region.size() - e.tensor.ndim();
for (size_t i = 0; i < e.start; ++i) {
CHECK(is_one(e.region[i]->extent))
<< "Tensorize: Input dimension mismatch with tensor intrin "
<< "Tensorize " << intrin->name << ":"
<< " Input dimension mismatch with tensor intrin "
<< " expected shape=" << e.tensor->shape
<< ", given region=" << e.region;
}
......@@ -223,6 +221,7 @@ class TensorIntrinMatcher final : public IRMutator {
<< "Tensorize: Output mismatch with tensor intrin "
<< " intrin-dim=" << intrin_compute->axis.size()
<< ", tensorize-dim=" << self->axis.size();
var_remap_[self->axis[i]->var.get()] = r->min;
}
// Assume we tensorize at regin axis i [min, min + extent)
// The corresponding intrinsic axis is j [0, extent)
......@@ -244,6 +243,7 @@ class TensorIntrinMatcher final : public IRMutator {
<< "Tensorize: Reduction mismatch with tensor intrin "
<< " intrin-dim=" << intrin_compute->reduce_axis.size()
<< ", tensorize-dim=" << self->reduce_axis.size();
var_remap_[self->reduce_axis[i]->var.get()] = r->min;
}
for (size_t i = axis_start; i < self->reduce_axis.size(); ++i) {
IterVar iv = self->reduce_axis[i];
......@@ -328,7 +328,7 @@ Stmt MakeTensorize(const ComputeOpNode* self,
VerifyTensorizeBody(self, stage, out_dom, in_region, intrin);
// Start bind data.
Stmt nop = Evaluate::make(0);
std::vector<Stmt> bind_nest;
std::vector<Stmt> input_bind_nest, output_bind_nest;
Array<Tensor> inputs = self->InputTensors();
CHECK_EQ(inputs.size(), intrin->inputs.size())
<< "Tensorize failed: input size mismatch ";
......@@ -345,7 +345,7 @@ Stmt MakeTensorize(const ComputeOpNode* self,
tuple.push_back(r->min);
tuple.push_back(r->extent);
}
bind_nest.emplace_back(AttrStmt::make(
input_bind_nest.emplace_back(AttrStmt::make(
bind_spec, ir::attr::buffer_bind_scope,
Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
}
......@@ -365,7 +365,7 @@ Stmt MakeTensorize(const ComputeOpNode* self,
Tensor tensor = stage->op.output(i - intrin->inputs.size());
Buffer buffer = intrin->buffers[i];
Array<NodeRef> bind_spec{buffer, tensor};
bind_nest.emplace_back(AttrStmt::make(
output_bind_nest.emplace_back(AttrStmt::make(
bind_spec, ir::attr::buffer_bind_scope,
Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
}
......@@ -400,11 +400,12 @@ Stmt MakeTensorize(const ComputeOpNode* self,
CHECK_EQ(n.init_predicates.size(), 0U);
CHECK(intrin->body.defined())
<< "Normal store op for intrin " << intrin << " is not defined";
Stmt body = ir::MergeNest(bind_nest, intrin->body);
Stmt body = MergeNest(output_bind_nest, intrin->body);
body = MergeNest(input_bind_nest, body);
body = Substitute(body, vmap);
body = ir::MergeNest(binder.asserts(), body);
body = MergeNest(binder.asserts(), body);
body = Substitute(body, n.main_vmap);
return ir::MergeNest(nest, body);
return MergeNest(nest, body);
} else {
// Need to split reduction
CHECK(intrin->reduce_init.defined())
......@@ -419,14 +420,15 @@ Stmt MakeTensorize(const ComputeOpNode* self,
std::vector<std::vector<Stmt> > init_nest(
n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
Stmt init = MergeNest(bind_nest, intrin->reduce_init);
Stmt init = MergeNest(output_bind_nest, intrin->reduce_init);
init = Substitute(init, n.init_vmap);
init = MergeNest(init_nest, init);
// The update
std::vector<std::vector<Stmt> > update_nest(
n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
update_nest.emplace_back(op::MakeIfNest(n.main_predicates));
Stmt update = MergeNest(bind_nest, intrin->reduce_update);
Stmt update = MergeNest(output_bind_nest, intrin->reduce_update);
update = MergeNest(input_bind_nest, update);
update = Substitute(update, vmap);
update = MergeNest(binder.asserts(), update);
update = Substitute(update, n.main_vmap);
......
......@@ -117,7 +117,6 @@ class StorageFlattener : public IRMutator {
Array<Expr>(), Expr(),
key.GetName(), skey.to_string(),
align, 0);
buf_map_[key] = e;
Stmt body = this->Mutate(op->body);
buf_map_[key].released = true;
......@@ -239,7 +238,8 @@ class StorageFlattener : public IRMutator {
CHECK(buffer && tensor);
CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
TensorKey key{tensor->op, tensor->value_index};
CHECK(buf_map_.count(key));
CHECK(buf_map_.count(key))
<< "Cannot find buffer of " << tensor->op << " value=" << tensor->value_index;
const BufferEntry& be = buf_map_.at(key);
CHECK(!be.released);
CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2);
......
......@@ -181,6 +181,19 @@ class SchedulePostProc : public IRMutator {
return this->Mutate(op->body);
}
}
} else if (op->attr_key == ir::attr::buffer_bind_scope) {
Array<NodeRef> tuple(op->node.node_);
Tensor tensor(tuple[1].node_);
auto it = replace_op_.find(tensor->op.get());
if (it != replace_op_.end()) {
if (it->second.defined()) {
return AttrStmt::make(
Array<NodeRef>{tuple[0], it->second.output(tensor->value_index)},
op->attr_key, op->value, Mutate(op->body));
} else {
return this->Mutate(op->body);
}
}
}
return IRMutator::Mutate_(op, s);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment