Commit 3c0dc79d by tqchen

Simplify for cxx

parent 9595a9c1
...@@ -16,9 +16,7 @@ namespace tvm { ...@@ -16,9 +16,7 @@ namespace tvm {
* \param src The source expression * \param src The source expression
* \return the simplified expression. * \return the simplified expression.
*/ */
inline Expr Simplify(Expr src) { Expr Simplify(Expr src);
return src;
}
/*! /*!
* \brief visit the exression node in expr tree in post DFS order. * \brief visit the exression node in expr tree in post DFS order.
......
...@@ -12,5 +12,6 @@ ...@@ -12,5 +12,6 @@
#include "./tensor.h" #include "./tensor.h"
#include "./domain.h" #include "./domain.h"
#include "./array.h" #include "./array.h"
#include "./expr_util.h"
#endif // TVM_TVM_H_ #endif // TVM_TVM_H_
...@@ -172,6 +172,7 @@ def register_node(type_key): ...@@ -172,6 +172,7 @@ def register_node(type_key):
""" """
def register(cls): def register(cls):
NODE_TYPE[type_key] = cls NODE_TYPE[type_key] = cls
return cls
return register return register
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/op.h> #include <tvm/op.h>
#include <tvm/tensor.h> #include <tvm/tensor.h>
#include <tvm/expr_util.h>
#include "./c_api_registry.h" #include "./c_api_registry.h"
namespace dmlc { namespace dmlc {
...@@ -104,6 +105,11 @@ TVM_REGISTER_API(_TensorInput) ...@@ -104,6 +105,11 @@ TVM_REGISTER_API(_TensorInput)
static_cast<DataType>(static_cast<int>(args.at(1)))); static_cast<DataType>(static_cast<int>(args.at(1))));
}); });
TVM_REGISTER_API(simplify)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Simplify(args.at(0));
});
// transformations // transformations
TVM_REGISTER_API(format_str) TVM_REGISTER_API(format_str)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
......
...@@ -9,63 +9,6 @@ ...@@ -9,63 +9,6 @@
namespace tvm { namespace tvm {
void Expr::Print(std::ostream& os) const {
if (is_null()) {
os << "null"; return;
}
switch (this->node_type()) {
case kVarNode: {
os << Get<VarNode>()->name; return;
}
case kIntNode: {
os << Get<IntNode>()->value; return;
}
case kFloatNode: {
os << Get<FloatNode>()->value; return;
}
case kBinaryOpNode: {
const auto* n = Get<BinaryOpNode>();
const char* fname = n->op->FunctionName();
if (fname[1] == '\0' && !isalpha(fname[0])) {
os << '(';
n->lhs.Print(os);
os << ' ' << fname[0] << ' ';
n->rhs.Print(os);
os << ')';
} else {
os << fname << '(';
n->lhs.Print(os);
os << ", ";
n->rhs.Print(os);
os << ')';
}
return;
}
case kUnaryOpNode: {
const auto* n = Get<UnaryOpNode>();
os << n->op->FunctionName() << '(';
n->src.Print(os);
os << ')';
return;
}
case kReduceNode: {
const auto* n = Get<ReduceNode>();
os << "reduce("<< n->op->FunctionName() << ", ";
n->src.Print(os);
os << ", " << n->rdom << ')';
return;
}
case kTensorReadNode: {
const auto* n = Get<TensorReadNode>();
os << n->tensor.name() << n->indices;
return;
}
default: {
LOG(FATAL) << "not able to handle type " << typeid(node_.get()).name();
}
}
}
Var::Var(std::string name, DataType dtype) { Var::Var(std::string name, DataType dtype) {
auto node = std::make_shared<VarNode>(); auto node = std::make_shared<VarNode>();
node->name = std::move(name); node->name = std::move(name);
......
...@@ -3,8 +3,205 @@ ...@@ -3,8 +3,205 @@
* \file expr_util.cc * \file expr_util.cc
*/ */
#include <tvm/expr_util.h> #include <tvm/expr_util.h>
#include <tvm/op.h>
namespace tvm { namespace tvm {
inline bool is_ingeter(DataType t) {
return t == kInt32;
}
/*! \brief Canonical form of expression */
struct CanonicalExpr {
/*! \brief the e->value */
std::unordered_map<Expr, int64_t> dict;
/*! \brief constant value in the expresssion */
int64_t constant{0};
// change CanonicalExpr as expr
inline Expr AsExpr() const {
Expr e;
using KV = std::pair<Expr, int64_t>;
std::vector<KV> tlist(dict.begin(), dict.end());
std::sort(tlist.begin(), tlist.end(), [](const KV& lhs, const KV& rhs) {
return lhs.first.hash() < rhs.first.hash();
});
for (auto &kv : tlist) {
if (kv.second == 0) continue;
Expr tmp;
if (kv.second == 1) {
tmp = kv.first;
} else {
tmp = kv.first * kv.second;
}
if (e.is_null()) {
e = tmp;
} else {
e = e + tmp;
}
}
if (e.is_null()) {
return IntConstant(constant);
} else {
if (constant != 0) e = e + constant;
return e;
}
}
inline void Add(const Expr& e, int beta) {
auto it = dict.find(e);
if (it != dict.end()) {
it->second += beta;
if (it->second == 0) dict.erase(it);
} else {
dict[e] = beta;
}
}
};
// out += beta * Canonicalize(e)
void AddCanonical(const Expr& e,
CanonicalExpr* out,
int beta) {
static const BinaryOp* add_op = BinaryOp::Get("+");
static const BinaryOp* sub_op = BinaryOp::Get("-");
static const BinaryOp* mul_op = BinaryOp::Get("*");
static const BinaryOp* max_op = BinaryOp::Get("max");
static const BinaryOp* min_op = BinaryOp::Get("min");
CHECK(!e.is_null()) << "cannot simplify null";
switch (e.node_type()) {
case kIntNode: {
out->constant += (e.Get<IntNode>()->value) * beta; return;
}
case kBinaryOpNode: {
const auto* n = e.Get<BinaryOpNode>();
if (n->op == add_op) {
AddCanonical(n->lhs, out, beta);
AddCanonical(n->rhs, out, beta);
return;
}
if (n->op == sub_op) {
AddCanonical(n->lhs, out, beta);
AddCanonical(n->rhs, out, -beta);
return;
}
if (n->op == mul_op) {
if (n->lhs.node_type() == kIntNode) {
AddCanonical(n->rhs, out, beta * (n->lhs.Get<IntNode>()->value)); return;
} else if (n->rhs.node_type() == kIntNode) {
AddCanonical(n->lhs, out, beta * (n->rhs.Get<IntNode>()->value)); return;
}
CanonicalExpr clhs, crhs;
AddCanonical(n->lhs, &clhs, 1);
if (clhs.dict.size() == 0) {
AddCanonical(n->rhs, out, beta * clhs.constant); return;
}
AddCanonical(n->rhs, &crhs, 1);
if (crhs.dict.size() == 0) {
AddCanonical(n->lhs, out, beta * crhs.constant); return;
}
out->Add(e, beta); return;
}
if (n->op == max_op) {
CanonicalExpr res;
AddCanonical(n->lhs, &res, 1);
AddCanonical(n->rhs, &res, -1);
if (res.dict.size() == 0) {
if (res.constant > 0) {
AddCanonical(n->lhs, out, beta); return;
} else {
AddCanonical(n->rhs, out, beta); return;
}
} else {
out->Add(e, beta); return;
}
}
if (n->op == min_op) {
CanonicalExpr res;
AddCanonical(n->lhs, &res, 1);
AddCanonical(n->rhs, &res, -1);
if (res.dict.size() == 0) {
if (res.constant <= 0) {
AddCanonical(n->lhs, out, beta); return;
} else {
AddCanonical(n->rhs, out, beta); return;
}
} else {
out->Add(e, beta); return;
}
}
out->Add(e, beta);
return;
}
default: {
out->Add(e, beta); return;
}
}
}
Expr Simplify(Expr src) {
CanonicalExpr cexpr;
AddCanonical(src, &cexpr, 1);
return cexpr.AsExpr();
}
void Expr::Print(std::ostream& os) const {
if (is_null()) {
os << "null"; return;
}
switch (this->node_type()) {
case kVarNode: {
os << Get<VarNode>()->name; return;
}
case kIntNode: {
os << Get<IntNode>()->value; return;
}
case kFloatNode: {
os << Get<FloatNode>()->value; return;
}
case kBinaryOpNode: {
const auto* n = Get<BinaryOpNode>();
const char* fname = n->op->FunctionName();
if (fname[1] == '\0' && !isalpha(fname[0])) {
os << '(';
n->lhs.Print(os);
os << ' ' << fname[0] << ' ';
n->rhs.Print(os);
os << ')';
} else {
os << fname << '(';
n->lhs.Print(os);
os << ", ";
n->rhs.Print(os);
os << ')';
}
return;
}
case kUnaryOpNode: {
const auto* n = Get<UnaryOpNode>();
os << n->op->FunctionName() << '(';
n->src.Print(os);
os << ')';
return;
}
case kReduceNode: {
const auto* n = Get<ReduceNode>();
os << "reduce("<< n->op->FunctionName() << ", ";
n->src.Print(os);
os << ", " << n->rdom << ')';
return;
}
case kTensorReadNode: {
const auto* n = Get<TensorReadNode>();
os << n->tensor.name() << n->indices;
return;
}
default: {
LOG(FATAL) << "not able to handle type " << typeid(node_.get()).name();
}
}
}
} // namespace tvm } // namespace tvm
...@@ -21,6 +21,15 @@ TEST(Expr, Reduction) { ...@@ -21,6 +21,15 @@ TEST(Expr, Reduction) {
CHECK(os.str() == "reduce(+, ((x + 1) + 2), rdomain([[0, 3)]))"); CHECK(os.str() == "reduce(+, ((x + 1) + 2), rdomain([[0, 3)]))");
} }
TEST(Expr, Simplify) {
using namespace tvm;
Var x("x");
auto z = max(x + 1 + 2, x + 10) * 100;
std::ostringstream os;
os << Simplify(z);
CHECK(os.str() == "((x * 100) + 1000)");
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe"; testing::FLAGS_gtest_death_test_style = "threadsafe";
......
from tvm import cpp as tvm from tvm import cpp as tvm
def test_basic(): def test_basic():
a = tvm.Var('a') a = tvm.Var('a')
b = tvm.Var('b') b = tvm.Var('b')
c = a + b c = a + b
assert a == c.lhs assert a == c.lhs
assert c.dtype == tvm.int32 assert c.dtype == tvm.int32
assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name) assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name)
...@@ -13,11 +13,29 @@ def test_basic(): ...@@ -13,11 +13,29 @@ def test_basic():
def test_array(): def test_array():
a = tvm.Var('a') a = tvm.Var('a')
x = tvm.function._symbol([1,2,a]) x = tvm.function._symbol([1,2,a])
print type(x)
print len(x)
print x[4] def assert_equal(x, y):
z = tvm.simplify(x - y)
assert isinstance(z, tvm.expr.IntExpr)
assert z.value == 0
def test_simplify():
a = tvm.Var('a')
b = tvm.Var('b')
e1 = a * (2 + 1) + b * 1
e2 = a * (2 + 1) - b * 1
e3 = tvm.max(a * 3 + 5, 3 + 3 * a)
e4 = a - a
assert_equal(e1, a * 3 + b)
assert_equal(e2, a * 3 - b)
assert_equal(e3, a * 3 + 5)
assert_equal(e4, 0)
if __name__ == "__main__": if __name__ == "__main__":
test_basic() test_basic()
test_array() test_array()
test_simplify()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment