Commit f2f1526d by Haichen Shen Committed by Tianqi Chen

[PASS] Export simplify and equal to python (#14)

* [PASS] Export simplify and equal to python

* fix naming convention
parent 7e025234
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#ifndef TVM_IR_PASS_H_ #ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_ #define TVM_IR_PASS_H_
#include <ir/IREquality.h>
#include <pass/Simplify.h>
#include <tvm/ir_functor.h> #include <tvm/ir_functor.h>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -19,6 +21,21 @@ ...@@ -19,6 +21,21 @@
namespace tvm { namespace tvm {
namespace ir { namespace ir {
inline bool Equal(Expr a, Expr b) {
return Halide::Internal::equal(a, b);
}
inline bool Equal(Stmt a, Stmt b) {
return Halide::Internal::equal(a, b);
}
inline Expr Simplify(Expr a) {
return Halide::Internal::simplify(a);
}
inline Stmt Simplify(Stmt a) {
return Halide::Internal::simplify(a);
}
/*! /*!
* \brief Schedule s' dependent operations. * \brief Schedule s' dependent operations.
......
...@@ -13,6 +13,27 @@ namespace ir { ...@@ -13,6 +13,27 @@ namespace ir {
using ArgStack = const std::vector<APIVariantValue>; using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue; using RetValue = APIVariantValue;
TVM_REGISTER_API(_pass_Simplify)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
if (dynamic_cast<Expr::ContainerType*>(args.at(0).sptr.get())) {
*ret = Simplify(args.at(0).operator Expr());
} else {
*ret = Simplify(args.at(0).operator Stmt());
}
});
TVM_REGISTER_API(_pass_Equal)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
CHECK(args.at(1).type_id == kNodeHandle);
if (dynamic_cast<Expr::ContainerType*>(args.at(0).sptr.get())) {
*ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr());
} else {
*ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt());
}
});
// make from two arguments // make from two arguments
#define REGISTER_PASS1(PassName) \ #define REGISTER_PASS1(PassName) \
TVM_REGISTER_API(_pass_## PassName) \ TVM_REGISTER_API(_pass_## PassName) \
......
import tvm import tvm
def test_simplify():
x = tvm.Var('x')
e1 = tvm.ir_pass.Simplify(x + 2 + 1)
assert(tvm.ir_pass.Equal(e1, x + 3))
e2 = tvm.ir_pass.Simplify(x * 3 + 5 * x)
assert(tvm.ir_pass.Equal(e2, x * 8))
e3 = tvm.ir_pass.Simplify(x - x / 3 * 3)
assert(tvm.ir_pass.Equal(e3, tvm.make.Mod(x, 3)))
def test_verify_ssa(): def test_verify_ssa():
x = tvm.Var('x') x = tvm.Var('x')
y = tvm.Var() y = tvm.Var()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment