Commit 8f51c5fd by Tianqi Chen Committed by GitHub

[SCHEDULE] Add group, refactor thread bind api. (#82)

* [SCHEDULE] Add group, refactor thread bind api.

* fix doc

* fix g++-4.8

* More testscase

* Remove graph context from fix pt analysis
parent 6268e183
Subproject commit ce80d58741688b200f498fed8c7b0ea33e0516c8
Subproject commit 59fdca16978b6184bab87fbff7a00c95f1804686
......@@ -32,17 +32,6 @@ struct TensorDom {
};
/*!
* \brief The map beteen tensor and operation it feeds to.
*/
using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
/*! \brief The graph context used during bound inference. */
struct GraphContext {
/*! \brief The feed graph */
FeedGraph feed_graph;
};
/*!
* \brief Base class of all operation nodes
*/
class OperationNode : public FunctionBaseNode {
......@@ -102,13 +91,11 @@ class OperationNode : public FunctionBaseNode {
* Set the range of each root_iter_vars in the op to out_dom_map
*
* \param self The reference to self.
* \param graph_ctx The global graph context information.
* \param tensor_dom Domain map of Tensor->access set of each dimension.
* \param out_dom_map The output domain map of each IterVar to be setted.
*/
virtual void GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const = 0;
/*!
......@@ -162,7 +149,6 @@ class PlaceholderOpNode : public OperationNode {
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
......@@ -214,7 +200,6 @@ class ComputeOpNode : public OperationNode {
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
......@@ -253,6 +238,11 @@ class ScanOpNode : public OperationNode {
/*! \brief The placeholder to refer as states in update. */
Array<Tensor> state_placeholder;
/*!
* \brief the inputs to the scan, these are optionally provided
* But they can be helpful to provide hints to speedup get of scan body.
*/
Array<Tensor> inputs;
/*!
* \brief Spatial axis to indicate spatial dimension of each output.
* They corresponds to flattened spatial axis of the outputs.
*
......@@ -279,7 +269,6 @@ class ScanOpNode : public OperationNode {
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
......@@ -296,13 +285,15 @@ class ScanOpNode : public OperationNode {
v->Visit("init", &init);
v->Visit("update", &update);
v->Visit("state_placeholder", &state_placeholder);
v->Visit("inputs", &inputs);
v->Visit("spatial_axis_", &spatial_axis_);
}
static Operation make(std::string name,
IterVar axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder);
Array<Tensor> state_placeholder,
Array<Tensor> input);
static constexpr const char* _type_key = "ScanOp";
TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode, OperationNode);
......@@ -339,7 +330,6 @@ class ExternOpNode : public OperationNode {
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
......@@ -388,16 +378,19 @@ Tensor placeholder(Array<Expr> shape,
Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
/*!
* \brief Construct new tensors by scan over scan_axis.
* \brief Construct new tensors by scan.
*
* \param init The intialize tensor of first K steps.
* \param update The update tensor indicated the updated result after each timestamp.
* \param state_placeholder The placeholder for the states.
* \param inputs The inputs to the scan body, this is optional,
* but recommended to provide concrete information about scan body.
* \param name The optional name of the tensor.
*/
Array<Tensor> scan(Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
Array<Tensor> inputs = Array<Tensor>(),
std::string name = "scan");
// same as compute, specialized for different fcompute function
......
......@@ -139,12 +139,20 @@ inline bool TVMArgValue::IsNodeType() const {
// extensions for TVMRetValue
inline TVMRetValue& TVMRetValue::operator=(
const std::shared_ptr<Node>& other) {
if (other.get() == nullptr) {
SwitchToPOD(kNull);
} else {
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other);
}
return *this;
}
inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) {
if (!other.defined()) {
SwitchToPOD(kNull);
} else {
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other.node_);
}
return *this;
}
......
......@@ -11,6 +11,7 @@ from ._ctypes._function import Function
from ._ctypes._function import _init_api_functions, register_func, get_global_func
from ._ctypes._function import convert_to_tvm_func as _convert_tvm_func
from . import _api_internal
from . import _base
from . import make as _make
from . import expr as _expr
from . import tensor as _tensor
......@@ -142,7 +143,7 @@ def compute(shape, fcompute, name="compute"):
return op_node.output(0)
def scan(init, update, state_placeholder, name="scan"):
def scan(init, update, state_placeholder, inputs=None, name="scan"):
"""Construct new tensors by scanning over axis.
Parameters
......@@ -156,6 +157,10 @@ def scan(init, update, state_placeholder, name="scan"):
state_placeholder: Tensor or list of Tensor
The placeholder variables used by update.
inputs: Tensor or list of Tensor, optional
The list of inputs to the scan. This is not required, but can
be useful for the compiler to detect scan body faster.
name: str, optional
The name hint of the tensor
......@@ -173,7 +178,7 @@ def scan(init, update, state_placeholder, name="scan"):
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
res = tvm.scan(s_init, s_update, s_state)
res = tvm.scan(s_init, s_update, s_state, X)
"""
if isinstance(init, _tensor.Tensor):
init = [init]
......@@ -181,10 +186,14 @@ def scan(init, update, state_placeholder, name="scan"):
update = [update]
if isinstance(state_placeholder, _tensor.Tensor):
state_placeholder = [state_placeholder]
if isinstance(inputs, _tensor.Tensor):
inputs = [inputs]
if inputs is None:
inputs = []
if len(init) != len(update) or len(init) != len(state_placeholder):
raise ValueError("init, update, state_placeholder must have same length")
axis = _IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3)
op = _api_internal._ScanOp(name, axis, init, update, state_placeholder)
op = _api_internal._ScanOp(name, axis, init, update, state_placeholder, inputs)
res = [op.output(i) for i in range(len(update))]
return res[0] if len(res) == 1 else res
......@@ -340,20 +349,25 @@ def _IterVar(dom, name, iter_type, thread_tag=''):
return _api_internal._IterVar(dom, var, iter_type, thread_tag)
def thread_axis(dom, tag, name=''):
def thread_axis(dom=None, tag='', name=''):
"""Create a new IterVar to represent thread index.
Parameters
----------
dom : Range
The domain of iteration.
dom : Range or str
The domain of iteration
When str is passed, dom is set to None and str is used as tag
tag : str
tag : str, optional
The thread tag
name : str, optional
The name of the var.
"""
if isinstance(dom, _base.string_types):
tag, dom = dom, None
if len(tag) == 0:
raise ValueError("tag must be given as Positional or keyword argument")
name = name if name else tag
return _IterVar(dom, name, 1, tag)
......
......@@ -41,6 +41,30 @@ class Schedule(NodeBase):
"""
_api_internal._ScheduleNormalize(self)
def create_group(self, outputs, inputs, include_inputs=False):
"""Create stage group by giving output and input boundary.
The operators between outputs and inputs are placed as member of group.
outputs are include in the group, while inputs are not included.
Parameters
----------
outputs : list of Tensors
The outputs of the group.
inputs : list of Tensors
The inputs of the group.
include_inputs : boolean, optional
Whether include input operations in the group if they are used by outputs.
"""
if isinstance(outputs, _tensor.Tensor):
outputs = [outputs]
if isinstance(inputs, _tensor.Tensor):
inputs = [inputs]
return _api_internal._ScheduleCreateGroup(
self, outputs, inputs, include_inputs)
def cache_read(self, tensor, scope, readers):
"""Create a cache read of original tensor for readers.
......@@ -112,25 +136,7 @@ class Schedule(NodeBase):
@register_node
class Stage(NodeBase):
"""A Stage represents schedule for one operation."""
def rebase(self, parent, rebased):
"""Rebase parent by an existing thread axis.
Parameters
----------
parent : IterVar
The parent iter var.
rebased : IterVar
The rebased iter var.
Returns
-------
rebased : IterVar
The rebased itervar.
"""
_api_internal._StageRebase(self, parent, rebased)
return rebased
def split(self, parent, factor=None, outer=None):
def split(self, parent, factor=None, nparts=None):
"""Split the stage either by factor providing outer scope, or both
Parameters
......@@ -141,8 +147,8 @@ class Stage(NodeBase):
factor : Expr, optional
The splitting factor
outer : IterVar, optional
The outer split variable
nparts : Expr, optional
The number of outer parts.
Returns
-------
......@@ -152,11 +158,13 @@ class Stage(NodeBase):
inner : IterVar
The inner variable of iteration.
"""
if outer is not None:
inner = _api_internal._StageSplitByOuter(self, parent, outer, factor)
if nparts is not None:
if factor is not None:
raise ValueError("Donot need to provide both outer and nparts")
outer, inner = _api_internal._StageSplitByNParts(self, parent, nparts)
else:
if factor is None:
raise ValueError("either outer or factor need to be provided")
raise ValueError("Either nparts or factor need to be provided")
outer, inner = _api_internal._StageSplitByFactor(self, parent, factor)
return outer, inner
......@@ -188,8 +196,21 @@ class Stage(NodeBase):
"""
return _api_internal._StageSetScope(self, scope)
def outermost_threads(self, threads):
"""Force launch threads at outermost scope of the stage.
def bind(self, ivar, thread_ivar):
"""Bind ivar to thread index thread_ivar
Parameters
----------
ivar : IterVar
The iteration to be binded to thread.
thread_ivar : IterVar
The thread to be binded.
"""
_api_internal._StageBind(self, ivar, thread_ivar)
def env_threads(self, threads):
"""Mark threads to be launched at the outer scope of composed op.
Parameters
----------
......@@ -198,7 +219,7 @@ class Stage(NodeBase):
"""
if isinstance(threads, _collections.IterVar):
threads = [threads]
_api_internal._StageOutermostThreads(self, threads)
_api_internal._StageEnvThreads(self, threads)
def compute_at(self, parent, scope):
"""Attach the stage at parent's scope
......
......@@ -182,7 +182,8 @@ TVM_REGISTER_API(_ScanOp)
args[1],
args[2],
args[3],
args[4]);
args[4],
args[5]);
});
TVM_REGISTER_API(_ExternOp)
......@@ -219,27 +220,26 @@ TVM_REGISTER_API(_StageSetScope)
.set_scope(args[1]);
});
TVM_REGISTER_API(_StageRebase)
TVM_REGISTER_API(_StageBind)
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar outer, inner;
args[0].operator Stage()
.rebase(args[1], args[2]);
.bind(args[1], args[2]);
});
TVM_REGISTER_API(_StageSplitByFactor)
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar outer, inner;
args[0].operator Stage()
.split(args[1], &outer, &inner, args[2]);
.split(args[1], args[2], &outer, &inner);
*ret = Array<IterVar>({outer, inner});
});
TVM_REGISTER_API(_StageSplitByOuter)
TVM_REGISTER_API(_StageSplitByNParts)
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar inner;
IterVar outer, inner;
args[0].operator Stage()
.split(args[1], args[2], &inner, args[3]);
*ret = inner;
.split_by_nparts(args[1], args[2], &outer, &inner);
*ret = Array<IterVar>({outer, inner});
});
TVM_REGISTER_API(_StageFuse)
......@@ -278,15 +278,17 @@ TVM_REGISTER_API(_StageTile)
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar x_outer, y_outer, x_inner, y_inner;
args[0].operator Stage()
.tile(args[1], args[2], &x_outer, &y_outer,
&x_inner, &y_inner, args[3], args[4]);
.tile(args[1], args[2],
args[3], args[4],
&x_outer, &y_outer,
&x_inner, &y_inner);
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});
TVM_REGISTER_API(_StageOutermostThreads)
TVM_REGISTER_API(_StageEnvThreads)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.outermost_threads(args[1]);
.env_threads(args[1]);
});
TVM_REGISTER_API(_StageUnroll)
......@@ -313,6 +315,12 @@ TVM_REGISTER_API(_ScheduleNormalize)
.normalize();
});
TVM_REGISTER_API(_ScheduleCreateGroup)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
.create_group(args[1], args[2], args[3]);
});
TVM_REGISTER_API(_ScheduleCacheRead)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
......
......@@ -34,9 +34,9 @@ TVM_REGISTER_API(_schedule_AutoInlineElemWise)
REGISTER_SCHEDULE_PASS1(InferBound);
REGISTER_SCHEDULE_PASS1(CreateReadGraph);
REGISTER_SCHEDULE_PASS2(PostDFSOrder);
REGISTER_SCHEDULE_PASS1(ScanGetBody);
REGISTER_SCHEDULE_PASS1(CreateAttachPath);
REGISTER_SCHEDULE_PASS2(ScanFixPointAnalysis);
REGISTER_SCHEDULE_PASS1(ScanGetBody);
REGISTER_SCHEDULE_PASS1(ScanFixPointAnalysis);
REGISTER_SCHEDULE_PASS2(ScheduleOps);
} // namespace schedule
......
......@@ -35,7 +35,6 @@ std::string CodeGenSourceBase::GetUniqueName(std::string prefix) {
}
std::string CodeGenSourceBase::SSAGetID(std::string src, Type t) {
LOG(INFO) << "ssa get id";
if (name_alloc_map_.count(src)) return src;
auto it = ssa_assign_map_.find(src);
if (it != ssa_assign_map_.end()) {
......
......@@ -132,7 +132,6 @@ void ComputeOpNode::PropBoundToInputs(
void ComputeOpNode::GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const {
const TensorDom& tdom = tensor_dom.at(self.output(0));
......
......@@ -99,7 +99,6 @@ void ExternOpNode::PropBoundToInputs(
void ExternOpNode::GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const {
}
......
......@@ -38,20 +38,29 @@ MakeLoopNest(const Stage& stage,
value_map[iv] = iv->var;
continue;
}
// Bind iv could be another thread.
IterVar bind_iv = iv;
if (stage->iter_var_attrs.count(iv)) {
IterVar bind_thread = stage->iter_var_attrs[iv]->bind_thread;
if (bind_thread.defined()) bind_iv = bind_thread;
}
Range dom = dom_map.at(iv);
// initialize the offset and loop_level
Var var = iv->var;
Var var = bind_iv->var;
if (new_loop_var) {
var = Var(iv->var->name_hint + ".init", iv->var.type());
var = Var(iv->var->name_hint + ".init", bind_iv->var.type());
}
// Mark the iter var in the IR, to remember the point
if (iv->thread_tag.length() == 0) {
if (bind_iv->thread_tag.length() == 0) {
ForType for_type = ForType::Serial;
if (stage->iter_var_attrs.count(iv)) {
switch (stage->iter_var_attrs[iv]->iter_type) {
case kUnrolled: for_type = ForType::Unrolled; break;
case kVectorized: for_type = ForType::Vectorized; break;
case kParallelized: for_type = ForType::Parallel; break;
case kDataPar: break;
default: LOG(FATAL) << "Unknown iter type"
<< stage->iter_var_attrs[iv]->iter_type
<< " in the iter_var_attrs";
......@@ -67,7 +76,7 @@ MakeLoopNest(const Stage& stage,
for_type, DeviceAPI::None, no_op));
value_map[iv] = var;
} else {
Var idx(iv->var->name_hint + ".idx", iv->var.type());
Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.type());
nest[i + 1].emplace_back(
For::make(idx, 0, dom->extent,
for_type, DeviceAPI::None, no_op));
......@@ -76,29 +85,29 @@ MakeLoopNest(const Stage& stage,
nest[i + 1].emplace_back(
LetStmt::make(var, new_value, no_op));
}
} else if (iv->thread_tag == "vthread") {
} else if (bind_iv->thread_tag == "vthread") {
// virtual thread
// Always restrict threaded IterVar to starts from 0.
CHECK(is_zero(dom->min));
CHECK(is_positive_const(dom->extent));
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::virtual_thread, dom->extent, no_op));
AttrStmt::make(bind_iv, ir::attr::virtual_thread, dom->extent, no_op));
value_map[iv] = var;
} else if (iv->thread_tag == "pipeline") {
} else if (bind_iv->thread_tag == "pipeline") {
// pipeline marker.
CHECK(is_zero(dom->min));
CHECK(is_one(dom->extent));
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::pipeline_exec_scope, dom->extent, no_op));
AttrStmt::make(bind_iv, ir::attr::pipeline_exec_scope, dom->extent, no_op));
value_map[iv] = dom->min;
} else {
// Always restrict threaded IterVar to starts from 0.
CHECK(is_zero(dom->min));
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::thread_extent, dom->extent, no_op));
AttrStmt::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op));
if (is_one(dom->extent)) {
value_map[iv] = dom->min;
} else {
......
......@@ -33,7 +33,6 @@ Array<Expr> PlaceholderOpNode::output_shape(size_t i) const {
return shape;
}
Operation PlaceholderOpNode::make(std::string name,
Array<Expr> shape,
Type dtype) {
......@@ -66,7 +65,6 @@ void PlaceholderOpNode::PropBoundToInputs(
void PlaceholderOpNode::GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const {
}
......
......@@ -47,7 +47,8 @@ Operation ScanOpNode::make(std::string name,
IterVar axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder) {
Array<Tensor> state_placeholder,
Array<Tensor> inputs) {
auto n = std::make_shared<ScanOpNode>();
CHECK_EQ(init.size(), update.size());
CHECK_EQ(init.size(), state_placeholder.size());
......@@ -89,12 +90,14 @@ Operation ScanOpNode::make(std::string name,
n->init = init;
n->update = update;
n->state_placeholder = state_placeholder;
n->inputs = inputs;
return Operation(n);
}
Array<Tensor> scan(Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
Array<Tensor> inputs,
std::string name) {
IterVar scan_axis =
IterVarNode::make(
......@@ -102,7 +105,7 @@ Array<Tensor> scan(Array<Tensor> init,
init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
Var(name + ".idx"), kOrdered);
Operation op = ScanOpNode::make(
name, scan_axis, init, update, state_placeholder);
name, scan_axis, init, update, state_placeholder, inputs);
Array<Tensor> res;
for (int i = 0; i < op->num_outputs(); ++i) {
res.push_back(op.output(i));
......@@ -179,7 +182,6 @@ void ScanOpNode::PropBoundToInputs(
void ScanOpNode::GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
......@@ -200,8 +202,7 @@ void ScanOpNode::GatherBound(
Range r = arith::Union(time_dom).cover_range(sdom);
(*out_dom_map)[this->scan_axis] = Range::make_with_min_extent(
sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
Array<Operation> body = ScanGetBody_(this, graph_ctx.feed_graph);
Map<IterVar, Expr> fix_pt = ScanFixPointAnalysis(self, body);
Map<IterVar, Expr> fix_pt = ScanFixPointAnalysis(self);
// Update for spatial axis.
size_t sp_idx = 0;
for (size_t i = 0; i < output.size(); ++i) {
......
......@@ -15,10 +15,23 @@
namespace tvm {
namespace schedule {
/*! \brief The graph context used during bound inference. */
struct GraphContext {
/*! \brief The feed graph */
FeedGraph feed_graph;
};
// check if scope
inline bool ScopeRelax(const IterVar& iv, const std::string& scope) {
inline bool ScopeRelax(const IterVar& ivar,
const std::unordered_map<IterVar, IterVar>& bind_map,
const std::string& scope) {
using runtime::ThreadScope;
using runtime::StorageScope;
auto it = bind_map.find(ivar);
IterVar iv = ivar;
if (it != bind_map.end()) {
iv = it->second;
}
if (iv->thread_tag.length() == 0) return false;
if (scope.length() == 0) return false;
......@@ -28,10 +41,16 @@ inline bool ScopeRelax(const IterVar& iv, const std::string& scope) {
void InferRootBound(const Stage& stage,
const GraphContext& ctx,
const AttachPath& attach_path,
const std::unordered_map<IterVar, IterVar>& bind_map,
std::unordered_map<IterVar, Range>* rmap) {
CHECK_NE(stage->attach_type, kInline)
<< "call schedule.normalize before scheduleops";
if (stage->attach_type == kInlinedAlready) return;
if (stage->is_output) {
// verify correctness.
CHECK_EQ(stage.GetAttachSpec()->attach_type, kGroupRoot)
<< "Output must be attached at root";
}
if (stage->is_output || stage->op.as<PlaceholderOpNode>()) {
for (auto iv : stage->op->root_iter_vars()) {
CHECK(iv->dom.defined());
......@@ -42,8 +61,10 @@ void InferRootBound(const Stage& stage,
}
// parent stage, if any
Stage parent;
if (stage->attach_type == kScope || stage->attach_type == kScanUpdate) {
parent = stage->attach_stage;
Stage attach_spec = stage.GetAttachSpec();
if (attach_spec->attach_type == kScope ||
attach_spec->attach_type == kScanUpdate) {
parent = attach_spec->attach_stage;
}
// The tensor domain.
std::unordered_map<Tensor, TensorDom> tmap;
......@@ -72,13 +93,11 @@ void InferRootBound(const Stage& stage,
// from the already inferred bounds.
std::unordered_map<const Variable*, IntSet> relax_set;
for (IterVar iv : attach_path.at(stage->op)) {
if (ScopeRelax(iv, stage->scope)) {
if (ScopeRelax(iv, bind_map, stage->scope)) {
relax_set[iv->var.get()] = IntSet::range(rmap->at(iv));
}
}
if (direct_consume_by_parent) {
// parent stage if exist
Stage parent = stage->attach_stage;
// Bound inference logics in parent.
std::unordered_map<IterVar, IntSet> up_state;
bool fix_value = true;
......@@ -89,16 +108,16 @@ void InferRootBound(const Stage& stage,
CHECK(is_zero(vrange->min))
<< "InferBound requires every leaf iter var's min equals 0, "
<< " call schedule.normalize to achieve this. "
<< " stage=" << parent;
<< " stage=" << parent << ", vrange=" << vrange->min;
// special optimization to remove trivial loop
if (is_one(vrange->extent)) {
up_state[iv] = IntSet::single_point(vrange->min);
} else if (fix_value && !ScopeRelax(iv, stage->scope)) {
} else if (fix_value && !ScopeRelax(iv, bind_map, stage->scope)) {
up_state[iv] = IntSet::single_point(iv->var);
} else {
up_state[iv] = IntSet::range(vrange);
}
if (stage->attach_ivar == iv) {
if (attach_spec->attach_ivar == iv) {
fix_value = false;
}
}
......@@ -159,7 +178,7 @@ void InferRootBound(const Stage& stage,
}
op->PropBoundToInputs(op, dom_map, &tmap);
}
stage->op->GatherBound(stage->op, ctx, tmap, rmap);
stage->op->GatherBound(stage->op, tmap, rmap);
}
Map<IterVar, Range> InferBound(const Schedule& sch) {
......@@ -167,18 +186,25 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
for (Operation op : sch->outputs) {
roots.push_back(sch->stage_map[op]->op);
}
std::unordered_map<IterVar, IterVar> bind_map;
for (Stage stage : sch->stages) {
for (auto kv : stage->iter_var_attrs) {
if (kv.second->bind_thread.defined()) {
CHECK(!bind_map.count(kv.first));
bind_map[kv.first] = kv.second->bind_thread;
}
}
}
GraphContext ctx;
ctx.feed_graph = CreateFeedGraph(CreateReadGraph(roots));
AttachPath attach_path = CreateAttachPath(sch);
std::unordered_map<IterVar, Range> ret;
for (size_t i = sch->stages.size(); i != 0; --i) {
const Stage& stage = sch->stages[i - 1];
InferRootBound(stage, ctx, attach_path, &ret);
InferRootBound(stage, ctx, attach_path, bind_map, &ret);
// pass down to get bound of all iter vars.
PassDownDomain(stage, &ret);
// setup outer most threads.
for (IterVar iv : stage->outermost_threads) {
for (IterVar iv : stage->env_threads) {
CHECK(iv->dom.defined());
ret[iv] = iv->dom;
}
......
......@@ -7,6 +7,7 @@
#include <tvm/ir_visitor.h>
#include <tvm/operation.h>
#include <unordered_set>
#include <unordered_map>
#include "./graph.h"
namespace tvm {
......@@ -82,6 +83,60 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) {
return rmap;
}
// Do DFS visit to get the subgraph.
// Return if op is inside the subgraph.
bool GetSubGraphByPostDFS_(
const Operation& op,
const std::unordered_set<const Node*>& boundary,
bool include_bounary,
std::unordered_map<const Node*, bool>* visited,
Array<Operation>* result) {
if (visited->count(op.get())) {
return visited->at(op.get());
}
if (boundary.count(op.get())) {
(*visited)[op.get()] = true;
if (include_bounary) {
result->push_back(op);
}
return true;
}
// mark to avoid loop
// Not necessary for DAG.
(*visited)[op.get()] = false;
// check if we can reach boundary.
bool reach_boundary = false;
for (Tensor t : op->InputTensors()) {
if (GetSubGraphByPostDFS_(t->op, boundary,
include_bounary,
visited, result)) {
reach_boundary = true;
}
}
(*visited)[op.get()] = reach_boundary;
if (reach_boundary) {
result->push_back(op);
}
return reach_boundary;
}
Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
const Array<Tensor>& inputs,
bool include_inputs) {
Array<Operation> result;
std::unordered_set<const Node*> boundary;
for (Tensor t : inputs) {
boundary.insert(t->op.get());
}
std::unordered_map<const Node*, bool> visited;
for (Tensor t : outputs) {
GetSubGraphByPostDFS_(t->op, boundary, include_inputs,
&visited, &result);
}
return result;
}
void PostDFSOrder(const Operation& op,
const ReadGraph& g,
std::unordered_set<Operation>* visited,
......@@ -118,30 +173,38 @@ FeedGraph CreateFeedGraph(const ReadGraph& g) {
AttachPath CreateAttachPath(Schedule sch) {
AttachPath ret;
for (Stage stage : sch->stages) {
if (stage->attach_type == kScanUpdate) {
const Stage& parent = stage->attach_stage;
stage->attach_ivar =
parent->leaf_iter_vars[parent->leaf_iter_vars.size() - 1];
}
}
for (Stage stage : sch->stages) {
std::unordered_set<const Node*> visited;
Array<IterVar> path;
for (Stage s = stage; s->attach_type == kScope || s->attach_type == kScanUpdate;) {
IterVar attach_ivar = s->attach_ivar;
s = s->attach_stage;
bool start_attach = false;
for (Stage s = stage; s.defined();) {
CHECK(!visited.count(s.get()))
<< "Find loop in compute_at attach group";
visited.insert(s.get());
Stage spec = s.GetAttachSpec();
bool start_attach;
IterVar attach_ivar;
if (spec->attach_type == kScope) {
attach_ivar = spec->attach_ivar;
s = spec->attach_stage;
start_attach = false;
CHECK(attach_ivar.defined());
} else if (spec->attach_type == kScanUpdate) {
s = spec->attach_stage;
start_attach = true;
} else {
break;
}
CHECK(s.defined());
for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
IterVar iv = s->leaf_iter_vars[i - 1];
if (iv == attach_ivar) start_attach = true;
if (!start_attach && iv.same_as(attach_ivar)) {
start_attach = true;
}
if (start_attach) path.push_back(iv);
}
CHECK(start_attach)
<< "Invalid Schedule: cannot find attach point " << attach_ivar
<< " in the schedule of " << s->op;
}
if (!ret.count(stage->op)) {
ret.Set(stage->op, path);
}
......@@ -203,53 +266,22 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
return reach;
}
// Get all the operations that forms body of scan
void ScanGetBodyPostDFS_(
Operation op,
const ScanOpNode* scan,
const FeedGraph& feed_graph,
std::unordered_set<const Node*>* visited,
Array<Operation>* result) {
if (op.get() == scan) return;
bool empty_feed = true;
for (int i = 0; i < op->num_outputs(); ++i) {
auto it = feed_graph.find(op.output(i));
if (it != feed_graph.end() && it->second.size()) {
empty_feed = false;
for (const Operation& xop : it->second) {
if (visited->count(xop.get())) continue;
visited->insert(xop.get());
ScanGetBodyPostDFS_(xop, scan, feed_graph, visited, result);
result->push_back(xop);
}
}
}
if (empty_feed && op.get() != scan) {
LOG(FATAL) << "Bad scan body, tensor reads scan_state but not connect to scan";
}
}
Array<Operation> ScanGetBody_(
const ScanOpNode* scan,
const FeedGraph& feed_graph) {
CHECK(scan != nullptr);
std::unordered_set<const Node*> visited;
Array<Operation> result;
Array<Operation> ScanGetBody(const Operation& scan_op) {
const ScanOpNode* scan = scan_op.as<ScanOpNode>();
// Get the body.
Array<Tensor> inputs;
for (Tensor t : scan->state_placeholder) {
ScanGetBodyPostDFS_(t->op, scan, feed_graph, &visited, &result);
inputs.push_back(t);
}
return result;
}
Array<Operation> ScanGetBody(const Operation& scan) {
return ScanGetBody_(scan.as<ScanOpNode>(),
CreateFeedGraph(CreateReadGraph({scan})));
for (Tensor t : scan->inputs) {
inputs.push_back(t);
}
return GetSubGraph(scan->update, inputs, false);
}
Map<IterVar, Expr> ScanFixPointAnalysis(
const Operation& scan_op, const Array<Operation>& body) {
Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
const ScanOpNode* scan = scan_op.as<ScanOpNode>();
CHECK(body[0].get() == scan);
Array<Operation> body = ScanGetBody(scan_op);
std::unordered_map<TensorDimKey, const Node*> exact_reach;
std::unordered_set<const Node*> fail_set;
......@@ -276,8 +308,8 @@ Map<IterVar, Expr> ScanFixPointAnalysis(
}
};
// prop exact reach back.
for (size_t i = body.size(); i != 1; --i) {
const Operation& op = body[i - 1];
for (size_t i = 0; i < body.size(); ++i) {
const Operation& op = body[i];
if (op.as<ScanOpNode>()) {
const auto& update = op.as<ScanOpNode>()->update;
const auto& init = op.as<ScanOpNode>()->init;
......
......@@ -27,6 +27,11 @@ using ReadGraph = Map<Operation, Array<Tensor> >;
using AttachPath = Map<Operation, Array<IterVar> >;
/*!
* \brief The map beteen tensor and operation it feeds to.
*/
using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
/*!
* \brief Get read graph of each operation to all the
* Tensors that it directly depends on.
*
......@@ -37,6 +42,23 @@ using AttachPath = Map<Operation, Array<IterVar> >;
ReadGraph CreateReadGraph(const Array<Operation>& roots);
/*!
* \brief Get minimum subgraph between outputs and inputs.
* The operations contains node which input-reachable from any inputs
* output reachable to any outputs.
*
* The inputs won't be included in the subgraph, the outputs will be inclued.
*
* \param outputs The outputs of the subgraph
* \param inputs The inputs to the subgraph.
* \param include_inputs Whether to include inputs
*
* \return The subgraph.
*/
Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
const Array<Tensor>& inputs,
bool include_inputs);
/*!
* \brief Get a post DFS ordered of operations in the graph.
* \param roots The root of the graph.
* \param g The read graph.
......@@ -67,14 +89,10 @@ AttachPath CreateAttachPath(Schedule sch);
/*!
* \brief Get all operations inside the recursion of scan.
* \param scan The scan node.
* \param feed_graph The feed graph to help analysis.
* \param scan_op The scan node ops.
* \return The body operations, in read dependency order.
*/
Array<Operation> ScanGetBody_(
const ScanOpNode* scan, const FeedGraph& feed_graph);
// same as ScanGetBody_, but create FeedGraph internally.
Array<Operation> ScanGetBody(const Operation& scan);
Array<Operation> ScanGetBody(const Operation& scan_op);
/*!
* \brief Analyze each spatial dimension of scan's result.
......@@ -85,11 +103,9 @@ Array<Operation> ScanGetBody(const Operation& scan);
* next_state[t, ..., axis, ...] = f(prev_state[t-1, ...,axis,...]
*
* \param scan The scan node.
* \param body The body of scan, sorted in reverse PostDFSOrder.
* \return Map of spatial_axis -> IntImm
*/
Map<IterVar, Expr> ScanFixPointAnalysis(
const Operation& scan, const Array<Operation>& body);
Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan);
} // namespace schedule
} // namespace tvm
......
......@@ -22,6 +22,23 @@ inline bool prove_equal(Expr lhs, Expr rhs) {
return is_zero(ir::Simplify(lhs - rhs));
}
void Update(std::unordered_map<IterVar, Range>* p_state,
const IterVar& iv,
Range r) {
auto it = p_state->find(iv);
if (it == p_state->end()) {
(*p_state)[iv] = r;
} else {
bool match = is_zero(it->second->min);
if (!prove_equal(r->extent, it->second->extent)) match = false;
CHECK(match)
<< iv
<< " domain already inferred,"
<< " cannot prove their extents are the same "
<< it->second->extent << " vs " << r->extent;
}
}
void PassDownDomain(const Stage& stage,
std::unordered_map<IterVar, Range>* p_state,
bool allow_missing) {
......@@ -36,30 +53,15 @@ void PassDownDomain(const Stage& stage,
CHECK(!state.count(r->inner));
const Range& range_parent = state.at(r->parent);
if (r->factor.defined()) {
state[r->inner] = Range::make_with_min_extent(0, r->factor);
if (r->outer->dom.defined()) {
state[r->outer] = r->outer->dom;
Update(p_state, r->inner, Range::make_with_min_extent(0, r->factor));
Update(p_state, r->outer,
Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->factor)));
} else {
if (!state.count(r->outer)) {
state[r->outer] = Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->factor));
} else {
Expr outer_ext = DivCeil(range_parent->extent, r->factor);
Range outer_rng = state.at(r->outer);
bool match = is_zero(outer_rng->min);
if (!prove_equal(outer_ext, outer_rng->extent)) match = false;
CHECK(match)
<< r->outer
<< "IterVar is used in two places as outer scope,"
<< " cannot prove their extents are the same "
<< outer_ext << " vs " << outer_rng->extent;
}
}
} else {
CHECK(r->outer->dom.defined());
state[r->outer] = r->outer->dom;
state[r->inner] = Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->outer->dom->extent));
Update(p_state, r->outer, Range::make_with_min_extent(0, r->nparts));
Update(p_state, r->inner,
Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->nparts)));
}
} else if (const FuseNode* r = rel.as<FuseNode>()) {
if (!state.count(r->outer) || !state.count(r->inner)) {
......@@ -75,20 +77,20 @@ void PassDownDomain(const Stage& stage,
CHECK(allow_missing);
continue;
}
Range res = Range::make_with_min_extent(
0, state.at(r->parent)->extent);
if (r->rebased->dom.defined()) {
Range rebase_rng = r->rebased->dom;
bool match = is_zero(rebase_rng->min);
if (!prove_equal(rebase_rng->extent, res->extent)) match = false;
CHECK(match) << r->rebased
<< " does not match parent scope's range";
}
state[r->rebased] = res;
Update(p_state, r->rebased,
Range::make_with_min_extent(
0, state.at(r->parent)->extent));
} else {
LOG(FATAL) << "unknown relation type";
}
}
// update the extents of binded threads.
for (auto kv : stage->iter_var_attrs) {
if (kv.second->bind_thread.defined()) {
CHECK(state.count(kv.first));
Update(p_state, kv.second->bind_thread, state.at(kv.first));
}
}
}
void PassUpIndex(const Stage& stage,
......
......@@ -55,6 +55,7 @@ void ReplaceDataFlow(const Array<Stage>& stages,
Tensor Schedule::cache_read(const Tensor& tensor,
const std::string& scope,
const Array<Operation>& readers) {
(*this)->InvalidateCache();
// create identity mapping.
std::ostringstream os;
os << tensor->op->name;
......@@ -81,18 +82,25 @@ Tensor Schedule::cache_read(const Tensor& tensor,
}
ReplaceDataFlow((*this)->stages, &vmap);
ArrayNode* stages = (*this)->stages.CopyOnWrite();
size_t pos = FindNodeRef(stages, operator[](tensor->op));
Stage op_stage = operator[](tensor->op);
size_t pos = FindNodeRef(stages, op_stage);
Stage cache_stage = Stage(cache->op);
cache_stage.set_scope(scope);
CHECK_LT(pos, stages->data.size());
stages->data.insert(stages->data.begin() + pos + 1,
cache_stage.node_);
(*this)->stage_map.Set(cache->op, cache_stage);
// Update group
cache_stage->group = op_stage->group;
if (cache_stage->group.defined()) {
++cache_stage->group->num_child_stages;
}
return cache;
}
Tensor Schedule::cache_write(const Tensor& tensor,
const std::string& scope) {
(*this)->InvalidateCache();
Stage orig_stage = operator[](tensor->op);
const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
CHECK(compute)
......@@ -123,7 +131,6 @@ Tensor Schedule::cache_write(const Tensor& tensor,
std::unordered_map<Tensor, Tensor> vmap;
vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
ReplaceDataFlow((*this)->stages, &vmap);
// mutate orig stage
orig_stage->op = orig_new_op;
orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
......@@ -137,6 +144,11 @@ Tensor Schedule::cache_write(const Tensor& tensor,
stages->data.insert(stages->data.begin() + pos,
cache_stage.node_);
(*this)->stage_map.Set(cache_op, cache_stage);
// Update group
cache_stage->group = orig_stage->group;
if (cache_stage->group.defined()) {
++cache_stage->group->num_child_stages;
}
return cache_tensor;
}
......@@ -152,6 +164,11 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
attach_mark[s.get()] = 1;
}
}
for (Stage s : sch->groups) {
if (s->attach_type == kScope) {
attach_mark[s->attach_stage.get()] = 1;
}
}
for (Stage s : sch->stages) {
if (!attach_mark.count(s.get())) continue;
......@@ -176,6 +193,12 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
s->attach_ivar = rebase_map.at(s->attach_ivar);
}
}
for (Stage s : sch->groups) {
if (s->attach_type != kScope) continue;
if (rebase_map.count(s->attach_ivar)) {
s->attach_ivar = rebase_map.at(s->attach_ivar);
}
}
}
void SetScanAttach(const Schedule& sch) { // NOLINT(*)
......@@ -188,8 +211,8 @@ void SetScanAttach(const Schedule& sch) { // NOLINT(*)
}
}
void InjectInline(const Schedule& sch) {
void InjectInline(ScheduleNode* sch) {
sch->InvalidateCache();
std::vector<Expr> new_body(sch->stages.size());
// inline all the ops
for (size_t i = sch->stages.size(); i != 0; --i) {
......@@ -241,12 +264,13 @@ void InjectInline(const Schedule& sch) {
void Schedule::normalize() {
RebaseNonZeroMinLoop(*this);
SetScanAttach(*this);
InjectInline(*this);
InjectInline(operator->());
}
// Handle reduction factor.
Tensor Schedule::rfactor(const Tensor& tensor,
const IterVar& axis) {
(*this)->InvalidateCache();
using ir::Reduce;
CHECK_EQ(axis->iter_type, kCommReduce)
<< "Can only factor reduction axis";
......
......@@ -12,6 +12,8 @@
#include <unordered_map>
#include <unordered_set>
#include "./graph.h"
#include "../op/op_util.h"
#include "../pass/ir_util.h"
namespace tvm {
namespace schedule {
......@@ -44,8 +46,9 @@ Stmt MakePipeline(const Stage& s,
class InjectAttach : public IRMutator {
public:
InjectAttach(const Stage& stage,
const Stage& attach_spec,
const std::unordered_map<IterVar, Range>& dom_map)
: stage_(stage), dom_map_(dom_map) {}
: stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map) {}
Stmt Mutate(Stmt stmt) final {
CHECK(stmt.defined());
......@@ -53,10 +56,11 @@ class InjectAttach : public IRMutator {
const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr &&
op->type_key == attr::loop_scope) {
CHECK_NE(producer_.size(), 0U);
if (op->node == stage_->attach_ivar &&
producer_.back() == stage_->attach_stage->op.get()) {
CHECK(!found_attach);
if (attach_spec_->attach_type == kScope &&
op->node == attach_spec_->attach_ivar) {
CHECK(!found_attach)
<< "Find IterVar" << attach_spec_->attach_ivar
<< " in multiple places in the IR";
found_attach = true;
stmt = AttrStmt::make(
op->node, op->type_key, op->value,
......@@ -65,26 +69,16 @@ class InjectAttach : public IRMutator {
}
return stmt;
}
Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
if (op->is_producer) {
producer_.push_back(op->func.get());
Stmt ret = IRMutator::Mutate_(op, s);
producer_.pop_back();
return ret;
} else {
return IRMutator::Mutate_(op, s);
}
}
// whether attach point is found
bool found_attach{false};
private:
// the operations to be carried
// The stage.
const Stage& stage_;
// The attach spec, may not contain op.
const Stage& attach_spec_;
// domain map
const std::unordered_map<IterVar, Range>& dom_map_;
// internal stack about realization scope.
std::vector<const Node*> producer_;
};
// inject the operator's realization on the stmt.
......@@ -128,7 +122,6 @@ class InjectScanStep : public IRMutator {
bool is_init_;
};
// Postprocessing of schedule op
// Replace the init and update's expression by scan's buffer.
class SchedulePostProc : public IRMutator {
......@@ -157,9 +150,8 @@ class SchedulePostProc : public IRMutator {
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == attr::loop_scope) {
return this->Mutate(op->body);
} else if (op->type_key == attr::scan_init_scope) {
if (op->type_key == attr::loop_scope ||
op->type_key == attr::scan_init_scope) {
return this->Mutate(op->body);
} else if (op->type_key == attr::scan_update_scope) {
const ScanOpNode* scan = op->node.as<ScanOpNode>();
......@@ -237,6 +229,15 @@ class SchedulePostProc : public IRMutator {
void Init(const Schedule& sch) {
for (Stage s : sch->stages) {
for (auto kv : s->iter_var_attrs) {
// Update bind thread information.
if (kv.second->bind_thread.defined()) {
const Var& from = kv.first->var;
const Var& to = kv.second->bind_thread->var;
CHECK(!var_value_.count(from.get()));
var_value_[from.get()] = to;
}
}
// This must be checked for all ops, including scan.
if (!s->op.same_as(s->origin_op)) {
Tensor target = s->origin_op.output(0);
......@@ -279,61 +280,67 @@ class SchedulePostProc : public IRMutator {
Stmt ScheduleOps(
Schedule sch, Map<IterVar, Range> dom_map_) {
Stmt body = Stmt();
std::unordered_map<IterVar, Range> dom_map;
for (auto kv : dom_map_) {
dom_map[kv.first] = kv.second;
}
// scan init and scan updates
std::unordered_map<Operation, std::pair<Operation, bool> > scan_attach;
std::unordered_map<Operation, Operation> scan_init;
for (Stage s : sch->stages) {
const ScanOpNode* scan = s->op.as<ScanOpNode>();
if (!scan) continue;
for (Tensor t : scan->init) {
if (scan_attach.count(t->op)) {
CHECK(scan_attach.at(t->op).first.same_as(s->op))
if (scan_init.count(t->op)) {
CHECK(scan_init.at(t->op).same_as(s->op))
<< "Scan init tensor can only belong to one scan";
} else {
scan_attach[t->op] = std::make_pair(s->op, true);
}
}
for (Tensor t : scan->update) {
if (scan_attach.count(t->op)) {
CHECK(scan_attach.at(t->op).first.same_as(s->op))
<< "Scan update tensor can only belong to one scan";
} else {
scan_attach[t->op] = std::make_pair(s->op, false);
scan_init[t->op] = s->op;
}
}
}
std::unordered_map<IterVar, Range> dom_map;
for (auto kv : dom_map_) {
dom_map[kv.first] = kv.second;
// verify correctness of group.
for (Stage g : sch->groups) {
CHECK(!g->op.defined());
CHECK_EQ(g->leaf_iter_vars.size(), 0U);
}
// reverse the post DFS order.
for (size_t i = sch->stages.size(); i != 0; --i) {
Stage s = sch->stages[i - 1];
CHECK_NE(s->attach_type, kInline)
<< "call schedule.normalize before scheduleops";
CHECK(s->op.defined());
// no need to specify place holder op.
if (s->op.as<PlaceholderOpNode>()) continue;
if (scan_attach.count(s->op)) {
CHECK(s->attach_type == kNone ||
s->attach_type == kScanUpdate)
<< "Cannot specify compute_at for scan's init/update";
// Remove grouping sugar, get the real attach spec.
Stage attach_spec = s.GetAttachSpec();
if (scan_init.count(s->op)) {
CHECK(body.defined());
InjectScanStep mu(s, scan_init.at(s->op), dom_map, true);
body = mu.Mutate(body);
CHECK(mu.found_attach)
<< "did not find attachment point for scan.init";
} else if (attach_spec->attach_type == kScanUpdate) {
// Handle scan update
CHECK(body.defined());
const auto& p = scan_attach.at(s->op);
InjectScanStep mu(s, p.first, dom_map, p.second);
InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false);
body = mu.Mutate(body);
CHECK(mu.found_attach)
<< "did not find attachment point for scan.init/update";
} else if (s->attach_type == kInlinedAlready) {
<< "did not find attachment point for scan.update";
} else if (attach_spec->attach_type == kInlinedAlready) {
// do nothing
} else if (s->attach_type == kRoot || s-> attach_type == kNone) {
} else if (attach_spec->attach_type == kGroupRoot) {
CHECK(!s->group.defined());
body = MakePipeline(s, dom_map, body);
} else if (s->attach_type == kScope) {
} else {
CHECK_EQ(attach_spec->attach_type, kScope);
CHECK(body.defined());
InjectAttach mutator(s, dom_map);
InjectAttach mutator(s, attach_spec, dom_map);
body = mutator.Mutate(body);
CHECK(mutator.found_attach)
<< "did not find attachment point for " << s << " in"
<< s->attach_stage->op << " x "
<< "did not find attachment point for " << s << " in "
<< attach_spec->attach_stage->op << " x " << attach_spec->attach_ivar
<< ", body:\n"
<< body;
}
}
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_mutator.h>
namespace {
......
......@@ -11,11 +11,11 @@ def test_add():
s = tvm.Schedule(C.op)
# create iter var and assign them tags.
num_thread = 256
block_x = tvm.thread_axis(None, "blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
_, x = s[C].split(C.op.axis[0], factor=num_thread*4, outer=block_x)
_, x = s[C].split(x, outer=thread_x)
bx, x = s[C].split(C.op.axis[0], factor=num_thread*4)
tx, x = s[C].split(x, nparts=num_thread)
_, x = s[C].split(x, factor=4)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
s[C].vectorize(x)
# one line to build the function.
......
......@@ -22,31 +22,41 @@ def test_gemm():
scale = 8
num_thread = 8
block_factor = scale * num_thread
block_x = tvm.thread_axis(None, "blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
block_y = tvm.thread_axis(None, "blockIdx.y")
thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis("threadIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
thread_y = tvm.thread_axis("threadIdx.y")
CC = s.cache_write(C, "local")
AA = s.cache_read(A, "shared", [CC])
BB = s.cache_read(B, "shared", [CC])
_, yi = s[C].split(C.op.axis[0], factor=block_factor, outer=block_y)
_, xi = s[C].split(C.op.axis[1], factor=block_factor, outer=block_x)
s[C].reorder(block_y, block_x, yi, xi)
_, yi = s[C].split(yi, outer=thread_y)
_, xi = s[C].split(xi, outer=thread_x)
s[C].reorder(thread_y, thread_x, yi, xi)
by, yi = s[C].split(C.op.axis[0], factor=block_factor)
bx, xi = s[C].split(C.op.axis[1], factor=block_factor)
s[C].reorder(by, bx, yi, xi)
s[C].bind(by, block_y)
s[C].bind(bx, block_x)
ty, yi = s[C].split(yi, nparts=num_thread)
tx, xi = s[C].split(xi, nparts=num_thread)
s[C].reorder(ty, tx, yi, xi)
s[C].bind(ty, thread_y)
s[C].bind(tx, thread_x)
yo, xo = CC.op.axis
s[CC].reorder(k, yo, xo)
s[CC].compute_at(s[C], thread_x)
s[CC].compute_at(s[C], tx)
s[AA].compute_at(s[CC], k)
s[BB].compute_at(s[CC], k)
_, xi = s[AA].split(s[AA].op.axis[0], outer=thread_y)
_, xi = s[AA].split(xi, outer=thread_x)
_, xi = s[BB].split(s[BB].op.axis[0], outer=thread_y)
_, xi = s[BB].split(xi, outer=thread_x)
ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread)
tx, xi = s[AA].split(xi, nparts=num_thread)
s[AA].bind(ty, thread_y)
s[AA].bind(tx, thread_x)
ty, xi = s[BB].split(s[BB].op.axis[0], nparts=num_thread)
tx, xi = s[BB].split(xi, nparts=num_thread)
s[BB].bind(ty, thread_y)
s[BB].bind(tx, thread_x)
max_auto_unroll_step = 0
# lowering test
......@@ -76,9 +86,9 @@ def test_gemm():
np.testing.assert_allclose(
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
check_device("cuda")
if tvm.module.enabled("opencl"):
tvm.module.init_opencl()
check_device("cuda")
check_device("opencl")
if __name__ == "__main__":
......
......@@ -12,10 +12,9 @@ def test_sum():
s = tvm.Schedule(B.op)
# create iter var and assign them tags.
num_thread = 1
block_x = tvm.thread_axis(None, "blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
_, x = s[B].split(B.op.axis[0], factor=num_thread, outer=block_x)
_, x = s[B].split(x, outer=thread_x)
xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
# one line to build the function.
def check_device(device, host="stackvm"):
......@@ -52,10 +51,9 @@ def test_rfactor():
A = tvm.placeholder((n,), name='A')
k = tvm.reduce_axis((0, n))
B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')
kf = tvm.reduce_axis((0, 4))
# schedule
s = tvm.Schedule(B.op)
_, ki = s[B].split(k, outer=kf)
kf, ki = s[B].split(k, nparts=4)
BF = s.rfactor(B, kf)
s[BF].parallel(BF.op.axis[0])
# one line to build the function.
......@@ -88,16 +86,14 @@ def test_rfactor_threads():
k = tvm.reduce_axis((0, n))
nthread = 16
B = tvm.compute((m,), lambda i: tvm.sum(A[i, k], axis=k, where=(i>1)), name='B')
tx = tvm.thread_axis((0, nthread), "threadIdx.x")
ty = tvm.thread_axis((0, nthread), "threadIdx.y")
bx = tvm.thread_axis(None, "blockIdx.x")
# schedule
s = tvm.Schedule(B.op)
ko, kf = s[B].split(k, factor=nthread)
BF = s.rfactor(B, kf)
xo, xi = s[B].split(s[B].op.axis[0], factor=nthread, outer=bx)
s[B].rebase(xi, ty)
s[B].rebase(s[B].op.reduce_axis[0], tx)
bx, tx = s[B].split(s[B].op.axis[0], factor=nthread)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.y"))
s[B].bind(s[B].op.reduce_axis[0], tvm.thread_axis("threadIdx.x"))
s[BF].compute_at(s[B], tx)
# one line to build the function.
......@@ -128,6 +124,6 @@ def test_rfactor_threads():
check_target("opencl")
if __name__ == "__main__":
test_rfactor_threads()
test_rfactor()
test_rfactor_threads()
test_sum()
......@@ -15,10 +15,12 @@ def test_scan():
num_thread = 256
block_x = tvm.thread_axis(None, "blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
_, x = s[s_init].split(s_init.op.axis[1], factor=num_thread, outer=block_x)
_, x = s[s_init].split(x, outer=thread_x)
_, x = s[s_update].split(s_update.op.axis[1], factor=num_thread, outer=block_x)
_, x = s[s_update].split(x, outer=thread_x)
xo, xi = s[s_init].split(s_init.op.axis[1], factor=num_thread)
s[s_init].bind(xo, block_x)
s[s_init].bind(xi, thread_x)
xo, xi = s[s_update].split(s_update.op.axis[1], factor=num_thread)
s[s_update].bind(xo, block_x)
s[s_update].bind(xi, thread_x)
# one line to build the function.
def check_device(device, host="stackvm"):
......
......@@ -11,10 +11,9 @@ def test_add_pipeline():
# GPU schedule have to split by gridIdx and threadIdx
num_thread = 256
grid_x = tvm.thread_axis(None, "blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
_, x = s[C].split(C.op.axis[0], factor=num_thread, outer=grid_x)
_, x = s[C].split(x, outer=thread_x)
xo, xi = s[C].split(C.op.axis[0], factor=num_thread)
s[C].bind(xo, tvm.thread_axis("threadIdx.x"))
s[C].bind(xi, tvm.thread_axis("blockIdx.x"))
# compile to IR
bounds = tvm.schedule.InferBound(s)
......
"""Test group effect"""
import tvm
def test_scan_group():
m = tvm.Var("m")
n = tvm.Var("n")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: x[0, i])
s_update1 = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + x[t, i])
s_update2 = tvm.compute((m, n), lambda t, i: s_update1[t, i] + 1)
s_update3 = tvm.compute((m, n), lambda t, i: s_update2[t, i] + 1)
res = tvm.scan(s_init, s_update3, s_state, inputs=x)
s = tvm.Schedule(res.op)
assert s[s_update1].group is not None
assert s[s_update2].group == s[s_update1].group
# Assign within group, is valid
s[s_update1].compute_at(s[s_update2], s_update2.op.axis[1])
# create a new group, for [s_update2 and s_update1]
g2 = s.create_group(outputs=s_update2, inputs=[s_state, x])
assert g2.group is not None
assert g2.group == s[s_update3].group
assert s[s_update2].group == g2
assert s[s_update1].group == g2
g2.compute_at(s[s_update3], s_update3.op.axis[1])
assert g2.attach_stage == s[s_update3]
try:
# compute outside group error.
s[s_update2].compute_at(s[s_init], s_init.op.axis[0])
assert False
except tvm.TVMError:
pass
def test_compute_group():
m = tvm.Var("m")
n = tvm.Var("n")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
s = tvm.Schedule(x2.op)
g = s.create_group(outputs=x1, inputs=x, include_inputs=True)
assert s[x1].group == g
assert s[x].group == g
g.compute_at(s[x2], x2.op.axis[1])
assert g.attach_stage == s[x2]
assert g.num_child_stages == 2
def test_nest_group():
m = tvm.Var("m")
n = tvm.Var("n")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
s = tvm.Schedule(x2.op)
g1 = s.create_group(outputs=x1, inputs=x)
g2 = s.create_group(outputs=x1, inputs=x, include_inputs=True)
assert set(s.groups) == set([g1, g2])
assert s[x].group == g2
assert s[x1].group == g1
assert g1.group == g2
assert g2.num_child_stages == 2
assert g1.num_child_stages == 1
if __name__ == "__main__":
test_nest_group()
test_compute_group()
test_scan_group()
......@@ -29,8 +29,8 @@ def test_basic_pipeline():
B = tvm.compute((n,), lambda i: B[i] + k, name="A%s" % k)
s = tvm.Schedule(B.op)
px = tvm.thread_axis((0, 1), "pipeline")
xo, xi = s[B].split(B.op.axis[0], outer=px)
xo, xi = s[B].split(B.op.axis[0], nparts=1)
s[B].bind(xo, tvm.thread_axis("pipeline"))
xo, xi = s[B].split(xi, factor=4)
for S in stages:
s[S].compute_at(s[B], xo)
......@@ -50,8 +50,8 @@ def test_conv1d():
return A[i-1] + A[i] + A[i+1]
B = tvm.compute(n, computeB, name='B')
s = tvm.Schedule(B.op)
px = tvm.thread_axis((0, 1), "pipeline")
xo, xi = s[B].split(B.op.axis[0], outer=px)
px, xi = s[B].split(B.op.axis[0], nparts=1)
s[B].bind(px, tvm.thread_axis("pipeline"))
s[A].compute_at(s[B], px)
stmt = lower(s, [B])
stmt = tvm.ir_pass.SplitPipeline(stmt, False)
......
......@@ -9,15 +9,15 @@ def test_storage_sync():
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = tvm.Schedule(A2.op)
block_x = tvm.thread_axis(None, "blockIdx.x")
xo, xi = s[A2].split(A2.op.axis[0], factor=8, outer=block_x)
xo, xi = s[A2].split(A2.op.axis[0], factor=8)
s[A2].bind(xo, tvm.thread_axis("blockIdx.x"))
s[A1].compute_at(s[A2], xo)
s[A1].set_scope("shared")
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
A2b = tvm.Buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
......
......@@ -7,8 +7,9 @@ def test_virtual_thread():
A2 = tvm.compute((m,), lambda i: A1[i] + 3, name='A2')
s = tvm.Schedule(A2.op)
vx = tvm.thread_axis((0, 2), "vthread", name="vx")
xo, xi = s[A2].split(A2.op.axis[0], outer=vx)
vx = tvm.thread_axis("vthread", name="vx")
xo, xi = s[A2].split(A2.op.axis[0], nparts=2)
s[A2].bind(xo, vx)
xo, xi = s[A2].split(xi, 8)
s[A1].compute_at(s[A2], xo)
......
......@@ -36,11 +36,10 @@ def test_bound3():
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = tvm.Schedule(A2.op)
s[A1].set_scope("shared")
thread_x = tvm.thread_axis((0, 16), "threadIdx.x")
xo, xi = s[A2].split(A2.op.axis[0], 32)
xi0, xi1 = s[A2].split(xi, outer=thread_x)
xi0, xi1 = s[A2].split(xi, nparts=16)
s[A2].bind(xi0, tvm.thread_axis("threadIdx.x"))
yo, yi = s[A2].split(A2.op.axis[1], 16)
s[A2].reorder(xo, xi0, yo, xi1, yi)
s[A1].compute_at(s[A2], yo)
......@@ -60,12 +59,10 @@ def test_bound_scan():
s_scan = tvm.scan(s_init, s_update, s_state)
assert tuple(s_scan.shape) == (m, n)
s = tvm.Schedule(s_scan.op)
XX = s.cache_read(X, "local", s_update)
xo, xi = s[s_update].split(s_update.op.axis[1], factor=4)
s[XX].compute_at(s[s_update], xo)
s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
......@@ -105,22 +102,59 @@ def test_bound_rfactor():
A = tvm.placeholder((n,), name='A')
k = tvm.reduce_axis((0, n))
B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k, where=(i>1)), name='B')
kf = tvm.reduce_axis((0, 4))
# schedule
s = tvm.Schedule(B.op)
_, ki = s[B].split(k, outer=kf)
kf, ki = s[B].split(k, nparts=4)
BF = s.rfactor(B, kf)
s.normalize()
bounds = tvm.schedule.InferBound(s)
assert(bounds[BF.op.axis[0]].extent.value == 4)
assert(bounds[BF.op.axis[1]].extent.value == 1)
def test_bound_group_schedule():
m = tvm.Var("m")
n = tvm.Var("n")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
s = tvm.Schedule(x2.op)
g = s.create_group(outputs=x1, inputs=x, include_inputs=True)
g.compute_at(s[x2], x2.op.axis[0])
assert s[x1].group == g
assert s[x].group == g
s.normalize()
bounds = tvm.schedule.InferBound(s)
assert bounds[x.op.axis[0]].extent.value == 1
assert bounds[x.op.axis[1]].extent == n
def test_bound_nest_group():
m = tvm.Var("m")
n = tvm.Var("n")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
s = tvm.Schedule(x2.op)
g1 = s.create_group(outputs=x, inputs=x, include_inputs=True)
g2 = s.create_group(outputs=x1, inputs=x, include_inputs=True)
assert s[x].group == g1
assert s[x1].group == g2
g2.compute_at(s[x2], x2.op.axis[0])
g1.compute_at(s[x1], s[x1].op.axis[1])
s.normalize()
bounds = tvm.schedule.InferBound(s)
assert bounds[x.op.axis[0]].extent.value == 1
assert bounds[x.op.axis[1]].extent.value == 1
assert bounds[x1.op.axis[0]].extent.value == 1
assert bounds[x1.op.axis[1]].extent == n
if __name__ == "__main__":
test_bound_nest_group()
test_bound_group_schedule()
test_bound_scan()
test_bound3()
test_bound_rfactor()
test_bound_blur()
test_bound_conv1d()
test_bound_scan()
test_bound3()
test_bound1()
test_bound2()
......@@ -59,7 +59,7 @@ def test_scan_fix_point():
s_update = tvm.compute((l, m, n), lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update")
s_scan = tvm.scan(s_init, s_update, s_state)
body = tvm.schedule.ScanGetBody(s_scan.op)
fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body)
fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op)
assert(fxpt[s_scan.op.spatial_axis_[0]].value == 1)
assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0)
......@@ -69,8 +69,7 @@ def test_scan_fix_point():
s_update = tvm.compute((l, m, n),
lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update")
s_scan = tvm.scan(s_init, s_update, s_state)
body = tvm.schedule.ScanGetBody(s_scan.op)
fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body)
fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op)
assert(fxpt[s_scan.op.spatial_axis_[0]].value == 0)
assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0)
......@@ -89,7 +88,7 @@ def test_scan_fix_point():
[s1_update, s2_update],
[s1, s2])
body = tvm.schedule.ScanGetBody(r0.op)
fxpt = tvm.schedule.ScanFixPointAnalysis(r0.op, body)
fxpt = tvm.schedule.ScanFixPointAnalysis(r0.op)
assert(fxpt[r1.op.spatial_axis_[0]].value == 1)
test_scan0()
......
......@@ -35,8 +35,8 @@ def test_add_pipeline():
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
s = tvm.Schedule(C.op)
grid_x = tvm.thread_axis((0, 1), "pipeline")
_, x = s[C].split(C.op.axis[0], outer=grid_x)
px, x = s[C].split(C.op.axis[0], nparts=1)
s[C].bind(px, tvm.thread_axis("pipeline"))
fapi = lower(s, [A, B, C], "myadd")
fsplits = tvm.ir_pass.SplitHostDevice(fapi)
print(fsplits[1].body)
......
......@@ -57,9 +57,9 @@ def test_buffer_linebuff():
# correctness checks
if (read_data_valid.get_int()):
# Derive convolution window indices
baseIdx = read_idx/(kernel_width*kernel_width)
offsetIdx = read_idx%(kernel_width*kernel_width)
yOffset = offsetIdx/kernel_width
baseIdx = read_idx // (kernel_width*kernel_width)
offsetIdx = read_idx % (kernel_width*kernel_width)
yOffset = offsetIdx // kernel_width
xOffset = offsetIdx%kernel_width
pixIndex = baseIdx + yOffset * window_width + xOffset
assert(read_data.get_int()==test_data[pixIndex])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment