/*!
 *  Copyright (c) 2017 by Contributors
 *  Exposure of pass functions.
 * \file api_pass.cc
 */
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/api_registry.h>

namespace tvm {
namespace ir {

TVM_REGISTER_API("ir_pass.Simplify")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    // The first argument may be either a statement or an expression;
    // anything that is not a Stmt node is handled as an Expr.
    if (!args[0].IsNodeType<Stmt>()) {
      *ret = Simplify(args[0].operator Expr());
      return;
    }
    *ret = Simplify(args[0].operator Stmt());
  });

TVM_REGISTER_API("ir_pass.CanonicalSimplify")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    // Pick the Stmt or Expr overload based on the runtime node type
    // of the single argument.
    const bool arg_is_stmt = args[0].IsNodeType<Stmt>();
    if (arg_is_stmt) {
      *ret = CanonicalSimplify(args[0].operator Stmt());
    } else {
      *ret = CanonicalSimplify(args[0].operator Expr());
    }
  });

TVM_REGISTER_API("ir_pass.Equal")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    // The kind of the first operand decides how BOTH operands are
    // converted: two Stmts, or otherwise two Exprs.
    if (!args[0].IsNodeType<Stmt>()) {
      *ret = Equal(args[0].operator Expr(), args[1].operator Expr());
      return;
    }
    *ret = Equal(args[0].operator Stmt(), args[1].operator Stmt());
  });

TVM_REGISTER_API("ir_pass.ExprUseVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    // args[0]: expression to inspect; args[1]: variable to look for.
    Expr expr = args[0].operator Expr();
    Var var = args[1].operator Var();
    *ret = ExprUseVar(expr, var);
  });

TVM_REGISTER_API("ir_pass.PostOrderVisit")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    // args[0]: root node to traverse; args[1]: packed callback invoked
    // with every visited node. Nothing is written to *ret.
    NodeRef root = args[0];
    PackedFunc callback = args[1];
    // The callback is captured by value so the visitor owns its own copy.
    ir::PostOrderVisit(root, [callback](const NodeRef& node) {
        callback(node);
      });
  });

// Helper macros that expose a pass `PassName` as the packed function
// "ir_pass.PassName". The digit N in REGISTER_PASSN is the number of
// leading packed arguments forwarded positionally to `PassName`; each
// args[i] is implicitly converted to the pass's parameter type, and the
// pass's result is assigned to *ret.
//
// The final line of each macro intentionally has no trailing backslash,
// so a definition can never splice the following source line into itself.
#define REGISTER_PASS1(PassName)                                  \
  TVM_REGISTER_API("ir_pass."#PassName)                           \
  .set_body([](TVMArgs args,  TVMRetValue *ret) {                 \
      *ret = PassName(args[0]);                                   \
    })

#define REGISTER_PASS2(PassName)                                  \
  TVM_REGISTER_API("ir_pass."#PassName)                           \
  .set_body([](TVMArgs args,  TVMRetValue *ret) {                 \
      *ret = PassName(args[0], args[1]);                          \
    })

#define REGISTER_PASS3(PassName)                                        \
  TVM_REGISTER_API("ir_pass."#PassName)                                 \
  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
      *ret = PassName(args[0], args[1], args[2]);                       \
    })

#define REGISTER_PASS4(PassName)                                        \
  TVM_REGISTER_API("ir_pass."#PassName)                                 \
  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
      *ret = PassName(args[0], args[1], args[2], args[3]);              \
    })

#define REGISTER_PASS5(PassName)                                        \
  TVM_REGISTER_API("ir_pass."#PassName)                                 \
  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
      *ret = PassName(args[0], args[1], args[2], args[3], args[4]);     \
    })

// Expose individual passes. Each REGISTER_PASSN(X) below registers the
// packed function "ir_pass.X", which forwards its first N packed
// arguments to X and stores X's result as the return value.
// NOTE(review): N must agree with the C++ signature of the pass being
// exposed; a mismatch surfaces as a compile error at the macro expansion.
REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten);
REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS4(UnrollLoop);
REGISTER_PASS2(StorageSync);
REGISTER_PASS5(MakeAPI);
REGISTER_PASS2(BindDeviceType);
REGISTER_PASS1(SplitHostDevice);
REGISTER_PASS1(StorageRewrite);
REGISTER_PASS1(InjectVirtualThread);
REGISTER_PASS1(LoopPartition);
REGISTER_PASS1(RemoveNoOp);
REGISTER_PASS2(SplitPipeline);
REGISTER_PASS1(NarrowChannelAccess);
REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS2(LowerIntrin);
REGISTER_PASS1(LowerPackedCall);
}  // namespace ir
}  // namespace tvm