Commit 400c1c48 by Tianqi Chen Committed by GitHub

[SCHEDULE] Enhance cache_write to enable layout change. (#432)

* [SCHEDULE] Enhance cache_write to enable layout change.

* more tests
parent 663d7c52
......@@ -284,8 +284,15 @@ class Schedule : public NodeRef {
/*!
* \brief Create a cache write tensor for producing tensor.
* The cache tensor will take over the body of the original tensor op.
* The original tensor's body will be changed to an identity read
* from the corresponding cache.
*
* This function can be used to do data layout transformation.
* If there is a split/fuse/reorder on the data parallel axis of tensor
* before cache_write is called, the intermediate cache stores
* the data in the layout given by the iteration order of the leaf axes.
* The data will be transformed back to the original layout in the original tensor.
* User can further call compute_inline to inline the original layout and keep
* the data stored in the transformed layout.
*
* \param tensor The tensor to be produced.
* \param scope The scope of the storage.
* \return The created tensor.
......
......@@ -248,6 +248,14 @@ class Schedule(NodeBase):
This will mutate the body of the tensor.
A new cache stage will be created before feeding into the tensor.
This function can be used to support data layout transformation.
If there is a split/fuse/reorder on the data parallel axis of tensor
before cache_write is called, the intermediate cache stores
the data in the layout given by the iteration order of the leaf axes.
The data will be transformed back to the original layout in the original tensor.
User can further call compute_inline to inline the original layout and keep
the data stored in the transformed layout.
Parameters
----------
tensor : Tensor
......
......@@ -383,8 +383,9 @@ ComputeLoopNest ComputeLoopNest::make(
// make main loop nest
ret.main_nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap);
ret.main_predicates = op::MakeBoundCheck(stage, dom_map, false,
std::unordered_set<IterVar>(), ret.main_vmap);
ret.main_predicates = schedule::MakeBoundCheck(
stage, dom_map, ret.main_vmap, false,
std::unordered_set<IterVar>());
for (auto& e : ret.main_predicates) {
e = likely(e);
}
......@@ -424,8 +425,8 @@ ComputeLoopNest ComputeLoopNest::make(
ret.init_nest = op::MakeLoopNest(
stage, dom_map, begin_loop, true,
skip_iter, &(ret.init_vmap));
ret.init_predicates = op::MakeBoundCheck(
stage, dom_map, true, skip_iter, ret.init_vmap);
ret.init_predicates = schedule::MakeBoundCheck(
stage, dom_map, ret.init_vmap, true, skip_iter);
for (auto& e : ret.init_predicates) {
e = likely(e);
}
......
......@@ -21,9 +21,9 @@ Stmt MakeCrossThreadReduction(
std::unordered_map<IterVar, Expr> value_map;
auto nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
auto conds = op::MakeBoundCheck(
stage, dom_map, false,
std::unordered_set<IterVar>(), value_map);
auto conds = schedule::MakeBoundCheck(
stage, dom_map, value_map, false,
std::unordered_set<IterVar>());
size_t size = self->body.size();
CHECK_GT(size, 0);
......
......@@ -147,91 +147,6 @@ MakeLoopNest(const Stage& stage,
return nest;
}
/*!
* \brief message passing to find if boundary checking on IterVar is needed.
* \param s The stage to be used.
* \param dom_map The range of each IterVar; read for the extents of the
*  inner, outer and parent IterVars of every split relation.
* \param p_state The message passing state
* IterVar->flag
*/
void PassUpBoundCheck(const Stage& s,
const Map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, bool>* p_state) {
auto& state = *p_state;
using Halide::Internal::can_prove;
// Visit the relations from last to first, so flags flow from the
// outputs of each relation to its inputs.
for (size_t i = s->relations.size(); i != 0; --i) {
IterVarRelation rel = s->relations[i - 1];
if (rel.as<SplitNode>()) {
const SplitNode* s = rel.as<SplitNode>();
bool outer = state.at(s->outer);
bool inner = state.at(s->inner);
Expr factor = dom_map.at(s->inner)->extent;
Expr step = dom_map.at(s->outer)->extent;
if (outer || inner) {
// A flag set on either half of the split propagates to the parent.
state[s->parent] = true;
} else {
// The parent needs no check only when its extent is provably
// equal to factor * step.
if (can_prove(dom_map.at(s->parent)->extent == factor * step)) {
state[s->parent] = false;
} else {
state[s->parent] = true;
}
}
} else if (rel.as<FuseNode>()) {
const FuseNode* s = rel.as<FuseNode>();
// Both fused inputs inherit the flag of the fused IterVar.
bool fused = state.at(s->fused);
state[s->outer] = fused;
state[s->inner] = fused;
} else if (rel.as<RebaseNode>()) {
const RebaseNode* s = rel.as<RebaseNode>();
state[s->parent] = state.at(s->rebased);
} else {
LOG(FATAL) << "unknown relation type";
}
}
}
// Build the list of boundary predicates for the root IterVars of a stage.
// For every root IterVar that is not skipped, up to two predicates of the
// form (value - min < extent) are emitted: one against the range in dom_map
// (only when PassUpBoundCheck flagged the IterVar) and one against the
// IterVar's own declared domain (unless skip_ivar_domain is set).
std::vector<Expr> MakeBoundCheck(
const Stage& stage,
const Map<IterVar, Range>& dom_map,
bool skip_ivar_domain,
const std::unordered_set<IterVar>& skip_iter,
const std::unordered_map<IterVar, Expr>& value_map) {
std::unordered_map<IterVar, bool> bound_state;
// Every leaf starts unflagged; flags are derived for the roots below.
for (IterVar iv : stage->leaf_iter_vars) {
bound_state[iv] = false;
}
PassUpBoundCheck(stage, dom_map, &bound_state);
std::vector<Expr> preds;
std::unordered_map<const Variable*, IntSet> iset_dmap;
// setup domain map for set analysis
for (const auto& kv : dom_map) {
iset_dmap[kv.first->var.get()] = IntSet::range(kv.second);
}
for (IterVar iv : stage->op->root_iter_vars()) {
if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
Range dom = dom_map.at(iv);
if (bound_state.at(iv)) {
Expr value = ComputeExpr<Sub>(value_map.at(iv), dom->min);
Expr vmax = EvalSet(value, iset_dmap).max();
// Drop the predicate only when the maximum is provably in range.
if (vmax.type() != value.type() || !can_prove(vmax < dom->extent)) {
preds.emplace_back(value < dom->extent);
}
}
CHECK(iv->dom.defined());
// Second check against the IterVar's own domain when it is a
// different Range object than the one in dom_map.
if (!skip_ivar_domain && !iv->dom.same_as(dom)) {
Expr value = ComputeExpr<Sub>(value_map.at(iv), iv->dom->min);
Expr vmax = EvalSet(value, iset_dmap).max();
if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) {
preds.emplace_back(value < iv->dom->extent);
}
}
}
return preds;
}
std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
Stmt no_op = Evaluate::make(0);
std::vector<Stmt> nest;
......
......@@ -13,6 +13,7 @@
#include <vector>
#include "../pass/ir_util.h"
#include "../pass/arg_binder.h"
#include "../schedule/message_passing.h"
namespace tvm {
namespace op {
......@@ -36,22 +37,6 @@ MakeLoopNest(const Stage& stage,
bool new_loop_var,
const std::unordered_set<IterVar>& skip_iter,
std::unordered_map<IterVar, Expr>* p_value_map);
/*!
* \brief Create boundary check condition for given stage.
*
* \param stage The stage to create boundary checks for.
* \param dom_map The range of each iter var.
* \param skip_ivar_domain Whether we can skip check for IterVar's original domain.
* \param skip_iter The set of iter vars for which no check is generated.
* \param value_map The result value of each IterVar.
* \return List of predicates that we need to check.
*/
std::vector<Expr>
MakeBoundCheck(const Stage& stage,
const Map<IterVar, Range>& dom_map,
bool skip_ivar_domain,
const std::unordered_set<IterVar>& skip_iter,
const std::unordered_map<IterVar, Expr>& value_map);
/*!
* \brief Create a nest of if checking the predicates.
......
......@@ -274,7 +274,7 @@ Stmt ScanOpNode::BuildProvide(
nest[begin_scan].push_back(init);
nest.push_back(
op::MakeIfNest(
op::MakeBoundCheck(stage, dom_map, false, empty, vmap)));
schedule::MakeBoundCheck(stage, dom_map, vmap, false, empty)));
return MergeNest(nest, provide);
}
} // namespace tvm
......@@ -7,10 +7,12 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "./message_passing.h"
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace schedule {
using namespace ir;
using namespace arith;
// result = ceil((a / b)), both a and b are positive integer
......@@ -123,8 +125,8 @@ void PassUpIndex(const Stage& stage,
Expr factor = dom_map.at(s->inner)->extent;
Expr outer_min = dom_map.at(s->outer)->min;
Expr inner_min = dom_map.at(s->inner)->min;
state[s->outer] = value / factor;
state[s->inner] = value % factor;
state[s->outer] = ComputeExpr<Div>(value, factor);
state[s->inner] = ComputeExpr<Mod>(value, factor);
// add min if they exist
if (!is_zero(outer_min)) {
state[s->outer] = state[s->outer] + outer_min;
......@@ -151,6 +153,51 @@ void PassUpIndex(const Stage& stage,
}
}
/*!
 * \brief Downward inference of the index of each IterVar,
 *  given the index assignment of the roots.
 *
 * Relations are visited in declaration order; for each relation the index
 * of its output(s) is derived from the index of its input(s) and written
 * back into the state.
 *
 * \param stage The stage to operate on.
 * \param dom_map The domain map of each iteration variable's domain.
 * \param p_state The index state of each IterVar, updated in place.
 * \param allow_missing Whether a relation whose input index is absent from
 *  the state may be skipped. When false, a missing input fails a CHECK.
 */
void PassDownIndex(const Stage& stage,
                   const Map<IterVar, Range>& dom_map,
                   std::unordered_map<IterVar, Expr>* p_state,
                   bool allow_missing) {
  auto& state = *p_state;
  for (IterVarRelation rel : stage->relations) {
    if (const SplitNode* s = rel.as<SplitNode>()) {
      if (!state.count(s->parent)) {
        CHECK(allow_missing);
        continue;
      }
      Range r = dom_map.at(s->inner);
      CHECK(is_zero(r->min));
      Expr parent = state.at(s->parent);
      Expr factor = r->extent;
      state[s->outer] = ComputeExpr<Div>(parent, factor);
      state[s->inner] = ComputeExpr<Mod>(parent, factor);
    } else if (const FuseNode* s = rel.as<FuseNode>()) {
      // Both inputs are read below, so the relation must be skipped when
      // EITHER of them is missing; otherwise state.at() would throw.
      if (!state.count(s->inner) || !state.count(s->outer)) {
        CHECK(allow_missing);
        continue;
      }
      Expr factor = dom_map.at(s->inner)->extent;
      Expr outer_min = dom_map.at(s->outer)->min;
      Expr inner_min = dom_map.at(s->inner)->min;
      Expr inner = state.at(s->inner);
      Expr outer = state.at(s->outer);
      CHECK(is_zero(outer_min));
      CHECK(is_zero(inner_min));
      state[s->fused] = outer * factor + inner;
    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
      // The input of a rebase in the downward direction is the parent;
      // the rebased IterVar is the output that gets written.
      if (!state.count(s->parent)) {
        CHECK(allow_missing);
        continue;
      }
      Expr value = state.at(s->parent);
      Expr parent_min = dom_map.at(s->parent)->min;
      CHECK(is_zero(parent_min));
      state[s->rebased] = value;
    } else {
      LOG(FATAL) << "unknown relation type";
    }
  }
}
// Domain message passing.
void PassUpDomain(const SplitNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
......@@ -349,5 +396,94 @@ void PassDownBitMaskOr(const Stage& stage,
}
}
/*!
* \brief message passing to find if boundary checking on IterVar is needed.
* \param s The stage to be used.
* \param p_state The message passing state
* IterVar->flag
*/
void PassUpBoundCheck(const Stage& s,
const Map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, bool>* p_state) {
auto& state = *p_state;
using Halide::Internal::can_prove;
for (size_t i = s->relations.size(); i != 0; --i) {
IterVarRelation rel = s->relations[i - 1];
if (rel.as<SplitNode>()) {
const SplitNode* s = rel.as<SplitNode>();
bool outer = state.at(s->outer);
bool inner = state.at(s->inner);
if (dom_map.count(s->inner) && dom_map.count(s->outer)) {
Expr factor = dom_map.at(s->inner)->extent;
Expr step = dom_map.at(s->outer)->extent;
if (outer || inner) {
state[s->parent] = true;
} else {
if (can_prove(dom_map.at(s->parent)->extent == factor * step)) {
state[s->parent] = false;
} else {
state[s->parent] = true;
}
}
} else {
state[s->parent] = true;
}
} else if (rel.as<FuseNode>()) {
const FuseNode* s = rel.as<FuseNode>();
bool fused = state.at(s->fused);
state[s->outer] = fused;
state[s->inner] = fused;
} else if (rel.as<RebaseNode>()) {
const RebaseNode* s = rel.as<RebaseNode>();
state[s->parent] = state.at(s->rebased);
} else {
LOG(FATAL) << "unknown relation type";
}
}
}
std::vector<Expr> MakeBoundCheck(
const Stage& stage,
const Map<IterVar, Range>& dom_map,
const std::unordered_map<IterVar, Expr>& value_map,
bool skip_ivar_domain,
const std::unordered_set<IterVar>& skip_iter) {
std::unordered_map<IterVar, bool> bound_state;
for (IterVar iv : stage->leaf_iter_vars) {
bound_state[iv] = false;
}
PassUpBoundCheck(stage, dom_map, &bound_state);
std::vector<Expr> preds;
std::unordered_map<const Variable*, IntSet> iset_dmap;
// setup domain map for set analysis
for (const auto& kv : dom_map) {
iset_dmap[kv.first->var.get()] = IntSet::range(kv.second);
}
for (IterVar iv : stage->op->root_iter_vars()) {
if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
Range dom = dom_map.at(iv);
if (bound_state.at(iv)) {
Expr value = ComputeExpr<Sub>(value_map.at(iv), dom->min);
Expr vmax = EvalSet(value, iset_dmap).max();
if (vmax.type() != value.type() || !can_prove(vmax < dom->extent)) {
preds.emplace_back(value < dom->extent);
}
}
CHECK(iv->dom.defined());
if (!skip_ivar_domain && !iv->dom.same_as(dom)) {
Expr value = ComputeExpr<Sub>(value_map.at(iv), iv->dom->min);
Expr vmax = EvalSet(value, iset_dmap).max();
if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) {
preds.emplace_back(value < iv->dom->extent);
}
}
}
return preds;
}
} // namespace schedule
} // namespace tvm
......@@ -46,6 +46,20 @@ void PassUpIndex(const Stage& stage,
bool allow_missing = false);
/*!
* \brief Downward inference of index of each IterVar.
* given index assignment of roots.
*
* \param stage The stage to operate on.
* \param dom_map The domain map of each iteration variable's domain.
* \param p_state The index state of each IterVar.
* \param allow_missing Whether allow missing value.
*/
void PassDownIndex(const Stage& stage,
const Map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, Expr>* p_state,
bool allow_missing = false);
/*!
* \brief Upward inference of domain set of each IterVar.
* given domain assignment of the leaves,
*
......@@ -76,6 +90,24 @@ void PassUpBitMaskOr(const Stage& stage,
void PassDownBitMaskOr(const Stage& stage,
std::unordered_map<IterVar, int>* p_state,
bool allow_missing = false);
/*!
* \brief Create boundary check predicates given remapped value of root
* \param stage The stage we operate on
* \param dom_map The domain map of each iteration variable.
* \param value_map The value expression of each root iter var.
* \param skip_ivar_domain Whether we skip check for IterVar's original domain.
* \param skip_iter The set of variables to skip bound condition.
* \return List of predicates that we need to check.
*/
std::vector<Expr>
MakeBoundCheck(
const Stage& stage,
const Map<IterVar, Range>& dom_map,
const std::unordered_map<IterVar, Expr>& value_map,
bool skip_ivar_domain,
const std::unordered_set<IterVar>& skip_iter);
} // namespace schedule
} // namespace tvm
#endif // TVM_SCHEDULE_MESSAGE_PASSING_H_
......@@ -9,6 +9,7 @@
#include <unordered_set>
#include "./message_passing.h"
#include "../pass/ir_util.h"
#include "../arithmetic/compute_expr.h"
namespace tvm {
......@@ -38,6 +39,22 @@ class VarReplacer : public ir::IRMutator {
const std::unordered_map<const Variable*, Expr>& vsub_;
};
// Guard an expression body with a list of predicates.
// - No predicates: the body is returned unchanged.
// - Reduce body: the conjunction of the predicates is and-ed into a copy
//   of the reduction's condition.
// - Anything else: the body is wrapped in a Select that yields zero of the
//   body's type when the conjunction is false.
Expr InjectPredicate(const Array<Expr>& predicates,
                     Expr body) {
  if (predicates.size() == 0) return body;
  Expr guard = arith::ComputeReduce<ir::And>(predicates);
  if (const ir::Reduce* reduce = body.as<ir::Reduce>()) {
    // Copy the node so the original reduction is left untouched.
    auto guarded = std::make_shared<ir::Reduce>(*reduce);
    guarded->condition = guarded->condition && guard;
    return Expr(guarded);
  }
  return ir::Select::make(guard, body, make_zero(body.type()));
}
// Replace data flow appears in all stages given the tensor change.
// Also update vmap if subsequent dataflow need to be replaced.
void ReplaceDataFlow(const Array<Stage>& stages,
......@@ -99,52 +116,101 @@ Tensor Schedule::cache_read(const Tensor& tensor,
return cache;
}
Tensor Schedule::cache_write(const Tensor& tensor,
const std::string& scope) {
(*this)->InvalidateCache();
Stage orig_stage = operator[](tensor->op);
const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
CHECK(compute)
<< "cache write only take ComputeOp as writers";
CHECK_EQ(orig_stage->relations.size(), 0U)
<< "Create cache_write before doing split/fuse/reorder";
compute = orig_stage->op.as<ComputeOpNode>();
CHECK(compute);
Array<Expr> args;
// Cache write and relayout the data according to loop pattern
Tensor CacheWriteWithReLayout(Schedule sch,
const Tensor& tensor,
const std::string& scope) {
sch->InvalidateCache();
Stage orig_stage = sch[tensor->op];
const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();
std::unordered_set<IterVar> red_axis;
for (IterVar iv : compute->reduce_axis) {
red_axis.insert(iv);
}
std::unordered_map<IterVar, Range> dom_map;
Array<IterVar> new_axis;
std::unordered_map<const Variable*, Expr> vsub;
for (IterVar iv : compute->axis) {
args.push_back(iv->var);
IterVar new_iv = IterVarNode::make(
iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
new_axis.push_back(new_iv);
vsub[iv->var.get()] = new_iv->var;
dom_map[iv] = iv->dom;
}
schedule::PassDownDomain(orig_stage, &dom_map, true);
std::unordered_map<const Variable*, Expr> vsub;
std::unordered_map<const Variable*, Expr> vsub2newvar;
std::vector<Expr> predicates;
{
  // The source->cache
  // Build one new cache axis per non-reduction leaf of the original stage
  // and record how each leaf maps to an index value.
  std::unordered_map<IterVar, Expr> value_map;
  for (IterVar iv : orig_stage->leaf_iter_vars) {
    if (red_axis.count(iv)) continue;
    CHECK_EQ(iv->iter_type, kDataPar)
        << "Can only relayout with in data parallel dimensions";
    Range dom = dom_map.at(iv);
    IterVar new_iv = IterVarNode::make(
        dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
    new_axis.push_back(new_iv);
    // A leaf is a constant only when it iterates exactly once, i.e. when
    // its extent is one; in that case its value is always dom->min.
    // (Testing dom->min here would collapse any axis that merely starts
    // at 1 to a single value.)
    if (is_one(dom->extent)) {
      value_map[iv] = dom->min;
    } else {
      value_map[iv] = iv->var;
      vsub2newvar[iv->var.get()] = new_iv->var;
    }
  }
  schedule::PassUpIndex(orig_stage, dom_map, &value_map, true);
  // Reduction iterations are skipped in the bound check; red_axis already
  // holds exactly compute->reduce_axis.
  predicates = schedule::MakeBoundCheck(
      orig_stage, dom_map, value_map, true, red_axis);
  // The root axis
  for (IterVar iv : compute->axis) {
    vsub[iv->var.get()] = value_map.at(iv);
  }
}
Expr body = VarReplacer(vsub).Mutate(compute->body[tensor->value_index]);
body = InjectPredicate(predicates, body);
body = VarReplacer(vsub2newvar).Mutate(body);
// The reader args
Array<Expr> args;
{
// cache->compute
std::unordered_map<IterVar, Expr> value_map;
for (IterVar iv : compute->axis) {
value_map[iv] = iv->var;
}
schedule::PassDownIndex(orig_stage, dom_map, &value_map, true);
for (IterVar iv : orig_stage->leaf_iter_vars) {
if (red_axis.count(iv)) continue;
args.push_back(value_map.at(iv));
}
}
VarReplacer repl(vsub);
Expr body = repl.Mutate(compute->body[tensor->value_index]);
Operation cache_op = ComputeOpNode::make(
compute->name + "." + scope, compute->tag, new_axis, {body});
Tensor cache_tensor = cache_op.output(0);
Operation orig_new_op = ComputeOpNode::make(
compute->name, compute->tag, compute->axis,
{cache_tensor(args)});
// The replace of the dataflow
std::unordered_map<Tensor, Tensor> vmap;
vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
ReplaceDataFlow((*this)->stages, &vmap);
ReplaceDataFlow(sch->stages, &vmap);
// mutate orig stage
orig_stage->op = orig_new_op;
orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
orig_stage->relations = Array<IterVarRelation>();
// create schedule for new cached stage.
ArrayNode* stages = (*this)->stages.CopyOnWrite();
ArrayNode* stages = sch->stages.CopyOnWrite();
size_t pos = FindNodeRef(stages, orig_stage);
Stage cache_stage = Stage(cache_op);
cache_stage.set_scope(scope);
CHECK_LT(pos, stages->data.size());
stages->data.insert(stages->data.begin() + pos,
cache_stage.node_);
(*this)->stage_map.Set(cache_op, cache_stage);
sch->stage_map.Set(cache_op, cache_stage);
// Update group
cache_stage->group = orig_stage->group;
if (cache_stage->group.defined()) {
......@@ -153,6 +219,19 @@ Tensor Schedule::cache_write(const Tensor& tensor,
return cache_tensor;
}
// Public entry point of cache_write: validates that the producer of
// `tensor` is a single-output ComputeOp and then delegates to
// CacheWriteWithReLayout, which builds the cache stage.
Tensor Schedule::cache_write(const Tensor& tensor,
                             const std::string& scope) {
  (*this)->InvalidateCache();
  // The stage lookup is kept although its result is not used further here.
  Stage orig_stage = (*this)[tensor->op];
  const auto* compute = tensor->op.as<ComputeOpNode>();
  CHECK(compute) << "cache write only take ComputeOp as writers";
  CHECK_EQ(compute->num_outputs(), 1)
      << "cache write only support single output ComputeOp";
  return CacheWriteWithReLayout(*this, tensor, scope);
}
void RebaseNonZeroMinLoop(const Schedule& sch) {
std::unordered_map<IterVar, IterVar> rebase_map;
for (Stage s : sch->stages) {
......@@ -295,16 +374,23 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
touch_map[axis] = 1;
schedule::PassUpBitMaskOr(reduce_stage, &touch_map, true);
schedule::PassDownBitMaskOr(reduce_stage, &touch_map, true);
// skip reduction iteration.
std::unordered_set<IterVar> skip_bound_check;
// Verify normal axis are not touched.
for (IterVar iv : compute_op->axis) {
CHECK(!touch_map.count(iv))
<< "Factor axis touches normal axis.";
skip_bound_check.insert(iv);
}
// Get the replace index
std::unordered_map<IterVar, Range> dom_map;
std::unordered_map<IterVar, Expr> value_map;
for (IterVar iv : compute_op->reduce_axis) {
if (touch_map.count(iv)) dom_map[iv] = iv->dom;
if (touch_map.count(iv)) {
dom_map[iv] = iv->dom;
} else {
skip_bound_check.insert(iv);
}
}
schedule::PassDownDomain(reduce_stage, &dom_map, true);
for (IterVar iv : reduce_stage->leaf_iter_vars) {
......@@ -318,6 +404,9 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
}
}
schedule::PassUpIndex(reduce_stage, dom_map, &value_map, true);
std::vector<Expr> predicates = schedule::MakeBoundCheck(
reduce_stage, dom_map, value_map, true, skip_bound_check);
// Get the factored op node.
auto n = std::make_shared<ComputeOpNode>();
n->name = compute_op->name + ".rf";
......@@ -339,8 +428,11 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
int idx = tensor->value_index;
const Reduce* reduce = compute_op->body[idx].as<Reduce>();
CHECK(reduce) << "Can only rfactor non-inline reductions";
Expr predicate = reduce->condition;
predicates.push_back(reduce->condition);
Expr predicate = arith::ComputeReduce<ir::And>(predicates);
std::unordered_map<const Variable*, Expr> vsub;
for (IterVar iv : compute_op->reduce_axis) {
if (!touch_map.count(iv)) {
n->reduce_axis.push_back(iv);
......@@ -348,16 +440,9 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
CHECK(value_map.count(iv));
Expr index = value_map.at(iv);
vsub[iv->var.get()] = index;
if (!index.same_as(iv->var)) {
Expr cond = (index < dom_map.at(iv)->extent);
if (is_one(predicate)) {
predicate = cond;
} else {
predicate = predicate && cond;
}
}
}
}
// Copy touched axis.
for (IterVar iv : reduce_stage->leaf_iter_vars) {
if (touch_map.count(iv) && !iv.same_as(axis)) {
......@@ -453,4 +538,5 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
reduce_stage->relations = Array<IterVarRelation>();
return factor_tensors;
}
} // namespace tvm
......@@ -55,7 +55,6 @@ def test_schedule_scan():
bounds = tvm.schedule.InferBound(s)
assert(bounds[res.op.scan_axis].min.value == 1)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
def test_auto_inline():
m = tvm.var('m')
......@@ -160,7 +159,58 @@ def test_schedule_cache():
stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_cache_relayout1():
    """cache_write on a 2-D elementwise compute whose two axes were reordered.

    The test only checks that bound inference and ScheduleOps run through.
    """
    m = tvm.var('m')
    n = tvm.var('n')
    A = tvm.placeholder((m, n), name='A')
    B = tvm.placeholder((m, n), name='B')
    C = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='C')

    s = tvm.create_schedule(C.op)
    # Swap the iteration order before creating the cache stage.
    row, col = C.op.axis
    s[C].reorder(col, row)
    CC = s.cache_write(C, "global")

    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_cache_relayout2():
    """cache_write after a split of the first axis followed by a reorder.

    The test only checks that normalize, bound inference and ScheduleOps
    run through.
    """
    m = tvm.var('m')
    n = tvm.var('n')
    A = tvm.placeholder((m*4, n), name='A')
    B = tvm.placeholder((m*4, n), name='B')
    C = tvm.compute(A.shape, lambda i, j: A(i, j) * B(i, j), name='C')

    s = tvm.create_schedule(C.op)
    x, y = C.op.axis
    # Split x by 4 and move the inner part to the innermost position.
    x_outer, x_inner = s[C].split(x, factor=4)
    s[C].reorder(x_outer, y, x_inner)
    CC = s.cache_write(C, "global")

    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_cache_relayout3():
    """cache_write on a reduction whose data-parallel axis was split.

    The test only checks that normalize, bound inference and ScheduleOps
    run through.
    """
    m = tvm.var('m')
    n = tvm.var('n')
    A = tvm.placeholder((m*4, n), name='A')
    B = tvm.placeholder((m*4, n), name='B')
    k = tvm.reduce_axis((0, n), "k")
    C = tvm.compute(
        (A.shape[0],),
        lambda i: tvm.sum(A(i, k) * B(i, k), axis=k),
        name='C')

    s = tvm.create_schedule(C.op)
    x = C.op.axis[0]
    x_outer, x_inner = s[C].split(x, factor=4)
    CC = s.cache_write(C, "global")

    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
if __name__ == "__main__":
    # Only the relayout tests 1-3 are defined in this file; a call to a
    # non-existent test_schedule_cache_relayout4 raised NameError here.
    test_schedule_cache_relayout3()
    test_schedule_cache_relayout2()
    test_schedule_cache_relayout1()
    test_schedule_const_bound()
    test_scan_inline1()
    test_scan_inline2()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment