Commit f2f1526d by Haichen Shen Committed by Tianqi Chen

[PASS] Export simplify and equal to python (#14)

* [PASS] Export simplify and equal to python

* fix naming convention
parent 7e025234
......@@ -9,6 +9,8 @@
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
#include <ir/IREquality.h>
#include <pass/Simplify.h>
#include <tvm/ir_functor.h>
#include <unordered_map>
#include <vector>
......@@ -19,6 +21,21 @@
namespace tvm {
namespace ir {
inline bool Equal(Expr a, Expr b) {
return Halide::Internal::equal(a, b);
}
inline bool Equal(Stmt a, Stmt b) {
return Halide::Internal::equal(a, b);
}
inline Expr Simplify(Expr a) {
return Halide::Internal::simplify(a);
}
inline Stmt Simplify(Stmt a) {
return Halide::Internal::simplify(a);
}
/*!
* \brief Schedule s' dependent operations.
......
......@@ -13,6 +13,27 @@ namespace ir {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
TVM_REGISTER_API(_pass_Simplify)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
if (dynamic_cast<Expr::ContainerType*>(args.at(0).sptr.get())) {
*ret = Simplify(args.at(0).operator Expr());
} else {
*ret = Simplify(args.at(0).operator Stmt());
}
});
TVM_REGISTER_API(_pass_Equal)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
CHECK(args.at(1).type_id == kNodeHandle);
if (dynamic_cast<Expr::ContainerType*>(args.at(0).sptr.get())) {
*ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr());
} else {
*ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt());
}
});
// make from two arguments
#define REGISTER_PASS1(PassName) \
TVM_REGISTER_API(_pass_## PassName) \
......
import tvm
def test_simplify():
x = tvm.Var('x')
e1 = tvm.ir_pass.Simplify(x + 2 + 1)
assert(tvm.ir_pass.Equal(e1, x + 3))
e2 = tvm.ir_pass.Simplify(x * 3 + 5 * x)
assert(tvm.ir_pass.Equal(e2, x * 8))
e3 = tvm.ir_pass.Simplify(x - x / 3 * 3)
assert(tvm.ir_pass.Equal(e3, tvm.make.Mod(x, 3)))
def test_verify_ssa():
x = tvm.Var('x')
y = tvm.Var()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment