Commit 592a1f65 by ziheng Committed by Tianqi Chen

[CODEGEN] Detect broadcast(cast(x)) pattern in FMA (#551)

* [CODEGEN] Detect broadcast(cast(x)) pattern in FMA

* [CODEGEN] Improve

* [CODEGEN] Fix
parent fde9b570
...@@ -34,22 +34,50 @@ class IntrinInjecter : public IRMutator { ...@@ -34,22 +34,50 @@ class IntrinInjecter : public IRMutator {
} }
Expr Mutate_(const Add* op, const Expr& e) final { Expr Mutate_(const Add* op, const Expr& e) final {
if (fma_ == nullptr || !op->type.is_float()) {
return IRMutator::Mutate_(op, e);
}
if (const Mul* mb = op->b.as<Mul>()) { if (const Mul* mb = op->b.as<Mul>()) {
Expr r = (*fma_)(Call::make( return MakeFMA(mb->a, mb->b, op->a, op, e);
op->type, "fma", {mb->a, mb->b, op->a}, Call::PureIntrinsic));
if (r.defined()) return this->Mutate(r);
} else if (const Mul* ma = op->a.as<Mul>()) { } else if (const Mul* ma = op->a.as<Mul>()) {
return MakeFMA(ma->a, ma->b, op->b, op, e);
}
return IRMutator::Mutate_(op, e);
}
private:
Expr SwapBroadcastCast(const Expr& e) {
// Try to change broadcast(cast(x)) to cast(broadcast(x))
// For some targets, LLVM will generate more efficient FMA
// instruction with the latter. For example, vmla vs. vmlal
// on ARM.
if (const Broadcast* bcast = e.as<Broadcast>()) {
if (const Cast* cast = bcast->value.as<Cast>()) {
if (cast->type.bits() == cast->value.type().bits() * 2) {
Expr new_bcast = Broadcast::make(cast->value, bcast->lanes);
return Cast::make(bcast->type, new_bcast);
}
}
}
return e;
}
Expr MakeFMA(const Expr& a, const Expr& b, const Expr& c,
const Add* op, const Expr& e) {
// emit fma instruction: a * b + c
Expr lhs = SwapBroadcastCast(a);
Expr rhs = SwapBroadcastCast(b);
if (fma_ != nullptr && op->type.is_float()) {
Expr r = (*fma_)(Call::make( Expr r = (*fma_)(Call::make(
op->type, "fma", {ma->a, ma->b, op->b}, Call::PureIntrinsic)); op->type, "fma", {lhs, rhs, c}, Call::PureIntrinsic));
if (r.defined()) return this->Mutate(r); if (r.defined()) return this->Mutate(r);
} else {
if (!lhs.same_as(a) || !rhs.same_as(b)) {
Expr mul = this->Mutate(Mul::make(lhs, rhs));
return Add::make(mul, this->Mutate(c));
}
} }
return IRMutator::Mutate_(op, e); return IRMutator::Mutate_(op, e);
} }
private:
Expr ApplyPattern(const std::string& name, const Expr& e) { Expr ApplyPattern(const std::string& name, const Expr& e) {
for (size_t i = 0; i < patterns_.size(); ++i) { for (size_t i = 0; i < patterns_.size(); ++i) {
std::string& p = patterns_[i]; std::string& p = patterns_[i];
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment