/*!
 *  Copyright (c) 2017 by Contributors
 *  Lower intrinsic calls to device specific ir when possible.
 * \file lower_intrin.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/api_registry.h>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>
#include "./ir_util.h"

namespace tvm {
namespace ir {

/*!
 * \brief Mutator that lowers intrinsic calls to target specific
 *  implementations.
 *
 *  Lowering rules are looked up in the global registry under
 *  "tvm.intrin.rule.<target>.<name>", falling back to
 *  "tvm.intrin.rule.default.<name>".  In addition, float expressions of
 *  the form a * b + c are rewritten into the "fma" intrinsic when a rule
 *  for it is registered for the active target.
 */
class IntrinInjecter : public IRMutator {
 public:
  explicit IntrinInjecter(std::string target) {
    // Only the leading token of the target string selects the rule
    // namespace (e.g. "llvm" out of "llvm -mcpu=core-avx2").
    std::istringstream is(target);
    std::string starget;
    is >> starget;
    // Target-specific rules are tried before the default rules.
    patterns_.push_back("tvm.intrin.rule." + starget + ".");
    patterns_.push_back("tvm.intrin.rule.default.");
    // Cache the fma rule lookup; it is consulted on every Add node.
    fma_ = runtime::Registry::Get(patterns_[0] + "fma");
  }

  Expr Mutate_(const Call* op, const Expr& e) final {
    if (op->call_type == Call::Intrinsic ||
        op->call_type == Call::PureIntrinsic) {
      Expr r = ApplyPattern(op->name, e);
      if (r.defined()) return r;
    }
    return IRMutator::Mutate_(op, e);
  }

  Expr Mutate_(const Add* op, const Expr& e) final {
    // Recognize a * b + c with the multiply on either side of the add.
    if (const Mul* mb = op->b.as<Mul>()) {
      return MakeFMA(mb->a, mb->b, op->a, op, e);
    } else if (const Mul* ma = op->a.as<Mul>()) {
      return MakeFMA(ma->a, ma->b, op->b, op, e);
    }
    return IRMutator::Mutate_(op, e);
  }

 private:
  Expr SwapBroadcastCast(const Expr& e) {
    // Try to change broadcast(cast(x)) to cast(broadcast(x))
    // For some targets, LLVM will generate more efficient FMA
    // instruction with the latter. For example, vmla vs. vmlal
    // on ARM.
    if (const Broadcast* bcast = e.as<Broadcast>()) {
      if (const Cast* cast = bcast->value.as<Cast>()) {
        // Only swap for widening casts whose result has exactly twice
        // the bits of the operand (e.g. int16 -> int32).
        if (cast->type.bits() == cast->value.type().bits() * 2) {
          Expr new_bcast = Broadcast::make(cast->value, bcast->lanes);
          return Cast::make(bcast->type, new_bcast);
        }
      }
    }
    return e;
  }

  /*!
   * \brief Emit an fma intrinsic for a * b + c when possible.
   * \param a First multiplicand.
   * \param b Second multiplicand.
   * \param c Addend.
   * \param op The original Add node.
   * \param e The original expression (used for the fallback path).
   * \return The lowered expression.
   */
  Expr MakeFMA(const Expr& a, const Expr& b, const Expr& c,
               const Add* op, const Expr& e) {
    // emit fma instruction: a * b + c
    Expr lhs = SwapBroadcastCast(a);
    Expr rhs = SwapBroadcastCast(b);

    if (fma_ != nullptr && op->type.is_float()) {
      Expr r = (*fma_)(Call::make(
          op->type, "fma", {lhs, rhs, c}, Call::PureIntrinsic));
      if (r.defined()) return this->Mutate(r);
    } else {
      // No fma rule (or non-float type): still apply the
      // broadcast/cast swap if it changed anything.
      if (!lhs.same_as(a) || !rhs.same_as(b)) {
        Expr mul = this->Mutate(Mul::make(lhs, rhs));
        return Add::make(mul, this->Mutate(c));
      }
    }
    return IRMutator::Mutate_(op, e);
  }

  /*!
   * \brief Look up and apply a registered lowering rule for \p name.
   * \param name The intrinsic name (e.g. "exp").
   * \param e The call expression to lower.
   * \return The lowered expression, or an undefined Expr when no rule
   *  applied.
   */
  Expr ApplyPattern(const std::string& name, const Expr& e) {
    for (size_t i = 0; i < patterns_.size(); ++i) {
      // Build the fully qualified rule name in a local string.  The
      // previous code temporarily appended `name` into patterns_[i] and
      // shrank it back afterwards, which left the member in a corrupted
      // state if the registry lookup or the rule invocation threw; a
      // local concatenation performs the same lookups exception-safely.
      std::string rule_name = patterns_[i] + name;
      const runtime::PackedFunc* f = runtime::Registry::Get(rule_name);
      // if pattern exists.
      if (f != nullptr) {
        Expr r = (*f)(e);
        CHECK(r.defined()) << "intrinsic rule must always return valid Expr";
        if (!r.same_as(e)) {
          // Recursively lower the rewritten expression: a rule may
          // itself introduce further lowerable intrinsics.
          return this->Mutate(r);
        }
      }
    }
    return Expr();
  }

  // Rule-name prefixes, searched in order (target-specific first,
  // then "tvm.intrin.rule.default.").
  std::vector<std::string> patterns_;
  // Cached lowering rule for "fma" on the active target; nullptr when
  // no such rule is registered.
  const runtime::PackedFunc* fma_{nullptr};
};

/*!
 * \brief Lower intrinsic calls in \p f to target specific forms.
 * \param f The function to transform.
 * \param target The target string; only its first token selects the
 *  rule namespace.
 * \return A new LoweredFunc with the rewritten body.
 */
LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target) {
  auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
  n->body = IntrinInjecter(target).Mutate(n->body);
  return LoweredFunc(n);
}

}  // namespace ir
}  // namespace tvm