Commit 8f51c5fd by Tianqi Chen Committed by GitHub

[SCHEDULE] Add group, refactor thread bind api. (#82)

* [SCHEDULE] Add group, refactor thread bind api.

* fix doc

* fix g++-4.8

* More testscase

* Remove graph context from fix pt analysis
parent 6268e183
Subproject commit ce80d58741688b200f498fed8c7b0ea33e0516c8 Subproject commit 59fdca16978b6184bab87fbff7a00c95f1804686
...@@ -32,17 +32,6 @@ struct TensorDom { ...@@ -32,17 +32,6 @@ struct TensorDom {
}; };
/*! /*!
* \brief The map beteen tensor and operation it feeds to.
*/
using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
/*! \brief The graph context used during bound inference. */
struct GraphContext {
/*! \brief The feed graph */
FeedGraph feed_graph;
};
/*!
* \brief Base class of all operation nodes * \brief Base class of all operation nodes
*/ */
class OperationNode : public FunctionBaseNode { class OperationNode : public FunctionBaseNode {
...@@ -102,13 +91,11 @@ class OperationNode : public FunctionBaseNode { ...@@ -102,13 +91,11 @@ class OperationNode : public FunctionBaseNode {
* Set the range of each root_iter_vars in the op to out_dom_map * Set the range of each root_iter_vars in the op to out_dom_map
* *
* \param self The reference to self. * \param self The reference to self.
* \param graph_ctx The global graph context information.
* \param tensor_dom Domain map of Tensor->access set of each dimension. * \param tensor_dom Domain map of Tensor->access set of each dimension.
* \param out_dom_map The output domain map of each IterVar to be setted. * \param out_dom_map The output domain map of each IterVar to be setted.
*/ */
virtual void GatherBound( virtual void GatherBound(
const Operation& self, const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const = 0; std::unordered_map<IterVar, Range>* out_dom_map) const = 0;
/*! /*!
...@@ -162,7 +149,6 @@ class PlaceholderOpNode : public OperationNode { ...@@ -162,7 +149,6 @@ class PlaceholderOpNode : public OperationNode {
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final; std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound( void GatherBound(
const Operation& self, const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final; std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize( Stmt BuildRealize(
...@@ -214,7 +200,6 @@ class ComputeOpNode : public OperationNode { ...@@ -214,7 +200,6 @@ class ComputeOpNode : public OperationNode {
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final; std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound( void GatherBound(
const Operation& self, const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final; std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize( Stmt BuildRealize(
...@@ -253,6 +238,11 @@ class ScanOpNode : public OperationNode { ...@@ -253,6 +238,11 @@ class ScanOpNode : public OperationNode {
/*! \brief The placeholder to refer as states in update. */ /*! \brief The placeholder to refer as states in update. */
Array<Tensor> state_placeholder; Array<Tensor> state_placeholder;
/*! /*!
* \brief the inputs to the scan, these are optionally provided
* But they can be helpful to provide hints to speedup get of scan body.
*/
Array<Tensor> inputs;
/*!
* \brief Spatial axis to indicate spatial dimension of each output. * \brief Spatial axis to indicate spatial dimension of each output.
* They corresponds to flattened spatial axis of the outputs. * They corresponds to flattened spatial axis of the outputs.
* *
...@@ -279,7 +269,6 @@ class ScanOpNode : public OperationNode { ...@@ -279,7 +269,6 @@ class ScanOpNode : public OperationNode {
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final; std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound( void GatherBound(
const Operation& self, const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final; std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize( Stmt BuildRealize(
...@@ -296,13 +285,15 @@ class ScanOpNode : public OperationNode { ...@@ -296,13 +285,15 @@ class ScanOpNode : public OperationNode {
v->Visit("init", &init); v->Visit("init", &init);
v->Visit("update", &update); v->Visit("update", &update);
v->Visit("state_placeholder", &state_placeholder); v->Visit("state_placeholder", &state_placeholder);
v->Visit("inputs", &inputs);
v->Visit("spatial_axis_", &spatial_axis_); v->Visit("spatial_axis_", &spatial_axis_);
} }
static Operation make(std::string name, static Operation make(std::string name,
IterVar axis, IterVar axis,
Array<Tensor> init, Array<Tensor> init,
Array<Tensor> update, Array<Tensor> update,
Array<Tensor> state_placeholder); Array<Tensor> state_placeholder,
Array<Tensor> input);
static constexpr const char* _type_key = "ScanOp"; static constexpr const char* _type_key = "ScanOp";
TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode, OperationNode); TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode, OperationNode);
...@@ -339,7 +330,6 @@ class ExternOpNode : public OperationNode { ...@@ -339,7 +330,6 @@ class ExternOpNode : public OperationNode {
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final; std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound( void GatherBound(
const Operation& self, const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final; std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize( Stmt BuildRealize(
...@@ -388,16 +378,19 @@ Tensor placeholder(Array<Expr> shape, ...@@ -388,16 +378,19 @@ Tensor placeholder(Array<Expr> shape,
Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor"); Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
/*! /*!
* \brief Construct new tensors by scan over scan_axis. * \brief Construct new tensors by scan.
* *
* \param init The intialize tensor of first K steps. * \param init The intialize tensor of first K steps.
* \param update The update tensor indicated the updated result after each timestamp. * \param update The update tensor indicated the updated result after each timestamp.
* \param state_placeholder The placeholder for the states. * \param state_placeholder The placeholder for the states.
* \param inputs The inputs to the scan body, this is optional,
* but recommended to provide concrete information about scan body.
* \param name The optional name of the tensor. * \param name The optional name of the tensor.
*/ */
Array<Tensor> scan(Array<Tensor> init, Array<Tensor> scan(Array<Tensor> init,
Array<Tensor> update, Array<Tensor> update,
Array<Tensor> state_placeholder, Array<Tensor> state_placeholder,
Array<Tensor> inputs = Array<Tensor>(),
std::string name = "scan"); std::string name = "scan");
// same as compute, specialized for different fcompute function // same as compute, specialized for different fcompute function
......
...@@ -139,12 +139,20 @@ inline bool TVMArgValue::IsNodeType() const { ...@@ -139,12 +139,20 @@ inline bool TVMArgValue::IsNodeType() const {
// extensions for TVMRetValue // extensions for TVMRetValue
inline TVMRetValue& TVMRetValue::operator=( inline TVMRetValue& TVMRetValue::operator=(
const std::shared_ptr<Node>& other) { const std::shared_ptr<Node>& other) {
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other); if (other.get() == nullptr) {
SwitchToPOD(kNull);
} else {
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other);
}
return *this; return *this;
} }
inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) { inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) {
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other.node_); if (!other.defined()) {
SwitchToPOD(kNull);
} else {
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other.node_);
}
return *this; return *this;
} }
......
...@@ -11,6 +11,7 @@ from ._ctypes._function import Function ...@@ -11,6 +11,7 @@ from ._ctypes._function import Function
from ._ctypes._function import _init_api_functions, register_func, get_global_func from ._ctypes._function import _init_api_functions, register_func, get_global_func
from ._ctypes._function import convert_to_tvm_func as _convert_tvm_func from ._ctypes._function import convert_to_tvm_func as _convert_tvm_func
from . import _api_internal from . import _api_internal
from . import _base
from . import make as _make from . import make as _make
from . import expr as _expr from . import expr as _expr
from . import tensor as _tensor from . import tensor as _tensor
...@@ -142,7 +143,7 @@ def compute(shape, fcompute, name="compute"): ...@@ -142,7 +143,7 @@ def compute(shape, fcompute, name="compute"):
return op_node.output(0) return op_node.output(0)
def scan(init, update, state_placeholder, name="scan"): def scan(init, update, state_placeholder, inputs=None, name="scan"):
"""Construct new tensors by scanning over axis. """Construct new tensors by scanning over axis.
Parameters Parameters
...@@ -156,6 +157,10 @@ def scan(init, update, state_placeholder, name="scan"): ...@@ -156,6 +157,10 @@ def scan(init, update, state_placeholder, name="scan"):
state_placeholder: Tensor or list of Tensor state_placeholder: Tensor or list of Tensor
The placeholder variables used by update. The placeholder variables used by update.
inputs: Tensor or list of Tensor, optional
The list of inputs to the scan. This is not required, but can
be useful for the compiler to detect scan body faster.
name: str, optional name: str, optional
The name hint of the tensor The name hint of the tensor
...@@ -173,7 +178,7 @@ def scan(init, update, state_placeholder, name="scan"): ...@@ -173,7 +178,7 @@ def scan(init, update, state_placeholder, name="scan"):
s_state = tvm.placeholder((m, n)) s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i]) s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i]) s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
res = tvm.scan(s_init, s_update, s_state) res = tvm.scan(s_init, s_update, s_state, X)
""" """
if isinstance(init, _tensor.Tensor): if isinstance(init, _tensor.Tensor):
init = [init] init = [init]
...@@ -181,10 +186,14 @@ def scan(init, update, state_placeholder, name="scan"): ...@@ -181,10 +186,14 @@ def scan(init, update, state_placeholder, name="scan"):
update = [update] update = [update]
if isinstance(state_placeholder, _tensor.Tensor): if isinstance(state_placeholder, _tensor.Tensor):
state_placeholder = [state_placeholder] state_placeholder = [state_placeholder]
if isinstance(inputs, _tensor.Tensor):
inputs = [inputs]
if inputs is None:
inputs = []
if len(init) != len(update) or len(init) != len(state_placeholder): if len(init) != len(update) or len(init) != len(state_placeholder):
raise ValueError("init, update, state_placeholder must have same length") raise ValueError("init, update, state_placeholder must have same length")
axis = _IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3) axis = _IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3)
op = _api_internal._ScanOp(name, axis, init, update, state_placeholder) op = _api_internal._ScanOp(name, axis, init, update, state_placeholder, inputs)
res = [op.output(i) for i in range(len(update))] res = [op.output(i) for i in range(len(update))]
return res[0] if len(res) == 1 else res return res[0] if len(res) == 1 else res
...@@ -340,20 +349,25 @@ def _IterVar(dom, name, iter_type, thread_tag=''): ...@@ -340,20 +349,25 @@ def _IterVar(dom, name, iter_type, thread_tag=''):
return _api_internal._IterVar(dom, var, iter_type, thread_tag) return _api_internal._IterVar(dom, var, iter_type, thread_tag)
def thread_axis(dom, tag, name=''): def thread_axis(dom=None, tag='', name=''):
"""Create a new IterVar to represent thread index. """Create a new IterVar to represent thread index.
Parameters Parameters
---------- ----------
dom : Range dom : Range or str
The domain of iteration. The domain of iteration
When str is passed, dom is set to None and str is used as tag
tag : str tag : str, optional
The thread tag The thread tag
name : str, optional name : str, optional
The name of the var. The name of the var.
""" """
if isinstance(dom, _base.string_types):
tag, dom = dom, None
if len(tag) == 0:
raise ValueError("tag must be given as Positional or keyword argument")
name = name if name else tag name = name if name else tag
return _IterVar(dom, name, 1, tag) return _IterVar(dom, name, 1, tag)
......
...@@ -41,6 +41,30 @@ class Schedule(NodeBase): ...@@ -41,6 +41,30 @@ class Schedule(NodeBase):
""" """
_api_internal._ScheduleNormalize(self) _api_internal._ScheduleNormalize(self)
def create_group(self, outputs, inputs, include_inputs=False):
"""Create stage group by giving output and input boundary.
The operators between outputs and inputs are placed as member of group.
outputs are include in the group, while inputs are not included.
Parameters
----------
outputs : list of Tensors
The outputs of the group.
inputs : list of Tensors
The inputs of the group.
include_inputs : boolean, optional
Whether include input operations in the group if they are used by outputs.
"""
if isinstance(outputs, _tensor.Tensor):
outputs = [outputs]
if isinstance(inputs, _tensor.Tensor):
inputs = [inputs]
return _api_internal._ScheduleCreateGroup(
self, outputs, inputs, include_inputs)
def cache_read(self, tensor, scope, readers): def cache_read(self, tensor, scope, readers):
"""Create a cache read of original tensor for readers. """Create a cache read of original tensor for readers.
...@@ -112,25 +136,7 @@ class Schedule(NodeBase): ...@@ -112,25 +136,7 @@ class Schedule(NodeBase):
@register_node @register_node
class Stage(NodeBase): class Stage(NodeBase):
"""A Stage represents schedule for one operation.""" """A Stage represents schedule for one operation."""
def rebase(self, parent, rebased): def split(self, parent, factor=None, nparts=None):
"""Rebase parent by an existing thread axis.
Parameters
----------
parent : IterVar
The parent iter var.
rebased : IterVar
The rebased iter var.
Returns
-------
rebased : IterVar
The rebased itervar.
"""
_api_internal._StageRebase(self, parent, rebased)
return rebased
def split(self, parent, factor=None, outer=None):
"""Split the stage either by factor providing outer scope, or both """Split the stage either by factor providing outer scope, or both
Parameters Parameters
...@@ -141,8 +147,8 @@ class Stage(NodeBase): ...@@ -141,8 +147,8 @@ class Stage(NodeBase):
factor : Expr, optional factor : Expr, optional
The splitting factor The splitting factor
outer : IterVar, optional nparts : Expr, optional
The outer split variable The number of outer parts.
Returns Returns
------- -------
...@@ -152,11 +158,13 @@ class Stage(NodeBase): ...@@ -152,11 +158,13 @@ class Stage(NodeBase):
inner : IterVar inner : IterVar
The inner variable of iteration. The inner variable of iteration.
""" """
if outer is not None: if nparts is not None:
inner = _api_internal._StageSplitByOuter(self, parent, outer, factor) if factor is not None:
raise ValueError("Donot need to provide both outer and nparts")
outer, inner = _api_internal._StageSplitByNParts(self, parent, nparts)
else: else:
if factor is None: if factor is None:
raise ValueError("either outer or factor need to be provided") raise ValueError("Either nparts or factor need to be provided")
outer, inner = _api_internal._StageSplitByFactor(self, parent, factor) outer, inner = _api_internal._StageSplitByFactor(self, parent, factor)
return outer, inner return outer, inner
...@@ -188,8 +196,21 @@ class Stage(NodeBase): ...@@ -188,8 +196,21 @@ class Stage(NodeBase):
""" """
return _api_internal._StageSetScope(self, scope) return _api_internal._StageSetScope(self, scope)
def outermost_threads(self, threads): def bind(self, ivar, thread_ivar):
"""Force launch threads at outermost scope of the stage. """Bind ivar to thread index thread_ivar
Parameters
----------
ivar : IterVar
The iteration to be binded to thread.
thread_ivar : IterVar
The thread to be binded.
"""
_api_internal._StageBind(self, ivar, thread_ivar)
def env_threads(self, threads):
"""Mark threads to be launched at the outer scope of composed op.
Parameters Parameters
---------- ----------
...@@ -198,7 +219,7 @@ class Stage(NodeBase): ...@@ -198,7 +219,7 @@ class Stage(NodeBase):
""" """
if isinstance(threads, _collections.IterVar): if isinstance(threads, _collections.IterVar):
threads = [threads] threads = [threads]
_api_internal._StageOutermostThreads(self, threads) _api_internal._StageEnvThreads(self, threads)
def compute_at(self, parent, scope): def compute_at(self, parent, scope):
"""Attach the stage at parent's scope """Attach the stage at parent's scope
......
...@@ -182,7 +182,8 @@ TVM_REGISTER_API(_ScanOp) ...@@ -182,7 +182,8 @@ TVM_REGISTER_API(_ScanOp)
args[1], args[1],
args[2], args[2],
args[3], args[3],
args[4]); args[4],
args[5]);
}); });
TVM_REGISTER_API(_ExternOp) TVM_REGISTER_API(_ExternOp)
...@@ -219,27 +220,26 @@ TVM_REGISTER_API(_StageSetScope) ...@@ -219,27 +220,26 @@ TVM_REGISTER_API(_StageSetScope)
.set_scope(args[1]); .set_scope(args[1]);
}); });
TVM_REGISTER_API(_StageRebase) TVM_REGISTER_API(_StageBind)
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar outer, inner;
args[0].operator Stage() args[0].operator Stage()
.rebase(args[1], args[2]); .bind(args[1], args[2]);
}); });
TVM_REGISTER_API(_StageSplitByFactor) TVM_REGISTER_API(_StageSplitByFactor)
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar outer, inner; IterVar outer, inner;
args[0].operator Stage() args[0].operator Stage()
.split(args[1], &outer, &inner, args[2]); .split(args[1], args[2], &outer, &inner);
*ret = Array<IterVar>({outer, inner}); *ret = Array<IterVar>({outer, inner});
}); });
TVM_REGISTER_API(_StageSplitByOuter) TVM_REGISTER_API(_StageSplitByNParts)
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar inner; IterVar outer, inner;
args[0].operator Stage() args[0].operator Stage()
.split(args[1], args[2], &inner, args[3]); .split_by_nparts(args[1], args[2], &outer, &inner);
*ret = inner; *ret = Array<IterVar>({outer, inner});
}); });
TVM_REGISTER_API(_StageFuse) TVM_REGISTER_API(_StageFuse)
...@@ -278,15 +278,17 @@ TVM_REGISTER_API(_StageTile) ...@@ -278,15 +278,17 @@ TVM_REGISTER_API(_StageTile)
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar x_outer, y_outer, x_inner, y_inner; IterVar x_outer, y_outer, x_inner, y_inner;
args[0].operator Stage() args[0].operator Stage()
.tile(args[1], args[2], &x_outer, &y_outer, .tile(args[1], args[2],
&x_inner, &y_inner, args[3], args[4]); args[3], args[4],
&x_outer, &y_outer,
&x_inner, &y_inner);
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner}); *ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
}); });
TVM_REGISTER_API(_StageOutermostThreads) TVM_REGISTER_API(_StageEnvThreads)
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage() args[0].operator Stage()
.outermost_threads(args[1]); .env_threads(args[1]);
}); });
TVM_REGISTER_API(_StageUnroll) TVM_REGISTER_API(_StageUnroll)
...@@ -313,6 +315,12 @@ TVM_REGISTER_API(_ScheduleNormalize) ...@@ -313,6 +315,12 @@ TVM_REGISTER_API(_ScheduleNormalize)
.normalize(); .normalize();
}); });
TVM_REGISTER_API(_ScheduleCreateGroup)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
.create_group(args[1], args[2], args[3]);
});
TVM_REGISTER_API(_ScheduleCacheRead) TVM_REGISTER_API(_ScheduleCacheRead)
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule() *ret = args[0].operator Schedule()
......
...@@ -34,9 +34,9 @@ TVM_REGISTER_API(_schedule_AutoInlineElemWise) ...@@ -34,9 +34,9 @@ TVM_REGISTER_API(_schedule_AutoInlineElemWise)
REGISTER_SCHEDULE_PASS1(InferBound); REGISTER_SCHEDULE_PASS1(InferBound);
REGISTER_SCHEDULE_PASS1(CreateReadGraph); REGISTER_SCHEDULE_PASS1(CreateReadGraph);
REGISTER_SCHEDULE_PASS2(PostDFSOrder); REGISTER_SCHEDULE_PASS2(PostDFSOrder);
REGISTER_SCHEDULE_PASS1(ScanGetBody);
REGISTER_SCHEDULE_PASS1(CreateAttachPath); REGISTER_SCHEDULE_PASS1(CreateAttachPath);
REGISTER_SCHEDULE_PASS2(ScanFixPointAnalysis); REGISTER_SCHEDULE_PASS1(ScanGetBody);
REGISTER_SCHEDULE_PASS1(ScanFixPointAnalysis);
REGISTER_SCHEDULE_PASS2(ScheduleOps); REGISTER_SCHEDULE_PASS2(ScheduleOps);
} // namespace schedule } // namespace schedule
......
...@@ -35,7 +35,6 @@ std::string CodeGenSourceBase::GetUniqueName(std::string prefix) { ...@@ -35,7 +35,6 @@ std::string CodeGenSourceBase::GetUniqueName(std::string prefix) {
} }
std::string CodeGenSourceBase::SSAGetID(std::string src, Type t) { std::string CodeGenSourceBase::SSAGetID(std::string src, Type t) {
LOG(INFO) << "ssa get id";
if (name_alloc_map_.count(src)) return src; if (name_alloc_map_.count(src)) return src;
auto it = ssa_assign_map_.find(src); auto it = ssa_assign_map_.find(src);
if (it != ssa_assign_map_.end()) { if (it != ssa_assign_map_.end()) {
......
...@@ -132,7 +132,6 @@ void ComputeOpNode::PropBoundToInputs( ...@@ -132,7 +132,6 @@ void ComputeOpNode::PropBoundToInputs(
void ComputeOpNode::GatherBound( void ComputeOpNode::GatherBound(
const Operation& self, const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const { std::unordered_map<IterVar, Range>* out_dom_map) const {
const TensorDom& tdom = tensor_dom.at(self.output(0)); const TensorDom& tdom = tensor_dom.at(self.output(0));
......
...@@ -99,7 +99,6 @@ void ExternOpNode::PropBoundToInputs( ...@@ -99,7 +99,6 @@ void ExternOpNode::PropBoundToInputs(
void ExternOpNode::GatherBound( void ExternOpNode::GatherBound(
const Operation& self, const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const { std::unordered_map<IterVar, Range>* out_dom_map) const {
} }
......
...@@ -38,20 +38,29 @@ MakeLoopNest(const Stage& stage, ...@@ -38,20 +38,29 @@ MakeLoopNest(const Stage& stage,
value_map[iv] = iv->var; value_map[iv] = iv->var;
continue; continue;
} }
// Bind iv could be another thread.
IterVar bind_iv = iv;
if (stage->iter_var_attrs.count(iv)) {
IterVar bind_thread = stage->iter_var_attrs[iv]->bind_thread;
if (bind_thread.defined()) bind_iv = bind_thread;
}
Range dom = dom_map.at(iv); Range dom = dom_map.at(iv);
// initialize the offset and loop_level // initialize the offset and loop_level
Var var = iv->var; Var var = bind_iv->var;
if (new_loop_var) { if (new_loop_var) {
var = Var(iv->var->name_hint + ".init", iv->var.type()); var = Var(iv->var->name_hint + ".init", bind_iv->var.type());
} }
// Mark the iter var in the IR, to remember the point // Mark the iter var in the IR, to remember the point
if (iv->thread_tag.length() == 0) { if (bind_iv->thread_tag.length() == 0) {
ForType for_type = ForType::Serial; ForType for_type = ForType::Serial;
if (stage->iter_var_attrs.count(iv)) { if (stage->iter_var_attrs.count(iv)) {
switch (stage->iter_var_attrs[iv]->iter_type) { switch (stage->iter_var_attrs[iv]->iter_type) {
case kUnrolled: for_type = ForType::Unrolled; break; case kUnrolled: for_type = ForType::Unrolled; break;
case kVectorized: for_type = ForType::Vectorized; break; case kVectorized: for_type = ForType::Vectorized; break;
case kParallelized: for_type = ForType::Parallel; break; case kParallelized: for_type = ForType::Parallel; break;
case kDataPar: break;
default: LOG(FATAL) << "Unknown iter type" default: LOG(FATAL) << "Unknown iter type"
<< stage->iter_var_attrs[iv]->iter_type << stage->iter_var_attrs[iv]->iter_type
<< " in the iter_var_attrs"; << " in the iter_var_attrs";
...@@ -67,7 +76,7 @@ MakeLoopNest(const Stage& stage, ...@@ -67,7 +76,7 @@ MakeLoopNest(const Stage& stage,
for_type, DeviceAPI::None, no_op)); for_type, DeviceAPI::None, no_op));
value_map[iv] = var; value_map[iv] = var;
} else { } else {
Var idx(iv->var->name_hint + ".idx", iv->var.type()); Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.type());
nest[i + 1].emplace_back( nest[i + 1].emplace_back(
For::make(idx, 0, dom->extent, For::make(idx, 0, dom->extent,
for_type, DeviceAPI::None, no_op)); for_type, DeviceAPI::None, no_op));
...@@ -76,29 +85,29 @@ MakeLoopNest(const Stage& stage, ...@@ -76,29 +85,29 @@ MakeLoopNest(const Stage& stage,
nest[i + 1].emplace_back( nest[i + 1].emplace_back(
LetStmt::make(var, new_value, no_op)); LetStmt::make(var, new_value, no_op));
} }
} else if (iv->thread_tag == "vthread") { } else if (bind_iv->thread_tag == "vthread") {
// virtual thread // virtual thread
// Always restrict threaded IterVar to starts from 0. // Always restrict threaded IterVar to starts from 0.
CHECK(is_zero(dom->min)); CHECK(is_zero(dom->min));
CHECK(is_positive_const(dom->extent)); CHECK(is_positive_const(dom->extent));
// annotate the extent of the IterVar // annotate the extent of the IterVar
nest[i + 1].emplace_back( nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::virtual_thread, dom->extent, no_op)); AttrStmt::make(bind_iv, ir::attr::virtual_thread, dom->extent, no_op));
value_map[iv] = var; value_map[iv] = var;
} else if (iv->thread_tag == "pipeline") { } else if (bind_iv->thread_tag == "pipeline") {
// pipeline marker. // pipeline marker.
CHECK(is_zero(dom->min)); CHECK(is_zero(dom->min));
CHECK(is_one(dom->extent)); CHECK(is_one(dom->extent));
// annotate the extent of the IterVar // annotate the extent of the IterVar
nest[i + 1].emplace_back( nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::pipeline_exec_scope, dom->extent, no_op)); AttrStmt::make(bind_iv, ir::attr::pipeline_exec_scope, dom->extent, no_op));
value_map[iv] = dom->min; value_map[iv] = dom->min;
} else { } else {
// Always restrict threaded IterVar to starts from 0. // Always restrict threaded IterVar to starts from 0.
CHECK(is_zero(dom->min)); CHECK(is_zero(dom->min));
// annotate the extent of the IterVar // annotate the extent of the IterVar
nest[i + 1].emplace_back( nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::thread_extent, dom->extent, no_op)); AttrStmt::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op));
if (is_one(dom->extent)) { if (is_one(dom->extent)) {
value_map[iv] = dom->min; value_map[iv] = dom->min;
} else { } else {
......
...@@ -33,7 +33,6 @@ Array<Expr> PlaceholderOpNode::output_shape(size_t i) const { ...@@ -33,7 +33,6 @@ Array<Expr> PlaceholderOpNode::output_shape(size_t i) const {
return shape; return shape;
} }
Operation PlaceholderOpNode::make(std::string name, Operation PlaceholderOpNode::make(std::string name,
Array<Expr> shape, Array<Expr> shape,
Type dtype) { Type dtype) {
...@@ -66,7 +65,6 @@ void PlaceholderOpNode::PropBoundToInputs( ...@@ -66,7 +65,6 @@ void PlaceholderOpNode::PropBoundToInputs(
void PlaceholderOpNode::GatherBound( void PlaceholderOpNode::GatherBound(
const Operation& self, const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const { std::unordered_map<IterVar, Range>* out_dom_map) const {
} }
......
...@@ -47,7 +47,8 @@ Operation ScanOpNode::make(std::string name, ...@@ -47,7 +47,8 @@ Operation ScanOpNode::make(std::string name,
IterVar axis, IterVar axis,
Array<Tensor> init, Array<Tensor> init,
Array<Tensor> update, Array<Tensor> update,
Array<Tensor> state_placeholder) { Array<Tensor> state_placeholder,
Array<Tensor> inputs) {
auto n = std::make_shared<ScanOpNode>(); auto n = std::make_shared<ScanOpNode>();
CHECK_EQ(init.size(), update.size()); CHECK_EQ(init.size(), update.size());
CHECK_EQ(init.size(), state_placeholder.size()); CHECK_EQ(init.size(), state_placeholder.size());
...@@ -89,12 +90,14 @@ Operation ScanOpNode::make(std::string name, ...@@ -89,12 +90,14 @@ Operation ScanOpNode::make(std::string name,
n->init = init; n->init = init;
n->update = update; n->update = update;
n->state_placeholder = state_placeholder; n->state_placeholder = state_placeholder;
n->inputs = inputs;
return Operation(n); return Operation(n);
} }
Array<Tensor> scan(Array<Tensor> init, Array<Tensor> scan(Array<Tensor> init,
Array<Tensor> update, Array<Tensor> update,
Array<Tensor> state_placeholder, Array<Tensor> state_placeholder,
Array<Tensor> inputs,
std::string name) { std::string name) {
IterVar scan_axis = IterVar scan_axis =
IterVarNode::make( IterVarNode::make(
...@@ -102,7 +105,7 @@ Array<Tensor> scan(Array<Tensor> init, ...@@ -102,7 +105,7 @@ Array<Tensor> scan(Array<Tensor> init,
init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]), init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
Var(name + ".idx"), kOrdered); Var(name + ".idx"), kOrdered);
Operation op = ScanOpNode::make( Operation op = ScanOpNode::make(
name, scan_axis, init, update, state_placeholder); name, scan_axis, init, update, state_placeholder, inputs);
Array<Tensor> res; Array<Tensor> res;
for (int i = 0; i < op->num_outputs(); ++i) { for (int i = 0; i < op->num_outputs(); ++i) {
res.push_back(op.output(i)); res.push_back(op.output(i));
...@@ -179,7 +182,6 @@ void ScanOpNode::PropBoundToInputs( ...@@ -179,7 +182,6 @@ void ScanOpNode::PropBoundToInputs(
void ScanOpNode::GatherBound( void ScanOpNode::GatherBound(
const Operation& self, const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const { std::unordered_map<IterVar, Range>* out_dom_map) const {
CHECK_EQ(self.operator->(), this); CHECK_EQ(self.operator->(), this);
...@@ -200,8 +202,7 @@ void ScanOpNode::GatherBound( ...@@ -200,8 +202,7 @@ void ScanOpNode::GatherBound(
Range r = arith::Union(time_dom).cover_range(sdom); Range r = arith::Union(time_dom).cover_range(sdom);
(*out_dom_map)[this->scan_axis] = Range::make_with_min_extent( (*out_dom_map)[this->scan_axis] = Range::make_with_min_extent(
sdom->min, ir::Simplify(r->extent + r->min - sdom->min)); sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
Array<Operation> body = ScanGetBody_(this, graph_ctx.feed_graph); Map<IterVar, Expr> fix_pt = ScanFixPointAnalysis(self);
Map<IterVar, Expr> fix_pt = ScanFixPointAnalysis(self, body);
// Update for spatial axis. // Update for spatial axis.
size_t sp_idx = 0; size_t sp_idx = 0;
for (size_t i = 0; i < output.size(); ++i) { for (size_t i = 0; i < output.size(); ++i) {
......
...@@ -15,10 +15,23 @@ ...@@ -15,10 +15,23 @@
namespace tvm { namespace tvm {
namespace schedule { namespace schedule {
/*! \brief The graph context used during bound inference. */
struct GraphContext {
/*! \brief The feed graph */
FeedGraph feed_graph;
};
// check if scope // check if scope
inline bool ScopeRelax(const IterVar& iv, const std::string& scope) { inline bool ScopeRelax(const IterVar& ivar,
const std::unordered_map<IterVar, IterVar>& bind_map,
const std::string& scope) {
using runtime::ThreadScope; using runtime::ThreadScope;
using runtime::StorageScope; using runtime::StorageScope;
auto it = bind_map.find(ivar);
IterVar iv = ivar;
if (it != bind_map.end()) {
iv = it->second;
}
if (iv->thread_tag.length() == 0) return false; if (iv->thread_tag.length() == 0) return false;
if (scope.length() == 0) return false; if (scope.length() == 0) return false;
...@@ -28,10 +41,16 @@ inline bool ScopeRelax(const IterVar& iv, const std::string& scope) { ...@@ -28,10 +41,16 @@ inline bool ScopeRelax(const IterVar& iv, const std::string& scope) {
void InferRootBound(const Stage& stage, void InferRootBound(const Stage& stage,
const GraphContext& ctx, const GraphContext& ctx,
const AttachPath& attach_path, const AttachPath& attach_path,
const std::unordered_map<IterVar, IterVar>& bind_map,
std::unordered_map<IterVar, Range>* rmap) { std::unordered_map<IterVar, Range>* rmap) {
CHECK_NE(stage->attach_type, kInline) CHECK_NE(stage->attach_type, kInline)
<< "call schedule.normalize before scheduleops"; << "call schedule.normalize before scheduleops";
if (stage->attach_type == kInlinedAlready) return; if (stage->attach_type == kInlinedAlready) return;
if (stage->is_output) {
// verify correctness.
CHECK_EQ(stage.GetAttachSpec()->attach_type, kGroupRoot)
<< "Output must be attached at root";
}
if (stage->is_output || stage->op.as<PlaceholderOpNode>()) { if (stage->is_output || stage->op.as<PlaceholderOpNode>()) {
for (auto iv : stage->op->root_iter_vars()) { for (auto iv : stage->op->root_iter_vars()) {
CHECK(iv->dom.defined()); CHECK(iv->dom.defined());
...@@ -42,8 +61,10 @@ void InferRootBound(const Stage& stage, ...@@ -42,8 +61,10 @@ void InferRootBound(const Stage& stage,
} }
// parent stage, if any // parent stage, if any
Stage parent; Stage parent;
if (stage->attach_type == kScope || stage->attach_type == kScanUpdate) { Stage attach_spec = stage.GetAttachSpec();
parent = stage->attach_stage; if (attach_spec->attach_type == kScope ||
attach_spec->attach_type == kScanUpdate) {
parent = attach_spec->attach_stage;
} }
// The tensor domain. // The tensor domain.
std::unordered_map<Tensor, TensorDom> tmap; std::unordered_map<Tensor, TensorDom> tmap;
...@@ -72,13 +93,11 @@ void InferRootBound(const Stage& stage, ...@@ -72,13 +93,11 @@ void InferRootBound(const Stage& stage,
// from the already inferred bounds. // from the already inferred bounds.
std::unordered_map<const Variable*, IntSet> relax_set; std::unordered_map<const Variable*, IntSet> relax_set;
for (IterVar iv : attach_path.at(stage->op)) { for (IterVar iv : attach_path.at(stage->op)) {
if (ScopeRelax(iv, stage->scope)) { if (ScopeRelax(iv, bind_map, stage->scope)) {
relax_set[iv->var.get()] = IntSet::range(rmap->at(iv)); relax_set[iv->var.get()] = IntSet::range(rmap->at(iv));
} }
} }
if (direct_consume_by_parent) { if (direct_consume_by_parent) {
// parent stage if exist
Stage parent = stage->attach_stage;
// Bound inference logics in parent. // Bound inference logics in parent.
std::unordered_map<IterVar, IntSet> up_state; std::unordered_map<IterVar, IntSet> up_state;
bool fix_value = true; bool fix_value = true;
...@@ -89,16 +108,16 @@ void InferRootBound(const Stage& stage, ...@@ -89,16 +108,16 @@ void InferRootBound(const Stage& stage,
CHECK(is_zero(vrange->min)) CHECK(is_zero(vrange->min))
<< "InferBound requires every leaf iter var's min equals 0, " << "InferBound requires every leaf iter var's min equals 0, "
<< " call schedule.normalize to achieve this. " << " call schedule.normalize to achieve this. "
<< " stage=" << parent; << " stage=" << parent << ", vrange=" << vrange->min;
// special optimization to remove trivial loop // special optimization to remove trivial loop
if (is_one(vrange->extent)) { if (is_one(vrange->extent)) {
up_state[iv] = IntSet::single_point(vrange->min); up_state[iv] = IntSet::single_point(vrange->min);
} else if (fix_value && !ScopeRelax(iv, stage->scope)) { } else if (fix_value && !ScopeRelax(iv, bind_map, stage->scope)) {
up_state[iv] = IntSet::single_point(iv->var); up_state[iv] = IntSet::single_point(iv->var);
} else { } else {
up_state[iv] = IntSet::range(vrange); up_state[iv] = IntSet::range(vrange);
} }
if (stage->attach_ivar == iv) { if (attach_spec->attach_ivar == iv) {
fix_value = false; fix_value = false;
} }
} }
...@@ -159,7 +178,7 @@ void InferRootBound(const Stage& stage, ...@@ -159,7 +178,7 @@ void InferRootBound(const Stage& stage,
} }
op->PropBoundToInputs(op, dom_map, &tmap); op->PropBoundToInputs(op, dom_map, &tmap);
} }
stage->op->GatherBound(stage->op, ctx, tmap, rmap); stage->op->GatherBound(stage->op, tmap, rmap);
} }
Map<IterVar, Range> InferBound(const Schedule& sch) { Map<IterVar, Range> InferBound(const Schedule& sch) {
...@@ -167,18 +186,25 @@ Map<IterVar, Range> InferBound(const Schedule& sch) { ...@@ -167,18 +186,25 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
for (Operation op : sch->outputs) { for (Operation op : sch->outputs) {
roots.push_back(sch->stage_map[op]->op); roots.push_back(sch->stage_map[op]->op);
} }
std::unordered_map<IterVar, IterVar> bind_map;
for (Stage stage : sch->stages) {
for (auto kv : stage->iter_var_attrs) {
if (kv.second->bind_thread.defined()) {
CHECK(!bind_map.count(kv.first));
bind_map[kv.first] = kv.second->bind_thread;
}
}
}
GraphContext ctx; GraphContext ctx;
ctx.feed_graph = CreateFeedGraph(CreateReadGraph(roots)); ctx.feed_graph = CreateFeedGraph(CreateReadGraph(roots));
AttachPath attach_path = CreateAttachPath(sch); AttachPath attach_path = CreateAttachPath(sch);
std::unordered_map<IterVar, Range> ret; std::unordered_map<IterVar, Range> ret;
for (size_t i = sch->stages.size(); i != 0; --i) { for (size_t i = sch->stages.size(); i != 0; --i) {
const Stage& stage = sch->stages[i - 1]; const Stage& stage = sch->stages[i - 1];
InferRootBound(stage, ctx, attach_path, &ret); InferRootBound(stage, ctx, attach_path, bind_map, &ret);
// pass down to get bound of all iter vars. // pass down to get bound of all iter vars.
PassDownDomain(stage, &ret); PassDownDomain(stage, &ret);
// setup outer most threads. for (IterVar iv : stage->env_threads) {
for (IterVar iv : stage->outermost_threads) {
CHECK(iv->dom.defined()); CHECK(iv->dom.defined());
ret[iv] = iv->dom; ret[iv] = iv->dom;
} }
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <tvm/ir_visitor.h> #include <tvm/ir_visitor.h>
#include <tvm/operation.h> #include <tvm/operation.h>
#include <unordered_set> #include <unordered_set>
#include <unordered_map>
#include "./graph.h" #include "./graph.h"
namespace tvm { namespace tvm {
...@@ -82,6 +83,60 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) { ...@@ -82,6 +83,60 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) {
return rmap; return rmap;
} }
// Do DFS visit to get the subgraph.
// Return if op is inside the subgraph.
bool GetSubGraphByPostDFS_(
const Operation& op,
const std::unordered_set<const Node*>& boundary,
bool include_bounary,
std::unordered_map<const Node*, bool>* visited,
Array<Operation>* result) {
if (visited->count(op.get())) {
return visited->at(op.get());
}
if (boundary.count(op.get())) {
(*visited)[op.get()] = true;
if (include_bounary) {
result->push_back(op);
}
return true;
}
// mark to avoid loop
// Not necessary for DAG.
(*visited)[op.get()] = false;
// check if we can reach boundary.
bool reach_boundary = false;
for (Tensor t : op->InputTensors()) {
if (GetSubGraphByPostDFS_(t->op, boundary,
include_bounary,
visited, result)) {
reach_boundary = true;
}
}
(*visited)[op.get()] = reach_boundary;
if (reach_boundary) {
result->push_back(op);
}
return reach_boundary;
}
Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
const Array<Tensor>& inputs,
bool include_inputs) {
Array<Operation> result;
std::unordered_set<const Node*> boundary;
for (Tensor t : inputs) {
boundary.insert(t->op.get());
}
std::unordered_map<const Node*, bool> visited;
for (Tensor t : outputs) {
GetSubGraphByPostDFS_(t->op, boundary, include_inputs,
&visited, &result);
}
return result;
}
void PostDFSOrder(const Operation& op, void PostDFSOrder(const Operation& op,
const ReadGraph& g, const ReadGraph& g,
std::unordered_set<Operation>* visited, std::unordered_set<Operation>* visited,
...@@ -118,30 +173,38 @@ FeedGraph CreateFeedGraph(const ReadGraph& g) { ...@@ -118,30 +173,38 @@ FeedGraph CreateFeedGraph(const ReadGraph& g) {
AttachPath CreateAttachPath(Schedule sch) { AttachPath CreateAttachPath(Schedule sch) {
AttachPath ret; AttachPath ret;
for (Stage stage : sch->stages) { for (Stage stage : sch->stages) {
if (stage->attach_type == kScanUpdate) { std::unordered_set<const Node*> visited;
const Stage& parent = stage->attach_stage;
stage->attach_ivar =
parent->leaf_iter_vars[parent->leaf_iter_vars.size() - 1];
}
}
for (Stage stage : sch->stages) {
Array<IterVar> path; Array<IterVar> path;
for (Stage s = stage; s.defined();) {
for (Stage s = stage; s->attach_type == kScope || s->attach_type == kScanUpdate;) { CHECK(!visited.count(s.get()))
IterVar attach_ivar = s->attach_ivar; << "Find loop in compute_at attach group";
s = s->attach_stage; visited.insert(s.get());
bool start_attach = false; Stage spec = s.GetAttachSpec();
bool start_attach;
IterVar attach_ivar;
if (spec->attach_type == kScope) {
attach_ivar = spec->attach_ivar;
s = spec->attach_stage;
start_attach = false;
CHECK(attach_ivar.defined());
} else if (spec->attach_type == kScanUpdate) {
s = spec->attach_stage;
start_attach = true;
} else {
break;
}
CHECK(s.defined());
for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) { for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
IterVar iv = s->leaf_iter_vars[i - 1]; IterVar iv = s->leaf_iter_vars[i - 1];
if (iv == attach_ivar) start_attach = true; if (!start_attach && iv.same_as(attach_ivar)) {
start_attach = true;
}
if (start_attach) path.push_back(iv); if (start_attach) path.push_back(iv);
} }
CHECK(start_attach) CHECK(start_attach)
<< "Invalid Schedule: cannot find attach point " << attach_ivar << "Invalid Schedule: cannot find attach point " << attach_ivar
<< " in the schedule of " << s->op; << " in the schedule of " << s->op;
} }
if (!ret.count(stage->op)) { if (!ret.count(stage->op)) {
ret.Set(stage->op, path); ret.Set(stage->op, path);
} }
...@@ -203,53 +266,22 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) { ...@@ -203,53 +266,22 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
return reach; return reach;
} }
// Get all the operations that forms body of scan Array<Operation> ScanGetBody(const Operation& scan_op) {
void ScanGetBodyPostDFS_( const ScanOpNode* scan = scan_op.as<ScanOpNode>();
Operation op, // Get the body.
const ScanOpNode* scan, Array<Tensor> inputs;
const FeedGraph& feed_graph,
std::unordered_set<const Node*>* visited,
Array<Operation>* result) {
if (op.get() == scan) return;
bool empty_feed = true;
for (int i = 0; i < op->num_outputs(); ++i) {
auto it = feed_graph.find(op.output(i));
if (it != feed_graph.end() && it->second.size()) {
empty_feed = false;
for (const Operation& xop : it->second) {
if (visited->count(xop.get())) continue;
visited->insert(xop.get());
ScanGetBodyPostDFS_(xop, scan, feed_graph, visited, result);
result->push_back(xop);
}
}
}
if (empty_feed && op.get() != scan) {
LOG(FATAL) << "Bad scan body, tensor reads scan_state but not connect to scan";
}
}
Array<Operation> ScanGetBody_(
const ScanOpNode* scan,
const FeedGraph& feed_graph) {
CHECK(scan != nullptr);
std::unordered_set<const Node*> visited;
Array<Operation> result;
for (Tensor t : scan->state_placeholder) { for (Tensor t : scan->state_placeholder) {
ScanGetBodyPostDFS_(t->op, scan, feed_graph, &visited, &result); inputs.push_back(t);
} }
return result; for (Tensor t : scan->inputs) {
} inputs.push_back(t);
}
Array<Operation> ScanGetBody(const Operation& scan) { return GetSubGraph(scan->update, inputs, false);
return ScanGetBody_(scan.as<ScanOpNode>(),
CreateFeedGraph(CreateReadGraph({scan})));
} }
Map<IterVar, Expr> ScanFixPointAnalysis( Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
const Operation& scan_op, const Array<Operation>& body) {
const ScanOpNode* scan = scan_op.as<ScanOpNode>(); const ScanOpNode* scan = scan_op.as<ScanOpNode>();
CHECK(body[0].get() == scan); Array<Operation> body = ScanGetBody(scan_op);
std::unordered_map<TensorDimKey, const Node*> exact_reach; std::unordered_map<TensorDimKey, const Node*> exact_reach;
std::unordered_set<const Node*> fail_set; std::unordered_set<const Node*> fail_set;
...@@ -276,8 +308,8 @@ Map<IterVar, Expr> ScanFixPointAnalysis( ...@@ -276,8 +308,8 @@ Map<IterVar, Expr> ScanFixPointAnalysis(
} }
}; };
// prop exact reach back. // prop exact reach back.
for (size_t i = body.size(); i != 1; --i) { for (size_t i = 0; i < body.size(); ++i) {
const Operation& op = body[i - 1]; const Operation& op = body[i];
if (op.as<ScanOpNode>()) { if (op.as<ScanOpNode>()) {
const auto& update = op.as<ScanOpNode>()->update; const auto& update = op.as<ScanOpNode>()->update;
const auto& init = op.as<ScanOpNode>()->init; const auto& init = op.as<ScanOpNode>()->init;
......
...@@ -27,6 +27,11 @@ using ReadGraph = Map<Operation, Array<Tensor> >; ...@@ -27,6 +27,11 @@ using ReadGraph = Map<Operation, Array<Tensor> >;
using AttachPath = Map<Operation, Array<IterVar> >; using AttachPath = Map<Operation, Array<IterVar> >;
/*! /*!
* \brief The map beteen tensor and operation it feeds to.
*/
using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
/*!
* \brief Get read graph of each operation to all the * \brief Get read graph of each operation to all the
* Tensors that it directly depends on. * Tensors that it directly depends on.
* *
...@@ -37,6 +42,23 @@ using AttachPath = Map<Operation, Array<IterVar> >; ...@@ -37,6 +42,23 @@ using AttachPath = Map<Operation, Array<IterVar> >;
ReadGraph CreateReadGraph(const Array<Operation>& roots); ReadGraph CreateReadGraph(const Array<Operation>& roots);
/*! /*!
* \brief Get minimum subgraph between outputs and inputs.
* The operations contains node which input-reachable from any inputs
* output reachable to any outputs.
*
* The inputs won't be included in the subgraph, the outputs will be inclued.
*
* \param outputs The outputs of the subgraph
* \param inputs The inputs to the subgraph.
* \param include_inputs Whether to include inputs
*
* \return The subgraph.
*/
Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
const Array<Tensor>& inputs,
bool include_inputs);
/*!
* \brief Get a post DFS ordered of operations in the graph. * \brief Get a post DFS ordered of operations in the graph.
* \param roots The root of the graph. * \param roots The root of the graph.
* \param g The read graph. * \param g The read graph.
...@@ -67,14 +89,10 @@ AttachPath CreateAttachPath(Schedule sch); ...@@ -67,14 +89,10 @@ AttachPath CreateAttachPath(Schedule sch);
/*! /*!
* \brief Get all operations inside the recursion of scan. * \brief Get all operations inside the recursion of scan.
* \param scan The scan node. * \param scan_op The scan node ops.
* \param feed_graph The feed graph to help analysis.
* \return The body operations, in read dependency order. * \return The body operations, in read dependency order.
*/ */
Array<Operation> ScanGetBody_( Array<Operation> ScanGetBody(const Operation& scan_op);
const ScanOpNode* scan, const FeedGraph& feed_graph);
// same as ScanGetBody_, but create FeedGraph internally.
Array<Operation> ScanGetBody(const Operation& scan);
/*! /*!
* \brief Analyze each spatial dimension of scan's result. * \brief Analyze each spatial dimension of scan's result.
...@@ -85,11 +103,9 @@ Array<Operation> ScanGetBody(const Operation& scan); ...@@ -85,11 +103,9 @@ Array<Operation> ScanGetBody(const Operation& scan);
* next_state[t, ..., axis, ...] = f(prev_state[t-1, ...,axis,...] * next_state[t, ..., axis, ...] = f(prev_state[t-1, ...,axis,...]
* *
* \param scan The scan node. * \param scan The scan node.
* \param body The body of scan, sorted in reverse PostDFSOrder.
* \return Map of spatial_axis -> IntImm * \return Map of spatial_axis -> IntImm
*/ */
Map<IterVar, Expr> ScanFixPointAnalysis( Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan);
const Operation& scan, const Array<Operation>& body);
} // namespace schedule } // namespace schedule
} // namespace tvm } // namespace tvm
......
...@@ -22,6 +22,23 @@ inline bool prove_equal(Expr lhs, Expr rhs) { ...@@ -22,6 +22,23 @@ inline bool prove_equal(Expr lhs, Expr rhs) {
return is_zero(ir::Simplify(lhs - rhs)); return is_zero(ir::Simplify(lhs - rhs));
} }
void Update(std::unordered_map<IterVar, Range>* p_state,
const IterVar& iv,
Range r) {
auto it = p_state->find(iv);
if (it == p_state->end()) {
(*p_state)[iv] = r;
} else {
bool match = is_zero(it->second->min);
if (!prove_equal(r->extent, it->second->extent)) match = false;
CHECK(match)
<< iv
<< " domain already inferred,"
<< " cannot prove their extents are the same "
<< it->second->extent << " vs " << r->extent;
}
}
void PassDownDomain(const Stage& stage, void PassDownDomain(const Stage& stage,
std::unordered_map<IterVar, Range>* p_state, std::unordered_map<IterVar, Range>* p_state,
bool allow_missing) { bool allow_missing) {
...@@ -36,30 +53,15 @@ void PassDownDomain(const Stage& stage, ...@@ -36,30 +53,15 @@ void PassDownDomain(const Stage& stage,
CHECK(!state.count(r->inner)); CHECK(!state.count(r->inner));
const Range& range_parent = state.at(r->parent); const Range& range_parent = state.at(r->parent);
if (r->factor.defined()) { if (r->factor.defined()) {
state[r->inner] = Range::make_with_min_extent(0, r->factor); Update(p_state, r->inner, Range::make_with_min_extent(0, r->factor));
if (r->outer->dom.defined()) { Update(p_state, r->outer,
state[r->outer] = r->outer->dom; Range::make_with_min_extent(
} else { 0, DivCeil(range_parent->extent, r->factor)));
if (!state.count(r->outer)) {
state[r->outer] = Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->factor));
} else {
Expr outer_ext = DivCeil(range_parent->extent, r->factor);
Range outer_rng = state.at(r->outer);
bool match = is_zero(outer_rng->min);
if (!prove_equal(outer_ext, outer_rng->extent)) match = false;
CHECK(match)
<< r->outer
<< "IterVar is used in two places as outer scope,"
<< " cannot prove their extents are the same "
<< outer_ext << " vs " << outer_rng->extent;
}
}
} else { } else {
CHECK(r->outer->dom.defined()); Update(p_state, r->outer, Range::make_with_min_extent(0, r->nparts));
state[r->outer] = r->outer->dom; Update(p_state, r->inner,
state[r->inner] = Range::make_with_min_extent( Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->outer->dom->extent)); 0, DivCeil(range_parent->extent, r->nparts)));
} }
} else if (const FuseNode* r = rel.as<FuseNode>()) { } else if (const FuseNode* r = rel.as<FuseNode>()) {
if (!state.count(r->outer) || !state.count(r->inner)) { if (!state.count(r->outer) || !state.count(r->inner)) {
...@@ -75,20 +77,20 @@ void PassDownDomain(const Stage& stage, ...@@ -75,20 +77,20 @@ void PassDownDomain(const Stage& stage,
CHECK(allow_missing); CHECK(allow_missing);
continue; continue;
} }
Range res = Range::make_with_min_extent( Update(p_state, r->rebased,
0, state.at(r->parent)->extent); Range::make_with_min_extent(
if (r->rebased->dom.defined()) { 0, state.at(r->parent)->extent));
Range rebase_rng = r->rebased->dom;
bool match = is_zero(rebase_rng->min);
if (!prove_equal(rebase_rng->extent, res->extent)) match = false;
CHECK(match) << r->rebased
<< " does not match parent scope's range";
}
state[r->rebased] = res;
} else { } else {
LOG(FATAL) << "unknown relation type"; LOG(FATAL) << "unknown relation type";
} }
} }
// update the extents of binded threads.
for (auto kv : stage->iter_var_attrs) {
if (kv.second->bind_thread.defined()) {
CHECK(state.count(kv.first));
Update(p_state, kv.second->bind_thread, state.at(kv.first));
}
}
} }
void PassUpIndex(const Stage& stage, void PassUpIndex(const Stage& stage,
......
...@@ -55,6 +55,7 @@ void ReplaceDataFlow(const Array<Stage>& stages, ...@@ -55,6 +55,7 @@ void ReplaceDataFlow(const Array<Stage>& stages,
Tensor Schedule::cache_read(const Tensor& tensor, Tensor Schedule::cache_read(const Tensor& tensor,
const std::string& scope, const std::string& scope,
const Array<Operation>& readers) { const Array<Operation>& readers) {
(*this)->InvalidateCache();
// create identity mapping. // create identity mapping.
std::ostringstream os; std::ostringstream os;
os << tensor->op->name; os << tensor->op->name;
...@@ -81,18 +82,25 @@ Tensor Schedule::cache_read(const Tensor& tensor, ...@@ -81,18 +82,25 @@ Tensor Schedule::cache_read(const Tensor& tensor,
} }
ReplaceDataFlow((*this)->stages, &vmap); ReplaceDataFlow((*this)->stages, &vmap);
ArrayNode* stages = (*this)->stages.CopyOnWrite(); ArrayNode* stages = (*this)->stages.CopyOnWrite();
size_t pos = FindNodeRef(stages, operator[](tensor->op)); Stage op_stage = operator[](tensor->op);
size_t pos = FindNodeRef(stages, op_stage);
Stage cache_stage = Stage(cache->op); Stage cache_stage = Stage(cache->op);
cache_stage.set_scope(scope); cache_stage.set_scope(scope);
CHECK_LT(pos, stages->data.size()); CHECK_LT(pos, stages->data.size());
stages->data.insert(stages->data.begin() + pos + 1, stages->data.insert(stages->data.begin() + pos + 1,
cache_stage.node_); cache_stage.node_);
(*this)->stage_map.Set(cache->op, cache_stage); (*this)->stage_map.Set(cache->op, cache_stage);
// Update group
cache_stage->group = op_stage->group;
if (cache_stage->group.defined()) {
++cache_stage->group->num_child_stages;
}
return cache; return cache;
} }
Tensor Schedule::cache_write(const Tensor& tensor, Tensor Schedule::cache_write(const Tensor& tensor,
const std::string& scope) { const std::string& scope) {
(*this)->InvalidateCache();
Stage orig_stage = operator[](tensor->op); Stage orig_stage = operator[](tensor->op);
const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>(); const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
CHECK(compute) CHECK(compute)
...@@ -123,7 +131,6 @@ Tensor Schedule::cache_write(const Tensor& tensor, ...@@ -123,7 +131,6 @@ Tensor Schedule::cache_write(const Tensor& tensor,
std::unordered_map<Tensor, Tensor> vmap; std::unordered_map<Tensor, Tensor> vmap;
vmap[orig_stage->op.output(0)] = orig_new_op.output(0); vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
ReplaceDataFlow((*this)->stages, &vmap); ReplaceDataFlow((*this)->stages, &vmap);
// mutate orig stage // mutate orig stage
orig_stage->op = orig_new_op; orig_stage->op = orig_new_op;
orig_stage->all_iter_vars = orig_stage->op->root_iter_vars(); orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
...@@ -137,6 +144,11 @@ Tensor Schedule::cache_write(const Tensor& tensor, ...@@ -137,6 +144,11 @@ Tensor Schedule::cache_write(const Tensor& tensor,
stages->data.insert(stages->data.begin() + pos, stages->data.insert(stages->data.begin() + pos,
cache_stage.node_); cache_stage.node_);
(*this)->stage_map.Set(cache_op, cache_stage); (*this)->stage_map.Set(cache_op, cache_stage);
// Update group
cache_stage->group = orig_stage->group;
if (cache_stage->group.defined()) {
++cache_stage->group->num_child_stages;
}
return cache_tensor; return cache_tensor;
} }
...@@ -152,6 +164,11 @@ void RebaseNonZeroMinLoop(const Schedule& sch) { ...@@ -152,6 +164,11 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
attach_mark[s.get()] = 1; attach_mark[s.get()] = 1;
} }
} }
for (Stage s : sch->groups) {
if (s->attach_type == kScope) {
attach_mark[s->attach_stage.get()] = 1;
}
}
for (Stage s : sch->stages) { for (Stage s : sch->stages) {
if (!attach_mark.count(s.get())) continue; if (!attach_mark.count(s.get())) continue;
...@@ -176,6 +193,12 @@ void RebaseNonZeroMinLoop(const Schedule& sch) { ...@@ -176,6 +193,12 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
s->attach_ivar = rebase_map.at(s->attach_ivar); s->attach_ivar = rebase_map.at(s->attach_ivar);
} }
} }
for (Stage s : sch->groups) {
if (s->attach_type != kScope) continue;
if (rebase_map.count(s->attach_ivar)) {
s->attach_ivar = rebase_map.at(s->attach_ivar);
}
}
} }
void SetScanAttach(const Schedule& sch) { // NOLINT(*) void SetScanAttach(const Schedule& sch) { // NOLINT(*)
...@@ -188,8 +211,8 @@ void SetScanAttach(const Schedule& sch) { // NOLINT(*) ...@@ -188,8 +211,8 @@ void SetScanAttach(const Schedule& sch) { // NOLINT(*)
} }
} }
void InjectInline(ScheduleNode* sch) {
void InjectInline(const Schedule& sch) { sch->InvalidateCache();
std::vector<Expr> new_body(sch->stages.size()); std::vector<Expr> new_body(sch->stages.size());
// inline all the ops // inline all the ops
for (size_t i = sch->stages.size(); i != 0; --i) { for (size_t i = sch->stages.size(); i != 0; --i) {
...@@ -241,12 +264,13 @@ void InjectInline(const Schedule& sch) { ...@@ -241,12 +264,13 @@ void InjectInline(const Schedule& sch) {
void Schedule::normalize() { void Schedule::normalize() {
RebaseNonZeroMinLoop(*this); RebaseNonZeroMinLoop(*this);
SetScanAttach(*this); SetScanAttach(*this);
InjectInline(*this); InjectInline(operator->());
} }
// Handle reduction factor. // Handle reduction factor.
Tensor Schedule::rfactor(const Tensor& tensor, Tensor Schedule::rfactor(const Tensor& tensor,
const IterVar& axis) { const IterVar& axis) {
(*this)->InvalidateCache();
using ir::Reduce; using ir::Reduce;
CHECK_EQ(axis->iter_type, kCommReduce) CHECK_EQ(axis->iter_type, kCommReduce)
<< "Can only factor reduction axis"; << "Can only factor reduction axis";
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "./graph.h" #include "./graph.h"
#include "../op/op_util.h"
#include "../pass/ir_util.h"
namespace tvm { namespace tvm {
namespace schedule { namespace schedule {
...@@ -44,8 +46,9 @@ Stmt MakePipeline(const Stage& s, ...@@ -44,8 +46,9 @@ Stmt MakePipeline(const Stage& s,
class InjectAttach : public IRMutator { class InjectAttach : public IRMutator {
public: public:
InjectAttach(const Stage& stage, InjectAttach(const Stage& stage,
const Stage& attach_spec,
const std::unordered_map<IterVar, Range>& dom_map) const std::unordered_map<IterVar, Range>& dom_map)
: stage_(stage), dom_map_(dom_map) {} : stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map) {}
Stmt Mutate(Stmt stmt) final { Stmt Mutate(Stmt stmt) final {
CHECK(stmt.defined()); CHECK(stmt.defined());
...@@ -53,10 +56,11 @@ class InjectAttach : public IRMutator { ...@@ -53,10 +56,11 @@ class InjectAttach : public IRMutator {
const AttrStmt* op = stmt.as<AttrStmt>(); const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr && if (op != nullptr &&
op->type_key == attr::loop_scope) { op->type_key == attr::loop_scope) {
CHECK_NE(producer_.size(), 0U); if (attach_spec_->attach_type == kScope &&
if (op->node == stage_->attach_ivar && op->node == attach_spec_->attach_ivar) {
producer_.back() == stage_->attach_stage->op.get()) { CHECK(!found_attach)
CHECK(!found_attach); << "Find IterVar" << attach_spec_->attach_ivar
<< " in multiple places in the IR";
found_attach = true; found_attach = true;
stmt = AttrStmt::make( stmt = AttrStmt::make(
op->node, op->type_key, op->value, op->node, op->type_key, op->value,
...@@ -65,26 +69,16 @@ class InjectAttach : public IRMutator { ...@@ -65,26 +69,16 @@ class InjectAttach : public IRMutator {
} }
return stmt; return stmt;
} }
Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
if (op->is_producer) {
producer_.push_back(op->func.get());
Stmt ret = IRMutator::Mutate_(op, s);
producer_.pop_back();
return ret;
} else {
return IRMutator::Mutate_(op, s);
}
}
// whether attach point is found // whether attach point is found
bool found_attach{false}; bool found_attach{false};
private: private:
// the operations to be carried // The stage.
const Stage& stage_; const Stage& stage_;
// The attach spec, may not contain op.
const Stage& attach_spec_;
// domain map // domain map
const std::unordered_map<IterVar, Range>& dom_map_; const std::unordered_map<IterVar, Range>& dom_map_;
// internal stack about realization scope.
std::vector<const Node*> producer_;
}; };
// inject the operator's realization on the stmt. // inject the operator's realization on the stmt.
...@@ -128,7 +122,6 @@ class InjectScanStep : public IRMutator { ...@@ -128,7 +122,6 @@ class InjectScanStep : public IRMutator {
bool is_init_; bool is_init_;
}; };
// Postprocessing of schedule op // Postprocessing of schedule op
// Replace the init and update's expression by scan's buffer. // Replace the init and update's expression by scan's buffer.
class SchedulePostProc : public IRMutator { class SchedulePostProc : public IRMutator {
...@@ -157,9 +150,8 @@ class SchedulePostProc : public IRMutator { ...@@ -157,9 +150,8 @@ class SchedulePostProc : public IRMutator {
} }
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == attr::loop_scope) { if (op->type_key == attr::loop_scope ||
return this->Mutate(op->body); op->type_key == attr::scan_init_scope) {
} else if (op->type_key == attr::scan_init_scope) {
return this->Mutate(op->body); return this->Mutate(op->body);
} else if (op->type_key == attr::scan_update_scope) { } else if (op->type_key == attr::scan_update_scope) {
const ScanOpNode* scan = op->node.as<ScanOpNode>(); const ScanOpNode* scan = op->node.as<ScanOpNode>();
...@@ -237,6 +229,15 @@ class SchedulePostProc : public IRMutator { ...@@ -237,6 +229,15 @@ class SchedulePostProc : public IRMutator {
void Init(const Schedule& sch) { void Init(const Schedule& sch) {
for (Stage s : sch->stages) { for (Stage s : sch->stages) {
for (auto kv : s->iter_var_attrs) {
// Update bind thread information.
if (kv.second->bind_thread.defined()) {
const Var& from = kv.first->var;
const Var& to = kv.second->bind_thread->var;
CHECK(!var_value_.count(from.get()));
var_value_[from.get()] = to;
}
}
// This must be checked for all ops, including scan. // This must be checked for all ops, including scan.
if (!s->op.same_as(s->origin_op)) { if (!s->op.same_as(s->origin_op)) {
Tensor target = s->origin_op.output(0); Tensor target = s->origin_op.output(0);
...@@ -279,61 +280,67 @@ class SchedulePostProc : public IRMutator { ...@@ -279,61 +280,67 @@ class SchedulePostProc : public IRMutator {
Stmt ScheduleOps( Stmt ScheduleOps(
Schedule sch, Map<IterVar, Range> dom_map_) { Schedule sch, Map<IterVar, Range> dom_map_) {
Stmt body = Stmt(); Stmt body = Stmt();
std::unordered_map<IterVar, Range> dom_map;
for (auto kv : dom_map_) {
dom_map[kv.first] = kv.second;
}
// scan init and scan updates // scan init and scan updates
std::unordered_map<Operation, std::pair<Operation, bool> > scan_attach; std::unordered_map<Operation, Operation> scan_init;
for (Stage s : sch->stages) { for (Stage s : sch->stages) {
const ScanOpNode* scan = s->op.as<ScanOpNode>(); const ScanOpNode* scan = s->op.as<ScanOpNode>();
if (!scan) continue; if (!scan) continue;
for (Tensor t : scan->init) { for (Tensor t : scan->init) {
if (scan_attach.count(t->op)) { if (scan_init.count(t->op)) {
CHECK(scan_attach.at(t->op).first.same_as(s->op)) CHECK(scan_init.at(t->op).same_as(s->op))
<< "Scan init tensor can only belong to one scan"; << "Scan init tensor can only belong to one scan";
} else { } else {
scan_attach[t->op] = std::make_pair(s->op, true); scan_init[t->op] = s->op;
}
}
for (Tensor t : scan->update) {
if (scan_attach.count(t->op)) {
CHECK(scan_attach.at(t->op).first.same_as(s->op))
<< "Scan update tensor can only belong to one scan";
} else {
scan_attach[t->op] = std::make_pair(s->op, false);
} }
} }
} }
std::unordered_map<IterVar, Range> dom_map; // verify correctness of group.
for (auto kv : dom_map_) { for (Stage g : sch->groups) {
dom_map[kv.first] = kv.second; CHECK(!g->op.defined());
CHECK_EQ(g->leaf_iter_vars.size(), 0U);
} }
// reverse the post DFS order. // reverse the post DFS order.
for (size_t i = sch->stages.size(); i != 0; --i) { for (size_t i = sch->stages.size(); i != 0; --i) {
Stage s = sch->stages[i - 1]; Stage s = sch->stages[i - 1];
CHECK_NE(s->attach_type, kInline) CHECK_NE(s->attach_type, kInline)
<< "call schedule.normalize before scheduleops"; << "call schedule.normalize before scheduleops";
CHECK(s->op.defined());
// no need to specify place holder op. // no need to specify place holder op.
if (s->op.as<PlaceholderOpNode>()) continue; if (s->op.as<PlaceholderOpNode>()) continue;
if (scan_attach.count(s->op)) { // Remove grouping sugar, get the real attach spec.
CHECK(s->attach_type == kNone || Stage attach_spec = s.GetAttachSpec();
s->attach_type == kScanUpdate)
<< "Cannot specify compute_at for scan's init/update"; if (scan_init.count(s->op)) {
CHECK(body.defined());
InjectScanStep mu(s, scan_init.at(s->op), dom_map, true);
body = mu.Mutate(body);
CHECK(mu.found_attach)
<< "did not find attachment point for scan.init";
} else if (attach_spec->attach_type == kScanUpdate) {
// Handle scan update
CHECK(body.defined()); CHECK(body.defined());
const auto& p = scan_attach.at(s->op); InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false);
InjectScanStep mu(s, p.first, dom_map, p.second);
body = mu.Mutate(body); body = mu.Mutate(body);
CHECK(mu.found_attach) CHECK(mu.found_attach)
<< "did not find attachment point for scan.init/update"; << "did not find attachment point for scan.update";
} else if (s->attach_type == kInlinedAlready) { } else if (attach_spec->attach_type == kInlinedAlready) {
// do nothing // do nothing
} else if (s->attach_type == kRoot || s-> attach_type == kNone) { } else if (attach_spec->attach_type == kGroupRoot) {
CHECK(!s->group.defined());
body = MakePipeline(s, dom_map, body); body = MakePipeline(s, dom_map, body);
} else if (s->attach_type == kScope) { } else {
CHECK_EQ(attach_spec->attach_type, kScope);
CHECK(body.defined()); CHECK(body.defined());
InjectAttach mutator(s, dom_map); InjectAttach mutator(s, attach_spec, dom_map);
body = mutator.Mutate(body); body = mutator.Mutate(body);
CHECK(mutator.found_attach) CHECK(mutator.found_attach)
<< "did not find attachment point for " << s << " in" << "did not find attachment point for " << s << " in "
<< s->attach_stage->op << " x " << attach_spec->attach_stage->op << " x " << attach_spec->attach_ivar
<< ", body:\n"
<< body; << body;
} }
} }
......
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
namespace { namespace {
......
...@@ -11,11 +11,11 @@ def test_add(): ...@@ -11,11 +11,11 @@ def test_add():
s = tvm.Schedule(C.op) s = tvm.Schedule(C.op)
# create iter var and assign them tags. # create iter var and assign them tags.
num_thread = 256 num_thread = 256
block_x = tvm.thread_axis(None, "blockIdx.x") bx, x = s[C].split(C.op.axis[0], factor=num_thread*4)
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x") tx, x = s[C].split(x, nparts=num_thread)
_, x = s[C].split(C.op.axis[0], factor=num_thread*4, outer=block_x)
_, x = s[C].split(x, outer=thread_x)
_, x = s[C].split(x, factor=4) _, x = s[C].split(x, factor=4)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
s[C].vectorize(x) s[C].vectorize(x)
# one line to build the function. # one line to build the function.
......
...@@ -22,31 +22,41 @@ def test_gemm(): ...@@ -22,31 +22,41 @@ def test_gemm():
scale = 8 scale = 8
num_thread = 8 num_thread = 8
block_factor = scale * num_thread block_factor = scale * num_thread
block_x = tvm.thread_axis(None, "blockIdx.x") block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x") thread_x = tvm.thread_axis("threadIdx.x")
block_y = tvm.thread_axis(None, "blockIdx.y") block_y = tvm.thread_axis("blockIdx.y")
thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y") thread_y = tvm.thread_axis("threadIdx.y")
CC = s.cache_write(C, "local") CC = s.cache_write(C, "local")
AA = s.cache_read(A, "shared", [CC]) AA = s.cache_read(A, "shared", [CC])
BB = s.cache_read(B, "shared", [CC]) BB = s.cache_read(B, "shared", [CC])
_, yi = s[C].split(C.op.axis[0], factor=block_factor, outer=block_y) by, yi = s[C].split(C.op.axis[0], factor=block_factor)
_, xi = s[C].split(C.op.axis[1], factor=block_factor, outer=block_x) bx, xi = s[C].split(C.op.axis[1], factor=block_factor)
s[C].reorder(block_y, block_x, yi, xi) s[C].reorder(by, bx, yi, xi)
_, yi = s[C].split(yi, outer=thread_y) s[C].bind(by, block_y)
_, xi = s[C].split(xi, outer=thread_x) s[C].bind(bx, block_x)
s[C].reorder(thread_y, thread_x, yi, xi) ty, yi = s[C].split(yi, nparts=num_thread)
tx, xi = s[C].split(xi, nparts=num_thread)
s[C].reorder(ty, tx, yi, xi)
s[C].bind(ty, thread_y)
s[C].bind(tx, thread_x)
yo, xo = CC.op.axis yo, xo = CC.op.axis
s[CC].reorder(k, yo, xo) s[CC].reorder(k, yo, xo)
s[CC].compute_at(s[C], thread_x)
s[CC].compute_at(s[C], tx)
s[AA].compute_at(s[CC], k) s[AA].compute_at(s[CC], k)
s[BB].compute_at(s[CC], k) s[BB].compute_at(s[CC], k)
_, xi = s[AA].split(s[AA].op.axis[0], outer=thread_y) ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread)
_, xi = s[AA].split(xi, outer=thread_x) tx, xi = s[AA].split(xi, nparts=num_thread)
_, xi = s[BB].split(s[BB].op.axis[0], outer=thread_y) s[AA].bind(ty, thread_y)
_, xi = s[BB].split(xi, outer=thread_x) s[AA].bind(tx, thread_x)
ty, xi = s[BB].split(s[BB].op.axis[0], nparts=num_thread)
tx, xi = s[BB].split(xi, nparts=num_thread)
s[BB].bind(ty, thread_y)
s[BB].bind(tx, thread_x)
max_auto_unroll_step = 0 max_auto_unroll_step = 0
# lowering test # lowering test
...@@ -76,9 +86,9 @@ def test_gemm(): ...@@ -76,9 +86,9 @@ def test_gemm():
np.testing.assert_allclose( np.testing.assert_allclose(
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5) c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
check_device("cuda")
if tvm.module.enabled("opencl"): if tvm.module.enabled("opencl"):
tvm.module.init_opencl() tvm.module.init_opencl()
check_device("cuda")
check_device("opencl") check_device("opencl")
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -12,10 +12,9 @@ def test_sum(): ...@@ -12,10 +12,9 @@ def test_sum():
s = tvm.Schedule(B.op) s = tvm.Schedule(B.op)
# create iter var and assign them tags. # create iter var and assign them tags.
num_thread = 1 num_thread = 1
block_x = tvm.thread_axis(None, "blockIdx.x") xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x") s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
_, x = s[B].split(B.op.axis[0], factor=num_thread, outer=block_x) s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
_, x = s[B].split(x, outer=thread_x)
# one line to build the function. # one line to build the function.
def check_device(device, host="stackvm"): def check_device(device, host="stackvm"):
...@@ -52,10 +51,9 @@ def test_rfactor(): ...@@ -52,10 +51,9 @@ def test_rfactor():
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
k = tvm.reduce_axis((0, n)) k = tvm.reduce_axis((0, n))
B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B') B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')
kf = tvm.reduce_axis((0, 4))
# schedule # schedule
s = tvm.Schedule(B.op) s = tvm.Schedule(B.op)
_, ki = s[B].split(k, outer=kf) kf, ki = s[B].split(k, nparts=4)
BF = s.rfactor(B, kf) BF = s.rfactor(B, kf)
s[BF].parallel(BF.op.axis[0]) s[BF].parallel(BF.op.axis[0])
# one line to build the function. # one line to build the function.
...@@ -88,16 +86,14 @@ def test_rfactor_threads(): ...@@ -88,16 +86,14 @@ def test_rfactor_threads():
k = tvm.reduce_axis((0, n)) k = tvm.reduce_axis((0, n))
nthread = 16 nthread = 16
B = tvm.compute((m,), lambda i: tvm.sum(A[i, k], axis=k, where=(i>1)), name='B') B = tvm.compute((m,), lambda i: tvm.sum(A[i, k], axis=k, where=(i>1)), name='B')
tx = tvm.thread_axis((0, nthread), "threadIdx.x")
ty = tvm.thread_axis((0, nthread), "threadIdx.y")
bx = tvm.thread_axis(None, "blockIdx.x")
# schedule # schedule
s = tvm.Schedule(B.op) s = tvm.Schedule(B.op)
ko, kf = s[B].split(k, factor=nthread) ko, kf = s[B].split(k, factor=nthread)
BF = s.rfactor(B, kf) BF = s.rfactor(B, kf)
xo, xi = s[B].split(s[B].op.axis[0], factor=nthread, outer=bx) bx, tx = s[B].split(s[B].op.axis[0], factor=nthread)
s[B].rebase(xi, ty) s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].rebase(s[B].op.reduce_axis[0], tx) s[B].bind(tx, tvm.thread_axis("threadIdx.y"))
s[B].bind(s[B].op.reduce_axis[0], tvm.thread_axis("threadIdx.x"))
s[BF].compute_at(s[B], tx) s[BF].compute_at(s[B], tx)
# one line to build the function. # one line to build the function.
...@@ -128,6 +124,6 @@ def test_rfactor_threads(): ...@@ -128,6 +124,6 @@ def test_rfactor_threads():
check_target("opencl") check_target("opencl")
if __name__ == "__main__": if __name__ == "__main__":
test_rfactor_threads()
test_rfactor() test_rfactor()
test_rfactor_threads()
test_sum() test_sum()
...@@ -15,10 +15,12 @@ def test_scan(): ...@@ -15,10 +15,12 @@ def test_scan():
num_thread = 256 num_thread = 256
block_x = tvm.thread_axis(None, "blockIdx.x") block_x = tvm.thread_axis(None, "blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x") thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
_, x = s[s_init].split(s_init.op.axis[1], factor=num_thread, outer=block_x) xo, xi = s[s_init].split(s_init.op.axis[1], factor=num_thread)
_, x = s[s_init].split(x, outer=thread_x) s[s_init].bind(xo, block_x)
_, x = s[s_update].split(s_update.op.axis[1], factor=num_thread, outer=block_x) s[s_init].bind(xi, thread_x)
_, x = s[s_update].split(x, outer=thread_x) xo, xi = s[s_update].split(s_update.op.axis[1], factor=num_thread)
s[s_update].bind(xo, block_x)
s[s_update].bind(xi, thread_x)
# one line to build the function. # one line to build the function.
def check_device(device, host="stackvm"): def check_device(device, host="stackvm"):
......
...@@ -11,10 +11,9 @@ def test_add_pipeline(): ...@@ -11,10 +11,9 @@ def test_add_pipeline():
# GPU schedule have to split by gridIdx and threadIdx # GPU schedule have to split by gridIdx and threadIdx
num_thread = 256 num_thread = 256
grid_x = tvm.thread_axis(None, "blockIdx.x") xo, xi = s[C].split(C.op.axis[0], factor=num_thread)
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x") s[C].bind(xo, tvm.thread_axis("threadIdx.x"))
_, x = s[C].split(C.op.axis[0], factor=num_thread, outer=grid_x) s[C].bind(xi, tvm.thread_axis("blockIdx.x"))
_, x = s[C].split(x, outer=thread_x)
# compile to IR # compile to IR
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
......
"""Test group effect"""
import tvm
def test_scan_group():
m = tvm.Var("m")
n = tvm.Var("n")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: x[0, i])
s_update1 = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + x[t, i])
s_update2 = tvm.compute((m, n), lambda t, i: s_update1[t, i] + 1)
s_update3 = tvm.compute((m, n), lambda t, i: s_update2[t, i] + 1)
res = tvm.scan(s_init, s_update3, s_state, inputs=x)
s = tvm.Schedule(res.op)
assert s[s_update1].group is not None
assert s[s_update2].group == s[s_update1].group
# Assign within group, is valid
s[s_update1].compute_at(s[s_update2], s_update2.op.axis[1])
# create a new group, for [s_update2 and s_update1]
g2 = s.create_group(outputs=s_update2, inputs=[s_state, x])
assert g2.group is not None
assert g2.group == s[s_update3].group
assert s[s_update2].group == g2
assert s[s_update1].group == g2
g2.compute_at(s[s_update3], s_update3.op.axis[1])
assert g2.attach_stage == s[s_update3]
try:
# compute outside group error.
s[s_update2].compute_at(s[s_init], s_init.op.axis[0])
assert False
except tvm.TVMError:
pass
def test_compute_group():
m = tvm.Var("m")
n = tvm.Var("n")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
s = tvm.Schedule(x2.op)
g = s.create_group(outputs=x1, inputs=x, include_inputs=True)
assert s[x1].group == g
assert s[x].group == g
g.compute_at(s[x2], x2.op.axis[1])
assert g.attach_stage == s[x2]
assert g.num_child_stages == 2
def test_nest_group():
m = tvm.Var("m")
n = tvm.Var("n")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
s = tvm.Schedule(x2.op)
g1 = s.create_group(outputs=x1, inputs=x)
g2 = s.create_group(outputs=x1, inputs=x, include_inputs=True)
assert set(s.groups) == set([g1, g2])
assert s[x].group == g2
assert s[x1].group == g1
assert g1.group == g2
assert g2.num_child_stages == 2
assert g1.num_child_stages == 1
if __name__ == "__main__":
test_nest_group()
test_compute_group()
test_scan_group()
...@@ -29,8 +29,8 @@ def test_basic_pipeline(): ...@@ -29,8 +29,8 @@ def test_basic_pipeline():
B = tvm.compute((n,), lambda i: B[i] + k, name="A%s" % k) B = tvm.compute((n,), lambda i: B[i] + k, name="A%s" % k)
s = tvm.Schedule(B.op) s = tvm.Schedule(B.op)
px = tvm.thread_axis((0, 1), "pipeline") xo, xi = s[B].split(B.op.axis[0], nparts=1)
xo, xi = s[B].split(B.op.axis[0], outer=px) s[B].bind(xo, tvm.thread_axis("pipeline"))
xo, xi = s[B].split(xi, factor=4) xo, xi = s[B].split(xi, factor=4)
for S in stages: for S in stages:
s[S].compute_at(s[B], xo) s[S].compute_at(s[B], xo)
...@@ -50,8 +50,8 @@ def test_conv1d(): ...@@ -50,8 +50,8 @@ def test_conv1d():
return A[i-1] + A[i] + A[i+1] return A[i-1] + A[i] + A[i+1]
B = tvm.compute(n, computeB, name='B') B = tvm.compute(n, computeB, name='B')
s = tvm.Schedule(B.op) s = tvm.Schedule(B.op)
px = tvm.thread_axis((0, 1), "pipeline") px, xi = s[B].split(B.op.axis[0], nparts=1)
xo, xi = s[B].split(B.op.axis[0], outer=px) s[B].bind(px, tvm.thread_axis("pipeline"))
s[A].compute_at(s[B], px) s[A].compute_at(s[B], px)
stmt = lower(s, [B]) stmt = lower(s, [B])
stmt = tvm.ir_pass.SplitPipeline(stmt, False) stmt = tvm.ir_pass.SplitPipeline(stmt, False)
......
...@@ -9,15 +9,15 @@ def test_storage_sync(): ...@@ -9,15 +9,15 @@ def test_storage_sync():
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = tvm.Schedule(A2.op) s = tvm.Schedule(A2.op)
block_x = tvm.thread_axis(None, "blockIdx.x") xo, xi = s[A2].split(A2.op.axis[0], factor=8)
xo, xi = s[A2].split(A2.op.axis[0], factor=8, outer=block_x) s[A2].bind(xo, tvm.thread_axis("blockIdx.x"))
s[A1].compute_at(s[A2], xo) s[A1].compute_at(s[A2], xo)
s[A1].set_scope("shared") s[A1].set_scope("shared")
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
Ab = tvm.Buffer(A.shape, A.dtype, name='A') Ab = tvm.Buffer(A.shape, A.dtype, name='A')
A2b = tvm.Buffer(A2.shape, A2.dtype, name='A2') A2b = tvm.Buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}) stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
......
...@@ -7,8 +7,9 @@ def test_virtual_thread(): ...@@ -7,8 +7,9 @@ def test_virtual_thread():
A2 = tvm.compute((m,), lambda i: A1[i] + 3, name='A2') A2 = tvm.compute((m,), lambda i: A1[i] + 3, name='A2')
s = tvm.Schedule(A2.op) s = tvm.Schedule(A2.op)
vx = tvm.thread_axis((0, 2), "vthread", name="vx") vx = tvm.thread_axis("vthread", name="vx")
xo, xi = s[A2].split(A2.op.axis[0], outer=vx) xo, xi = s[A2].split(A2.op.axis[0], nparts=2)
s[A2].bind(xo, vx)
xo, xi = s[A2].split(xi, 8) xo, xi = s[A2].split(xi, 8)
s[A1].compute_at(s[A2], xo) s[A1].compute_at(s[A2], xo)
......
...@@ -36,11 +36,10 @@ def test_bound3(): ...@@ -36,11 +36,10 @@ def test_bound3():
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = tvm.Schedule(A2.op) s = tvm.Schedule(A2.op)
s[A1].set_scope("shared") s[A1].set_scope("shared")
thread_x = tvm.thread_axis((0, 16), "threadIdx.x")
xo, xi = s[A2].split(A2.op.axis[0], 32) xo, xi = s[A2].split(A2.op.axis[0], 32)
xi0, xi1 = s[A2].split(xi, outer=thread_x) xi0, xi1 = s[A2].split(xi, nparts=16)
s[A2].bind(xi0, tvm.thread_axis("threadIdx.x"))
yo, yi = s[A2].split(A2.op.axis[1], 16) yo, yi = s[A2].split(A2.op.axis[1], 16)
s[A2].reorder(xo, xi0, yo, xi1, yi) s[A2].reorder(xo, xi0, yo, xi1, yi)
s[A1].compute_at(s[A2], yo) s[A1].compute_at(s[A2], yo)
...@@ -60,12 +59,10 @@ def test_bound_scan(): ...@@ -60,12 +59,10 @@ def test_bound_scan():
s_scan = tvm.scan(s_init, s_update, s_state) s_scan = tvm.scan(s_init, s_update, s_state)
assert tuple(s_scan.shape) == (m, n) assert tuple(s_scan.shape) == (m, n)
s = tvm.Schedule(s_scan.op) s = tvm.Schedule(s_scan.op)
XX = s.cache_read(X, "local", s_update) XX = s.cache_read(X, "local", s_update)
xo, xi = s[s_update].split(s_update.op.axis[1], factor=4) xo, xi = s[s_update].split(s_update.op.axis[1], factor=4)
s[XX].compute_at(s[s_update], xo) s[XX].compute_at(s[s_update], xo)
s.normalize() s.normalize()
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
...@@ -105,22 +102,59 @@ def test_bound_rfactor(): ...@@ -105,22 +102,59 @@ def test_bound_rfactor():
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
k = tvm.reduce_axis((0, n)) k = tvm.reduce_axis((0, n))
B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k, where=(i>1)), name='B') B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k, where=(i>1)), name='B')
kf = tvm.reduce_axis((0, 4))
# schedule # schedule
s = tvm.Schedule(B.op) s = tvm.Schedule(B.op)
_, ki = s[B].split(k, outer=kf) kf, ki = s[B].split(k, nparts=4)
BF = s.rfactor(B, kf) BF = s.rfactor(B, kf)
s.normalize() s.normalize()
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert(bounds[BF.op.axis[0]].extent.value == 4) assert(bounds[BF.op.axis[0]].extent.value == 4)
assert(bounds[BF.op.axis[1]].extent.value == 1) assert(bounds[BF.op.axis[1]].extent.value == 1)
def test_bound_group_schedule():
m = tvm.Var("m")
n = tvm.Var("n")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
s = tvm.Schedule(x2.op)
g = s.create_group(outputs=x1, inputs=x, include_inputs=True)
g.compute_at(s[x2], x2.op.axis[0])
assert s[x1].group == g
assert s[x].group == g
s.normalize()
bounds = tvm.schedule.InferBound(s)
assert bounds[x.op.axis[0]].extent.value == 1
assert bounds[x.op.axis[1]].extent == n
def test_bound_nest_group():
m = tvm.Var("m")
n = tvm.Var("n")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
s = tvm.Schedule(x2.op)
g1 = s.create_group(outputs=x, inputs=x, include_inputs=True)
g2 = s.create_group(outputs=x1, inputs=x, include_inputs=True)
assert s[x].group == g1
assert s[x1].group == g2
g2.compute_at(s[x2], x2.op.axis[0])
g1.compute_at(s[x1], s[x1].op.axis[1])
s.normalize()
bounds = tvm.schedule.InferBound(s)
assert bounds[x.op.axis[0]].extent.value == 1
assert bounds[x.op.axis[1]].extent.value == 1
assert bounds[x1.op.axis[0]].extent.value == 1
assert bounds[x1.op.axis[1]].extent == n
if __name__ == "__main__": if __name__ == "__main__":
test_bound_nest_group()
test_bound_group_schedule()
test_bound_scan()
test_bound3()
test_bound_rfactor() test_bound_rfactor()
test_bound_blur() test_bound_blur()
test_bound_conv1d() test_bound_conv1d()
test_bound_scan()
test_bound3()
test_bound1() test_bound1()
test_bound2() test_bound2()
...@@ -59,7 +59,7 @@ def test_scan_fix_point(): ...@@ -59,7 +59,7 @@ def test_scan_fix_point():
s_update = tvm.compute((l, m, n), lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update") s_update = tvm.compute((l, m, n), lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update")
s_scan = tvm.scan(s_init, s_update, s_state) s_scan = tvm.scan(s_init, s_update, s_state)
body = tvm.schedule.ScanGetBody(s_scan.op) body = tvm.schedule.ScanGetBody(s_scan.op)
fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body) fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op)
assert(fxpt[s_scan.op.spatial_axis_[0]].value == 1) assert(fxpt[s_scan.op.spatial_axis_[0]].value == 1)
assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0) assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0)
...@@ -69,8 +69,7 @@ def test_scan_fix_point(): ...@@ -69,8 +69,7 @@ def test_scan_fix_point():
s_update = tvm.compute((l, m, n), s_update = tvm.compute((l, m, n),
lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update") lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update")
s_scan = tvm.scan(s_init, s_update, s_state) s_scan = tvm.scan(s_init, s_update, s_state)
body = tvm.schedule.ScanGetBody(s_scan.op) fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op)
fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body)
assert(fxpt[s_scan.op.spatial_axis_[0]].value == 0) assert(fxpt[s_scan.op.spatial_axis_[0]].value == 0)
assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0) assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0)
...@@ -89,7 +88,7 @@ def test_scan_fix_point(): ...@@ -89,7 +88,7 @@ def test_scan_fix_point():
[s1_update, s2_update], [s1_update, s2_update],
[s1, s2]) [s1, s2])
body = tvm.schedule.ScanGetBody(r0.op) body = tvm.schedule.ScanGetBody(r0.op)
fxpt = tvm.schedule.ScanFixPointAnalysis(r0.op, body) fxpt = tvm.schedule.ScanFixPointAnalysis(r0.op)
assert(fxpt[r1.op.spatial_axis_[0]].value == 1) assert(fxpt[r1.op.spatial_axis_[0]].value == 1)
test_scan0() test_scan0()
......
...@@ -35,8 +35,8 @@ def test_add_pipeline(): ...@@ -35,8 +35,8 @@ def test_add_pipeline():
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C') C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
s = tvm.Schedule(C.op) s = tvm.Schedule(C.op)
grid_x = tvm.thread_axis((0, 1), "pipeline") px, x = s[C].split(C.op.axis[0], nparts=1)
_, x = s[C].split(C.op.axis[0], outer=grid_x) s[C].bind(px, tvm.thread_axis("pipeline"))
fapi = lower(s, [A, B, C], "myadd") fapi = lower(s, [A, B, C], "myadd")
fsplits = tvm.ir_pass.SplitHostDevice(fapi) fsplits = tvm.ir_pass.SplitHostDevice(fapi)
print(fsplits[1].body) print(fsplits[1].body)
......
...@@ -57,9 +57,9 @@ def test_buffer_linebuff(): ...@@ -57,9 +57,9 @@ def test_buffer_linebuff():
# correctness checks # correctness checks
if (read_data_valid.get_int()): if (read_data_valid.get_int()):
# Derive convolution window indices # Derive convolution window indices
baseIdx = read_idx/(kernel_width*kernel_width) baseIdx = read_idx // (kernel_width*kernel_width)
offsetIdx = read_idx%(kernel_width*kernel_width) offsetIdx = read_idx % (kernel_width*kernel_width)
yOffset = offsetIdx/kernel_width yOffset = offsetIdx // kernel_width
xOffset = offsetIdx%kernel_width xOffset = offsetIdx%kernel_width
pixIndex = baseIdx + yOffset * window_width + xOffset pixIndex = baseIdx + yOffset * window_width + xOffset
assert(read_data.get_int()==test_data[pixIndex]) assert(read_data.get_int()==test_data[pixIndex])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment