/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file intrin_rule_llvm.cc
 */
#ifdef TVM_LLVM_VERSION

#include "intrin_rule_llvm.h"

namespace tvm {
namespace codegen {
namespace llvm {

// Each registration below publishes a lowering rule under the global name
// "tvm.intrin.rule.llvm.<op>" and binds it to a dispatcher instantiated with
// an LLVM intrinsic ID.
// NOTE(review): the second template argument is presumably the number of
// overloaded signature types for the intrinsic -- confirm against
// intrin_rule_llvm.h.

// prefetch is the only rule here routed through DispatchLLVMIntrin rather
// than DispatchLLVMPureIntrin.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch")
    .set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 0>);

// exp -> llvm.exp
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>);

// fma is bound to llvm.fmuladd (not llvm.fma).
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma")
    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 1>);

// log -> llvm.log
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>);

// sqrt -> llvm.sqrt
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt")
    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>);

// floor -> llvm.floor
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.floor")
    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>);

// ceil -> llvm.ceil
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.ceil")
    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);

// trunc -> llvm.trunc
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.trunc")
    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);

// fabs -> llvm.fabs
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fabs")
    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>);

// round -> llvm.round
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.round")
    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);

// nearbyint -> llvm.nearbyint
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.nearbyint")
    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);

64 65 66 67 68 69
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
  Expr e = targs[0];
  const ir::Call* call = e.as<ir::Call>();
  CHECK(call != nullptr);
  const Expr& x = call->args[0];
70 71 72
  Expr one = make_const(x.dtype(), 1);
  Expr two = make_const(x.dtype(), 2);
  Expr neg_two = make_const(x.dtype(), -2);
73 74

  Expr exp_neg2x = ir::Call::make(
75
      x.dtype(), "exp", {neg_two * x}, ir::Call::PureIntrinsic);
76
  Expr exp_pos2x = ir::Call::make(
77
      x.dtype(), "exp", {two * x}, ir::Call::PureIntrinsic);
78 79 80 81

  Expr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x);
  Expr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one);
  *rv = ir::Select::make(
82
      x >= make_zero(x.dtype()), tanh_pos, tanh_neg);
83 84
});

// Lowering rules published as "tvm.intrin.rule.llvm.<op>", each bound to
// DispatchLLVMPureIntrin instantiated with the matching LLVM intrinsic ID.

// pow -> llvm.pow
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow")
    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 1>);

// popcount is bound to llvm.ctpop.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount")
    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);

// cos -> llvm.cos
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cos")
    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);

// sin -> llvm.sin
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sin")
    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);

}  // namespace llvm
}  // namespace codegen
}  // namespace tvm

#endif  // TVM_LLVM_VERSION