/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 *  Lower intrinsic calls to device specific ir when possible.
 * \file lower_intrin.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/api_registry.h>
#include <unordered_set>
#include "ir_util.h"

namespace tvm {
namespace ir {

class IntrinInjecter : public IRMutator {
 public:
  explicit IntrinInjecter(std::string target) {
38 39 40 41
    std::istringstream is(target);
    std::string starget;
    is >> starget;
    patterns_.push_back("tvm.intrin.rule." + starget + ".");
42
    patterns_.push_back("tvm.intrin.rule.default.");
43
    fma_ = runtime::Registry::Get(patterns_[0] + "fma");
44 45 46 47 48
  }

  Expr Mutate_(const Call* op, const Expr& e) final {
    if (op->call_type == Call::Intrinsic ||
        op->call_type == Call::PureIntrinsic) {
49 50
      Expr r = ApplyPattern(op->name, e);
      if (r.defined()) return r;
51 52 53 54
    }
    return IRMutator::Mutate_(op, e);
  }

55 56
  Expr Mutate_(const Add* op, const Expr& e) final {
    if (const Mul* mb = op->b.as<Mul>()) {
57
      return MakeFMA(mb->a, mb->b, op->a, op, e);
58
    } else if (const Mul* ma = op->a.as<Mul>()) {
59 60 61 62 63 64 65 66 67 68 69 70 71
      return MakeFMA(ma->a, ma->b, op->b, op, e);
    }
    return IRMutator::Mutate_(op, e);
  }

 private:
  Expr SwapBroadcastCast(const Expr& e) {
    // Try to change broadcast(cast(x)) to cast(broadcast(x))
    // For some targets, LLVM will generate more efficient FMA
    // instruction with the latter. For example, vmla vs. vmlal
    // on ARM.
    if (const Broadcast* bcast = e.as<Broadcast>()) {
      if (const Cast* cast = bcast->value.as<Cast>()) {
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
        auto should_swap = [&]() {
          // Maintain behaviour (int8 -> int16, fp16 -> fp32).
          if (cast->type.bits() == cast->value.type().bits() * 2) {
            return true;
          }
          // Check both operands are integer-like.
          if (!cast->type.is_uint() && !cast->type.is_int()) {
            return false;
          }
          if (!cast->value.type().is_uint() && !cast->value.type().is_int()) {
            return false;
          }
          // If both are integer-like, swap if we have a widening cast.
          return cast->type.bits() > cast->value.type().bits();
        };

        if (should_swap()) {
89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
          Expr new_bcast = Broadcast::make(cast->value, bcast->lanes);
          return Cast::make(bcast->type, new_bcast);
        }
      }
    }
    return e;
  }

  Expr MakeFMA(const Expr& a, const Expr& b, const Expr& c,
               const Add* op, const Expr& e) {
    // emit fma instruction: a * b + c
    Expr lhs = SwapBroadcastCast(a);
    Expr rhs = SwapBroadcastCast(b);

    if (fma_ != nullptr && op->type.is_float()) {
104
      Expr r = (*fma_)(Call::make(
105
          op->type, "fma", {lhs, rhs, c}, Call::PureIntrinsic));
106
      if (r.defined()) return this->Mutate(r);
107 108 109 110 111
    } else {
      if (!lhs.same_as(a) || !rhs.same_as(b)) {
        Expr mul = this->Mutate(Mul::make(lhs, rhs));
        return Add::make(mul, this->Mutate(c));
      }
112 113 114 115
    }
    return IRMutator::Mutate_(op, e);
  }

116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
  Expr ApplyPattern(const std::string& name, const Expr& e) {
    for (size_t i = 0; i < patterns_.size(); ++i) {
      std::string& p = patterns_[i];
      size_t psize = p.length();
      p.resize(psize + name.length());
      name.copy(&p[0] + psize, name.length());
      const runtime::PackedFunc* f = runtime::Registry::Get(p);
      p.resize(psize);
      // if pattern exists.
      if (f != nullptr) {
        Expr r = (*f)(e);
        CHECK(r.defined()) << "intrinsic rule must always return valid Expr";
        if (!r.same_as(e)) {
          return this->Mutate(r);
        }
      }
    }
    return Expr();
  }
  // patterns
136
  std::vector<std::string> patterns_;
137
  const PackedFunc* fma_{nullptr};
138 139 140 141
};

/*!
 * \brief Lower intrinsic calls in \p f to target specific instructions.
 * \param f The function to lower.
 * \param target The target string; its first token selects the rule set.
 * \return A new LoweredFunc with a lowered body (input is not mutated).
 */
LoweredFunc
LowerIntrin(LoweredFunc f, const std::string& target) {
  // Copy the node so the original function stays untouched.
  auto n = make_node<LoweredFuncNode>(*f.operator->());
  n->body = IntrinInjecter(target).Mutate(n->body);
  return LoweredFunc(n);
}

}  // namespace ir
}  // namespace tvm