#include <dmlc/logging.h> #include <gtest/gtest.h> #include <tvm/ir_mutator.h> #include <tvm/ir_operator.h> namespace { using namespace tvm::ir; using namespace HalideIR::Internal; using namespace HalideIR; // replace variable to constant class IRVar2Const : public IRMutator { public: VarExpr var; int int_val; Expr Mutate(Expr expr) final { static const FMutateExpr& f = IRVar2Const::vtable_expr(); return (f.can_dispatch(expr) ? f(expr, expr, this) : IRMutator::Mutate(expr)); } static FMutateExpr &vtable_expr(); }; // implement vtable IRMutator::FMutateExpr &IRVar2Const::vtable_expr() { // NOLINT(*) static FMutateExpr inst; return inst; } TVM_STATIC_IR_FUNCTOR(IRVar2Const, vtable_expr) .set_dispatch<Variable>([](const Variable* op, const Expr &e, IRMutator* m) { IRVar2Const* vm = static_cast<IRVar2Const*>(m); if (e.same_as(vm->var)) { return IntImm::make(Int(32), vm->int_val); } else { return e; } }); } // namespace TEST(IRMutator, Basic) { using namespace HalideIR::Internal; using namespace tvm; Var x("x"), y; auto z = x + y; IRVar2Const mu; mu.var = y; mu.int_val = 10; auto zz = mu.Mutate(z); std::ostringstream os; os << zz; CHECK(os.str() == "(x + 10)"); } int main(int argc, char ** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); }