Commit f73c461f by Tianqi Chen Committed by GitHub

[BACKEND] Explicitly allow specialization of FMA in llvm (#407)

parent a45d3b01
......@@ -67,6 +67,9 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
// Lowering rules for the LLVM target. Each rule is registered as a global
// function named "tvm.intrin.rule.llvm.<intrinsic>" and dispatches the call
// to the LLVM intrinsic given as the template argument.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp>);
// "fma" is mapped to llvm.fmuladd rather than llvm.fma.
// NOTE(review): per the LLVM language reference, fmuladd lets the backend
// choose between a fused and an unfused multiply-add — confirm this is the
// intended rounding behaviour for TVM's "fma".
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log>);
......
......@@ -16,11 +16,12 @@ namespace ir {
class IntrinInjecter : public IRMutator {
public:
// Build the ordered list of rule-name prefixes for `target` and look up the
// target's optional "fma" lowering rule.
//
// `target` may carry options after the target name (e.g. "llvm -mcpu=..."),
// so the first whitespace-delimited token is extracted as the target name.
explicit IntrinInjecter(std::string target) {
  std::istringstream is(target);
  std::string starget;
  is >> starget;
  // Prefix built from the bare target name, without any trailing options.
  const std::string target_prefix = "tvm.intrin.rule." + starget + ".";

  // Prefixes are pushed from most to least specific; this order is preserved.
  patterns_.push_back("tvm.intrin.rule." + target + ".");
  if (!strncmp(target.c_str(), "llvm", 4) && target != "llvm") {
    // Any other target string starting with "llvm" also gets the plain llvm rules.
    patterns_.push_back("tvm.intrin.rule.llvm.");
  }
  patterns_.push_back(target_prefix);
  patterns_.push_back("tvm.intrin.rule.default.");

  // Look the FMA rule up with the parsed target name. Using patterns_[0]
  // here would embed the target options in the key, so the rule registered
  // as "tvm.intrin.rule.<name>.fma" would never match for such targets.
  // Presumably Registry::Get yields nullptr when no rule is registered;
  // Mutate_(const Add*) treats a null fma_ as "no FMA specialization".
  fma_ = runtime::Registry::Get(target_prefix + "fma");
}
Expr Mutate_(const Call* op, const Expr& e) final {
......@@ -32,6 +33,22 @@ class IntrinInjecter : public IRMutator {
return IRMutator::Mutate_(op, e);
}
// Rewrite a floating-point `c + a * b` (or `a * b + c`) into the target's
// "fma" rule when one is available.
//
// Falls back to the default Add mutation when no FMA rule was found for the
// target, when the result type is not floating point, when neither operand
// is a Mul, or when the rule returns an undefined expression.
Expr Mutate_(const Add* op, const Expr& e) final {
  if (fma_ == nullptr || !op->type.is_float()) {
    return IRMutator::Mutate_(op, e);
  }
  if (const Mul* mb = op->b.as<Mul>()) {
    // op->a + mb->a * mb->b
    Expr r = (*fma_)(Call::make(
        op->type, "fma", {mb->a, mb->b, op->a}, Call::PureIntrinsic));
    // Keep visiting the lowered expression: its operands were taken from the
    // unvisited Add/Mul, so returning `r` directly would leave any nested
    // intrinsic calls or multiply-adds inside them unlowered.
    if (r.defined()) return this->Mutate(r);
  } else if (const Mul* ma = op->a.as<Mul>()) {
    // ma->a * ma->b + op->b
    Expr r = (*fma_)(Call::make(
        op->type, "fma", {ma->a, ma->b, op->b}, Call::PureIntrinsic));
    if (r.defined()) return this->Mutate(r);
  }
  return IRMutator::Mutate_(op, e);
}
private:
Expr ApplyPattern(const std::string& name, const Expr& e) {
for (size_t i = 0; i < patterns_.size(); ++i) {
......@@ -54,6 +71,7 @@ class IntrinInjecter : public IRMutator {
}
// Rule-name prefixes ("tvm.intrin.rule.<target>.") filled in by the
// constructor; ApplyPattern iterates over them in index order.
std::vector<std::string> patterns_;
// "fma" rule fetched from the runtime registry by the constructor.
// Left as nullptr by default; Mutate_(const Add*) skips the FMA rewrite
// whenever this is null.
const PackedFunc* fma_{nullptr};
};
LoweredFunc
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment