Commit f06429dd by Salem Derisavi Committed by Tianqi Chen

During tensorize, call Simplify on algorithm and intrinsic definitions before…

During tensorize, call Simplify on algorithm and intrinsic definitions before CanonicalSimplify. This will prevent a number of false tensorize mismatches. (#718)

thanks, this we can use this solution for now 
parent a31f6158
......@@ -648,6 +648,24 @@ T Simplify_(T a, Map<Var, Range> vrange) {
Expr Simplify(Expr a, Map<Var, Range> vrange) {
// We should not pass an expression having a non-HalideIR op to
// Halide::Internal::simplify. Reduce op is the only such op at this time
// and it only appears as the top op in an expression. So we strip it
// first and send the sub-expressions to the simplifier.
if (const Reduce* r = a.as<Reduce>()) {
Array<Expr> new_source;
for (auto& e : r->source) {
new_source.push_back(Simplify_(e, vrange));
}
Expr new_condition = Simplify_(r->condition, vrange);
if (r->source.same_as(new_source) &&
r->condition.same_as(new_condition)) {
return a;
} else {
return Reduce::make(
r->combiner, new_source, r->axis, new_condition, r->value_index);
}
}
return Simplify_(a, vrange);
}
......
......@@ -303,8 +303,10 @@ void VerifyTensorizeBody(
CHECK_EQ(body.size(), intrin_compute->body.size())
<< "Tensorize failed: body size mismatch";
for (size_t i = 0; i < body.size(); ++i) {
Expr lhs = CanonicalSimplify(body[i], compute_intrin_iter_space);
Expr rhs = CanonicalSimplify(intrin_compute->body[i], compute_intrin_iter_space);
Expr lhs = Simplify(body[i], compute_intrin_iter_space);
lhs = CanonicalSimplify(lhs, compute_intrin_iter_space);
Expr rhs = Simplify(intrin_compute->body[i], compute_intrin_iter_space);
rhs = CanonicalSimplify(rhs, compute_intrin_iter_space);
if (lhs.type() != rhs.type()) {
LOG(FATAL)
<< "Failed to match the data type with TensorIntrin "
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment