/*!
 *  Copyright (c) 2017 by Contributors
 * \file lower_intrin.cc
 * \brief Lower intrinsic calls to device-specific IR when possible.
 */
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/api_registry.h>
#include <sstream>
#include <string>
#include <vector>
#include "./ir_util.h"

namespace tvm {
namespace ir {

class IntrinInjecter : public IRMutator {
 public:
  explicit IntrinInjecter(std::string target) {
    // The target string can carry options, e.g. "llvm -mcpu=skylake";
    // only the leading target name selects the rule namespace.
    std::istringstream is(target);
    std::string starget;
    is >> starget;
    // Look up target-specific rules first, then fall back to defaults.
    patterns_.push_back("tvm.intrin.rule." + starget + ".");
    patterns_.push_back("tvm.intrin.rule.default.");
    // Cache the fma rule; MakeFMA queries it on every Add node.
    fma_ = runtime::Registry::Get(patterns_[0] + "fma");
  }
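
  // For example, given the (illustrative) target string
  // "llvm -mcpu=skylake", patterns_ becomes
  //   {"tvm.intrin.rule.llvm.", "tvm.intrin.rule.default."}
  // so a call named "exp" resolves to "tvm.intrin.rule.llvm.exp" first,
  // with "tvm.intrin.rule.default.exp" as the fallback.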

  Expr Mutate_(const Call* op, const Expr& e) final {
    // Only intrinsic calls are candidates for rule-based lowering.
    if (op->call_type == Call::Intrinsic ||
        op->call_type == Call::PureIntrinsic) {
      Expr r = ApplyPattern(op->name, e);
      if (r.defined()) return r;
    }
    return IRMutator::Mutate_(op, e);
  }

  Expr Mutate_(const Add* op, const Expr& e) final {
    // Recognize a multiply on either side of the add as an fma
    // candidate; see the worked example below.
    if (const Mul* mb = op->b.as<Mul>()) {
      return MakeFMA(mb->a, mb->b, op->a, op, e);
    } else if (const Mul* ma = op->a.as<Mul>()) {
      return MakeFMA(ma->a, ma->b, op->b, op, e);
    }
    return IRMutator::Mutate_(op, e);
  }
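
  // Worked example (float type and a registered fma rule are assumed):
  // for float32 x, y, z, both
  //   z + x * y   and   x * y + z
  // are rewritten to fma(x, y, z) when "tvm.intrin.rule.<target>.fma"
  // exists; otherwise the add falls through to the default mutator.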

 private:
  Expr SwapBroadcastCast(const Expr& e) {
    // Try to change broadcast(cast(x)) to cast(broadcast(x)).
    // For some targets, LLVM will generate more efficient FMA
    // instructions with the latter. For example, vmla vs. vmlal
    // on ARM.
    if (const Broadcast* bcast = e.as<Broadcast>()) {
      if (const Cast* cast = bcast->value.as<Cast>()) {
        // Only swap a widening cast that exactly doubles the bit width;
        // that is the shape a widening multiply-accumulate can use.
        if (cast->type.bits() == cast->value.type().bits() * 2) {
          Expr new_bcast = Broadcast::make(cast->value, bcast->lanes);
          return Cast::make(bcast->type, new_bcast);
        }
      }
    }
    return e;
  }
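
  // Example of the swap (lanes and types are illustrative):
  //   broadcast(cast<int32>(x : int16), 4)
  // becomes
  //   cast<int32x4>(broadcast(x, 4))
  // letting the following multiply-add lower to a widening vmlal
  // instead of a plain vmla on ARM.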

  Expr MakeFMA(const Expr& a, const Expr& b, const Expr& c,
               const Add* op, const Expr& e) {
    // Try to emit an fma instruction computing a * b + c.
    Expr lhs = SwapBroadcastCast(a);
    Expr rhs = SwapBroadcastCast(b);

    if (fma_ != nullptr && op->type.is_float()) {
      Expr r = (*fma_)(Call::make(
          op->type, "fma", {lhs, rhs, c}, Call::PureIntrinsic));
      if (r.defined()) return this->Mutate(r);
    } else if (!lhs.same_as(a) || !rhs.same_as(b)) {
      // No fma rule applies, but the broadcast/cast swap still helps:
      // rebuild the multiply-add with the swapped operands.
      Expr mul = this->Mutate(Mul::make(lhs, rhs));
      return Add::make(mul, this->Mutate(c));
    }
    return IRMutator::Mutate_(op, e);
  }

  Expr ApplyPattern(const std::string& name, const Expr& e) {
    for (size_t i = 0; i < patterns_.size(); ++i) {
      std::string& p = patterns_[i];
      // Build the lookup key by temporarily appending the call name to
      // the stored prefix, so no fresh string is allocated per lookup;
      // shrink the buffer back to the bare prefix afterwards.
      size_t psize = p.length();
      p.resize(psize + name.length());
      name.copy(&p[0] + psize, name.length());
      const runtime::PackedFunc* f = runtime::Registry::Get(p);
      p.resize(psize);
      if (f != nullptr) {
        // A rule is registered for this pattern; apply it.
        Expr r = (*f)(e);
        CHECK(r.defined()) << "intrinsic rule must always return valid Expr";
        // A rule that returns the expression unchanged declined to
        // lower this call; try the next (fallback) pattern.
        if (!r.same_as(e)) {
          return this->Mutate(r);
        }
      }
    }
    return Expr();
  }
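
  // A minimal sketch of a matching rule on the registration side (the
  // rule name and body here are assumptions for illustration, not the
  // real registrations): a rule receives the whole Call expression and
  // returns its replacement, e.g.
  //
  //   TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log")
  //     .set_body([](TVMArgs args, TVMRetValue* rv) {
  //       Expr e = args[0];
  //       const Call* call = e.as<Call>();
  //       // lower to an extern call of the same name
  //       *rv = Call::make(call->type, "log", call->args, Call::Extern);
  //     });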
  // candidate rule prefixes, in lookup order
  std::vector<std::string> patterns_;
  // cached fma rule for the primary target, if any
  const runtime::PackedFunc* fma_{nullptr};
};

LoweredFunc
LowerIntrin(LoweredFunc f, const std::string& target) {
  // Copy the function node so the input LoweredFunc stays unchanged,
  // then rewrite the body of the copy.
  auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
  n->body = IntrinInjecter(target).Mutate(n->body);
  return LoweredFunc(n);
}
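
// Illustrative usage from another pass or a test (the target string is
// an example):
//
//   LoweredFunc lowered = LowerIntrin(f, "llvm");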

}  // namespace ir
}  // namespace tvm