Commit 162ed02c by tqchen

Add new functor

parent 0153649e
Subproject commit bd94f8c8e41b46ae7ca69a3405aac7463a4e23d5 Subproject commit 7906ae1edea96e416e338ea21b8bc248d1d6411c
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* \file ir_node.h * \file ir.h
* \brief Additional high level nodes in the IR * \brief Additional high level nodes in the IR
*/ */
#ifndef TVM_IR_NODE_H_ #ifndef TVM_IR_H_
#define TVM_IR_NODE_H_ #define TVM_IR_H_
#include <ir/Expr.h> #include <ir/Expr.h>
#include <ir/IR.h> #include <ir/IR.h>
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
*/ */
#include <tvm/base.h> #include <tvm/base.h>
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir_node.h> #include <tvm/ir.h>
#include <ir/IR.h> #include <ir/IR.h>
#include <ir/IRPrinter.h> #include <ir/IRPrinter.h>
#include <memory> #include <memory>
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_node.h>
TEST(IRF, Basic) {
using namespace Halide::Internal;
using namespace tvm;
Var x("x");
auto z = x + 1;
IRFunctor<int(const IRNodeRef& n, int b)> f;
LOG(INFO) << "x";
f.set_dispatch<Variable>([](const Variable* n, int b) {
return b;
});
f.set_dispatch<Add>([](const Add* n, int b) {
return b + 2;
});
CHECK_EQ(f(x, 2), 2);
CHECK_EQ(f(z, 2), 4);
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment