Commit d114dfc9 by Tianqi Chen (committed via GitHub)

[SCHEDULE] Mutate dataflow in schedule, refactor Stage (#44)

parent 820a8597
......@@ -136,7 +136,7 @@ using FCompute = std::function<Expr (const Array<Var>& i)>;
* \param dtype the data type of the tensor.
* \param name The name of the Tensor.
*/
Tensor Placeholder(Array<Expr> shape,
Tensor placeholder(Array<Expr> shape,
Type dtype = Float(32),
std::string name = "placeholder");
......@@ -147,7 +147,7 @@ Tensor Placeholder(Array<Expr> shape,
* \param fcompute The compute function to create the tensor.
* \param name The optional name of the tensor.
*/
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
/*!
* \brief Construct new tensors by scan over scan_axis.
......@@ -158,36 +158,36 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor"
* \param state_placeholder The placeholder for the states.
* \param name The optional name of the tensor.
*/
Array<Tensor> Scan(IterVar scan_axis,
Array<Tensor> scan(IterVar scan_axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
std::string name = "scan");
// same as compute, specialized for fcompute functions of different arity
inline Tensor Compute(Array<Expr> shape,
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
return Compute(shape, fc, name);
return compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
return Compute(shape, fc, name);
return compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
return Compute(shape, fc, name);
return compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
return Compute(shape, fc, name);
return compute(shape, fc, name);
}
} // namespace tvm
......
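For reference, the renamed entry points now match the snake_case names already used by the Python front end. A minimal Python sketch (shapes and names illustrative; the lambda arity selects the matching compute overload):

    import tvm
    m = tvm.Var('m')
    n = tvm.Var('n')
    A = tvm.placeholder((m, n), name='A')
    B = tvm.compute((m, n), lambda i, j: A(i, j) + 1, name='B')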
......@@ -132,6 +132,13 @@ class Stage : public NodeRef {
IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor);
/*!
* \brief Specify the thread launching group in the
* outermost scope of the stage.
* This is only valid for composite operators.
* \param threads The threads to be launched.
*/
Stage& outermost_threads(Array<IterVar> threads);
/*!
* \brief Vectorize iteration.
* \param var The axis to be vectorized.
* \return reference to self.
......@@ -180,6 +187,28 @@ class Schedule : public NodeRef {
return this->operator[](tensor->op);
}
/*!
* \brief Create a cache read of the original tensor for readers.
* This will mutate the body of the readers.
* A new stage will be created for the tensor.
* \param tensor The tensor to be cached.
* \param scope The scope of the cache.
* \param readers The readers to redirect to the cache.
* \return The created tensor.
*/
Tensor cache_read(const Tensor& tensor,
const std::string& scope,
const Array<Operation>& readers);
/*!
* \brief Create a cache write tensor for the produced tensor.
* The cache tensor will take over the body of the original tensor op.
* The original tensor's body will be changed to an identity read
* from the corresponding cache.
* \param tensor The tensor to be produced.
* \param scope The scope of the storage.
* \return The created tensor.
*/
Tensor cache_write(const Tensor& tensor, const std::string& scope);
/*!
* \brief Normalize the schedule.
* This is needed before bound inference.
* Insert necessary RebaseNode to make sure all leaf_iter_vars
......@@ -193,6 +222,11 @@ class Schedule : public NodeRef {
* \return the pointer to the internal node container
*/
inline const ScheduleNode* operator->() const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline ScheduleNode* operator->();
// declare container type
using ContainerType = ScheduleNode;
};
......@@ -244,10 +278,16 @@ class IterVarAttr : public NodeRef {
*/
class StageNode : public Node {
public:
/*! \brief The operation to be scheduled */
Operation op;
/*! \brief The thread scope level of the stage */
std::string scope;
/*! \brief The operation of stage, can be different from original op. */
Operation op;
/*!
* \brief The original operator.
* The op field can change during scheduling to alter the dataflow,
* while origin_op remains fixed.
*/
Operation origin_op;
/*! \brief All the nodes in the iter var */
Array<IterVar> all_iter_vars;
/*!
......@@ -255,6 +295,11 @@ class StageNode : public Node {
* Operations can only be performed in leaves.
*/
Array<IterVar> leaf_iter_vars;
/*!
* \brief Specify threads to be launched at the stage.
* This is only valid for composite ops such as Scan.
*/
Array<IterVar> outermost_threads;
/*! \brief The relations between IterVars */
Array<IterVarRelation> relations;
/*! \brief additional attributes about iter var. */
......@@ -265,17 +310,22 @@ class StageNode : public Node {
IterVar attach_ivar;
/*! \brief The stage this node attaches to */
Stage attach_stage;
/*! \brief Whether this is an output stage */
bool is_output{false};
void VisitAttrs(AttrVisitor* v) final {
v->Visit("scope", &scope);
v->Visit("op", &op);
v->Visit("origin_op", &origin_op);
v->Visit("all_iter_vars", &all_iter_vars);
v->Visit("leaf_iter_vars", &leaf_iter_vars);
v->Visit("outermost_threads", &outermost_threads);
v->Visit("relations", &relations);
v->Visit("iter_var_attrs", &iter_var_attrs);
v->Visit("attach_type", &attach_type);
v->Visit("attach_ivar", &attach_ivar);
v->Visit("attach_stage", &attach_stage);
v->Visit("is_output", &is_output);
}
static constexpr const char* _type_key = "Stage";
......@@ -285,18 +335,18 @@ class StageNode : public Node {
/*! \brief node container for schedule */
class ScheduleNode : public Node {
public:
/*! \brief The root operations */
Array<Operation> roots;
/*! \brief The output operations in original data flow graph */
Array<Operation> outputs;
/*!
* \brief list of all stages for non-placeholder ops
* The stage are ordered in PostDFS order of their op.
* \brief list of all stages for non-placeholder ops.
* The stages are sorted in dependency order.
*/
Array<Stage> stages;
/*! \brief map of operation to the stages */
Map<Operation, Stage> stage_map;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("roots", &roots);
v->Visit("outputs", &outputs);
v->Visit("stages", &stages);
v->Visit("stage_map", &stage_map);
}
......@@ -412,12 +462,16 @@ inline StageNode* Stage::operator->() {
inline bool Stage::is_scheduled() const {
const StageNode* n = operator->();
return !(n->relations.empty() && n->attach_type == kNone);
return !(n->relations.empty() && n->attach_type == kNone &&
n->all_iter_vars.same_as(n->leaf_iter_vars));
}
inline const ScheduleNode* Schedule::operator->() const {
return static_cast<const ScheduleNode*>(node_.get());
}
inline ScheduleNode* Schedule::operator->() {
return static_cast<ScheduleNode*>(node_.get());
}
inline const IterVarRelationNode* IterVarRelation::operator->() const {
return static_cast<const IterVarRelationNode*>(node_.get());
......
......@@ -63,7 +63,6 @@ def build(sch,
arg_list.append(x)
else:
raise ValueError("args must be Tensor, Buffer or Var")
# lowering
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
......
......@@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs
from ._ctypes._node import NodeBase, register_node
from . import _api_internal
from . import tensor as _tensor
from . import collections as _collections
@register_node
class Buffer(NodeBase):
......@@ -41,6 +42,53 @@ class Schedule(NodeBase):
"""
_api_internal._ScheduleNormalize(self)
def cache_read(self, tensor, scope, readers):
"""Create a cache read of original tensor for readers.
This will mutate the body of the readers.
A new cache stage will be created for the tensor.
Call this before doing any split/fuse schedule.
Parameters
----------
tensor : Tensor
The tensor to be cached.
scope : str
The scope of the cache.
readers : list of Tensor or Operation
The readers to read the cache.
Returns
-------
cache : Tensor
The created cache tensor.
"""
if isinstance(readers, (_tensor.Tensor, _tensor.Operation)):
readers = [readers]
readers = [t.op if isinstance(t, _tensor.Tensor) else t for t in readers]
return _api_internal._ScheduleCacheRead(self, tensor, scope, readers)
def cache_write(self, tensor, scope):
"""Create a cache write of original tensor, before storing into tensor.
This will mutate the body of the tensor.
A new cache stage will created before feed into the tensor.
Parameters
----------
tensor : Tensor
The tensor to be fed to.
scope : str
The scope of the cache.
Returns
-------
cache : Tensor
The created cache tensor.
"""
return _api_internal._ScheduleCacheWrite(self, tensor, scope)
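# A hedged usage sketch (mirrors test_schedule_cache in this commit;
# names are illustrative): cache A into shared memory for the reader C,
# stage C's result through a shared cache, then attach the read cache
# at the cache-write stage.
#
#   A = tvm.placeholder((m, n), name='A')
#   C = tvm.compute((m, n), lambda i, j: A(i, j) * 2, name='C')
#   s = tvm.Schedule(C.op)
#   AA = s.cache_read(A, "shared", readers=[C])
#   CC = s.cache_write(C, "shared")
#   s[AA].compute_at(s[CC], CC.op.axis[0])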
@register_node
class Stage(NodeBase):
"""A Stage represents schedule for one operation."""
......@@ -104,6 +152,18 @@ class Stage(NodeBase):
"""
return _api_internal._StageSetScope(self, scope)
def outermost_threads(self, threads):
"""Force launch threads at outermost scope of the stage.
Parameters
----------
threads : IterVar or list of IterVar
The threads to be launched.
"""
if isinstance(threads, _collections.IterVar):
threads = [threads]
_api_internal._StageOutermostThreads(self, threads)
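# A hedged usage sketch: outermost_threads is only valid for composite
# ops such as scan. Assuming s_scan is the output tensor of a tvm.scan
# (see test_schedule_scan), a thread axis can be launched at the
# outermost scope of the scan stage:
#
#   block_x = tvm.IterVar(thread_tag="blockIdx.x")
#   s = tvm.Schedule(s_scan.op)
#   s[s_scan].outermost_threads([block_x])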
def compute_at(self, parent, scope):
"""Attach the stage at parent's scope
......
......@@ -161,7 +161,7 @@ TVM_REGISTER_API(_TensorHash)
TVM_REGISTER_API(_Placeholder)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Placeholder(args[0],
*ret = placeholder(args[0],
args[1],
args[2]);
});
......@@ -262,6 +262,12 @@ TVM_REGISTER_API(_StageTile)
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});
TVM_REGISTER_API(_StageOutermostThreads)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.outermost_threads(args[1]);
});
TVM_REGISTER_API(_StageUnroll)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
......@@ -280,4 +286,16 @@ TVM_REGISTER_API(_ScheduleNormalize)
.normalize();
});
TVM_REGISTER_API(_ScheduleCacheRead)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
.cache_read(args[1], args[2], args[3]);
});
TVM_REGISTER_API(_ScheduleCacheWrite)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
.cache_write(args[1], args[2]);
});
} // namespace tvm
......@@ -53,7 +53,7 @@ Operation PlaceholderOpNode::make(std::string name,
Tensor Placeholder(Array<Expr> shape, Type dtype, std::string name) {
Tensor placeholder(Array<Expr> shape, Type dtype, std::string name) {
return PlaceholderOpNode::make(name, shape, dtype).output(0);
}
......@@ -82,7 +82,7 @@ Array<Expr> ComputeOpNode::output_shape(size_t i) const {
return Array<Expr>(shape);
}
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name) {
auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension.
size_t ndim = shape.size();
......@@ -188,7 +188,7 @@ Operation ScanOpNode::make(std::string name,
return Operation(n);
}
Array<Tensor> Scan(IterVar scan_axis,
Array<Tensor> scan(IterVar scan_axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
......
......@@ -6,9 +6,11 @@
#include <tvm/ir_visitor.h>
namespace tvm {
namespace ir {
namespace schedule {
using namespace ir;
class ElemWiseDetector : public IRVisitor {
class ElemWiseDetector : public ir::IRVisitor {
public:
explicit ElemWiseDetector(Array<IterVar> axis) : axis_(axis) {}
......@@ -25,10 +27,7 @@ class ElemWiseDetector : public IRVisitor {
}
for (size_t i = 0; i < axis_.size(); ++i) {
// const Variable *v1 = axis_[i]->var.as<Variable>();
// const Variable *v2 = axis[i].as<Variable>();
if (!axis[i].same_as(axis_[i]->var)) {
// if (!(v1 && v2) || (v1 != v2)) {
is_elem_wise_ = false;
return;
}
......@@ -52,21 +51,9 @@ bool IsElemWise(const Operation& op) {
return false;
}
} // namespace ir
namespace schedule {
void AutoInlineElemWise(Schedule sch) {
for (Stage s : sch->stages) {
if (!s.is_scheduled() && ir::IsElemWise(s->op)) {
bool is_root = false;
for (auto r : sch->roots) {
if (r == s->op) {
is_root = true;
break;
}
}
if (!is_root)
if (!s.is_scheduled() && IsElemWise(s->op) && !s->is_output) {
s.compute_inline();
}
}
......
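AutoInlineElemWise now keys off the per-stage is_output flag instead of scanning the schedule roots. A hedged Python sketch of its effect (the binding name tvm.schedule.AutoInlineElemWise is assumed here, as exercised by test_auto_inline):

    m = tvm.Var('m')
    A = tvm.placeholder((m,), name='A')
    T1 = tvm.compute((m,), lambda i: A(i) * 2, name='T1')   # elementwise, not an output
    T2 = tvm.compute((m,), lambda i: T1(i) + 1, name='T2')  # output stage
    s = tvm.Schedule(T2.op)
    tvm.schedule.AutoInlineElemWise(s)  # inlines T1; T2 is kept because is_output is set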
......@@ -294,7 +294,6 @@ void GatherOpBound(const ScanOpNode* scan,
const TensorDom& d = tmap.at(output[i]);
time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end());
}
LOG(INFO) << time_dom.size();
CHECK(!rmap->count(scan->scan_axis));
Range sdom = scan->scan_axis->dom;
Range r = arith::Union(time_dom).cover_range(sdom);
......@@ -321,7 +320,7 @@ void GatherOpBound(const Operation& op,
const ComputeOpNode* compute = op.as<ComputeOpNode>();
const TensorDom& tdom = tmap.at(op.output(0));
for (size_t i = 0; i < compute->axis.size(); ++i) {
Range r = arith::Union(tdom.data[i]).cover_range(compute->axis[i]->dom);
Range r = arith::Union(tdom.data.at(i)).cover_range(compute->axis[i]->dom);
CHECK(!rmap->count(compute->axis[i]));
(*rmap)[compute->axis[i]] = r;
}
......@@ -392,6 +391,8 @@ void InferRootBound(const Stage& stage,
direct_consume_by_parent = true;
}
}
} else {
LOG(INFO) << "not in feed graph consumer = " << stage->op;
}
}
// The relax set
......@@ -486,7 +487,11 @@ void InferRootBound(const Stage& stage,
}
FeedGraph CreateFeedGraph(const Schedule& sch) {
auto g = CreateReadGraph(sch->roots);
Array<Operation> roots;
for (Operation op : sch->outputs) {
roots.push_back(sch->stage_map[op]->op);
}
auto g = CreateReadGraph(roots);
FeedGraph fg;
for (auto kv : g) {
for (Tensor t : kv.second) {
......@@ -523,6 +528,7 @@ AttachPath CreateAttachPath(const Schedule& sch) {
Map<IterVar, Range> InferBound(const Schedule& sch) {
FeedGraph feed_graph = CreateFeedGraph(sch);
AttachPath attach_path = CreateAttachPath(sch);
std::unordered_map<IterVar, Range> ret;
for (size_t i = sch->stages.size(); i != 0; --i) {
const Stage& stage = sch->stages[i - 1];
......
......@@ -3,6 +3,8 @@
* \file schedule.cc
*/
#include <tvm/schedule.h>
#include <tvm/ir_mutator.h>
#include <unordered_set>
#include "./graph.h"
namespace tvm {
......@@ -10,7 +12,8 @@ namespace tvm {
namespace {
// find the first occurrence location in the array
size_t FindIterVar(ArrayNode* array_node, const IterVar& v) {
template<typename T>
size_t FindNodeRef(ArrayNode* array_node, const T& v) {
const Node* n = v.get();
for (size_t i = 0; i < array_node->data.size(); ++i) {
if (array_node->data[i].get() == n) return i;
......@@ -19,10 +22,10 @@ size_t FindIterVar(ArrayNode* array_node, const IterVar& v) {
}
size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) {
size_t pos = FindIterVar(leaf_vars, v);
size_t pos = FindNodeRef(leaf_vars, v);
if (pos < leaf_vars->data.size()) return pos;
if (FindIterVar(all_vars, v) < all_vars->data.size()) {
if (FindNodeRef(all_vars, v) < all_vars->data.size()) {
LOG(FATAL) << "Operate on iter var " << v
<< "that has already been splitted";
} else {
......@@ -68,8 +71,9 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
Stage::Stage(Operation op) {
auto n = std::make_shared<StageNode>();
n->op = op;
n->origin_op = op;
n->all_iter_vars = op->root_iter_vars();
n->leaf_iter_vars = op->root_iter_vars();
n->leaf_iter_vars = n->all_iter_vars;
node_ = n;
}
......@@ -89,7 +93,7 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*)
}
}
CHECK(found)
<< "Cannot find the specified axis in parent stage's leaf_iter_vars";
<< "Cannot find the axis in parent's leaf_iter_vars or outermost_threads";
return *this;
}
......@@ -176,13 +180,63 @@ Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
return *this;
}
Stage& Stage::outermost_threads(Array<IterVar> threads) {
StageNode* self = operator->();
CHECK(self->op.as<ScanOpNode>())
<< "outermost_threads is only valid for composite ops such as ScanOp";
CHECK_EQ(self->outermost_threads.size(), 0U)
<< "Already set outermost_threads";
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
std::vector<std::shared_ptr<Node> > temp;
for (IterVar iv : threads) {
temp.push_back(iv.node_);
}
leaf_vars->data.insert(
leaf_vars->data.begin(), temp.begin(), temp.end());
all_vars->data.insert(
all_vars->data.end(), temp.begin(), temp.end());
(*this)->outermost_threads = threads;
return *this;
}
inline void SetAttr(StageNode* self, IterVar var, IterVarAttr attr) {
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
FindLeafVar(all_vars, leaf_vars, var);
auto it = self->iter_var_attrs.find(var);
if (it != self->iter_var_attrs.end()) {
CHECK_EQ((*it).second->iter_type, attr->iter_type)
<< "IterVar's is already set to "
<< (*it).second << " instead of " << attr;
} else {
self->iter_var_attrs.Set(var, attr);
}
}
Stage& Stage::vectorize(IterVar var) { // NOLINT(*)
SetAttr(operator->(), var, IterVarAttr(kVectorized));
return *this;
}
Stage& Stage::unroll(IterVar var) { // NOLINT(*)
SetAttr(operator->(), var, IterVarAttr(kUnrolled));
return *this;
}
Schedule::Schedule(Array<Operation> ops) {
auto n = std::make_shared<ScheduleNode>();
n->roots = ops;
auto g = schedule::CreateReadGraph(n->roots);
Array<Operation> post_order = schedule::PostDFSOrder(n->roots, g);
n->outputs = ops;
auto g = schedule::CreateReadGraph(n->outputs);
Array<Operation> post_order = schedule::PostDFSOrder(n->outputs, g);
// output set.
std::unordered_set<Operation> output_set;
for (Operation x : ops) {
output_set.insert(x);
}
for (Operation op : post_order) {
Stage stage(op);
stage->is_output = output_set.count(op);
n->stages.push_back(stage);
n->stage_map.Set(op, stage);
}
......@@ -237,7 +291,7 @@ void Schedule::normalize() {
ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
for (IterVar iv : root_iter_vars) {
size_t idx = FindIterVar(leaf_vars, iv);
size_t idx = FindNodeRef(leaf_vars, iv);
if (idx < leaf_vars->data.size()) {
// insert rebase
IterVar rebased(Range(), iv->var->name_hint + ".rb");
......@@ -262,35 +316,197 @@ IterVarAttr::IterVarAttr(IterVarType t) {
node_ = n;
}
inline void SetAttr(StageNode* self, IterVar var, IterVarAttr attr) {
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
FindLeafVar(all_vars, leaf_vars, var);
auto it = self->iter_var_attrs.find(var);
if (it != self->iter_var_attrs.end()) {
CHECK_EQ((*it).second->iter_type, attr->iter_type)
<< "IterVar's is already set to "
<< (*it).second << " instead of " << attr;
TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(IterVarAttrNode);
TVM_REGISTER_NODE_TYPE(SplitNode);
TVM_REGISTER_NODE_TYPE(FuseNode);
TVM_REGISTER_NODE_TYPE(RebaseNode);
TVM_REGISTER_NODE_TYPE(ScheduleNode);
using ir::TensorKey;
// The replacer of cache.
class TensorReplacer : public ir::IRMutator {
public:
TensorReplacer(const std::unordered_map<TensorKey, Tensor>& vmap)
: vmap_(vmap) {}
Expr Mutate_(const ir::Call* op, const Expr& e) {
if (op->call_type == ir::Call::Halide) {
ir::TensorKey key{op->func, op->value_index};
auto it = vmap_.find(key);
if (it != vmap_.end()) {
Expr ret = ir::Call::make(
op->type, it->second->op->name, op->args,
op->call_type, it->second->op, it->second->value_index);
found = true;
return IRMutator::Mutate_(ret.as<ir::Call>(), ret);
}
}
return IRMutator::Mutate_(op, e);
}
// whether it is found.
bool found{false};
private:
const std::unordered_map<TensorKey, Tensor>& vmap_;
};
class VarReplacer : public ir::IRMutator {
public:
explicit VarReplacer(
const std::unordered_map<const Variable*, Expr>& vsub)
: vsub_(vsub) {}
Expr Mutate_(const Variable* op, const Expr& e) {
auto it = vsub_.find(op);
if (it != vsub_.end()) return it->second;
return e;
}
private:
const std::unordered_map<const Variable*, Expr>& vsub_;
};
// Replace the dataflow that appears in all stages, given the tensor change.
// Also update vmap if subsequent dataflow needs to be replaced.
void ReplaceDataFlow(const Array<Stage>& stages,
std::unordered_map<TensorKey, Tensor>* vmap) {
for (Stage s : stages) {
if (s->op.as<ComputeOpNode>()) {
const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
TensorReplacer repl(*vmap);
Expr body = repl.Mutate(compute->body);
if (repl.found) {
Operation op = ComputeOpNode::make(
compute->name, compute->axis, body);
(*vmap)[TensorKey{s->op, 0}] = op.output(0);
s->op = op;
}
} else if (s->op.as<ScanOpNode>()) {
const ScanOpNode* scan = s->op.as<ScanOpNode>();
std::shared_ptr<ScanOpNode> n =
std::make_shared<ScanOpNode>(*scan);
// copy-on-write semantics guarantees correctness
for (size_t i = 0; i < n->init.size(); ++i) {
TensorKey key{n->init[i]->op, n->init[i]->value_index};
if (vmap->count(key)) {
n->init.Set(i, vmap->at(key));
}
}
for (size_t i = 0; i < n->update.size(); ++i) {
TensorKey key{n->update[i]->op, n->update[i]->value_index};
if (vmap->count(key)) {
n->update.Set(i, vmap->at(key));
}
}
if (!n->init.same_as(scan->init) ||
!n->update.same_as(scan->update)) {
Operation op(n);
for (int i = 0; i < op->num_outputs(); ++i) {
(*vmap)[TensorKey{s->op, i}] = op.output(i);
}
s->op = op;
}
} else if (s->op.as<PlaceholderOpNode>()) {
} else {
self->iter_var_attrs.Set(var, attr);
LOG(FATAL) << "unhandled problem";
}
}
}
Stage& Stage::vectorize(IterVar var) { // NOLINT(*)
SetAttr(operator->(), var, IterVarAttr(kVectorized));
return *this;
Tensor Schedule::cache_read(const Tensor& tensor,
const std::string& scope,
const Array<Operation>& readers) {
// create identity mapping.
std::ostringstream os;
os << tensor->op->name;
if (tensor->op->num_outputs() != 1) {
os << ".v" << tensor->value_index;
}
os << "." << scope;
Tensor cache = compute(tensor->shape, [&tensor](const Array<Var>& i) {
return tensor(Array<Expr>(i.begin(), i.end()));
}, os.str());
std::unordered_map<TensorKey, Tensor> vsub;
vsub[TensorKey{tensor->op, tensor->value_index}] = cache;
std::unordered_map<TensorKey, Tensor> vmap;
for (Operation op : readers) {
const ComputeOpNode* compute = op.as<ComputeOpNode>();
CHECK(compute)
<< "cache read only take ComputeOp as readers";
Stage s = operator[](op);
compute = s->op.as<ComputeOpNode>();
TensorReplacer repl(vsub);
Expr body = repl.Mutate(compute->body);
CHECK(repl.found)
<< "Cannot find " << tensor
<< " in the body of specified reader" << op;
Operation repl_op = ComputeOpNode::make(
compute->name, compute->axis, body);
vmap[TensorKey{s->op, 0}] = repl_op.output(0);
s->op = repl_op;
}
ReplaceDataFlow((*this)->stages, &vmap);
ArrayNode* stages = (*this)->stages.CopyOnWrite();
size_t pos = FindNodeRef(stages, operator[](tensor->op));
Stage cache_stage = Stage(cache->op);
cache_stage.set_scope(scope);
CHECK_LT(pos, stages->data.size());
stages->data.insert(stages->data.begin() + pos + 1,
cache_stage.node_);
(*this)->stage_map.Set(cache->op, cache_stage);
return cache;
}
Stage& Stage::unroll(IterVar var) { // NOLINT(*)
SetAttr(operator->(), var, IterVarAttr(kUnrolled));
return *this;
Tensor Schedule::cache_write(const Tensor& tensor,
const std::string& scope) {
Stage orig_stage = operator[](tensor->op);
const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
CHECK(compute)
<< "cache write only take ComputeOp as writers";
CHECK(!orig_stage.is_scheduled())
<< "Create cache_write before doing split/fuse/reorder";
compute = orig_stage->op.as<ComputeOpNode>();
CHECK(compute);
Array<Expr> args;
Array<IterVar> new_axis;
std::unordered_map<const Variable*, Expr> vsub;
for (IterVar iv : compute->axis) {
args.push_back(iv->var);
IterVar new_iv(iv->dom, iv->var->name_hint + ".c");
new_axis.push_back(new_iv);
vsub[iv->var.get()] = new_iv->var;
}
VarReplacer repl(vsub);
Expr body = repl.Mutate(compute->body);
Operation cache_op = ComputeOpNode::make(
compute->name + "." + scope, new_axis, body);
Tensor cache_tensor = cache_op.output(0);
Operation orig_new_op = ComputeOpNode::make(
compute->name, compute->axis,
cache_tensor(args));
std::unordered_map<TensorKey, Tensor> vmap;
vmap[TensorKey{orig_stage->op, 0}] = orig_new_op.output(0);
ReplaceDataFlow((*this)->stages, &vmap);
// mutate orig stage
orig_stage->op = orig_new_op;
orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
// create schedule for new cached stage.
ArrayNode* stages = (*this)->stages.CopyOnWrite();
size_t pos = FindNodeRef(stages, orig_stage);
Stage cache_stage = Stage(cache_op);
cache_stage.set_scope(scope);
CHECK_LT(pos, stages->data.size());
stages->data.insert(stages->data.begin() + pos,
cache_stage.node_);
(*this)->stage_map.Set(cache_op, cache_stage);
return cache_tensor;
}
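// A hedged sketch of the resulting dataflow, in Python front-end
// notation (names illustrative): after
//   CC = s.cache_write(C, "local")
// CC carries C's original body, with its axes renamed i.c/j.c and
// storage scope "local", while C is rewritten to the identity read
// C(i, j) = CC(i, j).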
TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(IterVarAttrNode);
TVM_REGISTER_NODE_TYPE(SplitNode);
TVM_REGISTER_NODE_TYPE(FuseNode);
TVM_REGISTER_NODE_TYPE(RebaseNode);
TVM_REGISTER_NODE_TYPE(ScheduleNode);
} // namespace tvm
......@@ -23,7 +23,8 @@ using namespace ir;
// Two private scope marks
namespace attr {
constexpr const char* loop_scope = "loop_scope";
constexpr const char* scan_scope = "scan_scope";
constexpr const char* scan_update_scope = "scan_update_scope";
constexpr const char* scan_init_scope = "scan_init_scope";
} // namespace attr
/*!
......@@ -280,23 +281,31 @@ Stmt MakeLoop(const Stage& s,
if (init.defined()) {
// try to find the location to insert the initialization.
// Fuse the initialization and provide loop when possible.
std::unordered_map<IterVar, int> reduce_state;
std::unordered_map<IterVar, int> update_state;
const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
const ScanOpNode* scan = s->op.as<ScanOpNode>();
if (compute) {
for (IterVar iv : compute->reduce_axis) {
reduce_state[iv] = 2;
update_state[iv] = 2;
}
for (IterVar iv : compute->axis) {
reduce_state[iv] = 1;
update_state[iv] = 1;
}
} else if (scan) {
update_state[scan->scan_axis] = 2;
for (IterVar iv : s->outermost_threads) {
update_state[iv] = 1;
}
}
// find which iter vars are related to the update and which are related to the axis.
PassDownFlag(s, &reduce_state);
PassDownFlag(s, &update_state);
auto leaf_iter_vars = s->leaf_iter_vars;
std::unordered_map<IterVar, Expr> init_value_map;
// find the first loop that is related to the update.
size_t begin_loop = leaf_iter_vars.size();
for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
auto iv = leaf_iter_vars[i];
int flag = reduce_state.at(iv);
int flag = update_state.at(iv);
if ((flag & 2) != 0) {
begin_loop = i; break;
}
......@@ -304,7 +313,7 @@ Stmt MakeLoop(const Stage& s,
}
// skip loops that do not relate to the axis.
std::unordered_map<IterVar, bool> skip_iter;
for (auto kv : reduce_state) {
for (auto kv : update_state) {
int flag = kv.second;
if ((flag & 1) == 0) skip_iter[kv.first] = true;
}
......@@ -422,7 +431,10 @@ Stmt MakePipeline(const Stage& s,
} else if (scan) {
// Provide is done by the sub operations.
provide = AttrStmt::make(
s->op, attr::scan_scope, scan->scan_axis->var,
s->op, attr::scan_update_scope, scan->scan_axis->var,
Evaluate::make(0));
init = AttrStmt::make(
s->op, attr::scan_init_scope, 0,
Evaluate::make(0));
} else {
LOG(FATAL) << "not supported op " << s->op->type_key();
......@@ -472,7 +484,9 @@ class InjectAttach : public IRMutator {
const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr &&
op->type_key == attr::loop_scope) {
if (op->node == stage_->attach_ivar) {
CHECK_NE(producer_.size(), 0U);
if (op->node == stage_->attach_ivar &&
producer_.back() == stage_->attach_stage->op.get()) {
CHECK(!found_attach);
found_attach = true;
stmt = AttrStmt::make(
......@@ -482,6 +496,16 @@ class InjectAttach : public IRMutator {
}
return stmt;
}
Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
if (op->is_producer) {
producer_.push_back(op->func.get());
Stmt ret = IRMutator::Mutate_(op, s);
producer_.pop_back();
return ret;
} else {
return IRMutator::Mutate_(op, s);
}
}
// whether attach point is found
bool found_attach{false};
......@@ -490,6 +514,8 @@ class InjectAttach : public IRMutator {
const Stage& stage_;
// domain map
const Map<IterVar, Range>& dom_map_;
// internal stack tracking the enclosing producer scopes.
std::vector<const Node*> producer_;
};
// inject the operator's realization on the stmt.
......@@ -505,21 +531,11 @@ class InjectScanStep : public IRMutator {
Stmt Mutate(Stmt stmt) final {
CHECK(stmt.defined());
stmt = IRMutator::Mutate(stmt);
if (is_init_) {
const ProducerConsumer* op = stmt.as<ProducerConsumer>();
if (op != nullptr &&
op->is_producer &&
op->func.same_as(scan_op_)) {
stmt = ProducerConsumer::make(
op->func, true,
MakePipeline(stage_, dom_map_, op->body));
found_attach = true;
}
} else {
// update
const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr &&
op->type_key == attr::scan_scope) {
((op->type_key == attr::scan_update_scope && !is_init_) ||
(op->type_key == attr::scan_init_scope && is_init_))) {
if (op->node.same_as(scan_op_)) {
found_attach = true;
stmt = AttrStmt::make(
......@@ -527,7 +543,6 @@ class InjectScanStep : public IRMutator {
MakePipeline(stage_, dom_map_, op->body));
}
}
}
return stmt;
}
......@@ -561,8 +576,15 @@ Stmt InjectInline(const Operation op, Stmt body) {
class SchedulePostProc : public IRMutator {
public:
Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
if (to_remove_.count(op->func.get())) {
return this->Mutate(op->body);
auto it = replace_op_.find(op->func.get());
if (it != replace_op_.end()) {
Stmt body = this->Mutate(op->body);
if (it->second.defined()) {
return ProducerConsumer::make(
it->second, op->is_producer, body);
} else {
return body;
}
} else {
return IRMutator::Mutate_(op, s);
}
......@@ -579,23 +601,40 @@ class SchedulePostProc : public IRMutator {
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == attr::loop_scope) {
return this->Mutate(op->body);
} else if (op->type_key == attr::scan_scope) {
} else if (op->type_key == attr::scan_init_scope) {
return this->Mutate(op->body);
} else if (op->type_key == attr::scan_update_scope) {
const ScanOpNode* scan = op->node.as<ScanOpNode>();
CHECK(scan);
var_value_[scan->scan_axis->var.get()] = op->value;
return this->Mutate(op->body);
} else if (op->type_key == ir::attr::realize_scope) {
if (to_remove_.count(op->node.get())) {
auto it = replace_op_.find(op->node.get());
if (it != replace_op_.end()) {
if (it->second.defined()) {
Stmt ret = AttrStmt::make(
it->second, op->type_key, op->value, op->body);
return this->Mutate_(ret.as<AttrStmt>(), ret);
} else {
return this->Mutate(op->body);
}
}
}
return IRMutator::Mutate_(op, s);
}
Stmt Mutate_(const Realize* op, const Stmt& s) final {
TensorKey key{op->func, op->value_index};
if (replace_.count(key)) {
auto it = replace_realize_.find(key);
if (it != replace_realize_.end()) {
if (it->second.defined()) {
Stmt ret = Realize::make(
it->second->op, it->second->value_index,
op->type, op->bounds, op->condition, op->body);
return this->Mutate_(ret.as<Realize>(), ret);
} else {
return this->Mutate(op->body);
}
} else {
return IRMutator::Mutate_(op, s);
}
......@@ -603,8 +642,8 @@ class SchedulePostProc : public IRMutator {
Stmt Mutate_(const Provide* op, const Stmt& s) final {
TensorKey key{op->func, op->value_index};
auto it = replace_.find(key);
if (it != replace_.end()) {
auto it = replace_buffer_.find(key);
if (it != replace_buffer_.end()) {
const Tensor& dst = it->second.first;
Stmt ret = Provide::make(
dst->op, dst->value_index, op->value,
......@@ -616,10 +655,10 @@ class SchedulePostProc : public IRMutator {
}
Expr Mutate_(const Call* op, const Expr& e) final {
if (op != nullptr && op->call_type == Call::Halide) {
if (op->call_type == Call::Halide) {
TensorKey key{op->func, op->value_index};
auto it = replace_.find(key);
if (it != replace_.end()) {
auto it = replace_buffer_.find(key);
if (it != replace_buffer_.end()) {
const Tensor& dst = it->second.first;
Expr ret = Call::make(
op->type, dst->op->name,
......@@ -642,22 +681,32 @@ class SchedulePostProc : public IRMutator {
void Init(const Schedule& sch) {
for (Stage s : sch->stages) {
if (s->op.as<ScanOpNode>()) {
const ScanOpNode* scan = s->op.as<ScanOpNode>();
if (!scan) continue;
for (size_t i = 0; i < scan->update.size(); ++i) {
Tensor t = s->op.output(i);
Tensor t = s->origin_op.output(i);
AddReplace(scan->init[i], t, Expr());
AddReplace(scan->update[i], t, scan->scan_axis->var);
AddReplace(scan->state_placeholder[i], t, Expr());
}
} else if (!s->op.same_as(s->origin_op)) {
Tensor target = s->origin_op.output(0);
AddReplace(s->op.output(0), target,
Expr(), target, s->origin_op);
}
}
}
private:
void AddReplace(Tensor src, Tensor dst, Expr head_idx) {
replace_[TensorKey{src->op, src->value_index}]
= std::make_pair(dst, head_idx);
to_remove_.insert(src->op.get());
void AddReplace(Tensor src,
Tensor dst,
Expr head_idx,
Tensor repl_realize = Tensor(),
Operation repl_op = Operation()) {
TensorKey key{src->op, src->value_index};
replace_buffer_[key] = std::make_pair(dst, head_idx);
replace_realize_[key] = repl_realize;
replace_op_[src->op.get()] = repl_op;
}
Array<Expr> RewriteArgs(Expr head, Array<Expr> args) {
if (!head.defined()) return args;
......@@ -670,9 +719,11 @@ class SchedulePostProc : public IRMutator {
// The scan value
std::unordered_map<const Variable*, Expr> var_value_;
// buffer replacement
std::unordered_map<TensorKey, std::pair<Tensor, Expr> > replace_;
// replaced functions
std::unordered_set<const Node*> to_remove_;
std::unordered_map<TensorKey, std::pair<Tensor, Expr> > replace_buffer_;
// buffer realization to be replaced
std::unordered_map<TensorKey, Tensor> replace_realize_;
// producer/consumer nodes to be replaced.
std::unordered_map<const Node*, Operation> replace_op_;
};
Stmt ScheduleOps(
......@@ -724,7 +775,9 @@ Stmt ScheduleOps(
InjectAttach mutator(s, dom_map);
body = mutator.Mutate(body);
CHECK(mutator.found_attach)
<< "did not find attachment point";
<< "did not find attachment point for " << s << " in"
<< s->attach_stage->op << " x "
<< body;
}
}
SchedulePostProc post_proc;
......
......@@ -6,10 +6,10 @@ TEST(Tensor, Basic) {
using namespace tvm;
Var m("m"), n("n"), l("l");
Tensor A = Placeholder({m, l}, Float(32), "A");
Tensor B = Placeholder({n, l}, Float(32), "B");
Tensor A = placeholder({m, l}, Float(32), "A");
Tensor B = placeholder({n, l}, Float(32), "B");
auto C = Compute({m, n}, [&](Var i, Var j) {
auto C = compute({m, n}, [&](Var i, Var j) {
return A[i][j];
}, "C");
......@@ -20,11 +20,11 @@ TEST(Tensor, Basic) {
TEST(Tensor, Reduce) {
using namespace tvm;
Var m("m"), n("n"), l("l");
Tensor A = Placeholder({m, l}, Float(32), "A");
Tensor B = Placeholder({n, l}, Float(32), "B");
Tensor A = placeholder({m, l}, Float(32), "A");
Tensor B = placeholder({n, l}, Float(32), "B");
IterVar rv(Range{0, l}, "k");
auto C = Compute({m, n}, [&](Var i, Var j) {
auto C = compute({m, n}, [&](Var i, Var j) {
return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv});
}, "C");
LOG(INFO) << C->op.as<ComputeOpNode>()->body;
......
......@@ -2,17 +2,6 @@ import tvm
from tvm.addon import nvcc_compiler
import numpy as np
@tvm.register_func
def tvm_callback_cuda_compile(code):
ptx = nvcc_compiler.compile_source(code, target="ptx")
print(ptx.decode("utf-8"))
return ptx
@tvm.register_func
def tvm_callback_cuda_postproc(code):
print(code)
return code
def test_gemm():
# graph
nn = 1024
......@@ -22,21 +11,14 @@ def test_gemm():
l = n
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((m, l), name='B')
AA = tvm.compute(A.shape, lambda *i : A(*i), name="AA")
BB = tvm.compute(B.shape, lambda *i : B(*i), name="BB")
k = tvm.IterVar((0, l), name='k')
CC = tvm.compute(
C = tvm.compute(
(n, m),
lambda ii, jj: tvm.sum(AA[ii, k] * BB[jj, k], axis=k),
lambda ii, jj: tvm.sum(A[ii, k] * B[jj, k], axis=k),
name='CC')
C = tvm.compute(CC.shape, lambda *i: CC(*i), name="C")
# schedule
s = tvm.Schedule(C.op)
xtile, ytile = 32, 32
s[AA].set_scope("shared")
s[BB].set_scope("shared")
scale = 8
num_thread = 8
block_factor = scale * num_thread
......@@ -45,6 +27,9 @@ def test_gemm():
block_y = tvm.IterVar(thread_tag="blockIdx.y")
thread_y = tvm.IterVar((0, num_thread), thread_tag="threadIdx.y")
CC = s.cache_write(C, "local")
AA = s.cache_read(A, "shared", [CC])
BB = s.cache_read(B, "shared", [CC])
_, yi = s[C].split(C.op.axis[0], factor=block_factor, outer=block_y)
_, xi = s[C].split(C.op.axis[1], factor=block_factor, outer=block_x)
s[C].reorder(block_y, block_x, yi, xi)
......@@ -92,8 +77,8 @@ def test_gemm():
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
check_device("cuda")
#tvm.init_opencl()
#check_device("opencl")
tvm.init_opencl()
check_device("opencl")
if __name__ == "__main__":
test_gemm()
......@@ -22,13 +22,13 @@ def test_schedule_create():
json_str = tvm.save_json(s)
s_loaded = tvm.load_json(json_str)
assert isinstance(s_loaded, tvm.schedule.Schedule)
assert(str(s_loaded.roots[0].body) == str(s.roots[0].body))
assert(str(s_loaded.outputs[0].body) == str(s.outputs[0].body))
# pickle unpickle
dump = pkl.dumps(s)
s_loaded = pkl.loads(dump)
assert isinstance(s_loaded, tvm.schedule.Schedule)
assert(str(s_loaded.roots[0].body) == str(s.roots[0].body))
assert(str(s_loaded.outputs[0].body) == str(s.outputs[0].body))
def test_reorder():
m = tvm.Var('m')
......
......@@ -74,6 +74,20 @@ def test_auto_inline():
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_cache():
m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
B = tvm.placeholder((m, n), name='B')
C = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='C')
s = tvm.Schedule(C.op)
AA = s.cache_read(A, "shared", readers=[C])
CC = s.cache_write(C, "shared")
s[AA].compute_at(s[CC], CC.op.axis[0])
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
if __name__ == "__main__":
test_schedule_scan()
......@@ -81,3 +95,4 @@ if __name__ == "__main__":
test_schedule1()
test_schedule2()
test_auto_inline()
test_schedule_cache()