api_ir.cc 6.74 KB
Newer Older
tqchen committed
1 2 3
/*!
 *  Copyright (c) 2016 by Contributors
 *  Implementation of API functions related to IR build
 * \file api_ir.cc
 */
#include <tvm/expr.h>
tqchen committed
7
#include <tvm/ir.h>
tqchen committed
8
#include <ir/IROperator.h>
9
#include <tvm/api_registry.h>
10
#include <tvm/ir_operator.h>
tqchen committed
11 12

namespace tvm {
tqchen committed
13
namespace ir {
tqchen committed
14

15
TVM_REGISTER_API("_Var")
16 17
.set_body([](TVMArgs args,  TVMRetValue *ret) {
    *ret = Variable::make(args[1], args[0]);
tqchen committed
18 19
  });

20 21 22 23 24
TVM_REGISTER_API("make.abs")
.set_body([](TVMArgs args,  TVMRetValue *ret) {
    *ret = tvm::abs(args[0]);
  });

25 26 27 28 29
TVM_REGISTER_API("make._range_by_min_extent")
.set_body([](TVMArgs args,  TVMRetValue *ret) {
    *ret = Range::make_by_min_extent(args[0], args[1]);
  });

30
TVM_REGISTER_API("make.For")
31 32 33 34 35
.set_body([](TVMArgs args,  TVMRetValue *ret) {
    *ret = For::make(args[0],
                     args[1],
                     args[2],
                     static_cast<ForType>(args[3].operator int()),
36
                     static_cast<HalideIR::DeviceAPI>(args[4].operator int()),
37
                     args[5]);
tqchen committed
38 39
  });

40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
TVM_REGISTER_API("make.Load")
.set_body([](TVMArgs args,  TVMRetValue *ret) {
    Type t = args[0];
    if (args.size() == 3) {
      *ret = Load::make(t, args[1], args[2], const_true(t.lanes()));
    } else {
      *ret = Load::make(t, args[1], args[2], args[3]);
    }
  });

// Packed function "make.Store": builds a Store node writing args[1] using
// args[0] and args[2].
// An optional fourth argument is forwarded as-is. With exactly three
// arguments it is synthesized as const_true over the value's lane count.
TVM_REGISTER_API("make.Store")
.set_body([](TVMArgs args,  TVMRetValue *ret) {
    Expr stored = args[1];
    if (args.size() != 3) {
      // Caller supplied the fourth argument explicitly.
      *ret = Store::make(args[0], stored, args[2], args[3]);
      return;
    }
    // Only three arguments: enable every lane of the stored value.
    *ret = Store::make(args[0], stored, args[2],
                       const_true(stored.type().lanes()));
  });

60
TVM_REGISTER_API("make.Realize")
61 62 63 64 65 66 67
.set_body([](TVMArgs args,  TVMRetValue *ret) {
    *ret = Realize::make(args[0],
                         args[1],
                         args[2],
                         args[3],
                         args[4],
                         args[5]);
68 69 70
  });


71
TVM_REGISTER_API("make.Call")
72 73 74 75 76 77 78
.set_body([](TVMArgs args,  TVMRetValue *ret) {
    *ret = Call::make(args[0],
                      args[1],
                      args[2],
                      static_cast<Call::CallType>(args[3].operator int()),
                      args[4],
                      args[5]);
tqchen committed
79 80
  });

ziheng committed
81
TVM_REGISTER_API("make.CommReducer")
82 83 84 85 86
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = CommReducerNode::make(args[0],
                                 args[1],
                                 args[2],
                                 args[3]);
ziheng committed
87 88
  });

// Registration helpers.
// REGISTER_MAKE<K>(Node) registers the packed function "make.<Node>" (the node
// name is stringified with #Node). Its body forwards args[0] .. args[K-1], in
// order, to Node::make. Use sites append the terminating ';'.
// Keep comments outside the macro bodies: a '//' comment on a line ending in
// '\' would swallow the following line.
#define REGISTER_MAKE1(Node)                                            \
  TVM_REGISTER_API("make."#Node)                                        \
  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
      *ret = Node::make(args[0]);                                       \
    })

#define REGISTER_MAKE2(Node)                                            \
  TVM_REGISTER_API("make."#Node)                                        \
  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
      *ret = Node::make(args[0], args[1]);                              \
    })

#define REGISTER_MAKE3(Node)                                            \
  TVM_REGISTER_API("make."#Node)                                        \
  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
      *ret = Node::make(args[0], args[1], args[2]);                     \
    })

#define REGISTER_MAKE4(Node)                                            \
  TVM_REGISTER_API("make."#Node)                                        \
  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
      *ret = Node::make(args[0], args[1], args[2], args[3]);            \
    })

#define REGISTER_MAKE5(Node)                                            \
  TVM_REGISTER_API("make."#Node)                                        \
  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
      *ret = Node::make(args[0], args[1], args[2], args[3], args[4]);   \
    })

// REGISTER_MAKE_BINARY_OP(Node, Func) registers "make.<Node>" as Func(a, b).
// Both packed arguments are first converted to Expr. Func is an operator or
// free function (operator+, min, ...), not Node::make.
#define REGISTER_MAKE_BINARY_OP(Node, Func)                             \
  TVM_REGISTER_API("make."#Node)                                        \
  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
      Expr a = args[0], b = args[1];                                    \
      *ret = (Func(a, b));                                              \
    })

// REGISTER_MAKE_BIT_OP(Node, Func) works like REGISTER_MAKE_BINARY_OP, except
// that an argument with type code kDLInt is passed to Func as an int rather
// than an Expr.
// The lhs check takes priority: if both arguments are ints, Func is called as
// Func(int, Expr), with args[1] converted to Expr.
#define REGISTER_MAKE_BIT_OP(Node, Func)                                \
  TVM_REGISTER_API("make."#Node)                                        \
  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
      bool lhs_is_int = args[0].type_code() == kDLInt;                  \
      bool rhs_is_int = args[1].type_code() == kDLInt;                  \
      if (lhs_is_int) {                                                 \
        *ret = (Func(args[0].operator int(), args[1].operator Expr())); \
      } else if (rhs_is_int) {                                          \
        *ret = (Func(args[0].operator Expr(), args[1].operator int())); \
      } else {                                                          \
        *ret = (Func(args[0].operator Expr(), args[1].operator Expr())); \
      }                                                                 \
    })

// Nodes whose packed arguments are forwarded straight to Node::make.
REGISTER_MAKE5(Reduce);
REGISTER_MAKE4(AttrStmt);

// Immediates.
REGISTER_MAKE2(IntImm);
REGISTER_MAKE2(UIntImm);
REGISTER_MAKE2(FloatImm);
REGISTER_MAKE1(StringImm);

// Arithmetic, min/max, comparison and logical operators; both operands are
// converted to Expr.
REGISTER_MAKE_BINARY_OP(Add, operator+);
REGISTER_MAKE_BINARY_OP(Sub, operator-);
REGISTER_MAKE_BINARY_OP(Mul, operator*);
REGISTER_MAKE_BINARY_OP(Div, operator/);
REGISTER_MAKE_BINARY_OP(Mod, operator%);
REGISTER_MAKE_BINARY_OP(Min, min);
REGISTER_MAKE_BINARY_OP(Max, max);
REGISTER_MAKE_BINARY_OP(EQ, operator==);
REGISTER_MAKE_BINARY_OP(NE, operator!=);
REGISTER_MAKE_BINARY_OP(LT, operator<); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(LE, operator<=); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(GT, operator>);  // NOLINT(*)
REGISTER_MAKE_BINARY_OP(GE, operator>=);
REGISTER_MAKE_BINARY_OP(And, operator&&);
REGISTER_MAKE_BINARY_OP(Or, operator||);

// Bitwise and shift operators; raw integer operands are passed as int.
REGISTER_MAKE_BIT_OP(bitwise_and, operator&);
REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, operator>>);

// Remaining expression and statement nodes.
REGISTER_MAKE1(Not);
REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Cast);
REGISTER_MAKE2(Broadcast);
REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE3(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE5(Allocate);
REGISTER_MAKE4(Provide);
REGISTER_MAKE4(Prefetch);
REGISTER_MAKE1(Free);
REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse);
REGISTER_MAKE1(Evaluate);

}  // namespace ir
tqchen committed
186
}  // namespace tvm