Commit 13383928 by Haichen Shen

add var binding for expr

parent 816419be
......@@ -6,6 +6,8 @@
#ifndef TVM_EXPR_UTIL_H_
#define TVM_EXPR_UTIL_H_
#include <vector>
#include "./expr.h"
#include "./expr_node.h"
......@@ -19,6 +21,14 @@ namespace tvm {
Expr Simplify(Expr src);
/*!
* \brief replace the variables in expression src by specification from dict
* \param src The source expression
* \param dict The specification for variable replacement
* \return the new expression with variable replaced
*/
Expr Bind(Expr src, std::unordered_map<Expr, Expr> dict);
/*!
* \brief visit the exression node in expr tree in post DFS order.
* \param expr The expression tree
* \param fvisit The visit function.
......@@ -55,6 +65,47 @@ inline void Visit(const Expr& expr, FVisit fvisit) {
fvisit(expr);
}
/*!
* \brief transform the exression node in expr tree in post DFS order.
* \param expr The expression tree
* \param fvisit The visit function.
* \return the new expression after transformation
*/
template<typename FVisit>
inline Expr Transform(const Expr& expr, FVisit fvisit) {
// TODO(tqchen) change to stack based impl.
std::vector<Expr> children;
switch (expr.node_type()) {
case kBinaryOpNode: {
const auto* n = expr.Get<BinaryOpNode>();
Expr e = Transform(n->lhs, fvisit);
children.push_back(e);
children.push_back(Transform(n->rhs, fvisit));
break;
}
case kUnaryOpNode: {
const auto* n = expr.Get<UnaryOpNode>();
children.push_back(Transform(n->src, fvisit));
break;
}
case kReduceNode: {
const auto* n = expr.Get<ReduceNode>();
children.push_back(Transform(n->src, fvisit));
break;
}
case kTensorReadNode: {
const auto* n = expr.Get<TensorReadNode>();
for (size_t i = 0; i < n->indices.size(); ++i) {
children.push_back(Transform(n->indices[i], fvisit));
}
break;
}
default: break;
}
Expr ret = fvisit(expr, children);
return ret;
}
} // namespace tvm
#endif // TVM_EXPR_UTIL_H_
......@@ -146,6 +146,68 @@ Expr Simplify(Expr src) {
return cexpr.AsExpr();
}
Expr ExprWithNewChildren(Expr src, std::vector<Expr> children) {
if (children.size()) {
switch (src.node_type()) {
case kBinaryOpNode: {
const auto* n = src.Get<BinaryOpNode>();
if (n->lhs == children[0] && n->rhs == children[0])
return src;
return (*n->op)(children[0], children[1]);
}
case kUnaryOpNode: {
const auto* n = src.Get<UnaryOpNode>();
if (n->src == children[0])
return src;
return (*n->op)(children[0]);
}
case kReduceNode: {
const auto* n = src.Get<ReduceNode>();
if (n->src == children[0])
return src;
return (n->op)->Reduce(children[0], n->rdom);
}
case kTensorReadNode: {
const auto* n = src.Get<TensorReadNode>();
bool same = true;
for (size_t i = 0; i < n->indices.size(); ++i) {
if (n->indices[i] != children[i]) {
same = false;
break;
}
}
if (same)
return src;
Array<Expr> indices(children);
return n->tensor(indices);
}
default: {
return src;
}
}
}
return src;
}
Expr Bind(Expr src, std::unordered_map<Expr, Expr> dict) {
auto replace = [&](Expr e, std::vector<Expr> children) {
switch (e.node_type()) {
case kVarNode: {
auto it = dict.find(e);
if (it != dict.end()) {
return it->second;
}
return e;
}
default: {
return ExprWithNewChildren(e, children);
}
}
};
return Transform(src, replace);
}
void Expr::Print(std::ostream& os) const {
if (is_null()) {
os << "null"; return;
......
......@@ -13,6 +13,12 @@ DMLC_REGISTRY_ENABLE(::tvm::UnaryOpReg);
namespace tvm {
Expr UnaryOp::operator()(Expr src) const {
auto nptr = std::make_shared<UnaryOpNode>(this, std::move(src));
nptr->Verify();
return Expr(std::move(nptr));
}
Expr BinaryOp::operator()(Expr lhs, Expr rhs) const {
auto nptr = std::make_shared<BinaryOpNode>(
this, std::move(lhs), std::move(rhs));
......
......@@ -30,6 +30,25 @@ TEST(Expr, Simplify) {
CHECK(os.str() == "((x * 100) + 1000)");
}
TEST(Expr, Bind) {
using namespace tvm;
Var x("x"), y("y"), z("z");
Var i("i"), j("j");
Tensor A({y, z}, "A");
Expr e1 = x * 5;
std::unordered_map<Expr, Expr> dict = {{x, y * 10 + z}};
std::ostringstream os1, os2;
os1 << Bind(e1, dict);
CHECK(os1.str() == "(((y * 10) + z) * 5)");
Expr e2 = A(i, j);
dict.clear();
dict[i] = 64 * x;
dict[j] = z + 16 * y;
os2 << Bind(e2, dict);
CHECK(os2.str() == "A[(64 * x), (z + (16 * y))]");
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment