/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file intrin_rule_cuda.cc
 * \brief CUDA intrinsic rules.
 */
#include "intrin_rule.h"

namespace tvm {
namespace codegen {
namespace intrin {
// Add float suffix to the intrinsics, CUDA fast math.
31
struct CUDAMath {
32 33 34 35 36
  std::string operator()(Type t, std::string name) const {
    if (t.lanes() == 1) {
      if (t.is_float()) {
        switch (t.bits()) {
          case 64: return name;
37
          case 32: return name + 'f';
38 39 40 41 42 43 44 45 46
          case 16: return 'h' + name;
          default: return "";
        }
      }
    }
    return "";
  }
};

47 48 49 50 51 52 53 54 55 56 57
struct CUDAFastMath : public CUDAMath {
  std::string operator()(Type t, std::string name) const {
    if (t.lanes() == 1 && t.is_float() && t.bits() == 32) {
      return "__" + name + 'f';
    } else {
      return CUDAMath::operator()(t, name);
    }
    return "";
  }
};

58 59 60 61 62 63 64 65 66 67 68 69 70
struct CUDAPopcount {
  std::string operator()(Type t, std::string name) const {
    if (t.lanes() == 1 && t.is_uint()) {
      switch (t.bits()) {
        case 32: return "__popc";
        case 64: return "__popcll";
        default: return "";
      }
    }
    return "";
  }
};

71 72 73 74 75 76
struct CUDAShuffle {
  std::string operator()(Type t, std::string name) const {
    return "__shfl";
  }
};

// CUDA lowering rules for the tvm.intrin.rule.cuda.* intrinsics. Each rule is
// DispatchExtern<F>, where the functor F turns (type, intrinsic name) into a
// CUDA function name. NOTE(review): DispatchExtern is declared in
// intrin_rule.h; its handling of an empty name from F should be confirmed
// there.

// Rounding and absolute value: precision-dependent names via CUDAMath.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor")
.set_body(DispatchExtern<CUDAMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil")
.set_body(DispatchExtern<CUDAMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc")
.set_body(DispatchExtern<CUDAMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fabs")
.set_body(DispatchExtern<CUDAMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round")
.set_body(DispatchExtern<CUDAMath>);

// exp and log use the "__<name>f" fast-math spelling for float32.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
.set_body(DispatchExtern<CUDAFastMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
.set_body(DispatchExtern<CUDAFastMath>);

// tanh, sqrt and pow use the regular precision-dependent names.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh")
.set_body(DispatchExtern<CUDAMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt")
.set_body(DispatchExtern<CUDAMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow")
.set_body(DispatchExtern<CUDAMath>);

// Bit counting on unsigned integers.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount")
.set_body(DispatchExtern<CUDAPopcount>);

// Warp-level shuffle.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle")
.set_body(DispatchExtern<CUDAShuffle>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod")
.set_body(DispatchExtern<CUDAMath>);

}  // namespace intrin
}  // namespace codegen
}  // namespace tvm