// ir_mutator_test.cc: unit test for tvm::ir::IRMutator.
#include <sstream>

#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/ir_mutator.h>

namespace {
using namespace tvm::ir;
using namespace Halide::Internal;
using namespace Halide;

// replace variable to constant
class IRVar2Const : public IRMutator {
 public:
  VarExpr var;
  int int_val;
tqchen committed
15
  Expr Mutate(Expr expr) final {
16 17
    static const FMutateExpr& f = IRVar2Const::vtable_expr();
    return (f.can_dispatch(expr) ?
tqchen committed
18
            f(expr, expr, this) : IRMutator::Mutate(expr));
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
  }
  static FMutateExpr &vtable_expr();
};

// implement vtable
// Returns the single functor table used to dispatch IRVar2Const expression
// handlers. It is a function-local static, so it is constructed on first use
// rather than depending on namespace-scope static initialization order.
IRMutator::FMutateExpr &IRVar2Const::vtable_expr() {  // NOLINT(*)
  static FMutateExpr table;
  return table;
}

// Variable handler for IRVar2Const: the configured target variable (matched by
// node identity) becomes an Int(32) immediate; any other variable is returned
// unchanged.
TVM_STATIC_IR_FUNCTOR(IRVar2Const, vtable_expr)
.set_dispatch<Variable>([](const Variable* op, const Expr &e, IRMutator* m) -> Expr {
    IRVar2Const* self = static_cast<IRVar2Const*>(m);
    if (!e.same_as(self->var)) {
      return e;
    }
    return IntImm::make(Int(32), self->int_val);
  });

}  // namespace

// Substituting y -> 10 in `x + y` must rewrite only the y operand and leave
// the named variable x untouched.
TEST(IRMutator, Basic) {
  using namespace Halide::Internal;
  using namespace tvm;
  Var x("x");
  Var y;
  auto z = x + y;

  IRVar2Const mu;
  mu.var = y;
  mu.int_val = 10;
  auto zz = mu.Mutate(z);

  std::ostringstream os;
  os << zz;
  // EXPECT_EQ is gtest's native assertion. On mismatch it prints both the
  // actual and expected strings and records an ordinary test failure. CHECK
  // only reports the failed condition through dmlc's fatal-logging path.
  EXPECT_EQ(os.str(), "(x + 10)");
}

int main(int argc, char ** argv) {
  testing::InitGoogleTest(&argc, argv);
  testing::FLAGS_gtest_death_test_style = "threadsafe";
  return RUN_ALL_TESTS();
}