/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 *  Exposure of IR pass functions through the global API registry.
 * \file api_pass.cc
 */
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/attrs.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/api_registry.h>

namespace tvm {
namespace ir {

36
TVM_REGISTER_API("ir_pass.Simplify")
37 38
.set_body([](TVMArgs args, TVMRetValue *ret) {
    if (args[0].IsNodeType<Stmt>()) {
39 40 41 42 43
      if (args.size() > 1) {
        *ret = Simplify(args[0].operator Stmt(), args[1]);
      } else {
        *ret = Simplify(args[0].operator Stmt());
      }
44
    } else {
45 46 47 48 49
      if (args.size() > 1) {
        *ret = Simplify(args[0].operator Expr(), args[1]);
      } else {
        *ret = Simplify(args[0].operator Expr());
      }
50 51 52
    }
  });

53 54 55
/*!
 * \brief FFI wrapper for the CanonicalSimplify pass.
 *
 *  Mirrors ir_pass.Simplify: a Stmt in args[0] goes to the Stmt overload,
 *  anything else is converted to Expr, and an optional second argument is
 *  forwarded unchanged as the overload's second parameter.
 */
TVM_REGISTER_API("ir_pass.CanonicalSimplify")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    if (args[0].IsNodeType<Stmt>()) {
      if (args.size() > 1) {
        *ret = CanonicalSimplify(args[0].operator Stmt(), args[1]);
      } else {
        *ret = CanonicalSimplify(args[0].operator Stmt());
      }
    } else {
      // Anything that is not a Stmt is treated as an Expr.
      if (args.size() > 1) {
        *ret = CanonicalSimplify(args[0].operator Expr(), args[1]);
      } else {
        *ret = CanonicalSimplify(args[0].operator Expr());
      }
    }
  });

/*!
 * \brief FFI wrapper for Substitute.
 *
 *  args[0] is the Stmt or Expr to rewrite; args[1] is converted to a
 *  Map<Var, Expr> of replacements. Non-Stmt inputs take the Expr overload.
 */
TVM_REGISTER_API("ir_pass.Substitute")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    const bool is_stmt = args[0].IsNodeType<Stmt>();
    if (!is_stmt) {
      *ret = Substitute(args[0].operator Expr(), args[1].operator Map<Var, Expr>());
      return;
    }
    *ret = Substitute(args[0].operator Stmt(), args[1].operator Map<Var, Expr>());
  });

/*!
 * \brief FFI wrapper for structural Equal.
 *
 *  When args[0] is a Stmt both arguments are converted to Stmt; otherwise
 *  both are converted to Expr. The result of the matching Equal overload is
 *  returned as-is.
 */
TVM_REGISTER_API("ir_pass.Equal")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    if (args[0].IsNodeType<Stmt>()) {
      *ret = Equal(args[0].operator Stmt(), args[1].operator Stmt());
    } else {
      *ret = Equal(args[0].operator Expr(), args[1].operator Expr());
    }
  });

/*!
 * \brief FFI wrapper for StorageFlatten.
 *
 *  The first three arguments are always forwarded; a fourth is forwarded
 *  only when the caller actually passed more than three arguments, so the
 *  C++ default for that parameter applies otherwise.
 */
TVM_REGISTER_API("ir_pass.StorageFlatten")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    if (args.size() > 3) {
      *ret = StorageFlatten(args[0], args[1], args[2], args[3]);
      return;
    }
    *ret = StorageFlatten(args[0], args[1], args[2]);
  });

/*!
 * \brief FFI wrapper that compares two nodes with the AttrsEqual functor.
 * \return true when AttrsEqual()(lhs, rhs) holds.
 */
TVM_REGISTER_API("ir_pass.AttrsEqual")
.set_body_typed<bool(const NodeRef&, const NodeRef&)>([](const NodeRef& lhs, const NodeRef& rhs) {
    return AttrsEqual()(lhs, rhs);
  });

/*!
 * \brief FFI wrapper that hashes a node with the AttrsHash functor.
 * \return The hash, exposed to the FFI as a signed 64-bit integer.
 */
TVM_REGISTER_API("ir_pass.AttrsHash")
.set_body_typed<int64_t(const NodeRef&)>([](const NodeRef &node) {
    // AttrsHash presumably yields an unsigned size_t; make the conversion to
    // the FFI's int64_t explicit rather than relying on an implicit narrowing.
    return static_cast<int64_t>(AttrsHash()(node));
  });

/*!
 * \brief FFI wrapper for ExprUseVar.
 *
 *  Converts args[0] to Expr and args[1] to Var, then returns the result of
 *  ExprUseVar on them.
 */
TVM_REGISTER_API("ir_pass.ExprUseVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    Expr expr = args[0].operator Expr();
    Var var = args[1].operator Var();
    *ret = ExprUseVar(expr, var);
  });

/*!
 * \brief FFI wrapper for ir::PostOrderVisit.
 *
 *  args[0] is the root node to traverse; args[1] is a PackedFunc that is
 *  invoked once per node visited. Nothing is written to `ret`.
 */
TVM_REGISTER_API("ir_pass.PostOrderVisit")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    PackedFunc f = args[1];
    // Capture the PackedFunc by value so the callback owns its own handle.
    ir::PostOrderVisit(args[0], [f](const NodeRef& n) {
        f(n);
      });
  });

// Register PassName as the global function "ir_pass.<PassName>", with the
// FFI signature deduced from the C++ function by set_body_typed. Overloaded
// functions cannot be passed this way (the function-pointer type would be
// ambiguous), so overloaded passes need a hand-written set_body instead.
// Use as `REGISTER_PASS(Name);` -- the caller supplies the semicolon.
#define REGISTER_PASS(PassName)                                   \
  TVM_REGISTER_API("ir_pass." #PassName)                          \
  .set_body_typed(PassName)

// Passes whose C++ signatures are exposed directly through REGISTER_PASS.
REGISTER_PASS(ConvertSSA);
REGISTER_PASS(VerifySSA);
REGISTER_PASS(RewriteUnsafeSelect);
REGISTER_PASS(Inline);
REGISTER_PASS(IRTransform);
REGISTER_PASS(VectorizeLoop);
REGISTER_PASS(SkipVectorize);
REGISTER_PASS(UnrollLoop);
REGISTER_PASS(InjectCopyIntrin);
REGISTER_PASS(ThreadSync);
REGISTER_PASS(MakeAPI);
REGISTER_PASS(BindDeviceType);
REGISTER_PASS(SplitHostDevice);
REGISTER_PASS(StorageRewrite);
REGISTER_PASS(CoProcSync);
REGISTER_PASS(LowerStorageAccessInfo);
REGISTER_PASS(InjectVirtualThread);
REGISTER_PASS(InjectPrefetch);
REGISTER_PASS(InjectDoubleBuffer);
REGISTER_PASS(LoopPartition);
REGISTER_PASS(RemoveNoOp);
REGISTER_PASS(SplitPipeline);
REGISTER_PASS(LiftAttrScope);
REGISTER_PASS(NarrowChannelAccess);
REGISTER_PASS(LowerThreadAllreduce);
REGISTER_PASS(LowerWarpMemory);
REGISTER_PASS(RemapThreadAxis);
REGISTER_PASS(LowerIntrin);
REGISTER_PASS(LowerCustomDatatypes);
REGISTER_PASS(LowerTVMBuiltin);
REGISTER_PASS(CombineContextCall);
REGISTER_PASS(VerifyMemory);
REGISTER_PASS(VerifyGPUCode);
REGISTER_PASS(DecorateDeviceScope);
REGISTER_PASS(InstrumentBoundCheckers);

}  // namespace ir
}  // namespace tvm