/*!
 *  Copyright (c) 2016 by Contributors
 *  Implementation of API functions related to IR build
 * \file api_ir.cc
 */
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_operator.h>
#include <tvm/api_registry.h>
#include <tvm/ir_operator.h>

namespace tvm {
namespace ir {

// ---------------------------------------------------------------------------
// Hand-written registrations.
//
// Every entry registers a packed function under the given global name. The
// lambda reads its inputs positionally from `args` and stores the constructed
// node into `*ret`.
// ---------------------------------------------------------------------------

// "_Var": note the swapped order -- args[1] is forwarded as the first argument
// of Variable::make and args[0] as the second.
TVM_REGISTER_API("_Var")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = Variable::make(args[1], args[0]);
  });

// "make.abs": forwards args[0] to tvm::abs.
TVM_REGISTER_API("make.abs")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = tvm::abs(args[0]);
  });

// "make._range_by_min_extent": forwards (args[0], args[1]) to
// Range::make_by_min_extent.
TVM_REGISTER_API("make._range_by_min_extent")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = Range::make_by_min_extent(args[0], args[1]);
  });

// "make.For": args[3] and args[4] arrive as plain ints and are cast to the
// ForType and HalideIR::DeviceAPI enums before being handed to For::make.
TVM_REGISTER_API("make.For")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = For::make(args[0],
                     args[1],
                     args[2],
                     static_cast<ForType>(args[3].operator int()),
                     static_cast<HalideIR::DeviceAPI>(args[4].operator int()),
                     args[5]);
  });

// "make.Load": the fourth argument is optional. With exactly three arguments
// the fourth parameter of Load::make is const_true(t.lanes()), where t is the
// type given in args[0]; otherwise args[3] is forwarded.
TVM_REGISTER_API("make.Load")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    Type t = args[0];
    if (args.size() == 3) {
      *ret = Load::make(t, args[1], args[2], const_true(t.lanes()));
    } else {
      *ret = Load::make(t, args[1], args[2], args[3]);
    }
  });

// "make.Store": the fourth argument is optional. With exactly three arguments
// the fourth parameter of Store::make is const_true(value.type().lanes()),
// where value is args[1]; otherwise args[3] is forwarded.
TVM_REGISTER_API("make.Store")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    Expr value = args[1];
    if (args.size() == 3) {
      *ret = Store::make(args[0], value, args[2],
                         const_true(value.type().lanes()));
    } else {
      *ret = Store::make(args[0], value, args[2], args[3]);
    }
  });

// "make.Realize": forwards six positional arguments to Realize::make.
TVM_REGISTER_API("make.Realize")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = Realize::make(args[0],
                         args[1],
                         args[2],
                         args[3],
                         args[4],
                         args[5]);
  });

// "make.Call": args[3] arrives as a plain int and is cast to Call::CallType.
TVM_REGISTER_API("make.Call")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = Call::make(args[0],
                      args[1],
                      args[2],
                      static_cast<Call::CallType>(args[3].operator int()),
                      args[4],
                      args[5]);
  });

// "make.CommReducer": forwards four positional arguments to
// CommReducerNode::make.
TVM_REGISTER_API("make.CommReducer")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = CommReducerNode::make(args[0], args[1], args[2], args[3]);
  });

// ---------------------------------------------------------------------------
// REGISTER_MAKE<N>(Node): registers "make.<Node>" as a packed function that
// forwards its first N positional arguments, unchanged, to Node::make.
//
// The definitions deliberately end at the closing "})" with no trailing line
// continuation, so they do not depend on a blank line following them.
// ---------------------------------------------------------------------------
#define REGISTER_MAKE1(Node)                                         \
  TVM_REGISTER_API("make."#Node)                                     \
  .set_body([](TVMArgs args, TVMRetValue *ret) {                     \
      *ret = Node::make(args[0]);                                    \
    })

#define REGISTER_MAKE2(Node)                                         \
  TVM_REGISTER_API("make."#Node)                                     \
  .set_body([](TVMArgs args, TVMRetValue *ret) {                     \
      *ret = Node::make(args[0], args[1]);                           \
    })

#define REGISTER_MAKE3(Node)                                         \
  TVM_REGISTER_API("make."#Node)                                     \
  .set_body([](TVMArgs args, TVMRetValue *ret) {                     \
      *ret = Node::make(args[0], args[1], args[2]);                  \
    })

#define REGISTER_MAKE4(Node)                                         \
  TVM_REGISTER_API("make."#Node)                                     \
  .set_body([](TVMArgs args, TVMRetValue *ret) {                     \
      *ret = Node::make(args[0], args[1], args[2], args[3]);         \
    })

#define REGISTER_MAKE5(Node)                                         \
  TVM_REGISTER_API("make."#Node)                                     \
  .set_body([](TVMArgs args, TVMRetValue *ret) {                     \
      *ret = Node::make(args[0], args[1], args[2], args[3], args[4]); \
    })

REGISTER_MAKE5(Reduce);
REGISTER_MAKE4(AttrStmt);

REGISTER_MAKE2(IntImm);
REGISTER_MAKE2(UIntImm);
REGISTER_MAKE2(FloatImm);
REGISTER_MAKE1(StringImm);

REGISTER_MAKE2(Add);
REGISTER_MAKE2(Sub);
REGISTER_MAKE2(Mul);
REGISTER_MAKE2(Div);
REGISTER_MAKE2(Mod);
REGISTER_MAKE2(Min);
REGISTER_MAKE2(Max);
REGISTER_MAKE2(EQ);
REGISTER_MAKE2(NE);
REGISTER_MAKE2(LT);
REGISTER_MAKE2(LE);
REGISTER_MAKE2(GT);
REGISTER_MAKE2(GE);
REGISTER_MAKE2(And);
REGISTER_MAKE2(Or);

REGISTER_MAKE1(Not);
REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Cast);
REGISTER_MAKE2(Broadcast);
REGISTER_MAKE2(Shuffle);
REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE3(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE5(Allocate);
REGISTER_MAKE4(Provide);
REGISTER_MAKE4(Prefetch);
REGISTER_MAKE1(Free);
REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse);
REGISTER_MAKE1(Evaluate);

// ---------------------------------------------------------------------------
// Operator overloading, smarter than make.
//
// REGISTER_MAKE_BINARY_OP(Node, Func): registers "make.<Node>"; both arguments
// are converted to Expr and the result of Func(a, b) is returned. Func is an
// overloaded operator or a free function such as min/max, i.e. the call does
// not go through Node::make directly.
// ---------------------------------------------------------------------------
#define REGISTER_MAKE_BINARY_OP(Node, Func)                          \
  TVM_REGISTER_API("make."#Node)                                     \
  .set_body([](TVMArgs args, TVMRetValue *ret) {                     \
      Expr a = args[0], b = args[1];                                 \
      *ret = (Func(a, b));                                           \
    })

// REGISTER_MAKE_BIT_OP(Node, Func): registers "make.<Node>" for a binary
// operator that is called with an int on one side when possible.
//   * args[0] has type code kDLInt  -> Func(int, Expr)
//   * else args[1] is kDLInt        -> Func(Expr, int)
//   * otherwise                     -> Func(Expr, Expr)
// The left operand is tested first, so when both operands are kDLInt the
// (int, Expr) form is used.
#define REGISTER_MAKE_BIT_OP(Node, Func)                                  \
  TVM_REGISTER_API("make."#Node)                                          \
  .set_body([](TVMArgs args, TVMRetValue *ret) {                          \
      bool lhs_is_int = args[0].type_code() == kDLInt;                    \
      bool rhs_is_int = args[1].type_code() == kDLInt;                    \
      if (lhs_is_int) {                                                   \
        *ret = (Func(args[0].operator int(), args[1].operator Expr()));   \
      } else if (rhs_is_int) {                                            \
        *ret = (Func(args[0].operator Expr(), args[1].operator int()));   \
      } else {                                                            \
        *ret = (Func(args[0].operator Expr(), args[1].operator Expr()));  \
      }                                                                   \
    })

REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
REGISTER_MAKE_BINARY_OP(_OpDiv, operator/);
REGISTER_MAKE_BINARY_OP(_OpMod, operator%);
REGISTER_MAKE_BINARY_OP(_OpMin, min);
REGISTER_MAKE_BINARY_OP(_OpMax, max);
REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
REGISTER_MAKE_BINARY_OP(_OpNE, operator!=);
REGISTER_MAKE_BINARY_OP(_OpLT, operator<);  // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpLE, operator<=);  // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpGT, operator>);  // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpGE, operator>=);
REGISTER_MAKE_BINARY_OP(_OpAnd, operator&&);
REGISTER_MAKE_BINARY_OP(_OpOr, operator||);
REGISTER_MAKE_BIT_OP(bitwise_and, operator&);
REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
REGISTER_MAKE_BIT_OP(left_shift, operator<<);  // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, operator>>);

}  // namespace ir
}  // namespace tvm