Commit 5b408d1d by Tianqi Chen Committed by GitHub

[IR] Move AttrStmt to HalideIR (#21)

parent 383494a5
Subproject commit af2a2fcee59378f33817d7745a8110b9cc836438
Subproject commit b6637f611f91dd075dc251438f72ad38901d17fb
......@@ -49,34 +49,6 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* Min = "Min";
};
/*!
* \brief Define certain auxiliary attribute for the body to be a symbolic value.
* This is used to insert hint(shape, storage, split) about certain scopes.
*/
struct AttrStmt : public StmtNode<AttrStmt> {
/*! \brief this is attribute about certain node */
NodeRef node;
/*! \brief the type key of the attribute */
std::string type_key;
/*! \brief The attribute value, value is well defined at current scope. */
Expr value;
/*! \brief The body statement to be executed */
Stmt body;
/*! \brief construct expr from name and rdom */
static Stmt make(NodeRef node, std::string type_key, Expr value, Stmt body);
void VisitAttrs(AttrVisitor* v) final {
v->Visit("node", &node);
v->Visit("type_key", &type_key);
v->Visit("value", &value);
v->Visit("body", &body);
}
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "AttrStmt";
};
// Reuse IR node defintiion from HalideIR
using Halide::Internal::IntImm;
using Halide::Internal::UIntImm;
......@@ -106,6 +78,7 @@ using Halide::Internal::Broadcast;
using Halide::Internal::Call;
using Halide::Internal::Let;
using Halide::Internal::LetStmt;
using Halide::Internal::AttrStmt;
using Halide::Internal::AssertStmt;
using Halide::Internal::ProducerConsumer;
using Halide::Internal::For;
......
......@@ -18,8 +18,6 @@ using RetValue = APIVariantValue;
TVM_REGISTER_API(_format_str)
.set_body([](const ArgStack& args, RetValue *ret) {
using Halide::Internal::BaseExprNode;
using Halide::Internal::BaseStmtNode;
CHECK(args.at(0).type_id == kNodeHandle);
std::ostringstream os;
os << args.at(0).operator NodeRef();
......
......@@ -16,7 +16,6 @@ using RetValue = APIVariantValue;
TVM_REGISTER_API(_const)
.set_body([](const ArgStack& args, RetValue *ret) {
using Halide::Internal::make_const;
if (args.at(0).type_id == kLong) {
*ret = make_const(args.at(1), args.at(0).operator int64_t());
} else if (args.at(0).type_id == kDouble) {
......
......@@ -20,11 +20,6 @@ void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
LOG(FATAL) << "Reduce do not work with old Visitor, use IRFunctor style visitor";
}
template<>
void StmtNode<AttrStmt>::accept(IRVisitor *v, const Stmt&) const {
LOG(FATAL) << "AttrStmt do not work with old Visitor, use IRFunctor style visitor";
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) {
p->stream << "reduce("
......@@ -34,15 +29,6 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << ", rdom=" << op->rdom << ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<AttrStmt>([](const AttrStmt *op, IRPrinter *p) {
p->do_indent();
p->stream << "// attr " << op->type_key << " = ";
p->print(op->value);
p->stream << '\n';
p->print(op->body);
});
} // namespace Internal
} // namespace Halide
......@@ -62,15 +48,6 @@ Expr Reduce::make(std::string op, Expr source, Array<IterVar> rdom) {
return Expr(n);
}
Stmt AttrStmt::make(NodeRef node, std::string type_key, Expr value, Stmt body) {
auto n = std::make_shared<AttrStmt>();
n->node = node;
n->type_key = type_key;
n->value = value;
n->body = body;
return Stmt(n);
}
TVM_REGISTER_NODE_TYPE(Reduce);
TVM_REGISTER_NODE_TYPE(AttrStmt);
......
......@@ -16,11 +16,6 @@ IRMutator::FMutateStmt& IRMutator::vtable_stmt() { // NOLINT(*)
static FMutateStmt inst; return inst;
}
// namespace to register the functors.
namespace {
using namespace Halide::Internal;
// const expr
inline Expr ReturnSelfExpr(const NodeRef&, const Expr& e, IRMutator*) {
return e;
......@@ -290,7 +285,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
return s;
})
.set_dispatch<Realize>([](const Realize *op, const Stmt& s, IRMutator* m) {
Region new_bounds;
Halide::Internal::Region new_bounds;
bool bounds_changed = false;
// Mutate the bounds
......@@ -350,7 +345,5 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
return Evaluate::make(v);
}
});
} // namespace
} // namespace ir
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment