Commit 400c1c48 by Tianqi Chen, committed by GitHub

[SCHEDULE] Enhance cache_write to enable layout change. (#432)

* [SCHEDULE] Enhance cache_write to enable layout change.

* more tests

parent 663d7c52
@@ -284,8 +284,15 @@ class Schedule : public NodeRef {
  /*!
   * \brief Create a cache write tensor for producing tensor.
   * The tensor will take over the body of the original tensor op.
-  * The original tensor's body will be changed to an identity read
-  * from the corresponding cache.
+  *
+  * This function can be used to do data layout transformation.
+  * If there is a split/fuse/reorder on the data-parallel axes of the tensor
+  * before cache_write is called, the intermediate cache stores
+  * the data in the layout of the iteration order of the leaf axes.
+  * The data will be transformed back to the original layout in the original tensor.
+  * The user can further call compute_inline to inline the original layout and keep
+  * the data stored in the transformed layout.
+  *
   * \param tensor The tensor to be produced.
   * \param scope The scope of the storage.
   * \return The created tensor.
...
@@ -248,6 +248,14 @@ class Schedule(NodeBase):
        This will mutate the body of the tensor.
        A new cache stage will be created before it feeds into the tensor.
+       This function can be used to support data layout transformation.
+       If there is a split/fuse/reorder on the data-parallel axes of the tensor
+       before cache_write is called, the intermediate cache stores
+       the data in the layout of the iteration order of the leaf axes.
+       The data will be transformed back to the original layout in the original tensor.
+       The user can further call compute_inline to inline the original layout and keep
+       the data stored in the transformed layout.
+
        Parameters
        ----------
        tensor : Tensor
...
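For orientation, a minimal usage sketch of the new behavior (shapes and names are illustrative, mirroring the tests added at the end of this diff): axes reordered before cache_write determine the storage layout of the cache.

import tvm

m = tvm.var('m')
n = tvm.var('n')
A = tvm.placeholder((m, n), name='A')
C = tvm.compute((m, n), lambda i, j: A(i, j) * 2, name='C')

s = tvm.create_schedule(C.op)
# Reorder the data-parallel axes *before* cache_write: the cache then
# stores its data in the (j, i) iteration order.
s[C].reorder(C.op.axis[1], C.op.axis[0])
CC = s.cache_write(C, "global")
# C is rewritten to copy CC back into the original (i, j) layout.
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)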
@@ -383,8 +383,9 @@ ComputeLoopNest ComputeLoopNest::make(
  // make main loop nest
  ret.main_nest = op::MakeLoopNest(
      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap);
- ret.main_predicates = op::MakeBoundCheck(stage, dom_map, false,
-     std::unordered_set<IterVar>(), ret.main_vmap);
+ ret.main_predicates = schedule::MakeBoundCheck(
+     stage, dom_map, ret.main_vmap, false,
+     std::unordered_set<IterVar>());
  for (auto& e : ret.main_predicates) {
    e = likely(e);
  }
@@ -424,8 +425,8 @@ ComputeLoopNest ComputeLoopNest::make(
  ret.init_nest = op::MakeLoopNest(
      stage, dom_map, begin_loop, true,
      skip_iter, &(ret.init_vmap));
- ret.init_predicates = op::MakeBoundCheck(
-     stage, dom_map, true, skip_iter, ret.init_vmap);
+ ret.init_predicates = schedule::MakeBoundCheck(
+     stage, dom_map, ret.init_vmap, true, skip_iter);
  for (auto& e : ret.init_predicates) {
    e = likely(e);
  }
...
@@ -21,9 +21,9 @@ Stmt MakeCrossThreadReduction(
  std::unordered_map<IterVar, Expr> value_map;
  auto nest = op::MakeLoopNest(
      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
- auto conds = op::MakeBoundCheck(
-     stage, dom_map, false,
-     std::unordered_set<IterVar>(), value_map);
+ auto conds = schedule::MakeBoundCheck(
+     stage, dom_map, value_map, false,
+     std::unordered_set<IterVar>());
  size_t size = self->body.size();
  CHECK_GT(size, 0);
...
@@ -147,91 +147,6 @@ MakeLoopNest(const Stage& stage,
  return nest;
}
-/*!
- * \brief message passing to find if boundary checking on IterVar is needed.
- * \param s The stage to be used.
- * \param p_state The message passing state
- *     IterVar->flag
- */
-void PassUpBoundCheck(const Stage& s,
-                      const Map<IterVar, Range>& dom_map,
-                      std::unordered_map<IterVar, bool>* p_state) {
-  auto& state = *p_state;
-  using Halide::Internal::can_prove;
-  for (size_t i = s->relations.size(); i != 0; --i) {
-    IterVarRelation rel = s->relations[i - 1];
-    if (rel.as<SplitNode>()) {
-      const SplitNode* s = rel.as<SplitNode>();
-      bool outer = state.at(s->outer);
-      bool inner = state.at(s->inner);
-      Expr factor = dom_map.at(s->inner)->extent;
-      Expr step = dom_map.at(s->outer)->extent;
-      if (outer || inner) {
-        state[s->parent] = true;
-      } else {
-        if (can_prove(dom_map.at(s->parent)->extent == factor * step)) {
-          state[s->parent] = false;
-        } else {
-          state[s->parent] = true;
-        }
-      }
-    } else if (rel.as<FuseNode>()) {
-      const FuseNode* s = rel.as<FuseNode>();
-      bool fused = state.at(s->fused);
-      state[s->outer] = fused;
-      state[s->inner] = fused;
-    } else if (rel.as<RebaseNode>()) {
-      const RebaseNode* s = rel.as<RebaseNode>();
-      state[s->parent] = state.at(s->rebased);
-    } else {
-      LOG(FATAL) << "unknown relation type";
-    }
-  }
-}
-
-std::vector<Expr> MakeBoundCheck(
-    const Stage& stage,
-    const Map<IterVar, Range>& dom_map,
-    bool skip_ivar_domain,
-    const std::unordered_set<IterVar>& skip_iter,
-    const std::unordered_map<IterVar, Expr>& value_map) {
-  std::unordered_map<IterVar, bool> bound_state;
-  for (IterVar iv : stage->leaf_iter_vars) {
-    bound_state[iv] = false;
-  }
-  PassUpBoundCheck(stage, dom_map, &bound_state);
-  std::vector<Expr> preds;
-  std::unordered_map<const Variable*, IntSet> iset_dmap;
-  // setup domain map for set analysis
-  for (const auto& kv : dom_map) {
-    iset_dmap[kv.first->var.get()] = IntSet::range(kv.second);
-  }
-  for (IterVar iv : stage->op->root_iter_vars()) {
-    if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
-    Range dom = dom_map.at(iv);
-    if (bound_state.at(iv)) {
-      Expr value = ComputeExpr<Sub>(value_map.at(iv), dom->min);
-      Expr vmax = EvalSet(value, iset_dmap).max();
-      if (vmax.type() != value.type() || !can_prove(vmax < dom->extent)) {
-        preds.emplace_back(value < dom->extent);
-      }
-    }
-    CHECK(iv->dom.defined());
-    if (!skip_ivar_domain && !iv->dom.same_as(dom)) {
-      Expr value = ComputeExpr<Sub>(value_map.at(iv), iv->dom->min);
-      Expr vmax = EvalSet(value, iset_dmap).max();
-      if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) {
-        preds.emplace_back(value < iv->dom->extent);
-      }
-    }
-  }
-  return preds;
-}
std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
  Stmt no_op = Evaluate::make(0);
  std::vector<Stmt> nest;
...
@@ -13,6 +13,7 @@
#include <vector>
#include "../pass/ir_util.h"
#include "../pass/arg_binder.h"
+#include "../schedule/message_passing.h"

namespace tvm {
namespace op {
@@ -36,22 +37,6 @@ MakeLoopNest(const Stage& stage,
             bool new_loop_var,
             const std::unordered_set<IterVar>& skip_iter,
             std::unordered_map<IterVar, Expr>* p_value_map);
-/*!
- * \brief Create boundary check condition for given stage.
- *
- * \param stage The stage to create a loop nest.
- * \param dom_map The range of each iter var.
- * \param skip_ivar_domain Whether we can skip check for IterVar's original domain.
- * \param skip_iter Whether skip certain iteration.
- * \param value_map The result value of each IterVar.
- * \return List of predicates that we need to check.
- */
-std::vector<Expr>
-MakeBoundCheck(const Stage& stage,
-               const Map<IterVar, Range>& dom_map,
-               bool skip_ivar_domain,
-               const std::unordered_set<IterVar>& skip_iter,
-               const std::unordered_map<IterVar, Expr>& value_map);
/*!
 * \brief Create a nest of if checking the predicates.
...
@@ -274,7 +274,7 @@ Stmt ScanOpNode::BuildProvide(
  nest[begin_scan].push_back(init);
  nest.push_back(
      op::MakeIfNest(
-         op::MakeBoundCheck(stage, dom_map, false, empty, vmap)));
+         schedule::MakeBoundCheck(stage, dom_map, vmap, false, empty)));
  return MergeNest(nest, provide);
}
}  // namespace tvm
@@ -7,10 +7,12 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "./message_passing.h"
+#include "../arithmetic/compute_expr.h"

namespace tvm {
namespace schedule {

+using namespace ir;
using namespace arith;

// result = ceil(a / b), both a and b are positive integers
@@ -123,8 +125,8 @@ void PassUpIndex(const Stage& stage,
      Expr factor = dom_map.at(s->inner)->extent;
      Expr outer_min = dom_map.at(s->outer)->min;
      Expr inner_min = dom_map.at(s->inner)->min;
-     state[s->outer] = value / factor;
-     state[s->inner] = value % factor;
+     state[s->outer] = ComputeExpr<Div>(value, factor);
+     state[s->inner] = ComputeExpr<Mod>(value, factor);
      // add min if they exist
      if (!is_zero(outer_min)) {
        state[s->outer] = state[s->outer] + outer_min;
...
@@ -151,6 +153,51 @@ void PassUpIndex(const Stage& stage,
  }
}

+void PassDownIndex(const Stage& stage,
+                   const Map<IterVar, Range>& dom_map,
+                   std::unordered_map<IterVar, Expr>* p_state,
+                   bool allow_missing) {
+  auto& state = *p_state;
+  for (IterVarRelation rel : stage->relations) {
+    if (const SplitNode* s = rel.as<SplitNode>()) {
+      if (!state.count(s->parent)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      Range r = dom_map.at(s->inner);
+      CHECK(is_zero(r->min));
+      Expr parent = state.at(s->parent);
+      Expr factor = r->extent;
+      state[s->outer] = ComputeExpr<Div>(parent, factor);
+      state[s->inner] = ComputeExpr<Mod>(parent, factor);
+    } else if (const FuseNode* s = rel.as<FuseNode>()) {
+      if (!state.count(s->inner) && !state.count(s->outer)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      Expr factor = dom_map.at(s->inner)->extent;
+      Expr outer_min = dom_map.at(s->outer)->min;
+      Expr inner_min = dom_map.at(s->inner)->min;
+      Expr inner = state.at(s->inner);
+      Expr outer = state.at(s->outer);
+      CHECK(is_zero(outer_min));
+      CHECK(is_zero(inner_min));
+      state[s->fused] = outer * factor + inner;
+    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
+      // check the parent (the known side of a downward pass),
+      // not the rebased var, before reading its value
+      if (!state.count(s->parent)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      Expr value = state.at(s->parent);
+      Expr parent_min = dom_map.at(s->parent)->min;
+      CHECK(is_zero(parent_min));
+      state[s->rebased] = value;
+    } else {
+      LOG(FATAL) << "unknown relation type";
+    }
+  }
+}
// Domain message passing.
void PassUpDomain(const SplitNode* s,
                  const std::unordered_map<IterVar, Range>& dom_map,
...
@@ -349,5 +396,94 @@ void PassDownBitMaskOr(const Stage& stage,
  }
}
+/*!
+ * \brief message passing to find if boundary checking on IterVar is needed.
+ * \param s The stage to be used.
+ * \param dom_map The domain map of each iteration variable.
+ * \param p_state The message passing state
+ *     IterVar->flag
+ */
+void PassUpBoundCheck(const Stage& s,
+                      const Map<IterVar, Range>& dom_map,
+                      std::unordered_map<IterVar, bool>* p_state) {
+  auto& state = *p_state;
+  using Halide::Internal::can_prove;
+  for (size_t i = s->relations.size(); i != 0; --i) {
+    IterVarRelation rel = s->relations[i - 1];
+    if (rel.as<SplitNode>()) {
+      const SplitNode* s = rel.as<SplitNode>();
+      bool outer = state.at(s->outer);
+      bool inner = state.at(s->inner);
+      if (dom_map.count(s->inner) && dom_map.count(s->outer)) {
+        Expr factor = dom_map.at(s->inner)->extent;
+        Expr step = dom_map.at(s->outer)->extent;
+        if (outer || inner) {
+          state[s->parent] = true;
+        } else {
+          if (can_prove(dom_map.at(s->parent)->extent == factor * step)) {
+            state[s->parent] = false;
+          } else {
+            state[s->parent] = true;
+          }
+        }
+      } else {
+        state[s->parent] = true;
+      }
+    } else if (rel.as<FuseNode>()) {
+      const FuseNode* s = rel.as<FuseNode>();
+      bool fused = state.at(s->fused);
+      state[s->outer] = fused;
+      state[s->inner] = fused;
+    } else if (rel.as<RebaseNode>()) {
+      const RebaseNode* s = rel.as<RebaseNode>();
+      state[s->parent] = state.at(s->rebased);
+    } else {
+      LOG(FATAL) << "unknown relation type";
+    }
+  }
+}
+
+std::vector<Expr> MakeBoundCheck(
+    const Stage& stage,
+    const Map<IterVar, Range>& dom_map,
+    const std::unordered_map<IterVar, Expr>& value_map,
+    bool skip_ivar_domain,
+    const std::unordered_set<IterVar>& skip_iter) {
+  std::unordered_map<IterVar, bool> bound_state;
+  for (IterVar iv : stage->leaf_iter_vars) {
+    bound_state[iv] = false;
+  }
+  PassUpBoundCheck(stage, dom_map, &bound_state);
+  std::vector<Expr> preds;
+  std::unordered_map<const Variable*, IntSet> iset_dmap;
+  // setup domain map for set analysis
+  for (const auto& kv : dom_map) {
+    iset_dmap[kv.first->var.get()] = IntSet::range(kv.second);
+  }
+  for (IterVar iv : stage->op->root_iter_vars()) {
+    if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
+    Range dom = dom_map.at(iv);
+    if (bound_state.at(iv)) {
+      Expr value = ComputeExpr<Sub>(value_map.at(iv), dom->min);
+      Expr vmax = EvalSet(value, iset_dmap).max();
+      if (vmax.type() != value.type() || !can_prove(vmax < dom->extent)) {
+        preds.emplace_back(value < dom->extent);
+      }
+    }
+    CHECK(iv->dom.defined());
+    if (!skip_ivar_domain && !iv->dom.same_as(dom)) {
+      Expr value = ComputeExpr<Sub>(value_map.at(iv), iv->dom->min);
+      Expr vmax = EvalSet(value, iset_dmap).max();
+      if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) {
+        preds.emplace_back(value < iv->dom->extent);
+      }
+    }
+  }
+  return preds;
+}
}  // namespace schedule
}  // namespace tvm
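The index relations these passes propagate reduce to simple integer arithmetic. A standalone sketch (illustrative, not TVM code) of the split and fuse cases:

# Split with factor f: parent index p maps down to (outer, inner) = (p // f, p % f).
# Fuse: (outer, inner) maps down to fused = outer * f + inner.
def split_down(p, f):
    return p // f, p % f

def fuse_down(outer, inner, f):
    return outer * f + inner

# The upward passes invert these maps, so a round trip is the identity.
for p in range(12):
    o, i = split_down(p, 4)
    assert fuse_down(o, i, 4) == p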
@@ -46,6 +46,20 @@ void PassUpIndex(const Stage& stage,
                 bool allow_missing = false);

+/*!
+ * \brief Downward inference of index of each IterVar,
+ *  given index assignment of the roots.
+ *
+ * \param stage The stage to operate on.
+ * \param dom_map The domain map of each iteration variable's domain.
+ * \param p_state The index state of each IterVar.
+ * \param allow_missing Whether allow missing value.
+ */
+void PassDownIndex(const Stage& stage,
+                   const Map<IterVar, Range>& dom_map,
+                   std::unordered_map<IterVar, Expr>* p_state,
+                   bool allow_missing = false);
+
/*!
 * \brief Upward inference of domain set of each IterVar,
 *  given domain assignment of the leaves.
 *
...
@@ -76,6 +90,24 @@ void PassUpBitMaskOr(const Stage& stage,
void PassDownBitMaskOr(const Stage& stage,
                       std::unordered_map<IterVar, int>* p_state,
                       bool allow_missing = false);

+/*!
+ * \brief Create boundary check predicates given remapped value of root iter vars.
+ * \param stage The stage we operate on.
+ * \param dom_map The domain map of each value.
+ * \param value_map The value map of the root iter vars.
+ * \param skip_ivar_domain Whether we skip the check for IterVar's original domain.
+ * \param skip_iter The set of variables to skip the bound condition on.
+ * \return List of predicates that we need to check.
+ */
+std::vector<Expr>
+MakeBoundCheck(
+    const Stage& stage,
+    const Map<IterVar, Range>& dom_map,
+    const std::unordered_map<IterVar, Expr>& value_map,
+    bool skip_ivar_domain,
+    const std::unordered_set<IterVar>& skip_iter);
+
}  // namespace schedule
}  // namespace tvm
#endif  // TVM_SCHEDULE_MESSAGE_PASSING_H_
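A quick way to see MakeBoundCheck in action from the Python side, using only calls that already appear in this diff (the exact predicate form is illustrative): splitting an extent of 10 by a factor of 4 cannot be proven in-bounds, so a guard like likely(xo*4 + xi < 10) is expected in the lowered loop nest.

import tvm

A = tvm.placeholder((10,), name='A')
B = tvm.compute((10,), lambda i: A(i) + 1, name='B')

s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=4)  # 4 does not divide 10

bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)  # the loop body should be wrapped in a likely(...) bound guard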
@@ -9,6 +9,7 @@
#include <unordered_set>
#include "./message_passing.h"
#include "../pass/ir_util.h"
+#include "../arithmetic/compute_expr.h"

namespace tvm {
@@ -38,6 +39,22 @@ class VarReplacer : public ir::IRMutator {
  const std::unordered_map<const Variable*, Expr>& vsub_;
};
+Expr InjectPredicate(const Array<Expr>& predicates,
+                     Expr body) {
+  using ir::Reduce;
+  using ir::Select;
+  if (predicates.size() == 0) return body;
+  const Reduce* reduce = body.as<Reduce>();
+  if (reduce) {
+    std::shared_ptr<Reduce> n = std::make_shared<Reduce>(*reduce);
+    n->condition = n->condition && arith::ComputeReduce<ir::And>(predicates);
+    return Expr(n);
+  }
+  return Select::make(arith::ComputeReduce<ir::And>(predicates),
+                      body,
+                      make_zero(body.type()));
+}
// Replace the dataflow that appears in all stages given the tensor change.
// Also update vmap if subsequent dataflow needs to be replaced.
void ReplaceDataFlow(const Array<Stage>& stages,
...
@@ -99,52 +116,101 @@ Tensor Schedule::cache_read(const Tensor& tensor,
  return cache;
}
-Tensor Schedule::cache_write(const Tensor& tensor,
-                             const std::string& scope) {
-  (*this)->InvalidateCache();
-  Stage orig_stage = operator[](tensor->op);
-  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
-  CHECK(compute)
-      << "cache write only takes ComputeOp as writers";
-  CHECK_EQ(orig_stage->relations.size(), 0U)
-      << "Create cache_write before doing split/fuse/reorder";
-  compute = orig_stage->op.as<ComputeOpNode>();
-  CHECK(compute);
-  Array<Expr> args;
-  Array<IterVar> new_axis;
-  std::unordered_map<const Variable*, Expr> vsub;
-  for (IterVar iv : compute->axis) {
-    args.push_back(iv->var);
-    IterVar new_iv = IterVarNode::make(
-        iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
-    new_axis.push_back(new_iv);
-    vsub[iv->var.get()] = new_iv->var;
-  }
-  VarReplacer repl(vsub);
-  Expr body = repl.Mutate(compute->body[tensor->value_index]);
+// Cache write and relayout the data according to loop pattern
+Tensor CacheWriteWithReLayout(Schedule sch,
+                              const Tensor& tensor,
+                              const std::string& scope) {
+  sch->InvalidateCache();
+  Stage orig_stage = sch[tensor->op];
+  const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();
+
+  std::unordered_set<IterVar> red_axis;
+  for (IterVar iv : compute->reduce_axis) {
+    red_axis.insert(iv);
+  }
+  std::unordered_map<IterVar, Range> dom_map;
+  Array<IterVar> new_axis;
+  for (IterVar iv : compute->axis) {
+    dom_map[iv] = iv->dom;
+  }
+  schedule::PassDownDomain(orig_stage, &dom_map, true);
+  std::unordered_map<const Variable*, Expr> vsub;
+  std::unordered_map<const Variable*, Expr> vsub2newvar;
+  std::vector<Expr> predicates;
+  {
+    // The source->cache
+    std::unordered_map<IterVar, Expr> value_map;
+    for (IterVar iv : orig_stage->leaf_iter_vars) {
+      if (red_axis.count(iv)) continue;
+      CHECK_EQ(iv->iter_type, kDataPar)
+          << "Can only relayout within data parallel dimensions";
+      Range dom = dom_map.at(iv);
+      IterVar new_iv = IterVarNode::make(
+          dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
+      new_axis.push_back(new_iv);
+      if (is_one(dom->extent)) {
+        // a unit-extent axis is constant and equal to dom->min
+        value_map[iv] = dom->min;
+      } else {
+        value_map[iv] = iv->var;
+        vsub2newvar[iv->var.get()] = new_iv->var;
+      }
+    }
+    // skip reduction iteration.
+    std::unordered_set<IterVar> skip_bound_check;
+    for (IterVar iv : compute->reduce_axis) {
+      skip_bound_check.insert(iv);
+    }
+    schedule::PassUpIndex(orig_stage, dom_map, &value_map, true);
+    predicates = schedule::MakeBoundCheck(
+        orig_stage, dom_map, value_map, true, skip_bound_check);
+    // The root axis
+    for (IterVar iv : compute->axis) {
+      vsub[iv->var.get()] = value_map.at(iv);
+    }
+  }
+  Expr body = VarReplacer(vsub).Mutate(compute->body[tensor->value_index]);
+  body = InjectPredicate(predicates, body);
+  body = VarReplacer(vsub2newvar).Mutate(body);
+  // The reader args
+  Array<Expr> args;
+  {
+    // cache->compute
+    std::unordered_map<IterVar, Expr> value_map;
+    for (IterVar iv : compute->axis) {
+      value_map[iv] = iv->var;
+    }
+    schedule::PassDownIndex(orig_stage, dom_map, &value_map, true);
+    for (IterVar iv : orig_stage->leaf_iter_vars) {
+      if (red_axis.count(iv)) continue;
+      args.push_back(value_map.at(iv));
+    }
+  }
  Operation cache_op = ComputeOpNode::make(
      compute->name + "." + scope, compute->tag, new_axis, {body});
  Tensor cache_tensor = cache_op.output(0);
  Operation orig_new_op = ComputeOpNode::make(
      compute->name, compute->tag, compute->axis,
      {cache_tensor(args)});
+ // The replace of the dataflow
  std::unordered_map<Tensor, Tensor> vmap;
  vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
- ReplaceDataFlow((*this)->stages, &vmap);
+ ReplaceDataFlow(sch->stages, &vmap);
  // mutate orig stage
  orig_stage->op = orig_new_op;
  orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
  orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
+ orig_stage->relations = Array<IterVarRelation>();
  // create schedule for new cached stage.
- ArrayNode* stages = (*this)->stages.CopyOnWrite();
+ ArrayNode* stages = sch->stages.CopyOnWrite();
  size_t pos = FindNodeRef(stages, orig_stage);
  Stage cache_stage = Stage(cache_op);
  cache_stage.set_scope(scope);
  CHECK_LT(pos, stages->data.size());
  stages->data.insert(stages->data.begin() + pos,
                      cache_stage.node_);
- (*this)->stage_map.Set(cache_op, cache_stage);
+ sch->stage_map.Set(cache_op, cache_stage);
  // Update group
  cache_stage->group = orig_stage->group;
  if (cache_stage->group.defined()) {
...
@@ -153,6 +219,19 @@ Tensor Schedule::cache_write(const Tensor& tensor,
  return cache_tensor;
}

+Tensor Schedule::cache_write(const Tensor& tensor,
+                             const std::string& scope) {
+  (*this)->InvalidateCache();
+  Stage orig_stage = operator[](tensor->op);
+  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
+  CHECK(compute)
+      << "cache write only takes ComputeOp as writers";
+  CHECK_EQ(compute->num_outputs(), 1)
+      << "cache write only supports single output ComputeOp";
+  return CacheWriteWithReLayout(*this, tensor, scope);
+}
+
void RebaseNonZeroMinLoop(const Schedule& sch) {
  std::unordered_map<IterVar, IterVar> rebase_map;
  for (Stage s : sch->stages) {
...
@@ -295,16 +374,23 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
  touch_map[axis] = 1;
  schedule::PassUpBitMaskOr(reduce_stage, &touch_map, true);
  schedule::PassDownBitMaskOr(reduce_stage, &touch_map, true);
+ // skip reduction iteration.
+ std::unordered_set<IterVar> skip_bound_check;
  // Verify normal axes are not touched.
  for (IterVar iv : compute_op->axis) {
    CHECK(!touch_map.count(iv))
        << "Factor axis touches normal axis.";
+   skip_bound_check.insert(iv);
  }
  // Get the replace index
  std::unordered_map<IterVar, Range> dom_map;
  std::unordered_map<IterVar, Expr> value_map;
  for (IterVar iv : compute_op->reduce_axis) {
-   if (touch_map.count(iv)) dom_map[iv] = iv->dom;
+   if (touch_map.count(iv)) {
+     dom_map[iv] = iv->dom;
+   } else {
+     skip_bound_check.insert(iv);
+   }
  }
  schedule::PassDownDomain(reduce_stage, &dom_map, true);
  for (IterVar iv : reduce_stage->leaf_iter_vars) {
...
@@ -318,6 +404,9 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
    }
  }
  schedule::PassUpIndex(reduce_stage, dom_map, &value_map, true);
+ std::vector<Expr> predicates = schedule::MakeBoundCheck(
+     reduce_stage, dom_map, value_map, true, skip_bound_check);
+
  // Get the factored op node.
  auto n = std::make_shared<ComputeOpNode>();
  n->name = compute_op->name + ".rf";
...
@@ -339,8 +428,11 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
  int idx = tensor->value_index;
  const Reduce* reduce = compute_op->body[idx].as<Reduce>();
  CHECK(reduce) << "Can only rfactor non-inline reductions";
- Expr predicate = reduce->condition;
+ predicates.push_back(reduce->condition);
+ Expr predicate = arith::ComputeReduce<ir::And>(predicates);
+
  std::unordered_map<const Variable*, Expr> vsub;
  for (IterVar iv : compute_op->reduce_axis) {
    if (!touch_map.count(iv)) {
      n->reduce_axis.push_back(iv);
...
@@ -348,16 +440,9 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
      CHECK(value_map.count(iv));
      Expr index = value_map.at(iv);
      vsub[iv->var.get()] = index;
-     if (!index.same_as(iv->var)) {
-       Expr cond = (index < dom_map.at(iv)->extent);
-       if (is_one(predicate)) {
-         predicate = cond;
-       } else {
-         predicate = predicate && cond;
-       }
-     }
    }
  }
  // Copy touched axis.
  for (IterVar iv : reduce_stage->leaf_iter_vars) {
    if (touch_map.count(iv) && !iv.same_as(axis)) {
...
@@ -453,4 +538,5 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
  reduce_stage->relations = Array<IterVarRelation>();
  return factor_tensors;
}

}  // namespace tvm
@@ -55,7 +55,6 @@ def test_schedule_scan():
    bounds = tvm.schedule.InferBound(s)
    assert(bounds[res.op.scan_axis].min.value == 1)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
-   print(stmt)

def test_auto_inline():
    m = tvm.var('m')
...
@@ -160,7 +159,58 @@ def test_schedule_cache():
    stmt = tvm.schedule.ScheduleOps(s, bounds)

+def test_schedule_cache_relayout1():
+    m = tvm.var('m')
+    n = tvm.var('n')
+    A = tvm.placeholder((m, n), name='A')
+    B = tvm.placeholder((m, n), name='B')
+    C = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='C')
+    s = tvm.create_schedule(C.op)
+    s[C].reorder(C.op.axis[1], C.op.axis[0])
+    CC = s.cache_write(C, "global")
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
+
+def test_schedule_cache_relayout2():
+    m = tvm.var('m')
+    n = tvm.var('n')
+    A = tvm.placeholder((m*4, n), name='A')
+    B = tvm.placeholder((m*4, n), name='B')
+    C = tvm.compute(A.shape, lambda i, j: A(i, j) * B(i, j), name='C')
+    s = tvm.create_schedule(C.op)
+    x, y = C.op.axis
+    xo, xi = s[C].split(x, factor=4)
+    s[C].reorder(xo, y, xi)
+    CC = s.cache_write(C, "global")
+    s = s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
+
+def test_schedule_cache_relayout3():
+    m = tvm.var('m')
+    n = tvm.var('n')
+    A = tvm.placeholder((m*4, n), name='A')
+    B = tvm.placeholder((m*4, n), name='B')
+    k = tvm.reduce_axis((0, n), "k")
+    C = tvm.compute((A.shape[0],),
+                    lambda i: tvm.sum(A(i, k) * B(i, k), axis=k), name='C')
+    s = tvm.create_schedule(C.op)
+    x = C.op.axis[0]
+    xo, xi = s[C].split(x, factor=4)
+    CC = s.cache_write(C, "global")
+    s = s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
if __name__ == "__main__":
+   test_schedule_cache_relayout3()
+   test_schedule_cache_relayout2()
+   test_schedule_cache_relayout1()
    test_schedule_const_bound()
    test_scan_inline1()
    test_scan_inline2()
...
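The compute_inline remark in the cache_write docstring can be exercised the same way. A hedged sketch, not one of the tests in this commit: when the cached tensor feeds another stage, inlining the original stage keeps only the transformed layout, and the consumer reads the cache through the inlined back-transformation.

import tvm

m = tvm.var('m')
n = tvm.var('n')
A = tvm.placeholder((m, n), name='A')
C = tvm.compute((m, n), lambda i, j: A(i, j) * 2, name='C')
D = tvm.compute((m, n), lambda i, j: C(i, j) + 1, name='D')

s = tvm.create_schedule(D.op)
s[C].reorder(C.op.axis[1], C.op.axis[0])  # relayout C before caching
CC = s.cache_write(C, "global")
s[C].compute_inline()  # keep only the transformed layout; D reads via CC
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)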