/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * Exposure of pass functions. * \file api_pass.cc */ #include <tvm/tir/expr.h> #include <tvm/tir/stmt.h> #include <tvm/ir/attrs.h> #include <tvm/tir/ir_pass.h> #include <tvm/tir/expr_functor.h> #include <tvm/tir/stmt_functor.h> #include <tvm/runtime/registry.h> namespace tvm { namespace tir { TVM_REGISTER_GLOBAL("ir_pass.Simplify") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args[0].IsObjectRef<Stmt>()) { if (args.size() > 1) { *ret = Simplify(args[0].operator Stmt(), args[1]); } else { *ret = Simplify(args[0].operator Stmt()); } } else { if (args.size() > 1) { *ret = Simplify(args[0].operator PrimExpr(), args[1]); } else { *ret = Simplify(args[0].operator PrimExpr()); } } }); TVM_REGISTER_GLOBAL("ir_pass.CanonicalSimplify") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args[0].IsObjectRef<Stmt>()) { if (args.size() > 1) { *ret = CanonicalSimplify(args[0].operator Stmt(), args[1]); } else { *ret = CanonicalSimplify(args[0].operator Stmt()); } } else { if (args.size() > 1) { *ret = CanonicalSimplify(args[0].operator PrimExpr(), args[1]); } else { *ret = CanonicalSimplify(args[0].operator PrimExpr()); } } }); TVM_REGISTER_GLOBAL("ir_pass.Substitute") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args[0].IsObjectRef<Stmt>()) { *ret = Substitute(args[0].operator Stmt(), args[1].operator Map<Var, PrimExpr>()); } else { *ret = Substitute(args[0].operator PrimExpr(), args[1].operator Map<Var, PrimExpr>()); } }); TVM_REGISTER_GLOBAL("ir_pass.Equal") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args[0].IsObjectRef<Stmt>()) { *ret = Equal(args[0].operator Stmt(), args[1].operator Stmt()); } else { *ret = Equal(args[0].operator PrimExpr(), args[1].operator PrimExpr()); } }); TVM_REGISTER_GLOBAL("ir_pass.StorageFlatten") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args.size() <= 3) { *ret = StorageFlatten(args[0], args[1], args[2]); } else { *ret = StorageFlatten(args[0], args[1], args[2], args[3]); } }); TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore") .set_body_typed ([](const Stmt& stmt, const te::Schedule& schedule, const Map<te::Tensor, Buffer>& extern_buffer) { return RewriteForTensorCore(stmt, schedule, extern_buffer); }); TVM_REGISTER_GLOBAL("ir_pass.AttrsEqual") .set_body_typed( [](const ObjectRef& lhs, const ObjectRef& rhs) { return AttrsEqual()(lhs, rhs); }); TVM_REGISTER_GLOBAL("ir_pass.AttrsHash") .set_body_typed([](const ObjectRef &node) -> int64_t { return AttrsHash()(node); }); TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var()); }); TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit") .set_body([](TVMArgs args, TVMRetValue *ret) { PackedFunc f = args[1]; tir::PostOrderVisit(args[0], [f](const ObjectRef& n) { f(n); }); }); TVM_REGISTER_GLOBAL("ir_pass.LowerStorageAccess") .set_body([](TVMArgs args, TVMRetValue *ret) { LoweredFunc f = args[0]; auto n = make_object<LoweredFuncNode>(*f.operator->()); n->body = LowerStorageAccessInfo(f->body); *ret = LoweredFunc(n); }); // make from two arguments #define REGISTER_PASS(PassName) \ TVM_REGISTER_GLOBAL("ir_pass."#PassName) \ .set_body_typed(PassName); \ REGISTER_PASS(ConvertSSA); REGISTER_PASS(VerifySSA); REGISTER_PASS(RewriteUnsafeSelect); REGISTER_PASS(Inline); REGISTER_PASS(IRTransform); REGISTER_PASS(VectorizeLoop); REGISTER_PASS(SkipVectorize); REGISTER_PASS(UnrollLoop); REGISTER_PASS(InjectCopyIntrin); REGISTER_PASS(ThreadSync); REGISTER_PASS(MakeAPI); REGISTER_PASS(BindDeviceType); REGISTER_PASS(SplitHostDevice); REGISTER_PASS(StorageRewrite); REGISTER_PASS(CoProcSync); REGISTER_PASS(LowerStorageAccessInfo); REGISTER_PASS(LowerDeviceStorageAccessInfo) REGISTER_PASS(InjectVirtualThread); REGISTER_PASS(InjectPrefetch); REGISTER_PASS(InjectDoubleBuffer); REGISTER_PASS(LoopPartition); REGISTER_PASS(RemoveNoOp); REGISTER_PASS(LiftAttrScope); REGISTER_PASS(LowerThreadAllreduce); REGISTER_PASS(LowerWarpMemory); REGISTER_PASS(RemapThreadAxis); REGISTER_PASS(LowerIntrin); REGISTER_PASS(LowerCustomDatatypes); REGISTER_PASS(LowerTVMBuiltin); REGISTER_PASS(CombineContextCall); REGISTER_PASS(VerifyMemory); REGISTER_PASS(VerifyGPUCode); REGISTER_PASS(DecorateDeviceScope); REGISTER_PASS(InstrumentBoundCheckers); REGISTER_PASS(VerifyCompactBuffer); REGISTER_PASS(HoistIfThenElse); REGISTER_PASS(InferFragment) } // namespace tir } // namespace tvm