/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2016 by Contributors
 *  Implementation of API functions related to IR building.
 * \file api_ir.cc
 */
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/api_registry.h>
#include <tvm/expr_operator.h>
#include <string>

namespace tvm {
tqchen committed
31
namespace ir {
tqchen committed
32

33
TVM_REGISTER_API("_Var")
34 35
.set_body_typed<VarExpr(std::string, Type)>([](std::string s, Type t) {
    return Variable::make(t, s);
tqchen committed
36 37
  });

38
TVM_REGISTER_API("make.abs")
39
.set_body_typed(tvm::abs);
40

41
TVM_REGISTER_API("make.floor")
42
.set_body_typed(tvm::floor);
43 44

TVM_REGISTER_API("make.ceil")
45
.set_body_typed(tvm::ceil);
46 47

TVM_REGISTER_API("make.round")
48
.set_body_typed(tvm::round);
49 50

TVM_REGISTER_API("make.trunc")
51
.set_body_typed(tvm::trunc);
52 53

TVM_REGISTER_API("make._cast")
54
.set_body_typed(tvm::cast);
55

56
TVM_REGISTER_API("make._range_by_min_extent")
57
.set_body_typed(Range::make_by_min_extent);
58

59
TVM_REGISTER_API("make.For")
60 61 62 63 64 65 66 67 68 69 70
.set_body_typed<Stmt(VarExpr, Expr, Expr, int, int, Stmt)>([](
  VarExpr loop_var, Expr min, Expr extent,
  int for_type, int device_api, Stmt body
) {
  return For::make(loop_var,
                    min,
                    extent,
                    static_cast<ForType>(for_type),
                    static_cast<HalideIR::DeviceAPI>(device_api),
                    body);
});
tqchen committed
71

72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
TVM_REGISTER_API("make.Load")
.set_body([](TVMArgs args,  TVMRetValue *ret) {
    Type t = args[0];
    if (args.size() == 3) {
      *ret = Load::make(t, args[1], args[2], const_true(t.lanes()));
    } else {
      *ret = Load::make(t, args[1], args[2], args[3]);
    }
  });

TVM_REGISTER_API("make.Store")
.set_body([](TVMArgs args,  TVMRetValue *ret) {
    Expr value = args[1];
    if (args.size() == 3) {
      *ret = Store::make(args[0], value, args[2], const_true(value.type().lanes()));
    } else {
      *ret = Store::make(args[0], value, args[2], args[3]);
    }
  });

92
TVM_REGISTER_API("make.Realize")
93
.set_body_typed(Realize::make);
94

95
TVM_REGISTER_API("make.Call")
96 97 98 99 100 101 102 103 104 105 106 107
.set_body_typed<Expr(Type, std::string, Array<Expr>, int, FunctionRef, int)>([](
  Type type, std::string name,
  Array<Expr> args, int call_type,
  FunctionRef func, int value_index
) {
  return Call::make(type,
                    name,
                    args,
                    static_cast<Call::CallType>(call_type),
                    func,
                    value_index);
});
tqchen committed
108

ziheng committed
109
TVM_REGISTER_API("make.CommReducer")
110
.set_body_typed(CommReducerNode::make);
ziheng committed
111

tqchen committed
112
// make from two arguments
113
#define REGISTER_MAKE(Node)                                  \
114
  TVM_REGISTER_API("make."#Node)                             \
115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167
  .set_body_typed(Node::make);                              \

REGISTER_MAKE(Reduce);
REGISTER_MAKE(AttrStmt);

REGISTER_MAKE(IntImm);
REGISTER_MAKE(UIntImm);
REGISTER_MAKE(FloatImm);
REGISTER_MAKE(StringImm);

REGISTER_MAKE(Add);
REGISTER_MAKE(Sub);
REGISTER_MAKE(Mul);
REGISTER_MAKE(Div);
REGISTER_MAKE(Mod);
REGISTER_MAKE(Min);
REGISTER_MAKE(Max);
REGISTER_MAKE(EQ);
REGISTER_MAKE(NE);
REGISTER_MAKE(LT);
REGISTER_MAKE(LE);
REGISTER_MAKE(GT);
REGISTER_MAKE(GE);
REGISTER_MAKE(And);
REGISTER_MAKE(Or);

REGISTER_MAKE(Not);
REGISTER_MAKE(Select);
REGISTER_MAKE(Ramp);
REGISTER_MAKE(Cast);
REGISTER_MAKE(Broadcast);
REGISTER_MAKE(Shuffle);
REGISTER_MAKE(Let);
REGISTER_MAKE(LetStmt);
REGISTER_MAKE(AssertStmt);
REGISTER_MAKE(ProducerConsumer);
REGISTER_MAKE(Provide);
REGISTER_MAKE(Prefetch);
REGISTER_MAKE(Free);
REGISTER_MAKE(IfThenElse);
REGISTER_MAKE(Evaluate);

// overloaded, needs special handling
TVM_REGISTER_API("make.Block")
  .set_body_typed(static_cast<Stmt (*)(Stmt, Stmt)>(Block::make));

// has default args
TVM_REGISTER_API("make.Allocate")
  .set_body_typed<Stmt(VarExpr, Type, Array<Expr>, Expr, Stmt)>([](
    VarExpr buffer_var, Type type, Array<Expr> extents, Expr condition, Stmt body
  ){
    return Allocate::make(buffer_var, type, extents, condition, body);
  });
168 169

// operator overloading, smarter than make
170
#define REGISTER_MAKE_BINARY_OP(Node, Func)                  \
171
  TVM_REGISTER_API("make."#Node)                             \
172 173
  .set_body_typed<Expr(Expr, Expr)>([](Expr a, Expr b) {     \
      return (Func(a, b));                                   \
174 175 176 177 178 179 180 181 182 183 184 185 186 187
    })

#define REGISTER_MAKE_BIT_OP(Node, Func)                                \
  TVM_REGISTER_API("make."#Node)                                        \
  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
      bool lhs_is_int = args[0].type_code() == kDLInt;                  \
      bool rhs_is_int = args[1].type_code() == kDLInt;                  \
      if (lhs_is_int) {                                                 \
        *ret = (Func(args[0].operator int(), args[1].operator Expr())); \
      } else if (rhs_is_int) {                                          \
        *ret = (Func(args[0].operator Expr(), args[1].operator int())); \
      } else {                                                          \
        *ret = (Func(args[0].operator Expr(), args[1].operator Expr())); \
      }                                                                 \
188
    })
tqchen committed
189

tqchen committed
190

191 192 193 194 195 196 197 198 199 200 201 202 203 204 205
REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
REGISTER_MAKE_BINARY_OP(_OpDiv, operator/);
REGISTER_MAKE_BINARY_OP(_OpMod, operator%);
REGISTER_MAKE_BINARY_OP(_OpMin, min);
REGISTER_MAKE_BINARY_OP(_OpMax, max);
REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
REGISTER_MAKE_BINARY_OP(_OpNE, operator!=);
REGISTER_MAKE_BINARY_OP(_OpLT, operator<); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpGT, operator>);  // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpGE, operator>=);
REGISTER_MAKE_BINARY_OP(_OpAnd, operator&&);
REGISTER_MAKE_BINARY_OP(_OpOr, operator||);
206 207 208 209 210
REGISTER_MAKE_BIT_OP(bitwise_and, operator&);
REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, operator>>);
tqchen committed
211

tqchen committed
212
}  // namespace ir
tqchen committed
213
}  // namespace tvm