Commit 8f51c5fd by Tianqi Chen Committed by GitHub

[SCHEDULE] Add group, refactor thread bind api. (#82)

* [SCHEDULE] Add group, refactor thread bind api.

* fix doc

* fix g++-4.8

* More testscase

* Remove graph context from fix pt analysis
parent 6268e183
Subproject commit ce80d58741688b200f498fed8c7b0ea33e0516c8
Subproject commit 59fdca16978b6184bab87fbff7a00c95f1804686
......@@ -32,17 +32,6 @@ struct TensorDom {
};
/*!
* \brief The map beteen tensor and operation it feeds to.
*/
using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
/*! \brief The graph context used during bound inference. */
struct GraphContext {
/*! \brief The feed graph */
FeedGraph feed_graph;
};
/*!
* \brief Base class of all operation nodes
*/
class OperationNode : public FunctionBaseNode {
......@@ -102,13 +91,11 @@ class OperationNode : public FunctionBaseNode {
* Set the range of each root_iter_vars in the op to out_dom_map
*
* \param self The reference to self.
* \param graph_ctx The global graph context information.
* \param tensor_dom Domain map of Tensor->access set of each dimension.
* \param out_dom_map The output domain map of each IterVar to be set.
*/
virtual void GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const = 0;
/*!
......@@ -162,7 +149,6 @@ class PlaceholderOpNode : public OperationNode {
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
......@@ -214,7 +200,6 @@ class ComputeOpNode : public OperationNode {
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
......@@ -253,6 +238,11 @@ class ScanOpNode : public OperationNode {
/*! \brief The placeholder to refer as states in update. */
Array<Tensor> state_placeholder;
/*!
* \brief The inputs to the scan. These are optional,
*  but providing them gives hints that help speed up getting the scan body.
*/
Array<Tensor> inputs;
/*!
* \brief Spatial axis to indicate spatial dimension of each output.
* They corresponds to flattened spatial axis of the outputs.
*
......@@ -279,7 +269,6 @@ class ScanOpNode : public OperationNode {
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
......@@ -296,13 +285,15 @@ class ScanOpNode : public OperationNode {
v->Visit("init", &init);
v->Visit("update", &update);
v->Visit("state_placeholder", &state_placeholder);
v->Visit("inputs", &inputs);
v->Visit("spatial_axis_", &spatial_axis_);
}
static Operation make(std::string name,
IterVar axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder);
Array<Tensor> state_placeholder,
Array<Tensor> input);
static constexpr const char* _type_key = "ScanOp";
TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode, OperationNode);
......@@ -339,7 +330,6 @@ class ExternOpNode : public OperationNode {
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
......@@ -388,16 +378,19 @@ Tensor placeholder(Array<Expr> shape,
Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
/*!
* \brief Construct new tensors by scan over scan_axis.
* \brief Construct new tensors by scan.
*
* \param init The initial tensor of the first K steps.
* \param update The update tensor indicating the updated result after each timestamp.
* \param state_placeholder The placeholder for the states.
* \param inputs The inputs to the scan body, this is optional,
* but recommended to provide concrete information about scan body.
* \param name The optional name of the tensor.
*/
Array<Tensor> scan(Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
Array<Tensor> inputs = Array<Tensor>(),
std::string name = "scan");
// same as compute, specialized for different fcompute function
......
......@@ -139,12 +139,20 @@ inline bool TVMArgValue::IsNodeType() const {
// extensions for TVMRetValue
/*!
 * \brief Assign a shared node pointer to this return value.
 *
 *  A non-null pointer is stored as a kNodeHandle; a null pointer
 *  switches the value to the kNull POD state instead.
 * \param other The node pointer to assign, may be null.
 * \return reference to self.
 */
inline TVMRetValue& TVMRetValue::operator=(
    const std::shared_ptr<Node>& other) {
  if (other) {
    SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other);
  } else {
    SwitchToPOD(kNull);
  }
  return *this;
}
/*!
 * \brief Assign a node reference to this return value.
 *
 *  A defined reference stores its node_ pointer as a kNodeHandle;
 *  an undefined reference switches the value to the kNull POD state.
 * \param other The node reference to assign, may be undefined.
 * \return reference to self.
 */
inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) {
  if (other.defined()) {
    SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other.node_);
  } else {
    SwitchToPOD(kNull);
  }
  return *this;
}
......
......@@ -8,6 +8,7 @@
#include <string>
#include "./base.h"
#include "./expr.h"
#include "./tensor.h"
namespace tvm {
......@@ -23,8 +24,7 @@ class IterVarAttrNode;
/*! \brief the attachment type */
enum AttachType : int {
kNone = 0,
kRoot = 1,
kGroupRoot = 1,
kInline = 2,
kInlinedAlready = 3,
kScope = 4,
......@@ -64,44 +64,50 @@ class Stage : public NodeRef {
*/
Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*)
/*!
* \brief Compute the function inline, attach it at parent.
* \brief Compute the function inline.
* \return reference to self.
*/
Stage& compute_inline(); // NOLINT(*)
/*!
* \brief Compute the function at root, attach it to its parent.
* \brief Compute the function at group root.
* \return reference to self.
*/
Stage& compute_root(); // NOLINT(*)
/*!
* \brief Rebase the parent iter var as rebased variable.
* \brief Bind the ivar to thread index.
*
* \param parent The parent iteration domain.
* \param rebased The variable to be used in rebase.
* \param ivar The IterVar to be bound.
* \param thread_ivar The thread axis to bind to.
* \return reference to self.
*/
Stage& rebase(IterVar parent, IterVar rebased);
Stage& bind(IterVar ivar, IterVar thread_ivar);
/*!
* \brief Specify environment threads that are launched around the group's scope.
* This can only be used in group stage.
* \param threads The threads to be launched around the scope.
* \note Each thread can only appear in one env_threads.
* \return reference to self.
*/
Stage& env_threads(Array<IterVar> threads);
/*!
* \brief Split the parent by factor, generating the outer and inner domains.
* \param parent The parent iteration domain.
* \param factor The split factor of the loop.
* \param p_outer The result outer domain
* \param p_inner The result inner domain.
* \param factor The split factor of the loop.
* \return reference to self.
*/
Stage& split(IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor); // NOLINT(*)
Stage& split(IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
/*!
* \brief Split the iteration with a given outer domain,
* the outer domain must have a thread-tag.
* \brief Split the iteration with given number of parts.
*
* \param parent The parent domain.
* \param outer The outer domain to be spliited, must have a thread_tag.
* \param nparts The number of parts in the outer domain.
* \param p_outer The result outer domain.
* \param p_inner The result inner domain.
* \param factor Optional, the factor of the split,
* factor must be provided such that factor * outer.extent >= parent.extent.
* \return reference to self.
*/
Stage& split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor = Expr()); // NOLINT(*)
Stage& split_by_nparts(IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
/*!
* \brief Fuse the inner outer domain to the target
* \param inner The inner domain to be fused
......@@ -123,25 +129,18 @@ class Stage : public NodeRef {
*
* \param x_parent The original x dimension
* \param y_parent The original y dimension
* \param x_factor The stride factor on x axis
* \param y_factor The stride factor on y axis
* \param p_x_outer Outer axis of x dimension
* \param p_y_outer Outer axis of y dimension
* \param p_x_inner Inner axis of x dimension
* \param p_y_inner Inner axis of y dimension
* \param x_factor The stride factor on x axis
* \param y_factor The stride factor on y axis
* \return reference to self.
*/
Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
Expr x_factor, Expr y_factor,
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor);
/*!
* \brief Specify thread launching group in
* outer most scope of the stage.
* This is only valid for composite operators.
* \param threads The threads to be launched.
*/
Stage& outermost_threads(Array<IterVar> threads);
IterVar* p_x_inner, IterVar* p_y_inner);
/*!
* \brief Vectorize iteration.
* \param var The axis to be vectorized.
......@@ -164,7 +163,15 @@ class Stage : public NodeRef {
* \brief whether the stage has been scheduled.
* \return whether the stage has been scheduled.
*/
inline bool is_scheduled() const;
bool is_scheduled() const;
/*!
* \brief Get attachment spec of current stage.
* If the stage compute at Group root, this function
* will traverse the group function to get the
* final spec from the group.
* \return A stage representing the attach spec of the group.
*/
Stage GetAttachSpec() const;
// declare container type
using ContainerType = StageNode;
};
......@@ -197,6 +204,18 @@ class Schedule : public NodeRef {
return this->operator[](tensor->op);
}
/*!
* \brief Create a new stage group for all intermediate
* operations between inputs and outputs.
*
* \param outputs The output boundary of the group.
* \param inputs The input boundary of the group.
* \param include_inputs Whether to include inputs if they are reachable from outputs.
* \return The new grouped stage.
*/
Stage create_group(const Array<Tensor>& outputs,
const Array<Tensor>& inputs,
bool include_inputs = false);
/*!
* \brief create a cache read of original tensor for readers.
* This will mutate the body of the readers.
* A new stage will be created for the tensor.
......@@ -274,7 +293,6 @@ class IterVarRelation : public NodeRef {
class IterVarAttr : public NodeRef {
public:
IterVarAttr() {}
explicit IterVarAttr(IterVarType t);
explicit IterVarAttr(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
......@@ -283,26 +301,27 @@ class IterVarAttr : public NodeRef {
inline const IterVarAttrNode* operator->() const;
};
// definition of node containers
/*!
* \brief represents the schedule of the tensor
* \brief represents a stage.
*
* A schedule is a Directed acylic hypergraph.
* The relations form a directed acyclic hypergraph in bipartite manner,
* with each node represented by an IterVar,
* and each hyper-edge represented by an IterVarRelation.
* The relations connects the IterVars in the graph.
*
* The relations can be Split/Fuse.
*
* The current data structure stores the hyper graph in its
* bipartite representation.
* Besides typical stages that correspond to operations,
* there are also group stages, which group stages together.
* Each stage's group (given by the group field) represents a constraint:
* the stage can only be attached to stages within the group.
*
* The relations connects the IterVars in the graph.
* The group stage node can be attached to IterVars as in normal stage.
*/
class StageNode : public Node {
public:
/*! \brief The thread scope level of the stage */
std::string scope;
/*! \brief The operation of stage, can be different from original op. */
/*!
* \brief The operation of stage, can be different from original op.
* If it is null, then this stage is a group stage.
*/
Operation op;
/*!
* \brief The original operator.
......@@ -312,42 +331,50 @@ class StageNode : public Node {
Operation origin_op;
/*! \brief All the nodes in the iter var */
Array<IterVar> all_iter_vars;
/*!
* \brief The current leafs in the schedule.
* Operations can only be performed in leaves.
*/
/*! \brief The current active leaf iter vars in the stage. */
Array<IterVar> leaf_iter_vars;
/*!
* \brief Specify threads to be launched at the stage.
* This is only valid for composite ops such as Scan.
*/
Array<IterVar> outermost_threads;
Array<IterVar> env_threads;
/*! \brief The relations between IterVars */
Array<IterVarRelation> relations;
/*! \brief additional attributes about iter var. */
Map<IterVar, IterVarAttr> iter_var_attrs;
/*! \brief The attachment type of the schedule */
AttachType attach_type{kNone};
AttachType attach_type{kGroupRoot};
/*! \brief The attach point of this schedule. */
IterVar attach_ivar;
/*! \brief The stage this node attaches to */
Stage attach_stage;
/*! \brief The thread storage scope level of the stage */
std::string scope;
/*! \brief Whether this is an output stage */
bool is_output{false};
/*!
* \brief The parent group of the current stage.
* The stage cannot be assigned to stages outside the group.
*/
Stage group;
/*! \brief Number of direct child stages, only used for group stage.*/
int num_child_stages{0};
void VisitAttrs(AttrVisitor* v) final {
v->Visit("scope", &scope);
v->Visit("op", &op);
v->Visit("origin_op", &origin_op);
v->Visit("all_iter_vars", &all_iter_vars);
v->Visit("leaf_iter_vars", &leaf_iter_vars);
v->Visit("outermost_threads", &outermost_threads);
v->Visit("env_threads", &env_threads);
v->Visit("relations", &relations);
v->Visit("iter_var_attrs", &iter_var_attrs);
v->Visit("attach_type", &attach_type);
v->Visit("attach_ivar", &attach_ivar);
v->Visit("attach_stage", &attach_stage);
v->Visit("scope", &scope);
v->Visit("is_output", &is_output);
v->Visit("group", &group);
v->Visit("num_child_stages", &num_child_stages);
}
static constexpr const char* _type_key = "Stage";
......@@ -360,19 +387,34 @@ class ScheduleNode : public Node {
/*! \brief The output operations in original data flow graph */
Array<Operation> outputs;
/*!
* \brief list of all stages for non-placeholder ops.
* \brief list of all stages for ops.
* The stages are sorted in dependency order.
*/
Array<Stage> stages;
/*! \brief map of operation to the stages */
/*!
* \brief List of all stage groups.
*/
Array<Stage> groups;
/*! \brief map of original operation to the stages */
Map<Operation, Stage> stage_map;
/*!
* \brief Internal stage map to map internal ops to stages.
* This is created on demand and can be invalidated.
*/
std::unordered_map<const Node*, Stage> op2stage_cache_;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("outputs", &outputs);
v->Visit("stages", &stages);
v->Visit("groups", &groups);
v->Visit("stage_map", &stage_map);
}
/*! \brief Initialize temp cache. */
void InitCache();
/*! \brief Invalidate temp cache. */
void InvalidateCache();
static constexpr const char* _type_key = "Schedule";
TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node);
};
......@@ -381,10 +423,13 @@ class ScheduleNode : public Node {
class IterVarAttrNode : public Node {
public:
/*! \brief The iteration type. */
IterVarType iter_type;
IterVarType iter_type{kDataPar};
/*! \brief The thread this iter Var binds, can be null */
IterVar bind_thread;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("iter_type", &iter_type);
v->Visit("bind_thread", &bind_thread);
}
static constexpr const char* _type_key = "IterVarAttr";
......@@ -412,17 +457,22 @@ class SplitNode : public IterVarRelationNode {
IterVar inner;
/*! \brief The split factor */
Expr factor;
/*! \brief Number of parts, only factor or nparts can be given */
Expr nparts;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("parent", &parent);
v->Visit("outer", &outer);
v->Visit("inner", &inner);
v->Visit("factor", &factor);
v->Visit("nparts", &nparts);
}
static IterVarRelation make(
IterVar parent, IterVar outer,
IterVar inner, Expr factor);
static IterVarRelation make(IterVar parent,
IterVar outer,
IterVar inner,
Expr factor,
Expr nparts);
static constexpr const char* _type_key = "Split";
TVM_DECLARE_NODE_TYPE_INFO(SplitNode, IterVarRelationNode);
......@@ -485,12 +535,6 @@ inline StageNode* Stage::operator->() {
return static_cast<StageNode*>(node_.get());
}
inline bool Stage::is_scheduled() const {
const StageNode* n = operator->();
return !(n->relations.empty() && n->attach_type == kNone &&
n->all_iter_vars.same_as(n->leaf_iter_vars));
}
/*!
 * \brief Read-only access to the underlying ScheduleNode.
 *  The pointer held by node_ is converted with static_cast, so no runtime type check is done.
 * \return const pointer to the node container.
 */
inline const ScheduleNode* Schedule::operator->() const {
return static_cast<const ScheduleNode*>(node_.get());
}
......@@ -505,6 +549,5 @@ inline const IterVarRelationNode* IterVarRelation::operator->() const {
/*!
 * \brief Read-only access to the underlying IterVarAttrNode.
 *  The pointer held by node_ is converted with static_cast, so no runtime type check is done.
 * \return const pointer to the node container.
 */
inline const IterVarAttrNode* IterVarAttr::operator->() const {
return static_cast<const IterVarAttrNode*>(node_.get());
}
} // namespace tvm
#endif // TVM_SCHEDULE_H_
......@@ -11,6 +11,7 @@ from ._ctypes._function import Function
from ._ctypes._function import _init_api_functions, register_func, get_global_func
from ._ctypes._function import convert_to_tvm_func as _convert_tvm_func
from . import _api_internal
from . import _base
from . import make as _make
from . import expr as _expr
from . import tensor as _tensor
......@@ -142,7 +143,7 @@ def compute(shape, fcompute, name="compute"):
return op_node.output(0)
def scan(init, update, state_placeholder, name="scan"):
def scan(init, update, state_placeholder, inputs=None, name="scan"):
"""Construct new tensors by scanning over axis.
Parameters
......@@ -156,6 +157,10 @@ def scan(init, update, state_placeholder, name="scan"):
state_placeholder: Tensor or list of Tensor
The placeholder variables used by update.
inputs: Tensor or list of Tensor, optional
The list of inputs to the scan. This is not required, but can
be useful for the compiler to detect scan body faster.
name: str, optional
The name hint of the tensor
......@@ -173,7 +178,7 @@ def scan(init, update, state_placeholder, name="scan"):
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
res = tvm.scan(s_init, s_update, s_state)
res = tvm.scan(s_init, s_update, s_state, X)
"""
if isinstance(init, _tensor.Tensor):
init = [init]
......@@ -181,10 +186,14 @@ def scan(init, update, state_placeholder, name="scan"):
update = [update]
if isinstance(state_placeholder, _tensor.Tensor):
state_placeholder = [state_placeholder]
if isinstance(inputs, _tensor.Tensor):
inputs = [inputs]
if inputs is None:
inputs = []
if len(init) != len(update) or len(init) != len(state_placeholder):
raise ValueError("init, update, state_placeholder must have same length")
axis = _IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3)
op = _api_internal._ScanOp(name, axis, init, update, state_placeholder)
op = _api_internal._ScanOp(name, axis, init, update, state_placeholder, inputs)
res = [op.output(i) for i in range(len(update))]
return res[0] if len(res) == 1 else res
......@@ -340,20 +349,25 @@ def _IterVar(dom, name, iter_type, thread_tag=''):
return _api_internal._IterVar(dom, var, iter_type, thread_tag)
def thread_axis(dom, tag, name=''):
def thread_axis(dom=None, tag='', name=''):
"""Create a new IterVar to represent thread index.
Parameters
----------
dom : Range
The domain of iteration.
dom : Range or str
The domain of iteration
When str is passed, dom is set to None and str is used as tag
tag : str
tag : str, optional
The thread tag
name : str, optional
The name of the var.
"""
if isinstance(dom, _base.string_types):
tag, dom = dom, None
if len(tag) == 0:
raise ValueError("tag must be given as Positional or keyword argument")
name = name if name else tag
return _IterVar(dom, name, 1, tag)
......
......@@ -41,6 +41,30 @@ class Schedule(NodeBase):
"""
_api_internal._ScheduleNormalize(self)
def create_group(self, outputs, inputs, include_inputs=False):
    """Create a stage group bounded by the given outputs and inputs.

    The operators between outputs and inputs are placed as members of the group.
    The outputs are included in the group, while the inputs are not included.

    Parameters
    ----------
    outputs : Tensor or list of Tensors
        The outputs of the group. A single Tensor is wrapped into a list.

    inputs : Tensor or list of Tensors
        The inputs of the group. A single Tensor is wrapped into a list.

    include_inputs : boolean, optional
        Whether to include input operations in the group if they are used by outputs.

    Returns
    -------
    The value returned by _api_internal._ScheduleCreateGroup.
    """
    group_outputs = [outputs] if isinstance(outputs, _tensor.Tensor) else outputs
    group_inputs = [inputs] if isinstance(inputs, _tensor.Tensor) else inputs
    return _api_internal._ScheduleCreateGroup(
        self, group_outputs, group_inputs, include_inputs)
def cache_read(self, tensor, scope, readers):
"""Create a cache read of original tensor for readers.
......@@ -112,25 +136,7 @@ class Schedule(NodeBase):
@register_node
class Stage(NodeBase):
"""A Stage represents schedule for one operation."""
def rebase(self, parent, rebased):
"""Rebase parent by an existing thread axis.
Parameters
----------
parent : IterVar
The parent iter var.
rebased : IterVar
The rebased iter var.
Returns
-------
rebased : IterVar
The rebased itervar.
"""
_api_internal._StageRebase(self, parent, rebased)
return rebased
def split(self, parent, factor=None, outer=None):
def split(self, parent, factor=None, nparts=None):
"""Split the stage either by factor or by nparts (the number of outer parts).
Parameters
......@@ -141,8 +147,8 @@ class Stage(NodeBase):
factor : Expr, optional
The splitting factor
outer : IterVar, optional
The outer split variable
nparts : Expr, optional
The number of outer parts.
Returns
-------
......@@ -152,11 +158,13 @@ class Stage(NodeBase):
inner : IterVar
The inner variable of iteration.
"""
if outer is not None:
inner = _api_internal._StageSplitByOuter(self, parent, outer, factor)
if nparts is not None:
if factor is not None:
raise ValueError("Donot need to provide both outer and nparts")
outer, inner = _api_internal._StageSplitByNParts(self, parent, nparts)
else:
if factor is None:
raise ValueError("either outer or factor need to be provided")
raise ValueError("Either nparts or factor need to be provided")
outer, inner = _api_internal._StageSplitByFactor(self, parent, factor)
return outer, inner
......@@ -188,8 +196,21 @@ class Stage(NodeBase):
"""
return _api_internal._StageSetScope(self, scope)
def outermost_threads(self, threads):
"""Force launch threads at outermost scope of the stage.
def bind(self, ivar, thread_ivar):
"""Bind ivar to the thread index thread_ivar.
Parameters
----------
ivar : IterVar
The iteration variable to be bound to the thread.
thread_ivar : IterVar
The thread axis that ivar is bound to.
"""
_api_internal._StageBind(self, ivar, thread_ivar)
def env_threads(self, threads):
"""Mark threads to be launched at the outer scope of composed op.
Parameters
----------
......@@ -198,7 +219,7 @@ class Stage(NodeBase):
"""
if isinstance(threads, _collections.IterVar):
threads = [threads]
_api_internal._StageOutermostThreads(self, threads)
_api_internal._StageEnvThreads(self, threads)
def compute_at(self, parent, scope):
"""Attach the stage at parent's scope
......
......@@ -182,7 +182,8 @@ TVM_REGISTER_API(_ScanOp)
args[1],
args[2],
args[3],
args[4]);
args[4],
args[5]);
});
TVM_REGISTER_API(_ExternOp)
......@@ -219,27 +220,26 @@ TVM_REGISTER_API(_StageSetScope)
.set_scope(args[1]);
});
TVM_REGISTER_API(_StageRebase)
TVM_REGISTER_API(_StageBind)
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar outer, inner;
args[0].operator Stage()
.rebase(args[1], args[2]);
.bind(args[1], args[2]);
});
TVM_REGISTER_API(_StageSplitByFactor)
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar outer, inner;
args[0].operator Stage()
.split(args[1], &outer, &inner, args[2]);
.split(args[1], args[2], &outer, &inner);
*ret = Array<IterVar>({outer, inner});
});
TVM_REGISTER_API(_StageSplitByOuter)
TVM_REGISTER_API(_StageSplitByNParts)
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar inner;
IterVar outer, inner;
args[0].operator Stage()
.split(args[1], args[2], &inner, args[3]);
*ret = inner;
.split_by_nparts(args[1], args[2], &outer, &inner);
*ret = Array<IterVar>({outer, inner});
});
TVM_REGISTER_API(_StageFuse)
......@@ -278,15 +278,17 @@ TVM_REGISTER_API(_StageTile)
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar x_outer, y_outer, x_inner, y_inner;
args[0].operator Stage()
.tile(args[1], args[2], &x_outer, &y_outer,
&x_inner, &y_inner, args[3], args[4]);
.tile(args[1], args[2],
args[3], args[4],
&x_outer, &y_outer,
&x_inner, &y_inner);
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});
TVM_REGISTER_API(_StageOutermostThreads)
TVM_REGISTER_API(_StageEnvThreads)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.outermost_threads(args[1]);
.env_threads(args[1]);
});
TVM_REGISTER_API(_StageUnroll)
......@@ -313,6 +315,12 @@ TVM_REGISTER_API(_ScheduleNormalize)
.normalize();
});
// Register Schedule::create_group under the API name _ScheduleCreateGroup.
// Arguments: args[0] the schedule, then args[1..3] forwarded in order to
// create_group (declared as outputs, inputs, include_inputs).
// The result of create_group is written to ret.
TVM_REGISTER_API(_ScheduleCreateGroup)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
.create_group(args[1], args[2], args[3]);
});
TVM_REGISTER_API(_ScheduleCacheRead)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
......
......@@ -34,9 +34,9 @@ TVM_REGISTER_API(_schedule_AutoInlineElemWise)
REGISTER_SCHEDULE_PASS1(InferBound);
REGISTER_SCHEDULE_PASS1(CreateReadGraph);
REGISTER_SCHEDULE_PASS2(PostDFSOrder);
REGISTER_SCHEDULE_PASS1(ScanGetBody);
REGISTER_SCHEDULE_PASS1(CreateAttachPath);
REGISTER_SCHEDULE_PASS2(ScanFixPointAnalysis);
REGISTER_SCHEDULE_PASS1(ScanGetBody);
REGISTER_SCHEDULE_PASS1(ScanFixPointAnalysis);
REGISTER_SCHEDULE_PASS2(ScheduleOps);
} // namespace schedule
......
......@@ -35,7 +35,6 @@ std::string CodeGenSourceBase::GetUniqueName(std::string prefix) {
}
std::string CodeGenSourceBase::SSAGetID(std::string src, Type t) {
LOG(INFO) << "ssa get id";
if (name_alloc_map_.count(src)) return src;
auto it = ssa_assign_map_.find(src);
if (it != ssa_assign_map_.end()) {
......
......@@ -132,7 +132,6 @@ void ComputeOpNode::PropBoundToInputs(
void ComputeOpNode::GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const {
const TensorDom& tdom = tensor_dom.at(self.output(0));
......
......@@ -99,7 +99,6 @@ void ExternOpNode::PropBoundToInputs(
void ExternOpNode::GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const {
}
......
......@@ -38,20 +38,29 @@ MakeLoopNest(const Stage& stage,
value_map[iv] = iv->var;
continue;
}
// Bind iv could be another thread.
IterVar bind_iv = iv;
if (stage->iter_var_attrs.count(iv)) {
IterVar bind_thread = stage->iter_var_attrs[iv]->bind_thread;
if (bind_thread.defined()) bind_iv = bind_thread;
}
Range dom = dom_map.at(iv);
// initialize the offset and loop_level
Var var = iv->var;
Var var = bind_iv->var;
if (new_loop_var) {
var = Var(iv->var->name_hint + ".init", iv->var.type());
var = Var(iv->var->name_hint + ".init", bind_iv->var.type());
}
// Mark the iter var in the IR, to remember the point
if (iv->thread_tag.length() == 0) {
if (bind_iv->thread_tag.length() == 0) {
ForType for_type = ForType::Serial;
if (stage->iter_var_attrs.count(iv)) {
switch (stage->iter_var_attrs[iv]->iter_type) {
case kUnrolled: for_type = ForType::Unrolled; break;
case kVectorized: for_type = ForType::Vectorized; break;
case kParallelized: for_type = ForType::Parallel; break;
case kDataPar: break;
default: LOG(FATAL) << "Unknown iter type"
<< stage->iter_var_attrs[iv]->iter_type
<< " in the iter_var_attrs";
......@@ -67,7 +76,7 @@ MakeLoopNest(const Stage& stage,
for_type, DeviceAPI::None, no_op));
value_map[iv] = var;
} else {
Var idx(iv->var->name_hint + ".idx", iv->var.type());
Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.type());
nest[i + 1].emplace_back(
For::make(idx, 0, dom->extent,
for_type, DeviceAPI::None, no_op));
......@@ -76,29 +85,29 @@ MakeLoopNest(const Stage& stage,
nest[i + 1].emplace_back(
LetStmt::make(var, new_value, no_op));
}
} else if (iv->thread_tag == "vthread") {
} else if (bind_iv->thread_tag == "vthread") {
// virtual thread
// Always restrict threaded IterVar to starts from 0.
CHECK(is_zero(dom->min));
CHECK(is_positive_const(dom->extent));
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::virtual_thread, dom->extent, no_op));
AttrStmt::make(bind_iv, ir::attr::virtual_thread, dom->extent, no_op));
value_map[iv] = var;
} else if (iv->thread_tag == "pipeline") {
} else if (bind_iv->thread_tag == "pipeline") {
// pipeline marker.
CHECK(is_zero(dom->min));
CHECK(is_one(dom->extent));
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::pipeline_exec_scope, dom->extent, no_op));
AttrStmt::make(bind_iv, ir::attr::pipeline_exec_scope, dom->extent, no_op));
value_map[iv] = dom->min;
} else {
// Always restrict threaded IterVar to starts from 0.
CHECK(is_zero(dom->min));
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::thread_extent, dom->extent, no_op));
AttrStmt::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op));
if (is_one(dom->extent)) {
value_map[iv] = dom->min;
} else {
......
......@@ -33,7 +33,6 @@ Array<Expr> PlaceholderOpNode::output_shape(size_t i) const {
return shape;
}
Operation PlaceholderOpNode::make(std::string name,
Array<Expr> shape,
Type dtype) {
......@@ -66,7 +65,6 @@ void PlaceholderOpNode::PropBoundToInputs(
void PlaceholderOpNode::GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const {
}
......
......@@ -47,7 +47,8 @@ Operation ScanOpNode::make(std::string name,
IterVar axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder) {
Array<Tensor> state_placeholder,
Array<Tensor> inputs) {
auto n = std::make_shared<ScanOpNode>();
CHECK_EQ(init.size(), update.size());
CHECK_EQ(init.size(), state_placeholder.size());
......@@ -89,12 +90,14 @@ Operation ScanOpNode::make(std::string name,
n->init = init;
n->update = update;
n->state_placeholder = state_placeholder;
n->inputs = inputs;
return Operation(n);
}
Array<Tensor> scan(Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
Array<Tensor> inputs,
std::string name) {
IterVar scan_axis =
IterVarNode::make(
......@@ -102,7 +105,7 @@ Array<Tensor> scan(Array<Tensor> init,
init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
Var(name + ".idx"), kOrdered);
Operation op = ScanOpNode::make(
name, scan_axis, init, update, state_placeholder);
name, scan_axis, init, update, state_placeholder, inputs);
Array<Tensor> res;
for (int i = 0; i < op->num_outputs(); ++i) {
res.push_back(op.output(i));
......@@ -179,7 +182,6 @@ void ScanOpNode::PropBoundToInputs(
void ScanOpNode::GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
......@@ -200,8 +202,7 @@ void ScanOpNode::GatherBound(
Range r = arith::Union(time_dom).cover_range(sdom);
(*out_dom_map)[this->scan_axis] = Range::make_with_min_extent(
sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
Array<Operation> body = ScanGetBody_(this, graph_ctx.feed_graph);
Map<IterVar, Expr> fix_pt = ScanFixPointAnalysis(self, body);
Map<IterVar, Expr> fix_pt = ScanFixPointAnalysis(self);
// Update for spatial axis.
size_t sp_idx = 0;
for (size_t i = 0; i < output.size(); ++i) {
......
......@@ -15,10 +15,23 @@
namespace tvm {
namespace schedule {
/*!
 * \brief The graph context used during bound inference.
 *  Holds only the feed graph.
 */
struct GraphContext {
/*! \brief The feed graph */
FeedGraph feed_graph;
};
// check if scope
inline bool ScopeRelax(const IterVar& iv, const std::string& scope) {
inline bool ScopeRelax(const IterVar& ivar,
const std::unordered_map<IterVar, IterVar>& bind_map,
const std::string& scope) {
using runtime::ThreadScope;
using runtime::StorageScope;
auto it = bind_map.find(ivar);
IterVar iv = ivar;
if (it != bind_map.end()) {
iv = it->second;
}
if (iv->thread_tag.length() == 0) return false;
if (scope.length() == 0) return false;
......@@ -28,10 +41,16 @@ inline bool ScopeRelax(const IterVar& iv, const std::string& scope) {
void InferRootBound(const Stage& stage,
const GraphContext& ctx,
const AttachPath& attach_path,
const std::unordered_map<IterVar, IterVar>& bind_map,
std::unordered_map<IterVar, Range>* rmap) {
CHECK_NE(stage->attach_type, kInline)
<< "call schedule.normalize before scheduleops";
if (stage->attach_type == kInlinedAlready) return;
if (stage->is_output) {
// verify correctness.
CHECK_EQ(stage.GetAttachSpec()->attach_type, kGroupRoot)
<< "Output must be attached at root";
}
if (stage->is_output || stage->op.as<PlaceholderOpNode>()) {
for (auto iv : stage->op->root_iter_vars()) {
CHECK(iv->dom.defined());
......@@ -42,8 +61,10 @@ void InferRootBound(const Stage& stage,
}
// parent stage, if any
Stage parent;
if (stage->attach_type == kScope || stage->attach_type == kScanUpdate) {
parent = stage->attach_stage;
Stage attach_spec = stage.GetAttachSpec();
if (attach_spec->attach_type == kScope ||
attach_spec->attach_type == kScanUpdate) {
parent = attach_spec->attach_stage;
}
// The tensor domain.
std::unordered_map<Tensor, TensorDom> tmap;
......@@ -72,13 +93,11 @@ void InferRootBound(const Stage& stage,
// from the already inferred bounds.
std::unordered_map<const Variable*, IntSet> relax_set;
for (IterVar iv : attach_path.at(stage->op)) {
if (ScopeRelax(iv, stage->scope)) {
if (ScopeRelax(iv, bind_map, stage->scope)) {
relax_set[iv->var.get()] = IntSet::range(rmap->at(iv));
}
}
if (direct_consume_by_parent) {
// parent stage if exist
Stage parent = stage->attach_stage;
// Bound inference logics in parent.
std::unordered_map<IterVar, IntSet> up_state;
bool fix_value = true;
......@@ -89,16 +108,16 @@ void InferRootBound(const Stage& stage,
CHECK(is_zero(vrange->min))
<< "InferBound requires every leaf iter var's min equals 0, "
<< " call schedule.normalize to achieve this. "
<< " stage=" << parent;
<< " stage=" << parent << ", vrange=" << vrange->min;
// special optimization to remove trivial loop
if (is_one(vrange->extent)) {
up_state[iv] = IntSet::single_point(vrange->min);
} else if (fix_value && !ScopeRelax(iv, stage->scope)) {
} else if (fix_value && !ScopeRelax(iv, bind_map, stage->scope)) {
up_state[iv] = IntSet::single_point(iv->var);
} else {
up_state[iv] = IntSet::range(vrange);
}
if (stage->attach_ivar == iv) {
if (attach_spec->attach_ivar == iv) {
fix_value = false;
}
}
......@@ -159,7 +178,7 @@ void InferRootBound(const Stage& stage,
}
op->PropBoundToInputs(op, dom_map, &tmap);
}
stage->op->GatherBound(stage->op, ctx, tmap, rmap);
stage->op->GatherBound(stage->op, tmap, rmap);
}
Map<IterVar, Range> InferBound(const Schedule& sch) {
......@@ -167,18 +186,25 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
for (Operation op : sch->outputs) {
roots.push_back(sch->stage_map[op]->op);
}
std::unordered_map<IterVar, IterVar> bind_map;
for (Stage stage : sch->stages) {
for (auto kv : stage->iter_var_attrs) {
if (kv.second->bind_thread.defined()) {
CHECK(!bind_map.count(kv.first));
bind_map[kv.first] = kv.second->bind_thread;
}
}
}
GraphContext ctx;
ctx.feed_graph = CreateFeedGraph(CreateReadGraph(roots));
AttachPath attach_path = CreateAttachPath(sch);
std::unordered_map<IterVar, Range> ret;
for (size_t i = sch->stages.size(); i != 0; --i) {
const Stage& stage = sch->stages[i - 1];
InferRootBound(stage, ctx, attach_path, &ret);
InferRootBound(stage, ctx, attach_path, bind_map, &ret);
// pass down to get bound of all iter vars.
PassDownDomain(stage, &ret);
// setup outer most threads.
for (IterVar iv : stage->outermost_threads) {
for (IterVar iv : stage->env_threads) {
CHECK(iv->dom.defined());
ret[iv] = iv->dom;
}
......
......@@ -7,6 +7,7 @@
#include <tvm/ir_visitor.h>
#include <tvm/operation.h>
#include <unordered_set>
#include <unordered_map>
#include "./graph.h"
namespace tvm {
......@@ -82,6 +83,60 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) {
return rmap;
}
// Do DFS visit to get the subgraph.
// Return if op is inside the subgraph.
/*!
 * \brief Post-order DFS helper used to collect a subgraph.
 *
 *  Walks from op towards its producers through InputTensors().
 *  An op belongs to the subgraph when it is a boundary op or when at
 *  least one of its inputs (transitively) reaches a boundary op.
 *  Ops are appended to result after their producers.
 *
 * \param op The operation to visit.
 * \param boundary The nodes that terminate the traversal.
 * \param include_boundary Whether boundary ops are appended to result.
 * \param visited Memo of already visited nodes -> whether they are inside.
 * \param result The collected operations.
 * \return Whether op is inside the subgraph.
 */
bool GetSubGraphByPostDFS_(
    const Operation& op,
    const std::unordered_set<const Node*>& boundary,
    bool include_boundary,
    std::unordered_map<const Node*, bool>* visited,
    Array<Operation>* result) {
  const Node* key = op.get();
  // Reuse the answer of an earlier visit.
  auto memo = visited->find(key);
  if (memo != visited->end()) {
    return memo->second;
  }
  // Boundary ops are always inside; they end the traversal.
  if (boundary.count(key) != 0) {
    (*visited)[key] = true;
    if (include_boundary) {
      result->push_back(op);
    }
    return true;
  }
  // Provisional mark so that a cycle terminates (not needed for a DAG).
  (*visited)[key] = false;
  // Every input is traversed, even after one of them reached the boundary,
  // so that all producers on a path to the boundary get collected.
  bool reach_boundary = false;
  for (Tensor t : op->InputTensors()) {
    const bool inside = GetSubGraphByPostDFS_(
        t->op, boundary, include_boundary, visited, result);
    reach_boundary |= inside;
  }
  (*visited)[key] = reach_boundary;
  if (reach_boundary) {
    result->push_back(op);
  }
  return reach_boundary;
}
// Collect the operations lying between inputs and outputs.
// The ops producing `inputs` form the boundary of the search; the
// traversal starts from the op of every output tensor.
Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
                             const Array<Tensor>& inputs,
                             bool include_inputs) {
  // Boundary: producers of the input tensors.
  std::unordered_set<const Node*> boundary;
  for (Tensor in : inputs) {
    boundary.insert(in->op.get());
  }
  // Shared memo so that an op reachable from several outputs is added once.
  std::unordered_map<const Node*, bool> visited;
  Array<Operation> result;
  for (Tensor out : outputs) {
    GetSubGraphByPostDFS_(out->op, boundary, include_inputs,
                          &visited, &result);
  }
  return result;
}
void PostDFSOrder(const Operation& op,
const ReadGraph& g,
std::unordered_set<Operation>* visited,
......@@ -118,30 +173,38 @@ FeedGraph CreateFeedGraph(const ReadGraph& g) {
AttachPath CreateAttachPath(Schedule sch) {
AttachPath ret;
for (Stage stage : sch->stages) {
if (stage->attach_type == kScanUpdate) {
const Stage& parent = stage->attach_stage;
stage->attach_ivar =
parent->leaf_iter_vars[parent->leaf_iter_vars.size() - 1];
}
}
for (Stage stage : sch->stages) {
std::unordered_set<const Node*> visited;
Array<IterVar> path;
for (Stage s = stage; s->attach_type == kScope || s->attach_type == kScanUpdate;) {
IterVar attach_ivar = s->attach_ivar;
s = s->attach_stage;
bool start_attach = false;
for (Stage s = stage; s.defined();) {
CHECK(!visited.count(s.get()))
<< "Find loop in compute_at attach group";
visited.insert(s.get());
Stage spec = s.GetAttachSpec();
bool start_attach;
IterVar attach_ivar;
if (spec->attach_type == kScope) {
attach_ivar = spec->attach_ivar;
s = spec->attach_stage;
start_attach = false;
CHECK(attach_ivar.defined());
} else if (spec->attach_type == kScanUpdate) {
s = spec->attach_stage;
start_attach = true;
} else {
break;
}
CHECK(s.defined());
for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
IterVar iv = s->leaf_iter_vars[i - 1];
if (iv == attach_ivar) start_attach = true;
if (!start_attach && iv.same_as(attach_ivar)) {
start_attach = true;
}
if (start_attach) path.push_back(iv);
}
CHECK(start_attach)
<< "Invalid Schedule: cannot find attach point " << attach_ivar
<< " in the schedule of " << s->op;
}
if (!ret.count(stage->op)) {
ret.Set(stage->op, path);
}
......@@ -203,53 +266,22 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
return reach;
}
// Get all the operations that forms body of scan
void ScanGetBodyPostDFS_(
Operation op,
const ScanOpNode* scan,
const FeedGraph& feed_graph,
std::unordered_set<const Node*>* visited,
Array<Operation>* result) {
if (op.get() == scan) return;
bool empty_feed = true;
for (int i = 0; i < op->num_outputs(); ++i) {
auto it = feed_graph.find(op.output(i));
if (it != feed_graph.end() && it->second.size()) {
empty_feed = false;
for (const Operation& xop : it->second) {
if (visited->count(xop.get())) continue;
visited->insert(xop.get());
ScanGetBodyPostDFS_(xop, scan, feed_graph, visited, result);
result->push_back(xop);
}
}
}
if (empty_feed && op.get() != scan) {
LOG(FATAL) << "Bad scan body, tensor reads scan_state but not connect to scan";
}
}
Array<Operation> ScanGetBody_(
const ScanOpNode* scan,
const FeedGraph& feed_graph) {
CHECK(scan != nullptr);
std::unordered_set<const Node*> visited;
Array<Operation> result;
Array<Operation> ScanGetBody(const Operation& scan_op) {
const ScanOpNode* scan = scan_op.as<ScanOpNode>();
// Get the body.
Array<Tensor> inputs;
for (Tensor t : scan->state_placeholder) {
ScanGetBodyPostDFS_(t->op, scan, feed_graph, &visited, &result);
inputs.push_back(t);
}
return result;
}
Array<Operation> ScanGetBody(const Operation& scan) {
return ScanGetBody_(scan.as<ScanOpNode>(),
CreateFeedGraph(CreateReadGraph({scan})));
for (Tensor t : scan->inputs) {
inputs.push_back(t);
}
return GetSubGraph(scan->update, inputs, false);
}
Map<IterVar, Expr> ScanFixPointAnalysis(
const Operation& scan_op, const Array<Operation>& body) {
Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
const ScanOpNode* scan = scan_op.as<ScanOpNode>();
CHECK(body[0].get() == scan);
Array<Operation> body = ScanGetBody(scan_op);
std::unordered_map<TensorDimKey, const Node*> exact_reach;
std::unordered_set<const Node*> fail_set;
......@@ -276,8 +308,8 @@ Map<IterVar, Expr> ScanFixPointAnalysis(
}
};
// prop exact reach back.
for (size_t i = body.size(); i != 1; --i) {
const Operation& op = body[i - 1];
for (size_t i = 0; i < body.size(); ++i) {
const Operation& op = body[i];
if (op.as<ScanOpNode>()) {
const auto& update = op.as<ScanOpNode>()->update;
const auto& init = op.as<ScanOpNode>()->init;
......
......@@ -27,6 +27,11 @@ using ReadGraph = Map<Operation, Array<Tensor> >;
using AttachPath = Map<Operation, Array<IterVar> >;
/*!
* \brief The map between tensor and operation it feeds to.
*/
using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
/*!
* \brief Get read graph of each operation to all the
* Tensors that it directly depends on.
*
......@@ -37,6 +42,23 @@ using AttachPath = Map<Operation, Array<IterVar> >;
ReadGraph CreateReadGraph(const Array<Operation>& roots);
/*!
* \brief Get the minimum subgraph between outputs and inputs.
*
* The result contains the operations that are reachable from the
* outputs by following input tensors and that, in turn, reach the
* operation of at least one of the inputs.
*
* The operations producing the outputs are included when they reach an input.
* The operations producing the inputs are included only when
* include_inputs is true.
*
* \param outputs The outputs of the subgraph.
* \param inputs The inputs to the subgraph.
* \param include_inputs Whether to include the operations of the inputs.
*
* \return The operations of the subgraph, producers before consumers.
*/
Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
const Array<Tensor>& inputs,
bool include_inputs);
/*!
* \brief Get a post DFS ordered of operations in the graph.
* \param roots The root of the graph.
* \param g The read graph.
......@@ -67,14 +89,10 @@ AttachPath CreateAttachPath(Schedule sch);
/*!
* \brief Get all operations inside the recursion of scan.
* \param scan The scan node.
* \param feed_graph The feed graph to help analysis.
* \param scan_op The scan node ops.
* \return The body operations, in read dependency order.
*/
Array<Operation> ScanGetBody_(
const ScanOpNode* scan, const FeedGraph& feed_graph);
// same as ScanGetBody_, but create FeedGraph internally.
Array<Operation> ScanGetBody(const Operation& scan);
Array<Operation> ScanGetBody(const Operation& scan_op);
/*!
* \brief Analyze each spatial dimension of scan's result.
......@@ -85,11 +103,9 @@ Array<Operation> ScanGetBody(const Operation& scan);
* next_state[t, ..., axis, ...] = f(prev_state[t-1, ...,axis,...]
*
* \param scan The scan node.
* \param body The body of scan, sorted in reverse PostDFSOrder.
* \return Map of spatial_axis -> IntImm
*/
Map<IterVar, Expr> ScanFixPointAnalysis(
const Operation& scan, const Array<Operation>& body);
Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan);
} // namespace schedule
} // namespace tvm
......
......@@ -22,6 +22,23 @@ inline bool prove_equal(Expr lhs, Expr rhs) {
return is_zero(ir::Simplify(lhs - rhs));
}
void Update(std::unordered_map<IterVar, Range>* p_state,
const IterVar& iv,
Range r) {
auto it = p_state->find(iv);
if (it == p_state->end()) {
(*p_state)[iv] = r;
} else {
bool match = is_zero(it->second->min);
if (!prove_equal(r->extent, it->second->extent)) match = false;
CHECK(match)
<< iv
<< " domain already inferred,"
<< " cannot prove their extents are the same "
<< it->second->extent << " vs " << r->extent;
}
}
void PassDownDomain(const Stage& stage,
std::unordered_map<IterVar, Range>* p_state,
bool allow_missing) {
......@@ -36,30 +53,15 @@ void PassDownDomain(const Stage& stage,
CHECK(!state.count(r->inner));
const Range& range_parent = state.at(r->parent);
if (r->factor.defined()) {
state[r->inner] = Range::make_with_min_extent(0, r->factor);
if (r->outer->dom.defined()) {
state[r->outer] = r->outer->dom;
Update(p_state, r->inner, Range::make_with_min_extent(0, r->factor));
Update(p_state, r->outer,
Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->factor)));
} else {
if (!state.count(r->outer)) {
state[r->outer] = Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->factor));
} else {
Expr outer_ext = DivCeil(range_parent->extent, r->factor);
Range outer_rng = state.at(r->outer);
bool match = is_zero(outer_rng->min);
if (!prove_equal(outer_ext, outer_rng->extent)) match = false;
CHECK(match)
<< r->outer
<< "IterVar is used in two places as outer scope,"
<< " cannot prove their extents are the same "
<< outer_ext << " vs " << outer_rng->extent;
}
}
} else {
CHECK(r->outer->dom.defined());
state[r->outer] = r->outer->dom;
state[r->inner] = Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->outer->dom->extent));
Update(p_state, r->outer, Range::make_with_min_extent(0, r->nparts));
Update(p_state, r->inner,
Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->nparts)));
}
} else if (const FuseNode* r = rel.as<FuseNode>()) {
if (!state.count(r->outer) || !state.count(r->inner)) {
......@@ -75,20 +77,20 @@ void PassDownDomain(const Stage& stage,
CHECK(allow_missing);
continue;
}
Range res = Range::make_with_min_extent(
0, state.at(r->parent)->extent);
if (r->rebased->dom.defined()) {
Range rebase_rng = r->rebased->dom;
bool match = is_zero(rebase_rng->min);
if (!prove_equal(rebase_rng->extent, res->extent)) match = false;
CHECK(match) << r->rebased
<< " does not match parent scope's range";
}
state[r->rebased] = res;
Update(p_state, r->rebased,
Range::make_with_min_extent(
0, state.at(r->parent)->extent));
} else {
LOG(FATAL) << "unknown relation type";
}
}
// update the extents of bound threads.
for (auto kv : stage->iter_var_attrs) {
if (kv.second->bind_thread.defined()) {
CHECK(state.count(kv.first));
Update(p_state, kv.second->bind_thread, state.at(kv.first));
}
}
}
void PassUpIndex(const Stage& stage,
......
......@@ -55,6 +55,7 @@ void ReplaceDataFlow(const Array<Stage>& stages,
Tensor Schedule::cache_read(const Tensor& tensor,
const std::string& scope,
const Array<Operation>& readers) {
(*this)->InvalidateCache();
// create identity mapping.
std::ostringstream os;
os << tensor->op->name;
......@@ -81,18 +82,25 @@ Tensor Schedule::cache_read(const Tensor& tensor,
}
ReplaceDataFlow((*this)->stages, &vmap);
ArrayNode* stages = (*this)->stages.CopyOnWrite();
size_t pos = FindNodeRef(stages, operator[](tensor->op));
Stage op_stage = operator[](tensor->op);
size_t pos = FindNodeRef(stages, op_stage);
Stage cache_stage = Stage(cache->op);
cache_stage.set_scope(scope);
CHECK_LT(pos, stages->data.size());
stages->data.insert(stages->data.begin() + pos + 1,
cache_stage.node_);
(*this)->stage_map.Set(cache->op, cache_stage);
// Update group
cache_stage->group = op_stage->group;
if (cache_stage->group.defined()) {
++cache_stage->group->num_child_stages;
}
return cache;
}
Tensor Schedule::cache_write(const Tensor& tensor,
const std::string& scope) {
(*this)->InvalidateCache();
Stage orig_stage = operator[](tensor->op);
const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
CHECK(compute)
......@@ -123,7 +131,6 @@ Tensor Schedule::cache_write(const Tensor& tensor,
std::unordered_map<Tensor, Tensor> vmap;
vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
ReplaceDataFlow((*this)->stages, &vmap);
// mutate orig stage
orig_stage->op = orig_new_op;
orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
......@@ -137,6 +144,11 @@ Tensor Schedule::cache_write(const Tensor& tensor,
stages->data.insert(stages->data.begin() + pos,
cache_stage.node_);
(*this)->stage_map.Set(cache_op, cache_stage);
// Update group
cache_stage->group = orig_stage->group;
if (cache_stage->group.defined()) {
++cache_stage->group->num_child_stages;
}
return cache_tensor;
}
......@@ -152,6 +164,11 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
attach_mark[s.get()] = 1;
}
}
for (Stage s : sch->groups) {
if (s->attach_type == kScope) {
attach_mark[s->attach_stage.get()] = 1;
}
}
for (Stage s : sch->stages) {
if (!attach_mark.count(s.get())) continue;
......@@ -176,6 +193,12 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
s->attach_ivar = rebase_map.at(s->attach_ivar);
}
}
for (Stage s : sch->groups) {
if (s->attach_type != kScope) continue;
if (rebase_map.count(s->attach_ivar)) {
s->attach_ivar = rebase_map.at(s->attach_ivar);
}
}
}
void SetScanAttach(const Schedule& sch) { // NOLINT(*)
......@@ -188,8 +211,8 @@ void SetScanAttach(const Schedule& sch) { // NOLINT(*)
}
}
void InjectInline(const Schedule& sch) {
void InjectInline(ScheduleNode* sch) {
sch->InvalidateCache();
std::vector<Expr> new_body(sch->stages.size());
// inline all the ops
for (size_t i = sch->stages.size(); i != 0; --i) {
......@@ -241,12 +264,13 @@ void InjectInline(const Schedule& sch) {
void Schedule::normalize() {
RebaseNonZeroMinLoop(*this);
SetScanAttach(*this);
InjectInline(*this);
InjectInline(operator->());
}
// Handle reduction factor.
Tensor Schedule::rfactor(const Tensor& tensor,
const IterVar& axis) {
(*this)->InvalidateCache();
using ir::Reduce;
CHECK_EQ(axis->iter_type, kCommReduce)
<< "Can only factor reduction axis";
......
......@@ -36,39 +36,28 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v)
return 0;
}
void CheckSplit(StageNode* self, IterVar parent, IterVar outer) {
void Split(StageNode* self,
IterVar parent,
Expr factor,
Expr nparts,
IterVar* p_outer,
IterVar* p_inner) {
// Check if split is valid.
if (self->attach_type == kScanUpdate) {
CHECK(!parent.same_as(self->all_iter_vars[0]))
<< "Cannot split on axis[0] of scan update";
}
if (outer.defined()) {
if (outer->iter_type == kThreadIndex) {
CHECK_EQ(parent->iter_type, kDataPar)
<< "Split by by kThreadIndex requires kDataPar IterVar "
<< " given " << IterVarType2String(parent->iter_type);
} else if (outer->iter_type == kCommReduce) {
CHECK_EQ(parent->iter_type, kCommReduce)
<< "Split by by kCommReduce requires kCommReduce IterVar "
<< " given " << IterVarType2String(parent->iter_type);
} else {
LOG(FATAL) << "Cannot take " << IterVarType2String(parent->iter_type)
<< " as outer IterVar";
}
} else {
CHECK(parent->iter_type == kDataPar ||
parent->iter_type == kCommReduce ||
parent->iter_type == kOrdered)
<< "Cannot split on " << IterVarType2String(parent->iter_type);
}
}
void Split(StageNode* self, IterVar parent,
IterVar outer, IterVar inner, Expr factor) {
IterVar outer = IterVarNode::make(
Range(), parent->var.copy_with_suffix(".outer"), parent->iter_type);
IterVar inner = IterVarNode::make(
Range(), parent->var.copy_with_suffix(".inner"), parent->iter_type);
*p_outer = outer;
*p_inner = inner;
// The splits
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
size_t pos = FindLeafVar(all_vars, leaf_vars, parent);
self->relations.push_back(SplitNode::make(parent, outer, inner, factor));
self->relations.push_back(SplitNode::make(parent, outer, inner, factor, nparts));
// add vars to all vars
all_vars->data.push_back(outer.node_);
all_vars->data.push_back(inner.node_);
......@@ -98,6 +87,21 @@ Stage::Stage(Operation op) {
node_ = n;
}
// A stage counts as scheduled as soon as it has any iter var relation,
// is attached somewhere other than the group root, or its leaf iter
// vars no longer share the same array as its all_iter_vars.
bool Stage::is_scheduled() const {
  const StageNode* n = operator->();
  return !n->relations.empty() ||
         n->attach_type != kGroupRoot ||
         !n->all_iter_vars.same_as(n->leaf_iter_vars);
}
// Find the stage that carries the effective attach specification.
// Starting from this stage, climb the enclosing groups for as long as
// the current one is attached at group root and has a parent group.
Stage Stage::GetAttachSpec() const {
  Stage attach_spec = *this;
  for (;;) {
    if (attach_spec->attach_type != kGroupRoot) break;
    if (!attach_spec->group.defined()) break;
    attach_spec = attach_spec->group;
  }
  return attach_spec;
}
Stage& Stage::set_scope(std::string scope) { // NOLINT(*)
(*this)->scope = scope;
return *this;
......@@ -106,6 +110,17 @@ Stage& Stage::set_scope(std::string scope) { // NOLINT(*)
Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*)
CHECK_NE((*this)->attach_type, kScanUpdate)
<< "Cannot specify compute_at for scan updates";
// Group constraint checking.
Stage group = (*this)->group;
if (group.defined()) {
Stage pg = parent->group;
while (pg.defined() && !pg.same_as(group)) {
pg = pg->group;
}
CHECK(pg.same_as(group))
<< "Can only assign compute_at to stages within the same group";
}
(*this)->attach_type = kScope;
(*this)->attach_ivar = scope;
(*this)->attach_stage = parent;
......@@ -117,7 +132,7 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*)
}
CHECK(found)
<< "Cannot find the axis " << scope
<< " in parent's leaf_iter_vars or outermost_threads:"
<< " in parent's leaf_iter_vars"
<< " parent=" << parent;
return *this;
}
......@@ -132,61 +147,73 @@ Stage& Stage::compute_inline() { // NOLINT(*)
Stage& Stage::compute_root() { // NOLINT(*)
CHECK_NE((*this)->attach_type, kScanUpdate)
<< "Cannot specify compute_at for scan updates";
(*this)->attach_type = kRoot;
(*this)->attach_type = kGroupRoot;
return *this;
}
Stage& Stage::rebase(IterVar parent, IterVar rebased) { // NOLINT(*)
CHECK(parent->iter_type == kDataPar ||
parent->iter_type == kCommReduce)
<< "Cannot rebase " << IterVarType2String(parent->iter_type);
CHECK(rebased->iter_type == kThreadIndex)
<< "Cannot rebase by " << IterVarType2String(rebased->iter_type)
Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*)
StageNode* self = operator->();
CHECK(ivar->iter_type == kDataPar ||
ivar->iter_type == kCommReduce)
<< "Cannot bind " << IterVarType2String(ivar->iter_type) << " to thread";
CHECK(thread_ivar->iter_type == kThreadIndex)
<< "Cannot rebase by " << IterVarType2String(ivar->iter_type)
<< ", only thread axis is allowed so far";
ArrayNode* all_vars = (*this)->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = (*this)->leaf_iter_vars.CopyOnWrite();
size_t pos = FindLeafVar(all_vars, leaf_vars, parent);
(*this)->relations.push_back(RebaseNode::make(parent, rebased));
// add vars to all vars
all_vars->data.push_back(rebased.node_);
// replace the position.
leaf_vars->data.erase(leaf_vars->data.begin() + pos);
leaf_vars->data.insert(leaf_vars->data.begin() + pos, rebased.node_);
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
FindLeafVar(all_vars, leaf_vars, ivar);
auto it = self->iter_var_attrs.find(ivar);
std::shared_ptr<IterVarAttrNode> n;
if (it != self->iter_var_attrs.end()) {
n = std::make_shared<IterVarAttrNode>(*(*it).second.operator->());
if (n->bind_thread.defined() &&
!n->bind_thread.same_as(thread_ivar)) {
LOG(WARNING) << "Axis " << ivar
<< " is already bind to another thread " << n->bind_thread;
}
} else {
n = std::make_shared<IterVarAttrNode>();
}
n->bind_thread = thread_ivar;
self->iter_var_attrs.Set(ivar, IterVarAttr(n));
return *this;
}
Stage& Stage::split(
IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor) { // NOLINT(*)
CheckSplit(operator->(), parent, IterVar());
IterVar outer = IterVarNode::make(
Range(), parent->var.copy_with_suffix(".outer"), parent->iter_type);
IterVar inner = IterVarNode::make(
Range(), parent->var.copy_with_suffix(".inner"), parent->iter_type);
*p_outer = outer;
*p_inner = inner;
Split(operator->(), parent, outer, inner, factor);
// Set the environment threads of a composite (scan) stage.
// The threads are put in front of the leaf iter vars and appended to
// all_iter_vars. May only be called once per stage.
Stage& Stage::env_threads(Array<IterVar> threads) {
  StageNode* self = operator->();
  CHECK(self->op.defined() && self->op.as<ScanOpNode>())
      << "env_threads is only valid for composite ops such as ScanOp";
  CHECK_EQ(self->env_threads.size(), 0U)
      << "Already set env_threads";
  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
  // Collect the underlying nodes once, they are inserted in two places.
  std::vector<std::shared_ptr<Node> > thread_nodes;
  for (size_t i = 0; i < threads.size(); ++i) {
    thread_nodes.push_back(threads[i].node_);
  }
  // Threads become the outermost leaf loops ...
  leaf_vars->data.insert(
      leaf_vars->data.begin(), thread_nodes.begin(), thread_nodes.end());
  // ... and are registered at the end of all iter vars.
  all_vars->data.insert(
      all_vars->data.end(), thread_nodes.begin(), thread_nodes.end());
  self->env_threads = threads;
  return *this;
}
Stage& Stage::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor) { // NOLINT(*)
CheckSplit(operator->(), parent, outer);
std::string name_inner = parent->var->name_hint + ".inner";
IterVar inner = IterVarNode::make(
Range(), Var(name_inner, parent->var.type()), parent->iter_type);
*p_inner = inner;
Split(operator->(), parent, outer, inner, factor);
// Split parent by the given factor; nparts is passed as an undefined Expr.
Stage& Stage::split(
    IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner) {  // NOLINT(*)
  StageNode* self = operator->();
  const Expr nparts;  // undefined
  Split(self, parent, factor, nparts, p_outer, p_inner);
  return *this;
}
// Split parent into nparts pieces; factor is passed as an undefined Expr.
Stage& Stage::split_by_nparts(
    IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner) {  // NOLINT(*)
  StageNode* self = operator->();
  const Expr factor;  // undefined
  Split(self, parent, factor, nparts, p_outer, p_inner);
  return *this;
}
Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT(*)
StageNode* self = operator->();
if (self->attach_type == kScanUpdate) {
CHECK(!inner.same_as(self->all_iter_vars[0]))
<< "Cannot fuse on axis[0] of scan update";
CHECK(!outer.same_as(self->all_iter_vars[0]))
<< "Cannot fuse on axis[0] of scan update";
}
CHECK(outer->iter_type == kDataPar ||
outer->iter_type == kCommReduce ||
outer->iter_type == kOrdered)
......@@ -236,10 +263,6 @@ Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*)
std::vector<size_t> pos;
for (size_t i = 0; i < order.size(); ++i) {
if ((*this)->attach_type == kScanUpdate) {
CHECK(!order[i].same_as(self->all_iter_vars[0]))
<< "Cannot split on axis[0] of scan update";
}
pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i]));
}
std::vector<std::shared_ptr<Node> > temp;
......@@ -254,66 +277,48 @@ Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*)
}
Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
Expr x_factor, Expr y_factor,
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor) { // NOLINT(*)
split(x_parent, p_x_outer, p_x_inner, x_factor);
split(y_parent, p_y_outer, p_y_inner, y_factor);
IterVar* p_x_inner, IterVar* p_y_inner) {
split(x_parent, x_factor, p_x_outer, p_x_inner);
split(y_parent, y_factor, p_y_outer, p_y_inner);
reorder(Array<IterVar>({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner}));
return *this;
}
Stage& Stage::outermost_threads(Array<IterVar> threads) {
StageNode* self = operator->();
CHECK(self->op.as<ScanOpNode>())
<< "outermost_threads is only valid for composite ops such as ScanOp";
CHECK_EQ(self->outermost_threads.size(), 0U)
<< "Already set outermost_threads";
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
std::vector<std::shared_ptr<Node> > temp;
for (IterVar iv : threads) {
temp.push_back(iv.node_);
}
leaf_vars->data.insert(
leaf_vars->data.begin(), temp.begin(), temp.end());
all_vars->data.insert(
all_vars->data.end(), temp.begin(), temp.end());
(*this)->outermost_threads = threads;
return *this;
}
inline void SetAttr(StageNode* self, IterVar var, IterVarAttr attr) {
inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type) {
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
FindLeafVar(all_vars, leaf_vars, var);
auto it = self->iter_var_attrs.find(var);
std::shared_ptr<IterVarAttrNode> n;
if (it != self->iter_var_attrs.end()) {
CHECK_EQ((*it).second->iter_type, attr->iter_type)
<< "IterVar's is already set to "
<< (*it).second << " instead of " << attr;
n = std::make_shared<IterVarAttrNode>(*(*it).second.operator->());
} else {
self->iter_var_attrs.Set(var, attr);
n = std::make_shared<IterVarAttrNode>();
}
n->iter_type = iter_type;
self->iter_var_attrs.Set(var, IterVarAttr(n));
}
Stage& Stage::vectorize(IterVar var) { // NOLINT(*)
SetAttr(operator->(), var, IterVarAttr(kVectorized));
SetAttrIterType(operator->(), var, kVectorized);
return *this;
}
Stage& Stage::unroll(IterVar var) { // NOLINT(*)
SetAttr(operator->(), var, IterVarAttr(kUnrolled));
SetAttrIterType(operator->(), var, kUnrolled);
return *this;
}
Stage& Stage::parallel(IterVar var) { // NOLINT(*)
SetAttr(operator->(), var, IterVarAttr(kParallelized));
SetAttrIterType(operator->(), var, kParallelized);
return *this;
}
Schedule::Schedule(Array<Operation> ops) {
auto n = std::make_shared<ScheduleNode>();
node_ = n;
n->outputs = ops;
auto g = schedule::CreateReadGraph(n->outputs);
Array<Operation> post_order = schedule::PostDFSOrder(n->outputs, g);
......@@ -330,14 +335,24 @@ Schedule::Schedule(Array<Operation> ops) {
// mark scan updates.
if (op.as<ScanOpNode>()) {
const ScanOpNode* scan = op.as<ScanOpNode>();
Array<Tensor> inputs;
for (Tensor t : scan->state_placeholder) {
inputs.push_back(t);
}
for (Tensor t : scan->inputs) {
inputs.push_back(t);
}
// Create the scan group.
Stage scan_group = create_group(scan->update, inputs, false);
scan_group->attach_type = kScanUpdate;
scan_group->attach_stage = stage;
for (size_t i = 0; i < scan->update.size(); ++i) {
Stage s = n->stage_map[scan->update[i]->op];
s->attach_type = kScanUpdate;
s->attach_stage = stage;
CHECK(scan_group.same_as(s->group));
}
}
}
node_ = std::move(n);
}
Stage Schedule::operator[](const Operation& op) {
......@@ -348,14 +363,174 @@ Stage Schedule::operator[](const Operation& op) {
return (*it).second;
}
IterVarRelation SplitNode::make(
IterVar parent, IterVar outer,
IterVar inner, Expr factor) {
/*!
 * \brief Find the deepest group that encloses both g1 and g2.
 *
 *  An undefined Stage stands for the root (no group). A group counts
 *  as an ancestor of itself, so when one argument encloses the other
 *  it is the result.
 *
 *  The ancestor chains are followed through StageNode::group. Groups
 *  that sit in different branches under a shared parent resolve to that
 *  parent; checking only whether one argument is an ancestor of the
 *  other would report the root for them.
 *
 * \param g1 The first group, may be undefined.
 * \param g2 The second group, may be undefined.
 * \return The least common ancestor, undefined if it is the root.
 */
Stage LeastCommonAncestor(Stage g1, Stage g2) {
  if (!g1.defined()) return g1;
  if (!g2.defined()) return g2;
  // Walk up from g1; the first ancestor that also lies on the chain of g2
  // is the deepest common one. Group nesting is shallow, so the nested
  // walk is cheap and needs no auxiliary container.
  for (Stage a = g1; a.defined(); a = a->group) {
    for (Stage b = g2; b.defined(); b = b->group) {
      if (a.same_as(b)) return a;
    }
  }
  // The chains never meet: only the root encloses both.
  return Stage();
}
// Map tensors onto the operations currently held by the schedule.
// A tensor whose op is not the op of any stage is looked up in
// stage_map and replaced by the same output index of that stage's op.
Array<Tensor> RemapTensor(ScheduleNode* self,
                          const Array<Tensor>& arr) {
  self->InitCache();
  const auto& op2stage_cache = self->op2stage_cache_;
  Array<Tensor> ret;
  for (Tensor t : arr) {
    const bool is_current = op2stage_cache.count(t->op.get()) != 0;
    if (is_current) {
      ret.push_back(t);
      continue;
    }
    // The op was replaced; go through stage_map to the stage's current op.
    CHECK(self->stage_map.count(t->op))
        << "Given tensor is not in the schedule plan";
    ret.push_back(self->stage_map[t->op]->op.output(t->value_index));
  }
  return ret;
}
// Group the schedule stages.
//
// Creates a new group stage that contains every stage whose op lies in the
// subgraph between `inputs` and `outputs` (see schedule::GetSubGraph), makes
// it a child of the least common ancestor group of those stages, re-parents
// the affected stages and fully covered subgroups, and records the new group
// in ScheduleNode::groups.
//
// outputs: output tensors of the region to group.
// inputs: input tensors of the region to group.
// include_inputs: whether the ops producing `inputs` are part of the group.
// Returns the newly created group stage.
Stage Schedule::create_group(const Array<Tensor>& outputs,
const Array<Tensor>& inputs,
bool include_inputs) {
ScheduleNode* self = operator->();
self->InitCache();
const auto& op2stage_cache = self->op2stage_cache_;
// Get the ops.
// Tensors are remapped first so that they refer to the ops currently
// held by the stages.
Array<Operation> ops = schedule::GetSubGraph(
RemapTensor(self, outputs),
RemapTensor(self, inputs),
include_inputs);
// local counter entry
// Automatically initialize to 0 during creation.
struct Entry {
int count{0};
};
// Map of group->touched counter
// (number of children of that group that fall inside the new region).
std::unordered_map<Stage, Entry, NodeHash, NodeEqual> counter;
// The parent group;
Stage parent_group;
// Detect common parent and child.
for (size_t i = 0; i < ops.size(); ++i) {
Operation op = ops[i];
auto it = op2stage_cache.find(op.get());
CHECK(it != op2stage_cache.end());
Stage op_group = it->second->group;
if (i == 0) {
parent_group = op_group;
} else {
parent_group = LeastCommonAncestor(parent_group, op_group);
}
if (op_group.defined()) {
++counter[op_group].count;
}
}
// Create the new group stage.
Stage gstage(std::make_shared<StageNode>());
gstage->group = parent_group;
if (parent_group.defined()) {
++parent_group->num_child_stages;
}
// Propagate the counter statistics upward: a subgroup whose children are
// all touched is itself counted as a touched child of its own group.
std::vector<Stage> stack;
for (auto &kv : counter) {
if (!kv.first.same_as(parent_group)) {
if (kv.first->num_child_stages == kv.second.count) {
stack.push_back(kv.first);
}
}
}
while (!stack.empty()) {
Stage g = stack.back();
stack.pop_back();
if (g->group.defined() && !g->group.same_as(parent_group)) {
Entry& e = counter[g->group];
++e.count;
if (e.count == g->group->num_child_stages) {
stack.push_back(g->group);
}
}
}
// Verification and remapping of the subgroups.
// Every touched group other than the parent must be fully covered;
// those directly under the parent are moved under the new group.
for (auto &kv : counter) {
if (kv.first.same_as(parent_group)) continue;
CHECK_EQ(kv.first->num_child_stages, kv.second.count)
<< "Trying to group region that intersect with an already existed group";
if (kv.first->group.same_as(parent_group)) {
Stage s = kv.first;
s->group = gstage;
++gstage->num_child_stages;
if (parent_group.defined()) {
--parent_group->num_child_stages;
}
}
}
// Remap the group of op stages.
// Only stages that were direct children of the parent group are moved;
// stages inside a subgroup stay in that subgroup.
for (Operation op : ops) {
auto it = op2stage_cache.find(op.get());
CHECK(it != op2stage_cache.end());
Stage s = it->second;
if (s->group.same_as(parent_group)) {
s->group = gstage;
++gstage->num_child_stages;
if (parent_group.defined()) {
--parent_group->num_child_stages;
}
}
}
// Correct the attach to keep everything in group.
// A stage attached (kScope) to a stage whose group is not the new group
// or one nested in it is reset with compute_root().
for (Operation op : ops) {
auto it = op2stage_cache.find(op.get());
CHECK(it != op2stage_cache.end());
Stage s = it->second;
if (s->attach_type == kScope) {
Stage cg = LeastCommonAncestor(s->attach_stage->group, gstage);
if (!cg.same_as(gstage)) {
LOG(WARNING) << "group invalidates some previous compute_at relation "
<< " and keeps things to be computed inside the group";
s.compute_root();
}
}
}
self->groups.push_back(gstage);
return gstage;
}
void ScheduleNode::InvalidateCache() {
op2stage_cache_.clear();
}
/*!
 * \brief Build the op -> stage lookup cache when it is out of date.
 *
 * The cache is treated as current when it holds as many entries as there
 * are stages; otherwise it is cleared and rebuilt from the stages whose op
 * is defined. A CHECK then requires one cache entry per stage.
 */
void ScheduleNode::InitCache() {
  if (op2stage_cache_.size() != stages.size()) {
    InvalidateCache();
    for (Stage stage : stages) {
      if (!stage->op.defined()) continue;
      op2stage_cache_[stage->op.get()] = stage;
    }
    CHECK_EQ(op2stage_cache_.size(), stages.size());
  }
}
/*!
 * \brief Construct a split relation node from its five fields.
 * \param parent Stored in SplitNode::parent.
 * \param outer Stored in SplitNode::outer.
 * \param inner Stored in SplitNode::inner.
 * \param factor Stored in SplitNode::factor.
 * \param nparts Stored in SplitNode::nparts.
 * \return An IterVarRelation wrapping the new node.
 */
IterVarRelation SplitNode::make(IterVar parent,
                                IterVar outer,
                                IterVar inner,
                                Expr factor,
                                Expr nparts) {
  std::shared_ptr<SplitNode> node = std::make_shared<SplitNode>();
  // Iteration variables involved in the split.
  node->parent = parent;
  node->outer = outer;
  node->inner = inner;
  // Split sizing expressions.
  node->factor = factor;
  node->nparts = nparts;
  return IterVarRelation(node);
}
......@@ -375,12 +550,6 @@ IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) {
return IterVarRelation(n);
}
/*!
 * \brief Create an attribute reference holding a fresh IterVarAttrNode.
 * \param t The iteration type stored in the node's iter_type field.
 */
IterVarAttr::IterVarAttr(IterVarType t) {
  auto attr_node = std::make_shared<IterVarAttrNode>();
  attr_node->iter_type = t;
  node_ = attr_node;
}
TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(IterVarAttrNode);
TVM_REGISTER_NODE_TYPE(SplitNode);
......@@ -391,7 +560,11 @@ TVM_REGISTER_NODE_TYPE(ScheduleNode);
// Printer
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<StageNode>([](const StageNode *op, IRPrinter *p) {
if (op->op.defined()) {
p->stream << "stage(" << op->origin_op->name << ", " << op << ")";
} else {
p->stream << "group-stage(" << op << ")";
}
})
.set_dispatch<IterVarAttrNode>([](const IterVarAttrNode *op, IRPrinter *p) {
p->stream << IterVarType2String(op->iter_type);
......
......@@ -12,6 +12,8 @@
#include <unordered_map>
#include <unordered_set>
#include "./graph.h"
#include "../op/op_util.h"
#include "../pass/ir_util.h"
namespace tvm {
namespace schedule {
......@@ -44,8 +46,9 @@ Stmt MakePipeline(const Stage& s,
class InjectAttach : public IRMutator {
public:
InjectAttach(const Stage& stage,
const Stage& attach_spec,
const std::unordered_map<IterVar, Range>& dom_map)
: stage_(stage), dom_map_(dom_map) {}
: stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map) {}
Stmt Mutate(Stmt stmt) final {
CHECK(stmt.defined());
......@@ -53,10 +56,11 @@ class InjectAttach : public IRMutator {
const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr &&
op->type_key == attr::loop_scope) {
CHECK_NE(producer_.size(), 0U);
if (op->node == stage_->attach_ivar &&
producer_.back() == stage_->attach_stage->op.get()) {
CHECK(!found_attach);
if (attach_spec_->attach_type == kScope &&
op->node == attach_spec_->attach_ivar) {
CHECK(!found_attach)
<< "Find IterVar" << attach_spec_->attach_ivar
<< " in multiple places in the IR";
found_attach = true;
stmt = AttrStmt::make(
op->node, op->type_key, op->value,
......@@ -65,26 +69,16 @@ class InjectAttach : public IRMutator {
}
return stmt;
}
Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
if (op->is_producer) {
producer_.push_back(op->func.get());
Stmt ret = IRMutator::Mutate_(op, s);
producer_.pop_back();
return ret;
} else {
return IRMutator::Mutate_(op, s);
}
}
// whether attach point is found
bool found_attach{false};
private:
// the operations to be carried
// The stage.
const Stage& stage_;
// The attach spec, may not contain op.
const Stage& attach_spec_;
// domain map
const std::unordered_map<IterVar, Range>& dom_map_;
// internal stack about realization scope.
std::vector<const Node*> producer_;
};
// inject the operator's realization on the stmt.
......@@ -128,7 +122,6 @@ class InjectScanStep : public IRMutator {
bool is_init_;
};
// Postprocessing of schedule op
// Replace the init and update's expression by scan's buffer.
class SchedulePostProc : public IRMutator {
......@@ -157,9 +150,8 @@ class SchedulePostProc : public IRMutator {
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == attr::loop_scope) {
return this->Mutate(op->body);
} else if (op->type_key == attr::scan_init_scope) {
if (op->type_key == attr::loop_scope ||
op->type_key == attr::scan_init_scope) {
return this->Mutate(op->body);
} else if (op->type_key == attr::scan_update_scope) {
const ScanOpNode* scan = op->node.as<ScanOpNode>();
......@@ -237,6 +229,15 @@ class SchedulePostProc : public IRMutator {
void Init(const Schedule& sch) {
for (Stage s : sch->stages) {
for (auto kv : s->iter_var_attrs) {
// Update bind thread information.
if (kv.second->bind_thread.defined()) {
const Var& from = kv.first->var;
const Var& to = kv.second->bind_thread->var;
CHECK(!var_value_.count(from.get()));
var_value_[from.get()] = to;
}
}
// This must be checked for all ops, including scan.
if (!s->op.same_as(s->origin_op)) {
Tensor target = s->origin_op.output(0);
......@@ -279,61 +280,67 @@ class SchedulePostProc : public IRMutator {
Stmt ScheduleOps(
Schedule sch, Map<IterVar, Range> dom_map_) {
Stmt body = Stmt();
std::unordered_map<IterVar, Range> dom_map;
for (auto kv : dom_map_) {
dom_map[kv.first] = kv.second;
}
// scan init and scan updates
std::unordered_map<Operation, std::pair<Operation, bool> > scan_attach;
std::unordered_map<Operation, Operation> scan_init;
for (Stage s : sch->stages) {
const ScanOpNode* scan = s->op.as<ScanOpNode>();
if (!scan) continue;
for (Tensor t : scan->init) {
if (scan_attach.count(t->op)) {
CHECK(scan_attach.at(t->op).first.same_as(s->op))
if (scan_init.count(t->op)) {
CHECK(scan_init.at(t->op).same_as(s->op))
<< "Scan init tensor can only belong to one scan";
} else {
scan_attach[t->op] = std::make_pair(s->op, true);
}
}
for (Tensor t : scan->update) {
if (scan_attach.count(t->op)) {
CHECK(scan_attach.at(t->op).first.same_as(s->op))
<< "Scan update tensor can only belong to one scan";
} else {
scan_attach[t->op] = std::make_pair(s->op, false);
scan_init[t->op] = s->op;
}
}
}
std::unordered_map<IterVar, Range> dom_map;
for (auto kv : dom_map_) {
dom_map[kv.first] = kv.second;
// verify correctness of group.
for (Stage g : sch->groups) {
CHECK(!g->op.defined());
CHECK_EQ(g->leaf_iter_vars.size(), 0U);
}
// reverse the post DFS order.
for (size_t i = sch->stages.size(); i != 0; --i) {
Stage s = sch->stages[i - 1];
CHECK_NE(s->attach_type, kInline)
<< "call schedule.normalize before scheduleops";
CHECK(s->op.defined());
// no need to specify place holder op.
if (s->op.as<PlaceholderOpNode>()) continue;
if (scan_attach.count(s->op)) {
CHECK(s->attach_type == kNone ||
s->attach_type == kScanUpdate)
<< "Cannot specify compute_at for scan's init/update";
// Remove grouping sugar, get the real attach spec.
Stage attach_spec = s.GetAttachSpec();
if (scan_init.count(s->op)) {
CHECK(body.defined());
InjectScanStep mu(s, scan_init.at(s->op), dom_map, true);
body = mu.Mutate(body);
CHECK(mu.found_attach)
<< "did not find attachment point for scan.init";
} else if (attach_spec->attach_type == kScanUpdate) {
// Handle scan update
CHECK(body.defined());
const auto& p = scan_attach.at(s->op);
InjectScanStep mu(s, p.first, dom_map, p.second);
InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false);
body = mu.Mutate(body);
CHECK(mu.found_attach)
<< "did not find attachment point for scan.init/update";
} else if (s->attach_type == kInlinedAlready) {
<< "did not find attachment point for scan.update";
} else if (attach_spec->attach_type == kInlinedAlready) {
// do nothing
} else if (s->attach_type == kRoot || s-> attach_type == kNone) {
} else if (attach_spec->attach_type == kGroupRoot) {
CHECK(!s->group.defined());
body = MakePipeline(s, dom_map, body);
} else if (s->attach_type == kScope) {
} else {
CHECK_EQ(attach_spec->attach_type, kScope);
CHECK(body.defined());
InjectAttach mutator(s, dom_map);
InjectAttach mutator(s, attach_spec, dom_map);
body = mutator.Mutate(body);
CHECK(mutator.found_attach)
<< "did not find attachment point for " << s << " in"
<< s->attach_stage->op << " x "
<< "did not find attachment point for " << s << " in "
<< attach_spec->attach_stage->op << " x " << attach_spec->attach_ivar
<< ", body:\n"
<< body;
}
}
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_mutator.h>
namespace {
......
......@@ -11,11 +11,11 @@ def test_add():
s = tvm.Schedule(C.op)
# create iter var and assign them tags.
num_thread = 256
block_x = tvm.thread_axis(None, "blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
_, x = s[C].split(C.op.axis[0], factor=num_thread*4, outer=block_x)
_, x = s[C].split(x, outer=thread_x)
bx, x = s[C].split(C.op.axis[0], factor=num_thread*4)
tx, x = s[C].split(x, nparts=num_thread)
_, x = s[C].split(x, factor=4)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
s[C].vectorize(x)
# one line to build the function.
......
......@@ -22,31 +22,41 @@ def test_gemm():
scale = 8
num_thread = 8
block_factor = scale * num_thread
block_x = tvm.thread_axis(None, "blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
block_y = tvm.thread_axis(None, "blockIdx.y")
thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis("threadIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
thread_y = tvm.thread_axis("threadIdx.y")
CC = s.cache_write(C, "local")
AA = s.cache_read(A, "shared", [CC])
BB = s.cache_read(B, "shared", [CC])
_, yi = s[C].split(C.op.axis[0], factor=block_factor, outer=block_y)
_, xi = s[C].split(C.op.axis[1], factor=block_factor, outer=block_x)
s[C].reorder(block_y, block_x, yi, xi)
_, yi = s[C].split(yi, outer=thread_y)
_, xi = s[C].split(xi, outer=thread_x)
s[C].reorder(thread_y, thread_x, yi, xi)
by, yi = s[C].split(C.op.axis[0], factor=block_factor)
bx, xi = s[C].split(C.op.axis[1], factor=block_factor)
s[C].reorder(by, bx, yi, xi)
s[C].bind(by, block_y)
s[C].bind(bx, block_x)
ty, yi = s[C].split(yi, nparts=num_thread)
tx, xi = s[C].split(xi, nparts=num_thread)
s[C].reorder(ty, tx, yi, xi)
s[C].bind(ty, thread_y)
s[C].bind(tx, thread_x)
yo, xo = CC.op.axis
s[CC].reorder(k, yo, xo)
s[CC].compute_at(s[C], thread_x)
s[CC].compute_at(s[C], tx)
s[AA].compute_at(s[CC], k)
s[BB].compute_at(s[CC], k)
_, xi = s[AA].split(s[AA].op.axis[0], outer=thread_y)
_, xi = s[AA].split(xi, outer=thread_x)
_, xi = s[BB].split(s[BB].op.axis[0], outer=thread_y)
_, xi = s[BB].split(xi, outer=thread_x)
ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread)
tx, xi = s[AA].split(xi, nparts=num_thread)
s[AA].bind(ty, thread_y)
s[AA].bind(tx, thread_x)
ty, xi = s[BB].split(s[BB].op.axis[0], nparts=num_thread)
tx, xi = s[BB].split(xi, nparts=num_thread)
s[BB].bind(ty, thread_y)
s[BB].bind(tx, thread_x)
max_auto_unroll_step = 0
# lowering test
......@@ -76,9 +86,9 @@ def test_gemm():
np.testing.assert_allclose(
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
check_device("cuda")
if tvm.module.enabled("opencl"):
tvm.module.init_opencl()
check_device("cuda")
check_device("opencl")
if __name__ == "__main__":
......
......@@ -12,10 +12,9 @@ def test_sum():
s = tvm.Schedule(B.op)
# create iter var and assign them tags.
num_thread = 1
block_x = tvm.thread_axis(None, "blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
_, x = s[B].split(B.op.axis[0], factor=num_thread, outer=block_x)
_, x = s[B].split(x, outer=thread_x)
xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
# one line to build the function.
def check_device(device, host="stackvm"):
......@@ -52,10 +51,9 @@ def test_rfactor():
A = tvm.placeholder((n,), name='A')
k = tvm.reduce_axis((0, n))
B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')
kf = tvm.reduce_axis((0, 4))
# schedule
s = tvm.Schedule(B.op)
_, ki = s[B].split(k, outer=kf)
kf, ki = s[B].split(k, nparts=4)
BF = s.rfactor(B, kf)
s[BF].parallel(BF.op.axis[0])
# one line to build the function.
......@@ -88,16 +86,14 @@ def test_rfactor_threads():
k = tvm.reduce_axis((0, n))
nthread = 16
B = tvm.compute((m,), lambda i: tvm.sum(A[i, k], axis=k, where=(i>1)), name='B')
tx = tvm.thread_axis((0, nthread), "threadIdx.x")
ty = tvm.thread_axis((0, nthread), "threadIdx.y")
bx = tvm.thread_axis(None, "blockIdx.x")
# schedule
s = tvm.Schedule(B.op)
ko, kf = s[B].split(k, factor=nthread)
BF = s.rfactor(B, kf)
xo, xi = s[B].split(s[B].op.axis[0], factor=nthread, outer=bx)
s[B].rebase(xi, ty)
s[B].rebase(s[B].op.reduce_axis[0], tx)
bx, tx = s[B].split(s[B].op.axis[0], factor=nthread)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.y"))
s[B].bind(s[B].op.reduce_axis[0], tvm.thread_axis("threadIdx.x"))
s[BF].compute_at(s[B], tx)
# one line to build the function.
......@@ -128,6 +124,6 @@ def test_rfactor_threads():
check_target("opencl")
if __name__ == "__main__":
test_rfactor_threads()
test_rfactor()
test_rfactor_threads()
test_sum()
......@@ -15,10 +15,12 @@ def test_scan():
num_thread = 256
block_x = tvm.thread_axis(None, "blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
_, x = s[s_init].split(s_init.op.axis[1], factor=num_thread, outer=block_x)
_, x = s[s_init].split(x, outer=thread_x)
_, x = s[s_update].split(s_update.op.axis[1], factor=num_thread, outer=block_x)
_, x = s[s_update].split(x, outer=thread_x)
xo, xi = s[s_init].split(s_init.op.axis[1], factor=num_thread)
s[s_init].bind(xo, block_x)
s[s_init].bind(xi, thread_x)
xo, xi = s[s_update].split(s_update.op.axis[1], factor=num_thread)
s[s_update].bind(xo, block_x)
s[s_update].bind(xi, thread_x)
# one line to build the function.
def check_device(device, host="stackvm"):
......
......@@ -11,10 +11,9 @@ def test_add_pipeline():
# GPU schedule have to split by gridIdx and threadIdx
num_thread = 256
grid_x = tvm.thread_axis(None, "blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
_, x = s[C].split(C.op.axis[0], factor=num_thread, outer=grid_x)
_, x = s[C].split(x, outer=thread_x)
xo, xi = s[C].split(C.op.axis[0], factor=num_thread)
s[C].bind(xo, tvm.thread_axis("threadIdx.x"))
s[C].bind(xi, tvm.thread_axis("blockIdx.x"))
# compile to IR
bounds = tvm.schedule.InferBound(s)
......
"""Test group effect"""
import tvm
def test_scan_group():
    """Check grouping of the update stages feeding a scan.

    Asserts that the update stages start out in a common, non-None group,
    that create_group re-parents s_update1/s_update2 under the new group,
    that the group can be attached with compute_at, and that attaching a
    grouped stage to s_init raises tvm.TVMError.
    """
    m = tvm.Var("m")
    n = tvm.Var("n")
    full_shape = (m, n)
    x = tvm.compute(full_shape, lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder(full_shape)
    s_init = tvm.compute((1, n), lambda _, i: x[0, i])
    s_update1 = tvm.compute(full_shape, lambda t, i: s_state[t-1, i] + x[t, i])
    s_update2 = tvm.compute(full_shape, lambda t, i: s_update1[t, i] + 1)
    s_update3 = tvm.compute(full_shape, lambda t, i: s_update2[t, i] + 1)
    res = tvm.scan(s_init, s_update3, s_state, inputs=x)

    s = tvm.Schedule(res.op)
    initial_group = s[s_update1].group
    assert initial_group is not None
    assert s[s_update2].group == initial_group

    # Attaching within the same group is valid.
    s[s_update1].compute_at(s[s_update2], s_update2.op.axis[1])

    # New group covering s_update2 and s_update1.
    g2 = s.create_group(outputs=s_update2, inputs=[s_state, x])
    assert g2.group is not None
    assert g2.group == s[s_update3].group
    for grouped in (s_update2, s_update1):
        assert s[grouped].group == g2

    g2.compute_at(s[s_update3], s_update3.op.axis[1])
    assert g2.attach_stage == s[s_update3]

    try:
        # Attaching outside of the group must be rejected.
        s[s_update2].compute_at(s[s_init], s_init.op.axis[0])
        assert False
    except tvm.TVMError:
        pass
def test_compute_group():
    """Check a group over x and x1 (inputs included) and its compute_at attachment."""
    m = tvm.Var("m")
    n = tvm.Var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
    x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")

    s = tvm.Schedule(x2.op)
    g = s.create_group(outputs=x1, inputs=x, include_inputs=True)
    # Both the output stage and the included input stage belong to the group.
    for member in (x1, x):
        assert s[member].group == g

    consumer = s[x2]
    g.compute_at(consumer, x2.op.axis[1])
    assert g.attach_stage == consumer
    assert g.num_child_stages == 2
def test_nest_group():
    """Check that a second, larger group nests an existing group inside it."""
    m = tvm.Var("m")
    n = tvm.Var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
    x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")

    s = tvm.Schedule(x2.op)
    # g1 is created without its inputs, g2 with them.
    g1 = s.create_group(outputs=x1, inputs=x)
    g2 = s.create_group(outputs=x1, inputs=x, include_inputs=True)

    assert set(s.groups) == set([g1, g2])
    # Membership: x sits directly in g2, x1 in g1, and g1 inside g2.
    assert s[x].group == g2
    assert s[x1].group == g1
    assert g1.group == g2
    # Child counts of both groups.
    assert g2.num_child_stages == 2
    assert g1.num_child_stages == 1
if __name__ == "__main__":
    # Run the group tests in the same order as before.
    for run_test in (test_nest_group, test_compute_group, test_scan_group):
        run_test()
......@@ -29,8 +29,8 @@ def test_basic_pipeline():
B = tvm.compute((n,), lambda i: B[i] + k, name="A%s" % k)
s = tvm.Schedule(B.op)
px = tvm.thread_axis((0, 1), "pipeline")
xo, xi = s[B].split(B.op.axis[0], outer=px)
xo, xi = s[B].split(B.op.axis[0], nparts=1)
s[B].bind(xo, tvm.thread_axis("pipeline"))
xo, xi = s[B].split(xi, factor=4)
for S in stages:
s[S].compute_at(s[B], xo)
......@@ -50,8 +50,8 @@ def test_conv1d():
return A[i-1] + A[i] + A[i+1]
B = tvm.compute(n, computeB, name='B')
s = tvm.Schedule(B.op)
px = tvm.thread_axis((0, 1), "pipeline")
xo, xi = s[B].split(B.op.axis[0], outer=px)
px, xi = s[B].split(B.op.axis[0], nparts=1)
s[B].bind(px, tvm.thread_axis("pipeline"))
s[A].compute_at(s[B], px)
stmt = lower(s, [B])
stmt = tvm.ir_pass.SplitPipeline(stmt, False)
......
......@@ -9,15 +9,15 @@ def test_storage_sync():
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = tvm.Schedule(A2.op)
block_x = tvm.thread_axis(None, "blockIdx.x")
xo, xi = s[A2].split(A2.op.axis[0], factor=8, outer=block_x)
xo, xi = s[A2].split(A2.op.axis[0], factor=8)
s[A2].bind(xo, tvm.thread_axis("blockIdx.x"))
s[A1].compute_at(s[A2], xo)
s[A1].set_scope("shared")
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
A2b = tvm.Buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
......
......@@ -7,8 +7,9 @@ def test_virtual_thread():
A2 = tvm.compute((m,), lambda i: A1[i] + 3, name='A2')
s = tvm.Schedule(A2.op)
vx = tvm.thread_axis((0, 2), "vthread", name="vx")
xo, xi = s[A2].split(A2.op.axis[0], outer=vx)
vx = tvm.thread_axis("vthread", name="vx")
xo, xi = s[A2].split(A2.op.axis[0], nparts=2)
s[A2].bind(xo, vx)
xo, xi = s[A2].split(xi, 8)
s[A1].compute_at(s[A2], xo)
......
......@@ -36,11 +36,10 @@ def test_bound3():
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = tvm.Schedule(A2.op)
s[A1].set_scope("shared")
thread_x = tvm.thread_axis((0, 16), "threadIdx.x")
xo, xi = s[A2].split(A2.op.axis[0], 32)
xi0, xi1 = s[A2].split(xi, outer=thread_x)
xi0, xi1 = s[A2].split(xi, nparts=16)
s[A2].bind(xi0, tvm.thread_axis("threadIdx.x"))
yo, yi = s[A2].split(A2.op.axis[1], 16)
s[A2].reorder(xo, xi0, yo, xi1, yi)
s[A1].compute_at(s[A2], yo)
......@@ -60,12 +59,10 @@ def test_bound_scan():
s_scan = tvm.scan(s_init, s_update, s_state)
assert tuple(s_scan.shape) == (m, n)
s = tvm.Schedule(s_scan.op)
XX = s.cache_read(X, "local", s_update)
xo, xi = s[s_update].split(s_update.op.axis[1], factor=4)
s[XX].compute_at(s[s_update], xo)
s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
......@@ -105,22 +102,59 @@ def test_bound_rfactor():
A = tvm.placeholder((n,), name='A')
k = tvm.reduce_axis((0, n))
B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k, where=(i>1)), name='B')
kf = tvm.reduce_axis((0, 4))
# schedule
s = tvm.Schedule(B.op)
_, ki = s[B].split(k, outer=kf)
kf, ki = s[B].split(k, nparts=4)
BF = s.rfactor(B, kf)
s.normalize()
bounds = tvm.schedule.InferBound(s)
assert(bounds[BF.op.axis[0]].extent.value == 4)
assert(bounds[BF.op.axis[1]].extent.value == 1)
def test_bound_group_schedule():
    """Check inferred bounds of x when the group {x, x1} is attached at x2's first axis."""
    m = tvm.Var("m")
    n = tvm.Var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
    x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")

    s = tvm.Schedule(x2.op)
    g = s.create_group(outputs=x1, inputs=x, include_inputs=True)
    g.compute_at(s[x2], x2.op.axis[0])
    for member in (x1, x):
        assert s[member].group == g

    s.normalize()
    bounds = tvm.schedule.InferBound(s)
    outer_axis, inner_axis = x.op.axis[0], x.op.axis[1]
    assert bounds[outer_axis].extent.value == 1
    assert bounds[inner_axis].extent == n
def test_bound_nest_group():
    """Check inferred bounds when a nested group is attached inside its enclosing group."""
    m = tvm.Var("m")
    n = tvm.Var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
    x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")

    s = tvm.Schedule(x2.op)
    g1 = s.create_group(outputs=x, inputs=x, include_inputs=True)
    g2 = s.create_group(outputs=x1, inputs=x, include_inputs=True)
    assert s[x].group == g1
    assert s[x1].group == g2

    # Attach the outer group to x2, then the inner group to x1.
    g2.compute_at(s[x2], x2.op.axis[0])
    g1.compute_at(s[x1], s[x1].op.axis[1])

    s.normalize()
    bounds = tvm.schedule.InferBound(s)
    x_axes = x.op.axis
    x1_axes = x1.op.axis
    assert bounds[x_axes[0]].extent.value == 1
    assert bounds[x_axes[1]].extent.value == 1
    assert bounds[x1_axes[0]].extent.value == 1
    assert bounds[x1_axes[1]].extent == n
if __name__ == "__main__":
test_bound_nest_group()
test_bound_group_schedule()
test_bound_scan()
test_bound3()
test_bound_rfactor()
test_bound_blur()
test_bound_conv1d()
test_bound_scan()
test_bound3()
test_bound1()
test_bound2()
......@@ -59,7 +59,7 @@ def test_scan_fix_point():
s_update = tvm.compute((l, m, n), lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update")
s_scan = tvm.scan(s_init, s_update, s_state)
body = tvm.schedule.ScanGetBody(s_scan.op)
fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body)
fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op)
assert(fxpt[s_scan.op.spatial_axis_[0]].value == 1)
assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0)
......@@ -69,8 +69,7 @@ def test_scan_fix_point():
s_update = tvm.compute((l, m, n),
lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update")
s_scan = tvm.scan(s_init, s_update, s_state)
body = tvm.schedule.ScanGetBody(s_scan.op)
fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body)
fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op)
assert(fxpt[s_scan.op.spatial_axis_[0]].value == 0)
assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0)
......@@ -89,7 +88,7 @@ def test_scan_fix_point():
[s1_update, s2_update],
[s1, s2])
body = tvm.schedule.ScanGetBody(r0.op)
fxpt = tvm.schedule.ScanFixPointAnalysis(r0.op, body)
fxpt = tvm.schedule.ScanFixPointAnalysis(r0.op)
assert(fxpt[r1.op.spatial_axis_[0]].value == 1)
test_scan0()
......
......@@ -35,8 +35,8 @@ def test_add_pipeline():
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
s = tvm.Schedule(C.op)
grid_x = tvm.thread_axis((0, 1), "pipeline")
_, x = s[C].split(C.op.axis[0], outer=grid_x)
px, x = s[C].split(C.op.axis[0], nparts=1)
s[C].bind(px, tvm.thread_axis("pipeline"))
fapi = lower(s, [A, B, C], "myadd")
fsplits = tvm.ir_pass.SplitHostDevice(fapi)
print(fsplits[1].body)
......
......@@ -57,9 +57,9 @@ def test_buffer_linebuff():
# correctness checks
if (read_data_valid.get_int()):
# Derive convolution window indices
baseIdx = read_idx/(kernel_width*kernel_width)
offsetIdx = read_idx%(kernel_width*kernel_width)
yOffset = offsetIdx/kernel_width
baseIdx = read_idx // (kernel_width*kernel_width)
offsetIdx = read_idx % (kernel_width*kernel_width)
yOffset = offsetIdx // kernel_width
xOffset = offsetIdx%kernel_width
pixIndex = baseIdx + yOffset * window_width + xOffset
assert(read_data.get_int()==test_data[pixIndex])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment