Commit f73c461f by Tianqi Chen Committed by GitHub

[BACKEND] Explicitly allow specialization of FMA in llvm (#407)

parent a45d3b01
...@@ -67,6 +67,9 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { ...@@ -67,6 +67,9 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp>); .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp>);
// Lowering rule for the "fma" intrinsic on the llvm target: it is dispatched
// to LLVM's fmuladd intrinsic through DispatchLLVMPureIntrin, the same helper
// used by the neighbouring exp/log rules.
// NOTE(review): per the LLVM LangRef, llvm.fmuladd permits but does not
// require fusing the multiply and add — confirm this is the intended
// rounding behaviour for callers that expect a strictly fused result.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log>); .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log>);
......
...@@ -16,11 +16,12 @@ namespace ir { ...@@ -16,11 +16,12 @@ namespace ir {
class IntrinInjecter : public IRMutator { class IntrinInjecter : public IRMutator {
public: public:
explicit IntrinInjecter(std::string target) { explicit IntrinInjecter(std::string target) {
patterns_.push_back("tvm.intrin.rule." + target + "."); std::istringstream is(target);
if (!strncmp(target.c_str(), "llvm", 4) && target != "llvm") { std::string starget;
patterns_.push_back("tvm.intrin.rule.llvm."); is >> starget;
} patterns_.push_back("tvm.intrin.rule." + starget + ".");
patterns_.push_back("tvm.intrin.rule.default."); patterns_.push_back("tvm.intrin.rule.default.");
fma_ = runtime::Registry::Get(patterns_[0] + "fma");
} }
Expr Mutate_(const Call* op, const Expr& e) final { Expr Mutate_(const Call* op, const Expr& e) final {
...@@ -32,6 +33,22 @@ class IntrinInjecter : public IRMutator { ...@@ -32,6 +33,22 @@ class IntrinInjecter : public IRMutator {
return IRMutator::Mutate_(op, e); return IRMutator::Mutate_(op, e);
} }
/*!
 * \brief Rewrite a floating point `x * y + z` (or `z + x * y`) into the
 *        target's "fma" lowering rule when one has been registered.
 *
 * The rewrite is attempted only when fma_ is non-null and the result type of
 * the addition is floating point. When both operands of the Add are
 * multiplications, the right-hand one (op->b) is the one that gets fused,
 * and the left-hand one becomes the addend.
 *
 * \param op The Add node being visited.
 * \param e  The expression handle that refers to op.
 * \return The lowered expression, or the default IRMutator result when no
 *         rewrite applies or the rule returns an undefined expression.
 */
Expr Mutate_(const Add* op, const Expr& e) final {
  if (fma_ == nullptr || !op->type.is_float()) {
    return IRMutator::Mutate_(op, e);
  }
  // Pick the multiplication to fuse; op->b takes precedence over op->a.
  const Mul* mul = op->b.as<Mul>();
  Expr addend = op->a;
  if (mul == nullptr) {
    mul = op->a.as<Mul>();
    addend = op->b;
  }
  if (mul != nullptr) {
    Expr r = (*fma_)(Call::make(
        op->type, "fma", {mul->a, mul->b, addend}, Call::PureIntrinsic));
    // Visit the lowered result so that intrinsic calls and further
    // multiply-add patterns nested inside the operands are lowered as well;
    // returning r directly would leave those sub-expressions untouched.
    if (r.defined()) return this->Mutate(r);
  }
  return IRMutator::Mutate_(op, e);
}
private: private:
Expr ApplyPattern(const std::string& name, const Expr& e) { Expr ApplyPattern(const std::string& name, const Expr& e) {
for (size_t i = 0; i < patterns_.size(); ++i) { for (size_t i = 0; i < patterns_.size(); ++i) {
...@@ -54,6 +71,7 @@ class IntrinInjecter : public IRMutator { ...@@ -54,6 +71,7 @@ class IntrinInjecter : public IRMutator {
} }
// patterns // patterns
std::vector<std::string> patterns_; std::vector<std::string> patterns_;
// Lowering rule for "fma", looked up in the constructor from the global
// registry under patterns_[0] + "fma". Stays nullptr when no such rule is
// registered, in which case Add nodes are not rewritten into fma calls.
const PackedFunc* fma_{nullptr};
}; };
LoweredFunc LoweredFunc
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment