/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Implementation of API functions related to IR build.
 *
 *  Every entry below registers a global packed function (looked up by
 *  its string name through the TVM registry) that constructs a TIR
 *  expression or statement node.
 * \file api_ir.cc
 */
#include <tvm/tir/expr.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/op.h>

#include <string>
#include <utility>

namespace tvm {
namespace tir {

// Variable constructors. The name string is taken by value, so move it
// into the node instead of copying it a second time.
TVM_REGISTER_GLOBAL("_Var")
.set_body_typed([](std::string s, DataType t) {
    return Var(std::move(s), t);
  });

TVM_REGISTER_GLOBAL("_SizeVar")
.set_body_typed([](std::string s, DataType t) {
    return SizeVar(std::move(s), t);
  });

// Unary math helpers bound directly to the functions of the same name
// in the tvm namespace.
TVM_REGISTER_GLOBAL("make.abs")
.set_body_typed(tvm::abs);

TVM_REGISTER_GLOBAL("make.isnan")
.set_body_typed(tvm::isnan);

TVM_REGISTER_GLOBAL("make.floor")
.set_body_typed(tvm::floor);

TVM_REGISTER_GLOBAL("make.ceil")
.set_body_typed(tvm::ceil);

TVM_REGISTER_GLOBAL("make.round")
.set_body_typed(tvm::round);

TVM_REGISTER_GLOBAL("make.nearbyint")
.set_body_typed(tvm::nearbyint);

TVM_REGISTER_GLOBAL("make.trunc")
.set_body_typed(tvm::trunc);

TVM_REGISTER_GLOBAL("make._cast")
.set_body_typed(tvm::cast);

TVM_REGISTER_GLOBAL("make._range_by_min_extent")
.set_body_typed(Range::make_by_min_extent);

TVM_REGISTER_GLOBAL("make.SeqStmt")
.set_body_typed([](Array<Stmt> seq) {
    return SeqStmt(std::move(seq));
  });

TVM_REGISTER_GLOBAL("make.For")
.set_body_typed([](
  Var loop_var, PrimExpr min, PrimExpr extent,
  int for_type, int device_api, Stmt body) {
    // for_type / device_api are received as plain ints and converted
    // back to their enum types before building the node.
    return ForNode::make(loop_var,
                         min,
                         extent,
                         static_cast<ForType>(for_type),
                         static_cast<DeviceAPI>(device_api),
                         body);
  });

// Load(dtype, buffer_var, index[, predicate]).
// When the predicate is omitted (3 arguments), const_true(t.lanes()) is used
// as the default predicate.
TVM_REGISTER_GLOBAL("make.Load")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    DataType t = args[0];
    if (args.size() == 3) {
      *ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes()));
    } else {
      *ret = LoadNode::make(t, args[1], args[2], args[3]);
    }
  });

// Store(buffer_var, value, index[, predicate]).
// When the predicate is omitted (3 arguments), the default predicate is
// const_true sized to the lane count of the stored value's dtype.
TVM_REGISTER_GLOBAL("make.Store")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    PrimExpr value = args[1];
    if (args.size() == 3) {
      *ret = StoreNode::make(args[0], value, args[2],
                             const_true(value.dtype().lanes()));
    } else {
      *ret = StoreNode::make(args[0], value, args[2], args[3]);
    }
  });

TVM_REGISTER_GLOBAL("make.Realize")
.set_body_typed(RealizeNode::make);

TVM_REGISTER_GLOBAL("make.Call")
.set_body_typed([](
  DataType type, std::string name,
  Array<PrimExpr> args, int call_type,
  FunctionRef func, int value_index) {
    // call_type is received as an int and converted back to the
    // CallNode::CallType enum.
    return CallNode::make(type,
                          name,
                          args,
                          static_cast<CallNode::CallType>(call_type),
                          func,
                          value_index);
  });

TVM_REGISTER_GLOBAL("make.CommReducer")
.set_body_typed(CommReducerNode::make);

// Register "make.<NodeName>" as a typed wrapper around <NodeName>Node::make,
// whatever its arity (e.g. Not takes one operand, Select three).
// The macro intentionally has no trailing ';' and no dangling line
// continuation: the semicolon is supplied at each use site, so every use
// expands to exactly one declaration (no stray empty declarations).
#define REGISTER_MAKE(NodeName)                     \
  TVM_REGISTER_GLOBAL("make."#NodeName)             \
  .set_body_typed(NodeName ## Node::make)

REGISTER_MAKE(Reduce);
REGISTER_MAKE(AttrStmt);

REGISTER_MAKE(StringImm);

REGISTER_MAKE(Add);
REGISTER_MAKE(Sub);
REGISTER_MAKE(Mul);
REGISTER_MAKE(Div);
REGISTER_MAKE(Mod);
REGISTER_MAKE(FloorDiv);
REGISTER_MAKE(FloorMod);
REGISTER_MAKE(Min);
REGISTER_MAKE(Max);
REGISTER_MAKE(EQ);
REGISTER_MAKE(NE);
REGISTER_MAKE(LT);
REGISTER_MAKE(LE);
REGISTER_MAKE(GT);
REGISTER_MAKE(GE);
REGISTER_MAKE(And);
REGISTER_MAKE(Or);

REGISTER_MAKE(Not);
REGISTER_MAKE(Select);
REGISTER_MAKE(Ramp);
REGISTER_MAKE(Cast);
REGISTER_MAKE(Broadcast);
REGISTER_MAKE(Shuffle);
REGISTER_MAKE(Let);
REGISTER_MAKE(LetStmt);
REGISTER_MAKE(AssertStmt);
REGISTER_MAKE(ProducerConsumer);
REGISTER_MAKE(Provide);
REGISTER_MAKE(Prefetch); REGISTER_MAKE(Free); REGISTER_MAKE(IfThenElse); REGISTER_MAKE(Evaluate); // overloaded, needs special handling // has default args TVM_REGISTER_GLOBAL("make.Allocate") .set_body_typed([]( Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition, Stmt body ){ return AllocateNode::make(buffer_var, type, extents, condition, body); }); // operator overloading, smarter than make #define REGISTER_MAKE_BINARY_OP(Node, Func) \ TVM_REGISTER_GLOBAL("make."#Node) \ .set_body_typed([](PrimExpr a, PrimExpr b) { \ return (Func(a, b)); \ }) #define REGISTER_MAKE_BIT_OP(Node, Func) \ TVM_REGISTER_GLOBAL("make."#Node) \ .set_body([](TVMArgs args, TVMRetValue *ret) { \ bool lhs_is_int = args[0].type_code() == kDLInt; \ bool rhs_is_int = args[1].type_code() == kDLInt; \ if (lhs_is_int) { \ *ret = (Func(args[0].operator int(), args[1].operator PrimExpr())); \ } else if (rhs_is_int) { \ *ret = (Func(args[0].operator PrimExpr(), args[1].operator int())); \ } else { \ *ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \ } \ }) REGISTER_MAKE_BINARY_OP(_OpAdd, operator+); REGISTER_MAKE_BINARY_OP(_OpSub, operator-); REGISTER_MAKE_BINARY_OP(_OpMul, operator*); REGISTER_MAKE_BINARY_OP(_OpDiv, div); REGISTER_MAKE_BINARY_OP(_OpMod, truncmod); REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv); REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod); REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv); REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod); REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv); REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod); REGISTER_MAKE_BINARY_OP(_OpPow, pow); REGISTER_MAKE_BINARY_OP(_OpMin, min); REGISTER_MAKE_BINARY_OP(_OpMax, max); REGISTER_MAKE_BINARY_OP(_OpEQ, operator==); REGISTER_MAKE_BINARY_OP(_OpNE, operator!=); REGISTER_MAKE_BINARY_OP(_OpLT, operator<); // NOLINT(*) REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*) REGISTER_MAKE_BINARY_OP(_OpGT, operator>); // NOLINT(*) REGISTER_MAKE_BINARY_OP(_OpGE, 
operator>=); REGISTER_MAKE_BINARY_OP(_OpAnd, operator&&); REGISTER_MAKE_BINARY_OP(_OpOr, operator||); REGISTER_MAKE_BIT_OP(bitwise_and, operator&); REGISTER_MAKE_BIT_OP(bitwise_or, operator|); REGISTER_MAKE_BIT_OP(bitwise_xor, operator^); REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*) REGISTER_MAKE_BIT_OP(right_shift, operator>>); TVM_REGISTER_GLOBAL("make._OpIfThenElse") .set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { return if_then_else(cond, true_value, false_value); }); } // namespace tir } // namespace tvm