Commit 10d9da48, authored by Salem Derisavi, committed by Tianqi Chen

Consider variable range information during simplification of tensorize expressions (#674)

parent cf81f9f9
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include <tvm/ir_pass.h>
#include "./canonical.h" #include "./canonical.h"
#include "./compute_expr.h" #include "./compute_expr.h"
#include "arithmetic/Simplify.h" #include "arithmetic/Simplify.h"
...@@ -612,6 +613,7 @@ void Canonical::SetRange(Var v, Range r, int level) { ...@@ -612,6 +613,7 @@ void Canonical::SetRange(Var v, Range r, int level) {
} // namespace arith } // namespace arith
namespace ir { namespace ir {
Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) { Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
return arith::Canonical(vrange).Simplify(stmt); return arith::Canonical(vrange).Simplify(stmt);
} }
......
...@@ -187,7 +187,8 @@ class TensorIntrinMatcher final : public IRMutator { ...@@ -187,7 +187,8 @@ class TensorIntrinMatcher final : public IRMutator {
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& out_dom, const std::unordered_map<IterVar, Range>& out_dom,
const std::unordered_map<Tensor, Array<Range> >& in_region, const std::unordered_map<Tensor, Array<Range> >& in_region,
const TensorIntrin& intrin) { const TensorIntrin& intrin,
Map<Var, Range>* compute_intrin_iter_space) {
CHECK(self == stage->op.get()); CHECK(self == stage->op.get());
// input remap. // input remap.
Array<Tensor> inputs = self->InputTensors(); Array<Tensor> inputs = self->InputTensors();
...@@ -232,6 +233,7 @@ class TensorIntrinMatcher final : public IRMutator { ...@@ -232,6 +233,7 @@ class TensorIntrinMatcher final : public IRMutator {
Range r = out_dom.at(iv); Range r = out_dom.at(iv);
var_remap_[iv->var.get()] = target_iv->var + r->min; var_remap_[iv->var.get()] = target_iv->var + r->min;
axis_remap_[iv] = target_iv; axis_remap_[iv] = target_iv;
compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
} }
// Remap reduction axis // Remap reduction axis
CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size()) CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
...@@ -251,6 +253,7 @@ class TensorIntrinMatcher final : public IRMutator { ...@@ -251,6 +253,7 @@ class TensorIntrinMatcher final : public IRMutator {
Range r = out_dom.at(iv); Range r = out_dom.at(iv);
var_remap_[iv->var.get()] = target_iv->var + r->min; var_remap_[iv->var.get()] = target_iv->var + r->min;
axis_remap_[iv] = target_iv; axis_remap_[iv] = target_iv;
compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
} }
} }
...@@ -275,9 +278,10 @@ Array<Expr> MatchTensorizeBody( ...@@ -275,9 +278,10 @@ Array<Expr> MatchTensorizeBody(
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& out_dom, const std::unordered_map<IterVar, Range>& out_dom,
const std::unordered_map<Tensor, Array<Range> >& in_region, const std::unordered_map<Tensor, Array<Range> >& in_region,
const TensorIntrin& intrin) { const TensorIntrin& intrin,
Map<Var, Range>* compute_intrin_iter_space) {
TensorIntrinMatcher matcher; TensorIntrinMatcher matcher;
matcher.Init(self, stage, out_dom, in_region, intrin); matcher.Init(self, stage, out_dom, in_region, intrin, compute_intrin_iter_space);
Array<Expr> ret; Array<Expr> ret;
for (Expr expr : self->body) { for (Expr expr : self->body) {
ret.push_back(matcher.Mutate(expr)); ret.push_back(matcher.Mutate(expr));
...@@ -291,14 +295,16 @@ void VerifyTensorizeBody( ...@@ -291,14 +295,16 @@ void VerifyTensorizeBody(
const std::unordered_map<IterVar, Range>& out_dom, const std::unordered_map<IterVar, Range>& out_dom,
const std::unordered_map<Tensor, Array<Range> >& in_region, const std::unordered_map<Tensor, Array<Range> >& in_region,
const TensorIntrin& intrin) { const TensorIntrin& intrin) {
Array<Expr> body = MatchTensorizeBody(self, stage, out_dom, in_region, intrin); Map<Var, Range> compute_intrin_iter_space;
Array<Expr> body = MatchTensorizeBody(self, stage, out_dom, in_region, intrin,
&compute_intrin_iter_space);
const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>(); const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
CHECK(intrin_compute) << "Only support compute intrinsic for now"; CHECK(intrin_compute) << "Only support compute intrinsic for now";
CHECK_EQ(body.size(), intrin_compute->body.size()) CHECK_EQ(body.size(), intrin_compute->body.size())
<< "Tensorize failed: body size mismatch"; << "Tensorize failed: body size mismatch";
for (size_t i = 0; i < body.size(); ++i) { for (size_t i = 0; i < body.size(); ++i) {
Expr lhs = CanonicalSimplify(body[i]); Expr lhs = CanonicalSimplify(body[i], compute_intrin_iter_space);
Expr rhs = CanonicalSimplify(intrin_compute->body[i]); Expr rhs = CanonicalSimplify(intrin_compute->body[i], compute_intrin_iter_space);
if (lhs.type() != rhs.type()) { if (lhs.type() != rhs.type()) {
LOG(FATAL) LOG(FATAL)
<< "Failed to match the data type with TensorIntrin " << "Failed to match the data type with TensorIntrin "
...@@ -459,11 +465,13 @@ TVM_REGISTER_API("test.op.MatchTensorizeBody") ...@@ -459,11 +465,13 @@ TVM_REGISTER_API("test.op.MatchTensorizeBody")
Map<IterVar, Range> out_dom = args[1]; Map<IterVar, Range> out_dom = args[1];
Map<Tensor, Array<Range> > in_region = args[2]; Map<Tensor, Array<Range> > in_region = args[2];
TensorIntrin intrin = args[3]; TensorIntrin intrin = args[3];
Map<Var, Range> vrange;
CHECK(stage->op.as<ComputeOpNode>()); CHECK(stage->op.as<ComputeOpNode>());
*ret = MatchTensorizeBody(stage->op.as<ComputeOpNode>(), *ret = MatchTensorizeBody(stage->op.as<ComputeOpNode>(),
stage, stage,
as_unordered_map(out_dom), as_unordered_map(out_dom),
as_unordered_map(in_region), as_unordered_map(in_region),
intrin); intrin,
&vrange);
}); });
} // namespace tvm } // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment