Commit 8e04361c by tqchen

Refactor IR Pass

parent ff6b8d82
Subproject commit 89b7939957d66a37dd6083ad6b09a5644e73fd8b
Subproject commit 4becbde67c8aa565941b02648cea90f50211f8dc
......@@ -27,6 +27,7 @@ using Halide::abs;
using Halide::select;
using Halide::Expr;
using Halide::IR::FunctionBaseNode;
using Halide::Internal::Stmt;
class Var : public Halide::VarExpr {
......
......@@ -29,7 +29,7 @@ class IRMutator {
* \brief mutate expression
* \return the mutated expr
*/
virtual Expr mutate(Expr expr) {
virtual Expr Mutate(Expr expr) {
static const FMutateExpr& f = vtable_expr();
return f(expr, expr, this);
}
......@@ -37,7 +37,7 @@ class IRMutator {
* \brief mutate expression
* \return the mutated stmt
*/
virtual Stmt mutate(Stmt stmt) {
virtual Stmt Mutate(Stmt stmt) {
static const FMutateStmt& f = vtable_stmt();
return f(stmt, stmt, this);
}
......@@ -58,28 +58,21 @@ class IRMutator {
*/
class IRMutatorExample : public IRMutator {
public:
Expr mutate(Expr expr) final {
Expr Mutate(Expr expr) final {
static const FMutateExpr& f = IRMutatorExample::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::mutate(expr));
f(expr, expr, this) : IRMutator::Mutate(expr));
}
Stmt mutate(Stmt stmt) final {
Stmt Mutate(Stmt stmt) final {
static const FMutateStmt& f = IRMutatorExample::vtable_stmt();
return (f.can_dispatch(stmt) ?
f(stmt, stmt, this) : IRMutator::mutate(stmt));
f(stmt, stmt, this) : IRMutator::Mutate(stmt));
}
// to be implemented by child class
static FMutateExpr& vtable_expr(); // NOLINT(*)
static FMutateStmt& vtable_stmt(); // NOLINT(*)
};
/*!
* \brief Substitute occurance of IRNode to be expr
* \param replacements The replacement rule of substitution
* \param expr The expression to be substituted.
*/
Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr);
} // namespace ir
} // namespace tvm
#endif // TVM_IR_MUTATOR_H_
/*!
* Copyright (c) 2016 by Contributors
* \file ir_pass.h
* \brief Collection of IR pass functions and visit functions
*/
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
#include <tvm/ir_node.h>
#include <unordered_map>
#include "./expr.h"
namespace tvm {
namespace ir {
/*!
* \brief Substitute occurance of IRNode in expr
* \param replacements The replacement rule of substitution
* \param expr The expression to be substituted.
*/
Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr);
} // namespace ir
} // namespace tvm
#endif // TVM_IR_PASS_H_
......@@ -24,7 +24,7 @@ class IRVisitor {
/*!
* \brief recursively visit an IR node
*/
virtual void visit(const IRNodeRef& node) {
virtual void Visit(const IRNodeRef& node) {
static const FVisit& f = vtable();
if (node.defined()) f(node, this);
}
......
......@@ -101,7 +101,7 @@ class Tensor : public FunctionRef {
};
/*! \brief Node to represent a tensor */
class TensorNode : public Node {
class TensorNode : public FunctionBaseNode {
public:
/*! \brief The shape of the tensor */
Array<Expr> shape;
......@@ -125,6 +125,12 @@ class TensorNode : public Node {
v->Visit("dim_var", &dim_var);
v->Visit("source", &source);
}
const std::string& func_name() const final {
return name;
}
int outputs() const final {
return 1;
}
static Tensor make(Array<Expr> shape,
std::string name,
Type dtype,
......
......@@ -8,32 +8,6 @@
namespace tvm {
namespace ir {
namespace {
// visitor to implement apply
class IRSubstitute : public IRMutator {
public:
Expr mutate(Expr expr) final {
const IRNode* v = expr.get();
if (v != nullptr) {
auto it = replacements_.find(v);
if (it != replacements_.end()) {
return it->second;
}
}
return IRMutator::mutate(expr);
}
explicit IRSubstitute(const std::unordered_map<const IRNode*, Expr>& replacements)
: replacements_(replacements) {}
private:
const std::unordered_map<const IRNode*, Expr>& replacements_;
};
} // namespace
Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr) {
return IRSubstitute(replacements).mutate(expr);
}
IRMutator::FMutateExpr& IRMutator::vtable_expr() { // NOLINT(*)
static FMutateExpr inst; return inst;
}
......@@ -57,7 +31,7 @@ inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
bool changed = false;
for (size_t i = 0; i < arr.size(); i++) {
Expr old_elem = arr[i];
Expr new_elem = m->mutate(old_elem);
Expr new_elem = m->Mutate(old_elem);
if (!new_elem.same_as(old_elem)) changed = true;
new_arr[i] = new_elem;
}
......@@ -73,8 +47,8 @@ inline RDomain MutateRDom(RDomain rdom, IRMutator *m) {
bool changed = false;
for (size_t i = 0; i < rdom->domain.size(); i++) {
Range r = rdom->domain[i];
Expr new_min = m->mutate(r->min);
Expr new_extent = m->mutate(r->extent);
Expr new_min = m->Mutate(r->min);
Expr new_extent = m->Mutate(r->extent);
if (!r->min.same_as(new_min)) changed = true;
if (!r->extent.same_as(new_extent)) changed = true;
new_dom[i] = Range::make_with_min_extent(new_min, new_extent);
......@@ -89,7 +63,7 @@ inline RDomain MutateRDom(RDomain rdom, IRMutator *m) {
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Reduce>([](const Reduce* op, const Expr& e, IRMutator* m) {
RDomain new_rdom = MutateRDom(op->rdom, m);
Expr new_source = m->mutate(op->source);
Expr new_source = m->Mutate(op->source);
if (op->rdom.same_as(new_rdom) &&
op->source.same_as(new_source)) {
return e;
......@@ -107,7 +81,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Cast>([](const Cast* op, const Expr& e, IRMutator* m) {
Expr value = m->mutate(op->value);
Expr value = m->Mutate(op->value);
if (value.same_as(op->value)) {
return e;
} else {
......@@ -118,8 +92,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
// binary operator
template<typename T>
inline Expr Binary(const T* op, const Expr& e, IRMutator* m) {
Expr a = m->mutate(op->a);
Expr b = m->mutate(op->b);
Expr a = m->Mutate(op->a);
Expr b = m->Mutate(op->b);
if (a.same_as(op->a) &&
b.same_as(op->b)) {
return e;
......@@ -147,7 +121,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Not>([](const Not* op, const Expr& e, IRMutator* m) {
Expr a = m->mutate(op->a);
Expr a = m->Mutate(op->a);
if (a.same_as(op->a)) {
return e;
} else {
......@@ -155,9 +129,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
}
})
.set_dispatch<Select>([](const Select *op, const Expr& e, IRMutator* m) {
Expr cond = m->mutate(op->condition);
Expr t = m->mutate(op->true_value);
Expr f = m->mutate(op->false_value);
Expr cond = m->Mutate(op->condition);
Expr t = m->Mutate(op->true_value);
Expr f = m->Mutate(op->false_value);
if (cond.same_as(op->condition) &&
t.same_as(op->true_value) &&
f.same_as(op->false_value)) {
......@@ -167,7 +141,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
}
})
.set_dispatch<Load>([](const Load *op, const Expr& e, IRMutator* m) {
Expr index = m->mutate(op->index);
Expr index = m->Mutate(op->index);
if (index.same_as(op->index)) {
return e;
} else {
......@@ -175,8 +149,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
}
})
.set_dispatch<Ramp>([](const Ramp *op, const Expr& e, IRMutator* m) {
Expr base = m->mutate(op->base);
Expr stride = m->mutate(op->stride);
Expr base = m->Mutate(op->base);
Expr stride = m->Mutate(op->stride);
if (base.same_as(op->base) &&
stride.same_as(op->stride)) {
return e;
......@@ -185,7 +159,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
}
})
.set_dispatch<Broadcast>([](const Broadcast *op, const Expr& e, IRMutator* m) {
Expr value = m->mutate(op->value);
Expr value = m->Mutate(op->value);
if (value.same_as(op->value)) {
return e;
} else {
......@@ -202,8 +176,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
}
})
.set_dispatch<Let>([](const Let *op, const Expr& e, IRMutator* m) {
Expr value = m->mutate(op->value);
Expr body = m->mutate(op->body);
Expr value = m->Mutate(op->value);
Expr body = m->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return e;
......@@ -214,8 +188,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.set_dispatch<LetStmt>([](const LetStmt *op, const Stmt& s, IRMutator* m) {
Expr value = m->mutate(op->value);
Stmt body = m->mutate(op->body);
Expr value = m->Mutate(op->value);
Stmt body = m->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
......@@ -224,8 +198,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}
})
.set_dispatch<AssertStmt>([](const AssertStmt *op, const Stmt& s, IRMutator* m) {
Expr condition = m->mutate(op->condition);
Expr message = m->mutate(op->message);
Expr condition = m->Mutate(op->condition);
Expr message = m->Mutate(op->message);
if (condition.same_as(op->condition) && message.same_as(op->message)) {
return s;
......@@ -234,7 +208,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}
})
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, const Stmt& s, IRMutator* m) {
Stmt body = m->mutate(op->body);
Stmt body = m->Mutate(op->body);
if (body.same_as(op->body)) {
return s;
} else {
......@@ -242,9 +216,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}
})
.set_dispatch<For>([](const For *op, const Stmt& s, IRMutator* m) {
Expr min = m->mutate(op->min);
Expr extent = m->mutate(op->extent);
Stmt body = m->mutate(op->body);
Expr min = m->Mutate(op->min);
Expr extent = m->Mutate(op->extent);
Stmt body = m->Mutate(op->body);
if (min.same_as(op->min) &&
extent.same_as(op->extent) &&
body.same_as(op->body)) {
......@@ -255,8 +229,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}
})
.set_dispatch<Store>([](const Store *op, const Stmt& s, IRMutator* m) {
Expr value = m->mutate(op->value);
Expr index = m->mutate(op->index);
Expr value = m->Mutate(op->value);
Expr index = m->Mutate(op->index);
if (value.same_as(op->value) && index.same_as(op->index)) {
return s;
} else {
......@@ -276,14 +250,14 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
std::vector<Expr> new_extents;
bool all_extents_unmodified = true;
for (size_t i = 0; i < op->extents.size(); i++) {
new_extents.push_back(m->mutate(op->extents[i]));
new_extents.push_back(m->Mutate(op->extents[i]));
all_extents_unmodified &= new_extents[i].same_as(op->extents[i]);
}
Stmt body = m->mutate(op->body);
Expr condition = m->mutate(op->condition);
Stmt body = m->Mutate(op->body);
Expr condition = m->Mutate(op->condition);
Expr new_expr;
if (op->new_expr.defined()) {
new_expr = m->mutate(op->new_expr);
new_expr = m->Mutate(op->new_expr);
}
if (all_extents_unmodified &&
body.same_as(op->body) &&
......@@ -308,16 +282,16 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
for (size_t i = 0; i < op->bounds.size(); i++) {
Expr old_min = op->bounds[i]->min;
Expr old_extent = op->bounds[i]->extent;
Expr new_min = m->mutate(old_min);
Expr new_extent = m->mutate(old_extent);
Expr new_min = m->Mutate(old_min);
Expr new_extent = m->Mutate(old_extent);
if (!new_min.same_as(old_min)) bounds_changed = true;
if (!new_extent.same_as(old_extent)) bounds_changed = true;
new_bounds.push_back(
Range::make_by_min_extent(new_min, new_extent));
}
Stmt body = m->mutate(op->body);
Expr condition = m->mutate(op->condition);
Stmt body = m->Mutate(op->body);
Expr condition = m->Mutate(op->condition);
if (!bounds_changed &&
body.same_as(op->body) &&
condition.same_as(op->condition)) {
......@@ -328,8 +302,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}
})
.set_dispatch<Block>([](const Block *op, const Stmt& s, IRMutator* m) {
Stmt first = m->mutate(op->first);
Stmt rest = m->mutate(op->rest);
Stmt first = m->Mutate(op->first);
Stmt rest = m->Mutate(op->rest);
if (first.same_as(op->first) &&
rest.same_as(op->rest)) {
return s;
......@@ -338,9 +312,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}
})
.set_dispatch<IfThenElse>([](const IfThenElse *op, const Stmt& s, IRMutator* m) {
Expr condition = m->mutate(op->condition);
Stmt then_case = m->mutate(op->then_case);
Stmt else_case = m->mutate(op->else_case);
Expr condition = m->Mutate(op->condition);
Stmt then_case = m->Mutate(op->then_case);
Stmt else_case = m->Mutate(op->else_case);
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
......@@ -350,7 +324,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}
})
.set_dispatch<Evaluate>([](const Evaluate *op, const Stmt& s, IRMutator* m) {
Expr v = m->mutate(op->value);
Expr v = m->Mutate(op->value);
if (v.same_as(op->value)) {
return s;
} else {
......
/*!
* Copyright (c) 2016 by Contributors
* \file ir_pass.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <unordered_set>
namespace tvm {
namespace ir {
namespace {
// visitor to implement apply
class IRSubstitute : public IRMutator {
public:
Expr Mutate(Expr expr) final {
const IRNode* v = expr.get();
if (v != nullptr) {
auto it = replacements_.find(v);
if (it != replacements_.end()) {
return it->second;
}
}
return IRMutator::Mutate(expr);
}
explicit IRSubstitute(const std::unordered_map<const IRNode*, Expr>& replacements)
: replacements_(replacements) {}
private:
const std::unordered_map<const IRNode*, Expr>& replacements_;
};
} // namespace
Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr) {
return IRSubstitute(replacements).Mutate(expr);
}
} // namespace ir
} // namespace tvm
......@@ -14,10 +14,10 @@ class IRApplyVisit : public IRVisitor {
public:
explicit IRApplyVisit(std::function<void(const IRNodeRef&)> f) : f_(f) {}
void visit(const IRNodeRef& node) final {
void Visit(const IRNodeRef& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());
IRVisitor::visit(node);
IRVisitor::Visit(node);
f_(node);
}
......@@ -25,18 +25,18 @@ class IRApplyVisit : public IRVisitor {
std::function<void(const IRNodeRef&)> f_;
std::unordered_set<const Node*> visited_;
};
} // namespace
void PostOrderVisit(const IRNodeRef& node, std::function<void(const IRNodeRef&)> fvisit) {
IRApplyVisit(fvisit).Visit(node);
}
IRVisitor::FVisit& IRVisitor::vtable() { // NOLINT(*)
static FVisit inst; return inst;
}
void PostOrderVisit(const IRNodeRef& node, std::function<void(const IRNodeRef&)> fvisit) {
IRApplyVisit v(fvisit);
v.visit(node);
}
// namespace to register the functors.
namespace {
......@@ -47,22 +47,22 @@ void NoOp(const IRNodeRef& n, IRVisitor* v) {
inline void VisitArray(Array<Expr> arr, IRVisitor* v) {
for (size_t i = 0; i < arr.size(); i++) {
v->visit(arr[i]);
v->Visit(arr[i]);
}
}
inline void VisitRDom(RDomain rdom, IRVisitor* v) {
for (size_t i = 0; i < rdom->domain.size(); i++) {
Range r = rdom->domain[i];
v->visit(r->min);
v->visit(r->extent);
v->Visit(r->min);
v->Visit(r->extent);
}
}
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Reduce>([](const Reduce* op, IRVisitor* v) {
VisitRDom(op->rdom, v);
v->visit(op->source);
v->Visit(op->source);
});
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
......@@ -74,14 +74,14 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Cast>([](const Cast* op, IRVisitor* v) {
v->visit(op->value);
v->Visit(op->value);
});
// binary operator
template<typename T>
inline void Binary(const T* op, IRVisitor* v) {
v->visit(op->a);
v->visit(op->b);
v->Visit(op->a);
v->Visit(op->b);
}
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
......@@ -103,51 +103,51 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Not>([](const Not* op, IRVisitor* v) {
v->visit(op->a);
v->Visit(op->a);
})
.set_dispatch<Select>([](const Select *op, IRVisitor* v) {
v->visit(op->condition);
v->visit(op->true_value);
v->visit(op->false_value);
v->Visit(op->condition);
v->Visit(op->true_value);
v->Visit(op->false_value);
})
.set_dispatch<Load>([](const Load *op, IRVisitor* v) {
v->visit(op->index);
v->Visit(op->index);
})
.set_dispatch<Ramp>([](const Ramp *op, IRVisitor* v) {
v->visit(op->base);
v->visit(op->stride);
v->Visit(op->base);
v->Visit(op->stride);
})
.set_dispatch<Broadcast>([](const Broadcast *op, IRVisitor* v) {
v->visit(op->value);
v->Visit(op->value);
})
.set_dispatch<Call>([](const Call *op, IRVisitor* v) {
VisitArray(op->args, v);
})
.set_dispatch<Let>([](const Let *op, IRVisitor* v) {
v->visit(op->value);
v->visit(op->body);
v->Visit(op->value);
v->Visit(op->body);
});
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<LetStmt>([](const LetStmt *op, IRVisitor* v) {
v->visit(op->value);
v->visit(op->body);
v->Visit(op->value);
v->Visit(op->body);
})
.set_dispatch<AssertStmt>([](const AssertStmt *op, IRVisitor* v) {
v->visit(op->condition);
v->visit(op->message);
v->Visit(op->condition);
v->Visit(op->message);
})
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, IRVisitor* v) {
v->visit(op->body);
v->Visit(op->body);
})
.set_dispatch<For>([](const For *op, IRVisitor* v) {
v->visit(op->min);
v->visit(op->extent);
v->visit(op->body);
v->Visit(op->min);
v->Visit(op->extent);
v->Visit(op->body);
})
.set_dispatch<Store>([](const Store *op, IRVisitor* v) {
v->visit(op->value);
v->visit(op->index);
v->Visit(op->value);
v->Visit(op->index);
})
.set_dispatch<Provide>([](const Provide *op, IRVisitor* v) {
VisitArray(op->args, v);
......@@ -155,36 +155,36 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
})
.set_dispatch<Allocate>([](const Allocate *op, IRVisitor* v) {
for (size_t i = 0; i < op->extents.size(); i++) {
v->visit(op->extents[i]);
v->Visit(op->extents[i]);
}
v->visit(op->body);
v->visit(op->condition);
v->Visit(op->body);
v->Visit(op->condition);
if (op->new_expr.defined()) {
v->visit(op->new_expr);
v->Visit(op->new_expr);
}
})
.set_dispatch<Free>(NoOp)
.set_dispatch<Realize>([](const Realize *op, IRVisitor* v) {
// Mutate the bounds
for (size_t i = 0; i < op->bounds.size(); i++) {
v->visit(op->bounds[i]->min);
v->visit(op->bounds[i]->extent);
v->Visit(op->bounds[i]->min);
v->Visit(op->bounds[i]->extent);
}
v->visit(op->body);
v->visit(op->condition);
v->Visit(op->body);
v->Visit(op->condition);
})
.set_dispatch<Block>([](const Block *op, IRVisitor* v) {
v->visit(op->first);
v->visit(op->rest);
v->Visit(op->first);
v->Visit(op->rest);
})
.set_dispatch<IfThenElse>([](const IfThenElse *op, IRVisitor* v) {
v->visit(op->condition);
v->visit(op->then_case);
v->visit(op->else_case);
v->Visit(op->condition);
v->Visit(op->then_case);
v->Visit(op->else_case);
})
.set_dispatch<Evaluate>([](const Evaluate *op, IRVisitor* v) {
v->visit(op->value);
v->Visit(op->value);
});
} // namespace
......
......@@ -13,10 +13,10 @@ class IRVar2Const : public IRMutator {
public:
VarExpr var;
int int_val;
Expr mutate(Expr expr) final {
Expr Mutate(Expr expr) final {
static const FMutateExpr& f = IRVar2Const::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::mutate(expr));
f(expr, expr, this) : IRMutator::Mutate(expr));
}
static FMutateExpr &vtable_expr();
};
......@@ -46,31 +46,12 @@ TEST(IRMutator, Basic) {
IRVar2Const mu;
mu.var = y;
mu.int_val = 10;
auto zz = mu.mutate(z);
auto zz = mu.Mutate(z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "(x + 10)");
}
TEST(IRMutator, Substitute) {
using namespace Halide::Internal;
using namespace tvm;
Var x("x"), y;
auto z = x + y;
{
auto zz = Substitute({{y.get(), 11}}, z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "(x + 11)");
}
{
auto zz = Substitute({{z.get(), 11}}, z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "11");
}
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_pass.h>
TEST(IRPass, Substitute) {
using namespace Halide::Internal;
using namespace tvm;
Var x("x"), y;
auto z = x + y;
{
auto zz = ir::Substitute({{y.get(), 11}}, z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "(x + 11)");
}
{
auto zz = ir::Substitute({{z.get(), 11}}, z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "11");
}
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
......@@ -2,6 +2,7 @@
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
TEST(IRVisitor, CountVar) {
using namespace Halide::Internal;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment