Commit 357ad592 by tqchen

Fix Schedule structure, refactor compute to all rely on iter var

parent 3a48b323
...@@ -133,13 +133,13 @@ std::ostream& operator<<(std::ostream& os, const NodeRef& n); // NOLINT(*) ...@@ -133,13 +133,13 @@ std::ostream& operator<<(std::ostream& os, const NodeRef& n); // NOLINT(*)
*/ */
class IterVarNode : public Node { class IterVarNode : public Node {
public: public:
/*! \brief The looping variable */
Var var;
/*! /*!
* \brief the domain of iteration, if known, can be None * \brief the domain of iteration, if known, can be None
* For the intermediate schedule node, before schedule. * For the intermediate schedule node, before schedule.
*/ */
Range dom; Range dom;
/*! \brief The looping variable */
Var var;
/*! /*!
* \brief additional tag on the iteration variable, * \brief additional tag on the iteration variable,
* set this if this is binded already to a known thread tag. * set this if this is binded already to a known thread tag.
...@@ -147,12 +147,13 @@ class IterVarNode : public Node { ...@@ -147,12 +147,13 @@ class IterVarNode : public Node {
std::string thread_tag; std::string thread_tag;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("dom", &dom); v->Visit("dom", &dom);
v->Visit("var", &var);
v->Visit("thread_tag", &thread_tag); v->Visit("thread_tag", &thread_tag);
} }
static IterVar make(Var var, Range dom, std::string thread_tag); static IterVar make(Range dom, Var var, std::string thread_tag);
static constexpr const char* _type_key = "IterVar"; static constexpr const char* _type_key = "IterVar";
TVM_DECLARE_NODE_TYPE_INFO(IterVarNode); TVM_DECLARE_NODE_TYPE_INFO(IterVarNode);
}; };
......
...@@ -17,6 +17,8 @@ namespace tvm { ...@@ -17,6 +17,8 @@ namespace tvm {
*/ */
class ComputeOpNode : public OperationNode { class ComputeOpNode : public OperationNode {
public: public:
/*! \brief Iteration variables over the dimensions */
Array<IterVar> dim_var;
/*! \brief the compute expression */ /*! \brief the compute expression */
Expr body; Expr body;
/*! \brief constructor */ /*! \brief constructor */
...@@ -25,19 +27,18 @@ class ComputeOpNode : public OperationNode { ...@@ -25,19 +27,18 @@ class ComputeOpNode : public OperationNode {
size_t num_outputs() const final { size_t num_outputs() const final {
return 1; return 1;
} }
Array<IterVar> root_iter_vars() const final;
std::string output_name(size_t i) const final; std::string output_name(size_t i) const final;
Type output_dtype(size_t i) const final; Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final; Array<Expr> output_shape(size_t i) const final;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("domain", &domain);
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("dim_var", &dim_var); v->Visit("dim_var", &dim_var);
v->Visit("body", &body); v->Visit("body", &body);
} }
static Operation make(Domain domain, static Operation make(std::string name,
std::string name, Array<IterVar> dim_var,
Array<Var> dim_var,
Expr body); Expr body);
static constexpr const char* _type_key = "ComputeOp"; static constexpr const char* _type_key = "ComputeOp";
......
...@@ -8,15 +8,14 @@ ...@@ -8,15 +8,14 @@
#include <string> #include <string>
#include "./base.h" #include "./base.h"
#include "./split.h"
#include "./operation.h" #include "./operation.h"
namespace tvm { namespace tvm {
// Node container for Schedule // Node container for Schedule
class ScheduleNode; class ScheduleNode;
// Node container for AttachSpec // Node container for IterVarRelation
class AttachSpecNode; class IterVarRelationNode;
/*! \brief the attachment type */ /*! \brief the attachment type */
enum AttachType : int { enum AttachType : int {
...@@ -38,42 +37,132 @@ class Schedule : public NodeRef { ...@@ -38,42 +37,132 @@ class Schedule : public NodeRef {
inline const ScheduleNode* operator->() const; inline const ScheduleNode* operator->() const;
}; };
/*!
* \brief The schedule relation between IterVars
* can be Split, Fuse.
*/
class IterVarRelation : public NodeRef {
public:
IterVarRelation() {}
explicit IterVarRelation(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const IterVarRelationNode* operator->() const;
};
// defintion of node containers // defintion of node containers
/*! \brief represents the schedule of the tensor */ /*!
* \brief represents the schedule of the tensor
*
* A schedule is a Directed acylic hypergraph.
* With each node is represented by a IterVar,
* and each hyper-edge is represented by a IterVarRelation.
*
* The relations can be Split/Fuse.
*
* The current data structure stores the hyper graph in its
* bipartite representation.
*
* The relations connects the IterVars in the graph.
*/
class ScheduleNode : public Node { class ScheduleNode : public Node {
public: public:
/*! \brief The operation to be scheduled */ /*! \brief The operation to be scheduled */
Operation op; Operation op;
/*! \brief The thread scope level of the schedule */ /*! \brief The thread scope level of the schedule */
std::string scope; std::string scope;
/*! \brief Splits over iteration domains */ /*! \brief All the nodes in the iter var */
Array<Split> splits; Array<IterVar> all_iter_vars;
/*!
* \brief The current leafs in the schedule.
* Operations can only be performed in leaves.
*/
Array<IterVar> leaf_iter_vars;
/*! \brief The relation bwteen of IterVars */
Array<IterVarRelation> relations;
/*! \brief The attachment type of the schedule */ /*! \brief The attachment type of the schedule */
AttachType attach_type; AttachType attach_type;
/*! /*!
* \brief The attach point of this schedule, if it is a split * \brief The attach point of this schedule.
* \note This is not a cyclic dependency,
* because split do not refer back to parent schedule.
*/ */
Split attach_parent; IterVar attach_parent;
/*! \brief the schedules that this schedule depend on */ /*! \brief the schedules that this schedule depend on */
Array<Schedule> children; Array<Schedule> children;
// the type key
const char* type_key() const final {
return "Schedule";
}
const uint32_t type_index() const final {
static uint32_t tidx = TypeKey2Index(type_key());
return tidx;
}
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("scope", &scope); v->Visit("scope", &scope);
v->Visit("op", &op); v->Visit("op", &op);
v->Visit("splits", &splits); v->Visit("all_iter_vars", &all_iter_vars);
v->Visit("leaf_iter_vars", &leaf_iter_vars);
v->Visit("relations", &relations);
v->Visit("attach_type", &attach_type); v->Visit("attach_type", &attach_type);
v->Visit("attach_parent", &attach_parent); v->Visit("attach_parent", &attach_parent);
v->Visit("children", &children); v->Visit("children", &children);
} }
static constexpr const char* _type_key = "Schedule";
TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode);
};
/*! \brief base node of iteration var */
class IterVarRelationNode : public Node {
};
/*!
* \brief Split the parent domain into product of
* outer and iter.
*/
class SplitNode : public IterVarRelationNode {
public:
/*! \brief The parent domain */
IterVar parent;
/*! \brief The outer domain */
IterVar outer;
/*! \brief The inner domain */
IterVar inner;
/*! \brief The split factor */
Expr factor;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("parent", &parent);
v->Visit("outer", &outer);
v->Visit("inner", &inner);
v->Visit("factor", &factor);
}
static IterVarRelation make(
IterVar parent, IterVar outer,
IterVar inner, Expr factor);
static constexpr const char* _type_key = "Split";
TVM_DECLARE_NODE_TYPE_INFO(SplitNode);
};
/*!
* \brief Fuse two domains into one domain.
*/
class FuseNode : public IterVarRelationNode {
public:
/*! \brief The outer domain */
IterVar outer;
/*! \brief The inner domain */
IterVar inner;
/*! \brief The target domain */
IterVar fused;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("outer", &outer);
v->Visit("inner", &inner);
v->Visit("fused", &fused);
}
static IterVarRelation make(
IterVar outer, IterVar inner, IterVar fused);
static constexpr const char* _type_key = "Fuse";
TVM_DECLARE_NODE_TYPE_INFO(FuseNode);
}; };
// implementations // implementations
...@@ -81,5 +170,9 @@ inline const ScheduleNode* Schedule::operator->() const { ...@@ -81,5 +170,9 @@ inline const ScheduleNode* Schedule::operator->() const {
return static_cast<const ScheduleNode*>(node_.get()); return static_cast<const ScheduleNode*>(node_.get());
} }
inline const IterVarRelationNode* IterVarRelation::operator->() const {
return static_cast<const IterVarRelationNode*>(node_.get());
}
} // namespace tvm } // namespace tvm
#endif // TVM_SCHEDULE_H_ #endif // TVM_SCHEDULE_H_
/*!
* Copyright (c) 2016 by Contributors
* \file split.h
* \brief Define a split over Domain or RDomain
*/
#ifndef TVM_SPLIT_H_
#define TVM_SPLIT_H_
#include "./base.h"
#include "./expr.h"
namespace tvm {
// internal node container for split.
class SplitNode;
/*! \brief Split over input domain */
class Split : public NodeRef {
public:
/*! \brief default constructor */
Split() {}
explicit Split(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const SplitNode* operator->() const;
};
/*!
* \brief base class of split node,
* specifies a split over domain
* split also defines how to generate
*/
class SplitNode : public Node {
public:
/*! \brief the variable to be splitted on */
Var var;
};
/*! \brief simple split node that splits over one dimension */
class DimSplitNode : public SplitNode {
public:
/*! \brief The factor of the split */
Expr factor;
/*! \brief constructor */
DimSplitNode() {}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("factor", &factor);
}
static Split make(Var var, Expr factor);
static constexpr const char* _type_key = "DimSplit";
TVM_DECLARE_NODE_TYPE_INFO(DimSplitNode);
};
// Implementations of inline functions
inline const SplitNode* Split::operator->() const {
return static_cast<const SplitNode*>(node_.get());
}
} // namespace tvm
#endif // TVM_SPLIT_H_
...@@ -174,12 +174,10 @@ class TensorNode : public FunctionBaseNode { ...@@ -174,12 +174,10 @@ class TensorNode : public FunctionBaseNode {
*/ */
class OperationNode : public Node { class OperationNode : public Node {
public: public:
/*! \brief The domain of iteration of this op. */
Domain domain;
/*! \brief iter-Var over the dimensions */
Array<Var> dim_var;
/*! \brief optional name of the operation */ /*! \brief optional name of the operation */
std::string name; std::string name;
/*! \return the list of iteration variable at root */
virtual Array<IterVar> root_iter_vars() const = 0;
/*! \return number of outputs of this op */ /*! \return number of outputs of this op */
virtual size_t num_outputs() const = 0; virtual size_t num_outputs() const = 0;
/*! \return name of i-th output */ /*! \return name of i-th output */
......
...@@ -83,11 +83,11 @@ def compute(shape, fcompute, name="TensorCompute"): ...@@ -83,11 +83,11 @@ def compute(shape, fcompute, name="TensorCompute"):
arg_names = fcompute.__code__.co_varnames arg_names = fcompute.__code__.co_varnames
if ndim != len(arg_names): if ndim != len(arg_names):
raise ValueError("fcompute do not match dimension") raise ValueError("fcompute do not match dimension")
dim_var = [Var(x) for x in arg_names]
body = fcompute(*dim_var) dim_var = [IterVar((0, s), x) for x, s in zip(arg_names, shape)]
dom = [Range(0, x) for x in shape] body = fcompute(*[v.var for v in dim_var])
op_node = _function_internal._ComputeOp( op_node = _function_internal._ComputeOp(
dom, name, dim_var, body) name, dim_var, body)
return _function_internal._Tensor( return _function_internal._Tensor(
shape, name, body.dtype, op_node, 0) shape, name, body.dtype, op_node, 0)
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
*/ */
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/tensor.h> #include <tvm/tensor.h>
#include <tvm/split.h>
#include <tvm/schedule.h> #include <tvm/schedule.h>
#include "./c_api_registry.h" #include "./c_api_registry.h"
...@@ -89,8 +88,7 @@ TVM_REGISTER_API(_ComputeOp) ...@@ -89,8 +88,7 @@ TVM_REGISTER_API(_ComputeOp)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
*ret = ComputeOpNode::make(args.at(0), *ret = ComputeOpNode::make(args.at(0),
args.at(1), args.at(1),
args.at(2), args.at(2));
args.at(3));
}); });
...@@ -100,11 +98,6 @@ TVM_REGISTER_API(_IterVar) ...@@ -100,11 +98,6 @@ TVM_REGISTER_API(_IterVar)
}); });
TVM_REGISTER_API(_DimSplit)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = DimSplitNode::make(args.at(0), args.at(1));
});
TVM_REGISTER_API(_Schedule) TVM_REGISTER_API(_Schedule)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
*ret = Schedule(args.at(0), args.at(1)); *ret = Schedule(args.at(0), args.at(1));
......
...@@ -24,12 +24,12 @@ Range Range::make_with_min_extent(Expr min, Expr extent) { ...@@ -24,12 +24,12 @@ Range Range::make_with_min_extent(Expr min, Expr extent) {
} }
IterVar::IterVar(Range dom, std::string var_name, std::string thread_tag) IterVar::IterVar(Range dom, std::string var_name, std::string thread_tag)
: IterVar(IterVarNode::make(Var(var_name, Int(32)), dom, thread_tag)) {} : IterVar(IterVarNode::make(dom, Var(var_name, Int(32)), thread_tag)) {}
IterVar IterVarNode::make(Var var, Range dom, std::string thread_tag) { IterVar IterVarNode::make(Range dom, Var var, std::string thread_tag) {
std::shared_ptr<IterVarNode> n = std::make_shared<IterVarNode>(); std::shared_ptr<IterVarNode> n = std::make_shared<IterVarNode>();
n->var = var;
n->dom = dom; n->dom = dom;
n->var = var;
n->thread_tag = thread_tag; n->thread_tag = thread_tag;
return IterVar(n); return IterVar(n);
} }
......
...@@ -13,32 +13,25 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) { ...@@ -13,32 +13,25 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
auto op_node = std::make_shared<ComputeOpNode>(); auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension. // compute dimension.
size_t ndim = shape.size(); size_t ndim = shape.size();
std::vector<Var> dim_index; std::vector<IterVar> dim_var;
std::vector<Var> args;
for (size_t i = 0; i < ndim; ++i) { for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os; std::ostringstream os;
os << "dim_var" << i; os << "dim_var" << i;
dim_index.push_back(Var(os.str())); dim_var.push_back(IterVar(Range(0, shape[i]), os.str()));
args.push_back(dim_var.back()->var);
} }
std::vector<Range> dom; op_node->dim_var = Array<IterVar>(dim_var);
for (size_t i = 0; i < ndim; ++i) { op_node->body = fcompute(args);
dom.push_back(Range(0, shape[i]));
}
op_node->dim_var = Array<Var>(dim_index);
op_node->domain = Domain(dom);
op_node->body = fcompute(op_node->dim_var);
op_node->name = name; op_node->name = name;
return Operation(op_node).output(0); return Operation(op_node).output(0);
} }
Operation ComputeOpNode::make(Domain domain, Operation ComputeOpNode::make(std::string name,
std::string name, Array<IterVar> dim_var,
Array<Var> dim_var,
Expr body) { Expr body) {
auto n = std::make_shared<ComputeOpNode>(); auto n = std::make_shared<ComputeOpNode>();
n->domain = domain;
n->name = name; n->name = name;
n->dim_var = dim_var; n->dim_var = dim_var;
n->body = body; n->body = body;
...@@ -55,6 +48,10 @@ Tensor Operation::output(size_t i) const { ...@@ -55,6 +48,10 @@ Tensor Operation::output(size_t i) const {
return Tensor(node); return Tensor(node);
} }
Array<IterVar> ComputeOpNode::root_iter_vars() const {
return dim_var;
}
std::string ComputeOpNode::output_name(size_t i) const { std::string ComputeOpNode::output_name(size_t i) const {
CHECK_EQ(i, 0); CHECK_EQ(i, 0);
return name; return name;
...@@ -68,8 +65,9 @@ Type ComputeOpNode::output_dtype(size_t i) const { ...@@ -68,8 +65,9 @@ Type ComputeOpNode::output_dtype(size_t i) const {
Array<Expr> ComputeOpNode::output_shape(size_t i) const { Array<Expr> ComputeOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0); CHECK_EQ(i, 0);
std::vector<Expr> shape; std::vector<Expr> shape;
for (size_t i = 0; i < domain.size(); ++i) { for (size_t i = 0; i < dim_var.size(); ++i) {
shape.push_back(domain[i]->extent); const Range& r = dim_var[i]->dom;
shape.push_back(r->extent);
} }
return Array<Expr>(shape); return Array<Expr>(shape);
} }
......
...@@ -13,6 +13,28 @@ Schedule::Schedule(Operation op, std::string scope) { ...@@ -13,6 +13,28 @@ Schedule::Schedule(Operation op, std::string scope) {
node_ = n; node_ = n;
} }
IterVarRelation SplitNode::make(
IterVar parent, IterVar outer,
IterVar inner, Expr factor) {
auto n = std::make_shared<SplitNode>();
n->parent = parent;
n->outer = outer;
n->inner = inner;
n->factor = factor;
return IterVarRelation(n);
}
IterVarRelation FuseNode::make(
IterVar outer, IterVar inner, IterVar fused) {
auto n = std::make_shared<FuseNode>();
n->outer = outer;
n->inner = inner;
n->fused = fused;
return IterVarRelation(n);
}
TVM_REGISTER_NODE_TYPE(ScheduleNode); TVM_REGISTER_NODE_TYPE(ScheduleNode);
TVM_REGISTER_NODE_TYPE(SplitNode);
TVM_REGISTER_NODE_TYPE(FuseNode);
} // namespace tvm } // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file split.cc
*/
#include <tvm/split.h>
namespace tvm {
Split DimSplitNode::make(Var var,
Expr factor) {
auto n = std::make_shared<DimSplitNode>();
CHECK_EQ(factor.type().lanes(), 1);
n->var = var;
n->factor = factor;
return Split(n);
}
TVM_REGISTER_NODE_TYPE(DimSplitNode);
} // namespace tvm
...@@ -53,7 +53,8 @@ inline Array<IterVar> MutateRDom(Array<IterVar> rdom, IRMutator *m) { ...@@ -53,7 +53,8 @@ inline Array<IterVar> MutateRDom(Array<IterVar> rdom, IRMutator *m) {
if (!r->min.same_as(new_min)) changed = true; if (!r->min.same_as(new_min)) changed = true;
if (!r->extent.same_as(new_extent)) changed = true; if (!r->extent.same_as(new_extent)) changed = true;
new_dom[i] = IterVarNode::make( new_dom[i] = IterVarNode::make(
v->var, Range::make_with_min_extent(new_min, new_extent), v->thread_tag); Range::make_with_min_extent(new_min, new_extent),
v->var, v->thread_tag);
} }
if (!changed) { if (!changed) {
return rdom; return rdom;
......
...@@ -38,32 +38,6 @@ Stmt MakeLoop(std::vector<Stmt>&& nest, Stmt body) { ...@@ -38,32 +38,6 @@ Stmt MakeLoop(std::vector<Stmt>&& nest, Stmt body) {
return body; return body;
} }
void MakeLoop(const DimSplitNode* op,
const Split& s,
Scope<AttrKey, Expr>* pscope,
std::vector<Stmt>* nest) {
auto& scope = *pscope;
Expr out_min = scope[{op->var, "min"}];
Expr out_ext = scope[{op->var, "extent"}];
Expr stride = op->factor;
Var offset(s->var->name_hint + ".offset", Int(32));
// for loop with stride
// TODO(tqchen) split the loop to deal with tails
nest->emplace_back(
For::make(
offset, out_min, out_ext,
ForType::Parallel, DeviceAPI::None, Stmt()));
Expr in_min = offset + out_min;
Expr in_ext = min(stride, out_ext - offset);
// declare min and extent of the corresponding variable
nest->emplace_back(AttrStmt::make(op->var, "min", in_min, Stmt()));
nest->emplace_back(AttrStmt::make(op->var, "extent", in_ext, Stmt()));
// declare this is the loop
nest->emplace_back(AttrStmt::make(s, "split", 0, Stmt()));
// setup the scope.
pscope->Push({op->var, "min"}, in_min);
pscope->Push({op->var, "extent"}, in_ext);
}
Stmt MakePipeline(const Schedule& sch, Stmt body) { Stmt MakePipeline(const Schedule& sch, Stmt body) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment