Commit 28876530 by tqchen

Add AttrStmt

parent 61de73b4
Subproject commit 7f1d811972bccc26f651ea2289d88bcadea9fe9f
Subproject commit bf96f8af0dfd1f79d258c7c1506f9ded932b94a9
......@@ -17,6 +17,7 @@ namespace tvm {
namespace ir {
using Halide::Internal::ExprNode;
using Halide::Internal::StmtNode;
using Halide::Internal::IRNodeType;
using Halide::Internal::ForType;
......@@ -47,6 +48,34 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* Min = "Min";
};
/*!
* \brief Define certain auxiliary attribute for the body to be a symbolic value.
* This is used to insert hint(shape, storage, split) about certain scopes.
*/
struct AttrStmt : public StmtNode<AttrStmt> {
/*! \brief this is attribute about certain node */
NodeRef node;
/*! \brief the type key of the attribute */
std::string type_key;
/*! \brief The attribute value, value is well defined at current scope. */
Expr value;
/*! \brief The body statement to be executed */
Stmt body;
/*! \brief construct expr from name and rdom */
static Stmt make(NodeRef node, std::string type_key, Expr value, Stmt body);
void VisitAttrs(AttrVisitor* v) final {
v->Visit("node", &node);
v->Visit("type_key", &type_key);
v->Visit("value", &value);
v->Visit("body", &body);
}
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "AttrStmt";
};
// Reuse IR node defintiion from HalideIR
using Halide::Internal::IntImm;
using Halide::Internal::UIntImm;
......
......@@ -32,6 +32,7 @@ class ComputeOpNode : public OperationNode {
std::string output_name(size_t i) const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("domain", &domain);
v->Visit("name", &name);
......
......@@ -38,42 +38,7 @@ class Schedule : public NodeRef {
inline const ScheduleNode* operator->() const;
};
/*! \brief schedule container */
class AttachSpec : public NodeRef {
public:
AttachSpec() {}
explicit AttachSpec(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const AttachSpecNode* operator->() const;
};
// defintion of node containers
/*! \brief The attach specification of each subschedule */
class AttachSpecNode : public Node {
public:
/*! \brief The attachment type */
AttachType attach_type;
/*!
* \brief The split to be attached to,
* only valid when attach_type is kRoot
*/
Split attach_split;
/*! \brief the child schedule to be attached. */
Schedule schedule;
const char* type_key() const final {
return "AttachSpec";
}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("attach_type", &attach_type);
v->Visit("attach_split", &attach_split);
v->Visit("schedule", &schedule);
}
};
/*! \brief represents the schedule of the tensor */
class ScheduleNode : public Node {
public:
......@@ -83,8 +48,17 @@ class ScheduleNode : public Node {
std::string scope;
/*! \brief Splits over iteration domains */
Array<Split> splits;
/*! \brief attach specifications */
Array<AttachSpec> attachs;
/*! \brief The attachment type of the schedule */
AttachType attach_type;
/*!
* \brief The attach point of this schedule, if it is a split
* \note This is not a cyclic dependency,
* because split do not refer back to parent schedule.
*/
Split attach_parent;
/*! \brief the schedules that this schedule depend on */
Array<Schedule> children;
// the type key
const char* type_key() const final {
return "Schedule";
}
......@@ -92,7 +66,9 @@ class ScheduleNode : public Node {
v->Visit("scope", &scope);
v->Visit("op", &op);
v->Visit("splits", &splits);
v->Visit("attachs", &attachs);
v->Visit("attach_type", &attach_type);
v->Visit("attach_parent", &attach_parent);
v->Visit("children", &children);
}
};
......@@ -101,9 +77,5 @@ inline const ScheduleNode* Schedule::operator->() const {
return static_cast<const ScheduleNode*>(node_.get());
}
inline const AttachSpecNode* AttachSpec::operator->() const {
return static_cast<const AttachSpecNode*>(node_.get());
}
} // namespace tvm
#endif // TVM_SCHEDULE_H_
......@@ -29,13 +29,6 @@ TVM_REGISTER_API(_make_For)
args.at(5));
});
TVM_REGISTER_API(_make_Reduce)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Reduce::make(args.at(0),
args.at(1),
args.at(2));
});
TVM_REGISTER_API(_make_Call)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Call::make(args.at(0),
......@@ -54,22 +47,6 @@ TVM_REGISTER_API(_make_Allocate)
args.at(4));
});
TVM_REGISTER_API(_make_LetStmt)
.set_body([](const ArgStack& args, RetValue *ret) {
if (args.size() == 3) {
*ret = LetStmt::make(args.at(0),
args.at(1),
args.at(2));
} else {
CHECK_EQ(args.size(), 5);
*ret = LetStmt::make(args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4));
}
});
// make from two arguments
#define REGISTER_MAKE1(Node) \
TVM_REGISTER_API(_make_## Node) \
......@@ -89,6 +66,12 @@ TVM_REGISTER_API(_make_LetStmt)
*ret = Node::make(args.at(0), args.at(1), args.at(2)); \
}) \
#define REGISTER_MAKE4(Node) \
TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = Node::make(args.at(0), args.at(1), args.at(2), args.at(3)); \
}) \
#define REGISTER_MAKE_BINARY_OP(Node) \
TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \
......@@ -99,6 +82,9 @@ TVM_REGISTER_API(_make_LetStmt)
.add_argument("lhs", "Expr", "left operand") \
.add_argument("rhs", "Expr", "right operand")
REGISTER_MAKE3(Reduce);
REGISTER_MAKE4(AttrStmt);
REGISTER_MAKE2(IntImm);
REGISTER_MAKE2(UIntImm);
REGISTER_MAKE2(FloatImm);
......@@ -123,6 +109,7 @@ REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Broadcast);
REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE3(Store);
......
......@@ -18,10 +18,16 @@ namespace Halide {
namespace Internal {
using tvm::ir::Reduce;
using tvm::ir::AttrStmt;
template<>
void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
LOG(FATAL) << "Reduce do not work with IRVisitor yet";
LOG(FATAL) << "Reduce do not work with old Visitor, use IRFunctor style visitor";
}
template<>
void StmtNode<AttrStmt>::accept(IRVisitor *v, const Stmt&) const {
LOG(FATAL) << "AttrStmt do not work with old Visitor, use IRFunctor style visitor";
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
......@@ -33,15 +39,20 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << ", rdom=" << op->rdom << ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<AttrStmt>([](const AttrStmt *op, IRPrinter *p) {
p->stream << "attr " << op->type_key << " = ";
p->print(op->value);
p->stream << '\n';
p->print(op->body);
});
} // namespace Internal
} // namespace Halide
namespace tvm {
namespace ir {
// reduce
TVM_REGISTER_NODE_TYPE(Reduce);
Expr Reduce::make(std::string op, Expr source, RDomain rdom) {
auto n = std::make_shared<Reduce>();
CHECK(source.defined());
......@@ -52,9 +63,17 @@ Expr Reduce::make(std::string op, Expr source, RDomain rdom) {
return Expr(n);
}
Stmt AttrStmt::make(NodeRef node, std::string type_key, Expr value, Stmt body) {
auto n = std::make_shared<AttrStmt>();
n->node = node;
n->type_key = type_key;
n->value = value;
n->body = body;
return Stmt(n);
}
// HalideIR node
using namespace Halide::Internal;
TVM_REGISTER_NODE_TYPE(Reduce);
TVM_REGISTER_NODE_TYPE(AttrStmt);
TVM_REGISTER_NODE_TYPE(FloatImm);
TVM_REGISTER_NODE_TYPE(IntImm);
......
......@@ -74,8 +74,6 @@ Array<Expr> ComputeOpNode::output_shape(size_t i) const {
return Array<Expr>(shape);
}
TVM_REGISTER_NODE_TYPE(ComputeOpNode);
} // namespace tvm
......@@ -13,7 +13,6 @@ Schedule::Schedule(Operation op, std::string scope) {
node_ = n;
}
TVM_REGISTER_NODE_TYPE(AttachSpecNode);
TVM_REGISTER_NODE_TYPE(ScheduleNode);
} // namespace tvm
......@@ -19,11 +19,12 @@ class IRInline : public IRMutator {
: f_(f), args_(args), body_(body) {}
Expr Mutate(Expr expr) final {
expr = IRMutator::Mutate(expr);
const Call* call = expr.as<Call>();
if (call != nullptr && call->func == f_) {
return InlineCall(call);
} else {
return IRMutator::Mutate(expr);
return expr;
}
}
......
......@@ -72,6 +72,18 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
}
});
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.set_dispatch<AttrStmt>([](const AttrStmt* op, const Stmt& s, IRMutator* m) {
Expr value = m->Mutate(op->value);
Stmt body = m->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
} else {
return AttrStmt::make(op->node, op->type_key, op->value, op->body);
}
});
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<IntImm>(ReturnSelfExpr)
.set_dispatch<UIntImm>(ReturnSelfExpr)
......
......@@ -66,6 +66,12 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
});
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<AttrStmt>([](const AttrStmt* op, IRVisitor* v) {
v->Visit(op->value);
v->Visit(op->body);
});
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<IntImm>(NoOp)
.set_dispatch<UIntImm>(NoOp)
.set_dispatch<FloatImm>(NoOp)
......
......@@ -13,11 +13,19 @@ namespace {
// inject the operator's realization on the stmt.
class InjectRealize : public IRMutator {
public:
explicit InjectRealize(std::vector<Tensor> tensors)
: tensors_(tensors) {}
std::vector<Tensor> tensors_;
};
explicit InjectRealize(Schedule sch)
: sch_(sch) {}
Stmt Mutate(Stmt stmt) final {
stmt = IRMutator::Mutate(stmt);
const For* op = stmt.as<For>();
return stmt;
}
private:
// the operations to be carried
Schedule sch_;
};
} // namespace
} // namespace ir
......
......@@ -22,10 +22,15 @@ def test_let():
x = tvm.Var('x')
y = tvm.Var('y')
stmt = tvm.make.LetStmt(
x, 10, tvm.make.Evaluate(x + 1), y, "stride")
assert stmt.attr_of_node == y
print(stmt)
x, 10, tvm.make.Evaluate(x + 1));
def test_attr():
x = tvm.Var('x')
y = tvm.Var('y')
stmt = tvm.make.AttrStmt(
y, "stride", 10, tvm.make.Evaluate(x + 1));
assert stmt.node == y
print(stmt)
def test_basic():
a = tvm.Var('a')
......@@ -44,6 +49,8 @@ def test_stmt():
if __name__ == "__main__":
test_attr()
test_const()
test_make()
test_ir()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment