/*!
 *  Copyright (c) 2017 by Contributors
 *  Exposure of IR pass functions through the API registry ("ir_pass.*").
 * \file api_pass.cc
 */
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/api_registry.h>

namespace tvm {
namespace ir {

16
TVM_REGISTER_API("ir_pass.Simplify")
17 18
.set_body([](TVMArgs args, TVMRetValue *ret) {
    if (args[0].IsNodeType<Stmt>()) {
19 20 21 22 23
      if (args.size() > 1) {
        *ret = Simplify(args[0].operator Stmt(), args[1]);
      } else {
        *ret = Simplify(args[0].operator Stmt());
      }
24
    } else {
25 26 27 28 29
      if (args.size() > 1) {
        *ret = Simplify(args[0].operator Expr(), args[1]);
      } else {
        *ret = Simplify(args[0].operator Expr());
      }
30 31 32
    }
  });

33 34 35
TVM_REGISTER_API("ir_pass.CanonicalSimplify")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    if (args[0].IsNodeType<Stmt>()) {
36 37 38 39 40
      if (args.size() > 1) {
        *ret = CanonicalSimplify(args[0].operator Stmt(), args[1]);
      } else {
        *ret = CanonicalSimplify(args[0].operator Stmt());
      }
41
    } else {
42 43 44 45 46
      if (args.size() > 1) {
        *ret = CanonicalSimplify(args[0].operator Expr(), args[1]);
      } else {
        *ret = CanonicalSimplify(args[0].operator Expr());
      }
47 48 49
    }
  });

50
TVM_REGISTER_API("ir_pass.Equal")
51 52 53
.set_body([](TVMArgs args, TVMRetValue *ret) {
    if (args[0].IsNodeType<Stmt>()) {
      *ret = Equal(args[0].operator Stmt(), args[1].operator Stmt());
54
    } else {
55
      *ret = Equal(args[0].operator Expr(), args[1].operator Expr());
56 57 58
    }
  });

59 60 61 62 63
TVM_REGISTER_API("ir_pass.ExprUseVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = ExprUseVar(args[0].operator Expr(), args[1].operator Var());
  });

64
TVM_REGISTER_API("ir_pass.PostOrderVisit")
65 66 67 68 69 70 71
.set_body([](TVMArgs args, TVMRetValue *ret) {
    PackedFunc f = args[1];
    ir::PostOrderVisit(args[0], [f](const NodeRef& n) {
        f(n);
      });
  });

tqchen committed
72 73
// make from two arguments
#define REGISTER_PASS1(PassName)                                  \
74
  TVM_REGISTER_API("ir_pass."#PassName)                           \
75 76
  .set_body([](TVMArgs args,  TVMRetValue *ret) {                 \
      *ret = PassName(args[0]);                                   \
tqchen committed
77 78
    })                                                            \

79
#define REGISTER_PASS2(PassName)                                  \
80
  TVM_REGISTER_API("ir_pass."#PassName)                           \
81 82
  .set_body([](TVMArgs args,  TVMRetValue *ret) {                 \
      *ret = PassName(args[0], args[1]);                          \
83 84
    })                                                            \

85 86 87 88 89 90
#define REGISTER_PASS3(PassName)                                        \
  TVM_REGISTER_API("ir_pass."#PassName)                                 \
  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
      *ret = PassName(args[0], args[1], args[2]);                       \
    })                                                                  \

tqchen committed
91
#define REGISTER_PASS4(PassName)                                        \
92
  TVM_REGISTER_API("ir_pass."#PassName)                                 \
93 94
  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
      *ret = PassName(args[0], args[1], args[2], args[3]);              \
tqchen committed
95 96
    })                                                                  \

97 98 99 100 101 102
#define REGISTER_PASS5(PassName)                                        \
  TVM_REGISTER_API("ir_pass."#PassName)                                 \
  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
      *ret = PassName(args[0], args[1], args[2], args[3], args[4]);     \
    })                                                                  \

tqchen committed
103 104
REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
105
REGISTER_PASS1(RewriteUnsafeSelect);
tqchen committed
106
REGISTER_PASS4(Inline);
107
REGISTER_PASS3(StorageFlatten);
108
REGISTER_PASS4(IRTransform);
109
REGISTER_PASS1(VectorizeLoop);
110
REGISTER_PASS5(UnrollLoop);
111
REGISTER_PASS3(InjectCopyIntrin);
112
REGISTER_PASS2(ThreadSync);
113
REGISTER_PASS5(MakeAPI);
114
REGISTER_PASS2(BindDeviceType);
115
REGISTER_PASS1(SplitHostDevice);
116
REGISTER_PASS1(StorageRewrite);
117
REGISTER_PASS1(CoProcSync);
118
REGISTER_PASS1(LowerStorageAccessInfo);
119
REGISTER_PASS1(InjectVirtualThread);
120
REGISTER_PASS1(InjectPrefetch);
121
REGISTER_PASS2(InjectDoubleBuffer);
122
REGISTER_PASS2(LoopPartition);
Tianqi Chen committed
123
REGISTER_PASS1(RemoveNoOp);
124
REGISTER_PASS2(SplitPipeline);
125
REGISTER_PASS2(LiftAttrScope);
126
REGISTER_PASS1(NarrowChannelAccess);
127
REGISTER_PASS2(LowerThreadAllreduce);
128
REGISTER_PASS2(LowerWarpMemory);
129
REGISTER_PASS2(RemapThreadAxis);
130
REGISTER_PASS2(LowerIntrin);
131
REGISTER_PASS1(LowerTVMBuiltin);
132
REGISTER_PASS1(CombineContextCall);
133
REGISTER_PASS2(VerifyMemory);
134
REGISTER_PASS2(VerifyGPUCode);
tqchen committed
135 136
}  // namespace ir
}  // namespace tvm