// ir_mutator_test.cc
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include <sstream>

#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_mutator.h>

namespace {
using namespace tvm::ir;
27 28
using namespace HalideIR::Internal;
using namespace HalideIR;
29 30 31 32 33 34

// replace variable to constant
class IRVar2Const : public IRMutator {
 public:
  VarExpr var;
  int int_val;
tqchen committed
35
  Expr Mutate(Expr expr) final {
36 37
    static const FMutateExpr& f = IRVar2Const::vtable_expr();
    return (f.can_dispatch(expr) ?
tqchen committed
38
            f(expr, expr, this) : IRMutator::Mutate(expr));
39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
  }
  static FMutateExpr &vtable_expr();
};

// implement vtable
IRMutator::FMutateExpr &IRVar2Const::vtable_expr() {  // NOLINT(*)
  static FMutateExpr inst; return inst;
}

TVM_STATIC_IR_FUNCTOR(IRVar2Const, vtable_expr)
.set_dispatch<Variable>([](const Variable* op, const Expr &e, IRMutator* m) {
    IRVar2Const* vm = static_cast<IRVar2Const*>(m);
    if (e.same_as(vm->var)) {
      return IntImm::make(Int(32), vm->int_val);
    } else {
      return e;
    }
  });

}  // namespace

TEST(IRMutator, Basic) {
  using namespace HalideIR::Internal;
  using namespace tvm;
  // Build `x + y`, then ask the mutator to substitute y := 10.
  Var x("x");
  Var y;
  auto z = x + y;
  IRVar2Const mu;
  mu.var = y;
  mu.int_val = 10;
  auto zz = mu.Mutate(z);
  // Only y is replaced; x is left as a variable in the printed form.
  std::ostringstream os;
  os << zz;
  // EXPECT_EQ reports both strings on failure and does not abort the binary.
  EXPECT_EQ(os.str(), "(x + 10)");
}

int main(int argc, char ** argv) {
  testing::InitGoogleTest(&argc, argv);
  testing::FLAGS_gtest_death_test_style = "threadsafe";
  return RUN_ALL_TESTS();
}