Commit 357ad592 by tqchen

Fix Schedule structure; refactor compute to rely entirely on IterVar

parent 3a48b323
@@ -133,13 +133,13 @@ std::ostream& operator<<(std::ostream& os, const NodeRef& n); // NOLINT(*)
*/
class IterVarNode : public Node {
public:
- /*! \brief The looping variable */
- Var var;
/*!
* \brief the domain of iteration, if known; can be None
*  for intermediate schedule nodes before scheduling.
*/
Range dom;
+ /*! \brief The looping variable */
+ Var var;
/*!
* \brief additional tag on the iteration variable;
*  set this if it is already bound to a known thread tag.
@@ -147,12 +147,13 @@ class IterVarNode : public Node {
std::string thread_tag;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("dom", &dom);
v->Visit("var", &var);
v->Visit("thread_tag", &thread_tag);
}
- static IterVar make(Var var, Range dom, std::string thread_tag);
+ static IterVar make(Range dom, Var var, std::string thread_tag);
static constexpr const char* _type_key = "IterVar";
TVM_DECLARE_NODE_TYPE_INFO(IterVarNode);
};
......
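For orientation, a minimal usage sketch of the reordered constructor (not part of this commit; it assumes IterVar is declared in tvm/expr.h as in this tree, and the function and variable names are illustrative):

#include <tvm/expr.h>

namespace tvm {
inline void IterVarSketch() {
  // New (dom, var, thread_tag) argument order introduced by this commit.
  Range dom = Range(0, 16);                    // iterate over [0, 16)
  Var i("i", Int(32));                         // the explicit looping variable
  IterVar iv = IterVarNode::make(dom, i, "");  // no thread tag bound yet

  // The convenience constructor creates the Var from a name internally.
  IterVar j(Range(0, 16), "j", "");
}
}  // namespace tvm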
@@ -17,6 +17,8 @@ namespace tvm {
*/
class ComputeOpNode : public OperationNode {
public:
/*! \brief Iteration variables over the dimensions */
Array<IterVar> dim_var;
/*! \brief the compute expression */
Expr body;
/*! \brief constructor */
@@ -25,19 +27,18 @@ class ComputeOpNode : public OperationNode {
size_t num_outputs() const final {
return 1;
}
Array<IterVar> root_iter_vars() const final;
std::string output_name(size_t i) const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("domain", &domain);
v->Visit("name", &name);
v->Visit("dim_var", &dim_var);
v->Visit("body", &body);
}
- static Operation make(Domain domain,
- std::string name,
- Array<Var> dim_var,
+ static Operation make(std::string name,
+ Array<IterVar> dim_var,
Expr body);
static constexpr const char* _type_key = "ComputeOp";
......
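A sketch of driving the new interface directly (illustrative only: MakeSum and its variable names are not part of this commit; the body simply adds the two loop variables):

#include <tvm/operation.h>
#include <tvm/tensor.h>
#include <vector>

namespace tvm {
inline Tensor MakeSum(Expr m, Expr n) {
  // Each dimension now carries its iteration domain via an IterVar.
  std::vector<IterVar> dim_var;
  dim_var.push_back(IterVar(Range(0, m), "i", ""));
  dim_var.push_back(IterVar(Range(0, n), "j", ""));
  // The compute body refers to the underlying looping Vars.
  Expr body = dim_var[0]->var + dim_var[1]->var;
  // New signature: no separate Domain argument.
  Operation op = ComputeOpNode::make("sum", Array<IterVar>(dim_var), body);
  return op.output(0);  // root_iter_vars() of this op is dim_var itself
}
}  // namespace tvm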
@@ -8,15 +8,14 @@
#include <string>
#include "./base.h"
#include "./split.h"
#include "./operation.h"
namespace tvm {
// Node container for Schedule
class ScheduleNode;
- // Node container for AttachSpec
- class AttachSpecNode;
+ // Node container for IterVarRelation
+ class IterVarRelationNode;
/*! \brief the attachment type */
enum AttachType : int {
@@ -38,42 +37,132 @@ class Schedule : public NodeRef {
inline const ScheduleNode* operator->() const;
};
/*!
* \brief The schedule relation between IterVars;
*  can be Split or Fuse.
*/
class IterVarRelation : public NodeRef {
public:
IterVarRelation() {}
explicit IterVarRelation(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const IterVarRelationNode* operator->() const;
};
// definition of node containers
- /*! \brief represents the schedule of the tensor */
/*!
* \brief represents the schedule of the tensor
*
* A schedule is a directed acyclic hypergraph.
* Each node is represented by an IterVar,
* and each hyper-edge is represented by an IterVarRelation.
*
* The relations can be Split/Fuse.
*
* The current data structure stores the hypergraph in its
* bipartite representation.
*
* The relations connect the IterVars in the graph.
*/
class ScheduleNode : public Node {
public:
/*! \brief The operation to be scheduled */
Operation op;
/*! \brief The thread scope level of the schedule */
std::string scope;
- /*! \brief Splits over iteration domains */
- Array<Split> splits;
/*! \brief All the IterVar nodes in the schedule's graph */
Array<IterVar> all_iter_vars;
/*!
* \brief The current leaves of the schedule.
*  Scheduling operations can only be performed on leaves.
*/
Array<IterVar> leaf_iter_vars;
/*! \brief The relations between the IterVars */
Array<IterVarRelation> relations;
/*! \brief The attachment type of the schedule */
AttachType attach_type;
/*!
- * \brief The attach point of this schedule, if it is a split
- * \note This is not a cyclic dependency,
- *  because split do not refer back to parent schedule.
+ * \brief The attach point of this schedule.
*/
- Split attach_parent;
+ IterVar attach_parent;
/*! \brief the schedules that this schedule depends on */
Array<Schedule> children;
- // the type key
- const char* type_key() const final {
- return "Schedule";
- }
- const uint32_t type_index() const final {
- static uint32_t tidx = TypeKey2Index(type_key());
- return tidx;
- }
void VisitAttrs(AttrVisitor* v) final {
v->Visit("scope", &scope);
v->Visit("op", &op);
v->Visit("splits", &splits);
v->Visit("all_iter_vars", &all_iter_vars);
v->Visit("leaf_iter_vars", &leaf_iter_vars);
v->Visit("relations", &relations);
v->Visit("attach_type", &attach_type);
v->Visit("attach_parent", &attach_parent);
v->Visit("children", &children);
}
static constexpr const char* _type_key = "Schedule";
TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode);
};
/*! \brief base node of the IterVar relation */
class IterVarRelationNode : public Node {
};
/*!
* \brief Split the parent domain into the product of
*  outer and inner.
*/
class SplitNode : public IterVarRelationNode {
public:
/*! \brief The parent domain */
IterVar parent;
/*! \brief The outer domain */
IterVar outer;
/*! \brief The inner domain */
IterVar inner;
/*! \brief The split factor */
Expr factor;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("parent", &parent);
v->Visit("outer", &outer);
v->Visit("inner", &inner);
v->Visit("factor", &factor);
}
static IterVarRelation make(
IterVar parent, IterVar outer,
IterVar inner, Expr factor);
static constexpr const char* _type_key = "Split";
TVM_DECLARE_NODE_TYPE_INFO(SplitNode);
};
/*!
* \brief Fuse two domains into one domain.
*/
class FuseNode : public IterVarRelationNode {
public:
/*! \brief The outer domain */
IterVar outer;
/*! \brief The inner domain */
IterVar inner;
/*! \brief The target domain */
IterVar fused;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("outer", &outer);
v->Visit("inner", &inner);
v->Visit("fused", &fused);
}
static IterVarRelation make(
IterVar outer, IterVar inner, IterVar fused);
static constexpr const char* _type_key = "Fuse";
TVM_DECLARE_NODE_TYPE_INFO(FuseNode);
};
// implementations
@@ -81,5 +170,9 @@ inline const ScheduleNode* Schedule::operator->() const {
return static_cast<const ScheduleNode*>(node_.get());
}
inline const IterVarRelationNode* IterVarRelation::operator->() const {
return static_cast<const IterVarRelationNode*>(node_.get());
}
} // namespace tvm
#endif // TVM_SCHEDULE_H_
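To make the bipartite picture concrete, here is a hypothetical helper, not part of this commit, showing what a single factor-split contributes to the ScheduleNode arrays. It assumes a default-constructed Range can stand for a not-yet-known domain, per the IterVarNode comment above.

#include <tvm/schedule.h>

namespace tvm {
inline void SplitSketch(IterVar parent, Expr factor) {
  // Two new nodes of the hypergraph; their domains are unknown
  // before scheduling, so dom is left undefined ("None").
  IterVar outer(Range(), parent->var->name_hint + ".outer", "");
  IterVar inner(Range(), parent->var->name_hint + ".inner", "");

  // One hyper-edge connecting {parent} -> {outer, inner}.
  IterVarRelation rel = SplitNode::make(parent, outer, inner, factor);
  (void)rel;  // a real helper would append rel to the schedule's relations

  // The schedule's bipartite bookkeeping would then read:
  //   all_iter_vars : ..., parent, outer, inner
  //   leaf_iter_vars: parent replaced by outer, inner
  //   relations     : ..., rel
}
}  // namespace tvm

FuseNode::make is the dual edge: it connects {outer, inner} to a single fused IterVar, so two leaves collapse into one.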
- /*!
- * Copyright (c) 2016 by Contributors
- * \file split.h
- * \brief Define a split over Domain or RDomain
- */
- #ifndef TVM_SPLIT_H_
- #define TVM_SPLIT_H_
- #include "./base.h"
- #include "./expr.h"
- namespace tvm {
- // internal node container for split.
- class SplitNode;
- /*! \brief Split over input domain */
- class Split : public NodeRef {
- public:
- /*! \brief default constructor */
- Split() {}
- explicit Split(std::shared_ptr<Node> n) : NodeRef(n) {}
- /*!
- * \brief access the internal node container
- * \return the pointer to the internal node container
- */
- inline const SplitNode* operator->() const;
- };
- /*!
- * \brief base class of split node,
- * specifies a split over domain
- * split also defines how to generate
- */
- class SplitNode : public Node {
- public:
- /*! \brief the variable to be splitted on */
- Var var;
- };
- /*! \brief simple split node that splits over one dimension */
- class DimSplitNode : public SplitNode {
- public:
- /*! \brief The factor of the split */
- Expr factor;
- /*! \brief constructor */
- DimSplitNode() {}
- void VisitAttrs(AttrVisitor* v) final {
- v->Visit("var", &var);
- v->Visit("factor", &factor);
- }
- static Split make(Var var, Expr factor);
- static constexpr const char* _type_key = "DimSplit";
- TVM_DECLARE_NODE_TYPE_INFO(DimSplitNode);
- };
- // Implementations of inline functions
- inline const SplitNode* Split::operator->() const {
- return static_cast<const SplitNode*>(node_.get());
- }
- } // namespace tvm
- #endif // TVM_SPLIT_H_
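Both split representations appear on this page; for contrast, a sketch (parameter names are placeholders) of the same transformation expressed in each form:

#include <tvm/schedule.h>

namespace tvm {
inline IterVarRelation ContrastSketch(IterVar parent, IterVar outer,
                                      IterVar inner, Expr factor) {
  // Old (split.h, deleted above): a split hung off a bare Var plus a
  // factor, leaving the outer/inner loop variables implicit:
  //   Split s = DimSplitNode::make(parent->var, factor);
  // New (schedule.h): both halves are first-class IterVar nodes joined
  // by a relation, so later passes can address each loop directly.
  return SplitNode::make(parent, outer, inner, factor);
}
}  // namespace tvm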
@@ -174,12 +174,10 @@ class TensorNode : public FunctionBaseNode {
*/
class OperationNode : public Node {
public:
- /*! \brief The domain of iteration of this op. */
- Domain domain;
- /*! \brief iter-Var over the dimensions */
- Array<Var> dim_var;
/*! \brief optional name of the operation */
std::string name;
+ /*! \return the list of iteration variables at root */
+ virtual Array<IterVar> root_iter_vars() const = 0;
/*! \return number of outputs of this op */
virtual size_t num_outputs() const = 0;
/*! \return name of i-th output */
......
@@ -83,11 +83,11 @@ def compute(shape, fcompute, name="TensorCompute"):
arg_names = fcompute.__code__.co_varnames
if ndim != len(arg_names):
raise ValueError("fcompute does not match dimension")
- dim_var = [Var(x) for x in arg_names]
- body = fcompute(*dim_var)
- dom = [Range(0, x) for x in shape]
+ dim_var = [IterVar((0, s), x) for x, s in zip(arg_names, shape)]
+ body = fcompute(*[v.var for v in dim_var])
op_node = _function_internal._ComputeOp(
- dom, name, dim_var, body)
+ name, dim_var, body)
return _function_internal._Tensor(
shape, name, body.dtype, op_node, 0)
......
@@ -5,7 +5,6 @@
*/
#include <tvm/expr.h>
#include <tvm/tensor.h>
- #include <tvm/split.h>
#include <tvm/schedule.h>
#include "./c_api_registry.h"
@@ -89,8 +88,7 @@ TVM_REGISTER_API(_ComputeOp)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = ComputeOpNode::make(args.at(0),
args.at(1),
- args.at(2),
- args.at(3));
+ args.at(2));
});
@@ -100,11 +98,6 @@ TVM_REGISTER_API(_IterVar)
});
- TVM_REGISTER_API(_DimSplit)
- .set_body([](const ArgStack& args, RetValue *ret) {
- *ret = DimSplitNode::make(args.at(0), args.at(1));
- });
TVM_REGISTER_API(_Schedule)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Schedule(args.at(0), args.at(1));
......
@@ -24,12 +24,12 @@ Range Range::make_with_min_extent(Expr min, Expr extent) {
}
IterVar::IterVar(Range dom, std::string var_name, std::string thread_tag)
- : IterVar(IterVarNode::make(Var(var_name, Int(32)), dom, thread_tag)) {}
+ : IterVar(IterVarNode::make(dom, Var(var_name, Int(32)), thread_tag)) {}
- IterVar IterVarNode::make(Var var, Range dom, std::string thread_tag) {
+ IterVar IterVarNode::make(Range dom, Var var, std::string thread_tag) {
std::shared_ptr<IterVarNode> n = std::make_shared<IterVarNode>();
- n->var = var;
n->dom = dom;
+ n->var = var;
n->thread_tag = thread_tag;
return IterVar(n);
}
......
@@ -13,32 +13,25 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension.
size_t ndim = shape.size();
- std::vector<Var> dim_index;
+ std::vector<IterVar> dim_var;
+ std::vector<Var> args;
for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os;
os << "dim_var" << i;
- dim_index.push_back(Var(os.str()));
+ dim_var.push_back(IterVar(Range(0, shape[i]), os.str()));
+ args.push_back(dim_var.back()->var);
}
- std::vector<Range> dom;
- for (size_t i = 0; i < ndim; ++i) {
- dom.push_back(Range(0, shape[i]));
- }
- op_node->dim_var = Array<Var>(dim_index);
- op_node->domain = Domain(dom);
- op_node->body = fcompute(op_node->dim_var);
+ op_node->dim_var = Array<IterVar>(dim_var);
+ op_node->body = fcompute(args);
op_node->name = name;
return Operation(op_node).output(0);
}
- Operation ComputeOpNode::make(Domain domain,
- std::string name,
- Array<Var> dim_var,
+ Operation ComputeOpNode::make(std::string name,
+ Array<IterVar> dim_var,
Expr body) {
auto n = std::make_shared<ComputeOpNode>();
- n->domain = domain;
n->name = name;
n->dim_var = dim_var;
n->body = body;
@@ -55,6 +48,10 @@ Tensor Operation::output(size_t i) const {
return Tensor(node);
}
Array<IterVar> ComputeOpNode::root_iter_vars() const {
return dim_var;
}
std::string ComputeOpNode::output_name(size_t i) const {
CHECK_EQ(i, 0);
return name;
@@ -68,8 +65,9 @@ Type ComputeOpNode::output_dtype(size_t i) const {
Array<Expr> ComputeOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0);
std::vector<Expr> shape;
- for (size_t i = 0; i < domain.size(); ++i) {
- shape.push_back(domain[i]->extent);
+ for (size_t i = 0; i < dim_var.size(); ++i) {
+ const Range& r = dim_var[i]->dom;
+ shape.push_back(r->extent);
}
return Array<Expr>(shape);
}
......
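Putting the refactor together, a round-trip sketch through the new Compute (illustrative only; it assumes FCompute accepts a lambda over Array<Var>, as the call sites above suggest):

#include <tvm/operation.h>
#include <tvm/tensor.h>
#include <vector>

namespace tvm {
inline Tensor RoundTrip(Expr m, Expr n) {
  std::vector<Expr> shape = {m, n};
  // Compute() builds one IterVar per dimension with dom = Range(0, shape[i]).
  Tensor t = Compute(Array<Expr>(shape),
                     [](Array<Var> idx) { return idx[0] + idx[1]; },
                     "roundtrip");
  // output_shape(0) now reads each dim_var[i]->dom->extent, so the
  // tensor's shape is {m, n} again; no separate Domain is involved.
  return t;
}
}  // namespace tvm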
@@ -13,6 +13,28 @@ Schedule::Schedule(Operation op, std::string scope) {
node_ = n;
}
IterVarRelation SplitNode::make(
IterVar parent, IterVar outer,
IterVar inner, Expr factor) {
auto n = std::make_shared<SplitNode>();
n->parent = parent;
n->outer = outer;
n->inner = inner;
n->factor = factor;
return IterVarRelation(n);
}
IterVarRelation FuseNode::make(
IterVar outer, IterVar inner, IterVar fused) {
auto n = std::make_shared<FuseNode>();
n->outer = outer;
n->inner = inner;
n->fused = fused;
return IterVarRelation(n);
}
TVM_REGISTER_NODE_TYPE(ScheduleNode);
TVM_REGISTER_NODE_TYPE(SplitNode);
TVM_REGISTER_NODE_TYPE(FuseNode);
} // namespace tvm
- /*!
- * Copyright (c) 2016 by Contributors
- * \file split.cc
- */
- #include <tvm/split.h>
- namespace tvm {
- Split DimSplitNode::make(Var var,
- Expr factor) {
- auto n = std::make_shared<DimSplitNode>();
- CHECK_EQ(factor.type().lanes(), 1);
- n->var = var;
- n->factor = factor;
- return Split(n);
- }
- TVM_REGISTER_NODE_TYPE(DimSplitNode);
- } // namespace tvm
@@ -53,7 +53,8 @@ inline Array<IterVar> MutateRDom(Array<IterVar> rdom, IRMutator *m) {
if (!r->min.same_as(new_min)) changed = true;
if (!r->extent.same_as(new_extent)) changed = true;
new_dom[i] = IterVarNode::make(
- v->var, Range::make_with_min_extent(new_min, new_extent), v->thread_tag);
+ Range::make_with_min_extent(new_min, new_extent),
+ v->var, v->thread_tag);
}
if (!changed) {
return rdom;
......
@@ -38,32 +38,6 @@ Stmt MakeLoop(std::vector<Stmt>&& nest, Stmt body) {
return body;
}
- void MakeLoop(const DimSplitNode* op,
- const Split& s,
- Scope<AttrKey, Expr>* pscope,
- std::vector<Stmt>* nest) {
- auto& scope = *pscope;
- Expr out_min = scope[{op->var, "min"}];
- Expr out_ext = scope[{op->var, "extent"}];
- Expr stride = op->factor;
- Var offset(s->var->name_hint + ".offset", Int(32));
- // for loop with stride
- // TODO(tqchen) split the loop to deal with tails
- nest->emplace_back(
- For::make(
- offset, out_min, out_ext,
- ForType::Parallel, DeviceAPI::None, Stmt()));
- Expr in_min = offset + out_min;
- Expr in_ext = min(stride, out_ext - offset);
- // declare min and extent of the corresponding variable
- nest->emplace_back(AttrStmt::make(op->var, "min", in_min, Stmt()));
- nest->emplace_back(AttrStmt::make(op->var, "extent", in_ext, Stmt()));
- // declare this is the loop
- nest->emplace_back(AttrStmt::make(s, "split", 0, Stmt()));
- // setup the scope.
- pscope->Push({op->var, "min"}, in_min);
- pscope->Push({op->var, "extent"}, in_ext);
- }
Stmt MakePipeline(const Schedule& sch, Stmt body) {
......