Commit 10d9da48 by Salem Derisavi Committed by Tianqi Chen

Consider variable range information during simplification of tensorize expressions (#674)

parent cf81f9f9
......@@ -5,6 +5,7 @@
*/
#include <tvm/ir_mutator.h>
#include <tvm/arithmetic.h>
#include <tvm/ir_pass.h>
#include "./canonical.h"
#include "./compute_expr.h"
#include "arithmetic/Simplify.h"
......@@ -612,6 +613,7 @@ void Canonical::SetRange(Var v, Range r, int level) {
} // namespace arith
namespace ir {
Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
return arith::Canonical(vrange).Simplify(stmt);
}
......
......@@ -187,7 +187,8 @@ class TensorIntrinMatcher final : public IRMutator {
const Stage& stage,
const std::unordered_map<IterVar, Range>& out_dom,
const std::unordered_map<Tensor, Array<Range> >& in_region,
const TensorIntrin& intrin) {
const TensorIntrin& intrin,
Map<Var, Range>* compute_intrin_iter_space) {
CHECK(self == stage->op.get());
// input remap.
Array<Tensor> inputs = self->InputTensors();
......@@ -232,6 +233,7 @@ class TensorIntrinMatcher final : public IRMutator {
Range r = out_dom.at(iv);
var_remap_[iv->var.get()] = target_iv->var + r->min;
axis_remap_[iv] = target_iv;
compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
}
// Remap reduction axis
CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
......@@ -251,6 +253,7 @@ class TensorIntrinMatcher final : public IRMutator {
Range r = out_dom.at(iv);
var_remap_[iv->var.get()] = target_iv->var + r->min;
axis_remap_[iv] = target_iv;
compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
}
}
......@@ -275,9 +278,10 @@ Array<Expr> MatchTensorizeBody(
const Stage& stage,
const std::unordered_map<IterVar, Range>& out_dom,
const std::unordered_map<Tensor, Array<Range> >& in_region,
const TensorIntrin& intrin) {
const TensorIntrin& intrin,
Map<Var, Range>* compute_intrin_iter_space) {
TensorIntrinMatcher matcher;
matcher.Init(self, stage, out_dom, in_region, intrin);
matcher.Init(self, stage, out_dom, in_region, intrin, compute_intrin_iter_space);
Array<Expr> ret;
for (Expr expr : self->body) {
ret.push_back(matcher.Mutate(expr));
......@@ -291,14 +295,16 @@ void VerifyTensorizeBody(
const std::unordered_map<IterVar, Range>& out_dom,
const std::unordered_map<Tensor, Array<Range> >& in_region,
const TensorIntrin& intrin) {
Array<Expr> body = MatchTensorizeBody(self, stage, out_dom, in_region, intrin);
Map<Var, Range> compute_intrin_iter_space;
Array<Expr> body = MatchTensorizeBody(self, stage, out_dom, in_region, intrin,
&compute_intrin_iter_space);
const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
CHECK(intrin_compute) << "Only support compute intrinsic for now";
CHECK_EQ(body.size(), intrin_compute->body.size())
<< "Tensorize failed: body size mismatch";
for (size_t i = 0; i < body.size(); ++i) {
Expr lhs = CanonicalSimplify(body[i]);
Expr rhs = CanonicalSimplify(intrin_compute->body[i]);
Expr lhs = CanonicalSimplify(body[i], compute_intrin_iter_space);
Expr rhs = CanonicalSimplify(intrin_compute->body[i], compute_intrin_iter_space);
if (lhs.type() != rhs.type()) {
LOG(FATAL)
<< "Failed to match the data type with TensorIntrin "
......@@ -459,11 +465,13 @@ TVM_REGISTER_API("test.op.MatchTensorizeBody")
Map<IterVar, Range> out_dom = args[1];
Map<Tensor, Array<Range> > in_region = args[2];
TensorIntrin intrin = args[3];
Map<Var, Range> vrange;
CHECK(stage->op.as<ComputeOpNode>());
*ret = MatchTensorizeBody(stage->op.as<ComputeOpNode>(),
stage,
as_unordered_map(out_dom),
as_unordered_map(in_region),
intrin);
intrin,
&vrange);
});
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment