Commit c8ec4111 by Tianqi Chen Committed by GitHub

[SCAN/Refactor] Refactor scan interface, enable fix point analysis. (#47)

parent 5198c100
......@@ -152,14 +152,12 @@ Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor"
/*!
* \brief Construct new tensors by scan over scan_axis.
*
* \param scan_axis The iteration representing the scan.
* \param init The initial tensor of the first K steps.
* \param update The update tensor indicating the updated result after each timestamp.
* \param state_placeholder The placeholder for the states.
* \param name The optional name of the tensor.
*/
Array<Tensor> scan(IterVar scan_axis,
Array<Tensor> init,
Array<Tensor> scan(Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
std::string name = "scan");
......
......@@ -26,7 +26,9 @@ enum AttachType : int {
kNone = 0,
kRoot = 1,
kInline = 2,
kScope = 3
kInlinedAlready = 3,
kScope = 4,
kScanUpdate = 5
};
/*! \brief IterVar type */
......
......@@ -175,6 +175,8 @@ class OperationNode : public FunctionBaseNode {
virtual Type output_dtype(size_t i) const = 0;
/*! \return shape of i-th output */
virtual Array<Expr> output_shape(size_t i) const = 0;
static constexpr const char* _type_key = "Operation";
};
// Implementations of inline functions
......
......@@ -4,7 +4,7 @@ import sys
import tempfile
import subprocess
def compile_source(code, target="cubin"):
def compile_source(code, target="cubin", options=None):
"""Compile cuda code with NVCC from env.
Parameters
......@@ -12,9 +12,12 @@ def compile_source(code, target="cubin"):
code : str
The cuda code.
target: str
target : str
The target format
options : list of str
The additional options, appended to the nvcc command line
Return
------
cubin : bytearray
......@@ -32,6 +35,8 @@ def compile_source(code, target="cubin"):
cmd = ["nvcc"]
cmd += ["--%s" % target, "-O3"]
cmd += ["-o", path_target]
if options:
cmd += options
cmd += [path_code]
args = ' '.join(cmd)
......
......@@ -140,14 +140,11 @@ def compute(shape, fcompute, name="compute"):
return op_node.output(0)
def scan(axis, init, update, state_placeholder, name="scan"):
def scan(init, update, state_placeholder, name="scan"):
"""Construct new tensors by scanning over axis.
Parameters
----------
axis: IterVar
The scanning axis.
init: Tensor or list of Tensor
The initial condition of first init.shape[0] timestamps
......@@ -170,12 +167,11 @@ def scan(axis, init, update, state_placeholder, name="scan"):
# The following code is equivalent to numpy.cumsum
m = tvm.Var("m")
n = tvm.Var("n")
t = tvm.IterVar((1, m), name="t")
X = tvm.placeholder((m, n), name="X")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + X[t, i])
res = tvm.scan(t, s_init, s_update, s_state)
s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
res = tvm.scan(s_init, s_update, s_state)
"""
if isinstance(init, _tensor.Tensor):
init = [init]
......@@ -185,6 +181,7 @@ def scan(axis, init, update, state_placeholder, name="scan"):
state_placeholder = [state_placeholder]
if len(init) != len(update) or len(init) != len(state_placeholder):
raise ValueError("init, update, state_placeholder must have same length")
axis = IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name)
op = _api_internal._ScanOp(name, axis, init, update, state_placeholder)
res = [op.output(i) for i in range(len(update))]
return (res[0] if len(res) == 1 else res)
......
......@@ -63,7 +63,8 @@ def build(sch,
arg_list.append(x)
else:
raise ValueError("args must be Tensor, Buffer or Var")
# lowering
# normalize schedule first
sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.StorageFlatten(stmt, binds)
......
......@@ -34,6 +34,9 @@ TVM_REGISTER_API(_schedule_AutoInlineElemWise)
REGISTER_SCHEDULE_PASS1(InferBound);
REGISTER_SCHEDULE_PASS1(CreateReadGraph);
REGISTER_SCHEDULE_PASS2(PostDFSOrder);
REGISTER_SCHEDULE_PASS1(ScanGetBody);
REGISTER_SCHEDULE_PASS1(CreateAttachPath);
REGISTER_SCHEDULE_PASS2(ScanFixPointAnalysis);
REGISTER_SCHEDULE_PASS2(ScheduleOps);
} // namespace schedule
......
......@@ -166,7 +166,15 @@ IntSet Union(const Array<IntSet>& set) {
if (set.size() == 1) return set[0];
Interval x = set[0].cover_interval().as<IntervalSet>()->i;
for (size_t i = 1; i < set.size(); ++i) {
x.include(set[i].cover_interval().as<IntervalSet>()->i);
IntSet s = set[i].cover_interval();
const Interval& y = s.as<IntervalSet>()->i;
if (can_prove(x.max + 1 >= y.min)) {
x.max = y.max;
} else if (can_prove(y.max + 1 >= x.min)) {
x.min = y.min;
} else {
x.include(y);
}
}
return IntervalSet::make(x);
}
......
......@@ -51,8 +51,6 @@ Operation PlaceholderOpNode::make(std::string name,
return Operation(n);
}
Tensor placeholder(Array<Expr> shape, Type dtype, std::string name) {
return PlaceholderOpNode::make(name, shape, dtype).output(0);
}
......@@ -162,24 +160,25 @@ Operation ScanOpNode::make(std::string name,
<< " scan_axis.dom.min + scan_axis.dom.extent";
CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
<< "The dimension of init need to match state_placeholder";
CHECK_EQ(update[i].ndim() + 1, state_placeholder[i].ndim())
CHECK_EQ(update[i].ndim(), state_placeholder[i].ndim())
<< "The update.ndim need to be state_placeholder.ndim - 1";
for (size_t k = 0; k < update[i].ndim(); ++k) {
CHECK(prove_equal(
update[i]->shape[k], state_placeholder[i]->shape[k + 1]));
// setup spatial axis
std::ostringstream spatial_name;
spatial_name << name << ".out" << i << ".i" << k + 1;
n->spatial_axis_.push_back(
IterVar(Range::make_with_min_extent(0, update[i]->shape[k]),
spatial_name.str()));
update[i]->shape[k], state_placeholder[i]->shape[k]));
if (k != 0) {
// setup spatial axis
std::ostringstream spatial_name;
spatial_name << name << ".out" << i << ".i" << k;
n->spatial_axis_.push_back(
IterVar(Range::make_with_min_extent(0, update[i]->shape[k]),
spatial_name.str()));
}
}
for (size_t k = 1; k < init[i].ndim(); ++k) {
CHECK(prove_equal(
init[i]->shape[k], state_placeholder[i]->shape[k]));
}
}
n->name = name;
n->scan_axis = axis;
n->init = init;
......@@ -188,11 +187,14 @@ Operation ScanOpNode::make(std::string name,
return Operation(n);
}
Array<Tensor> scan(IterVar scan_axis,
Array<Tensor> init,
Array<Tensor> scan(Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
std::string name) {
IterVar scan_axis(
Range::make_with_min_extent(
init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
name + ".idx");
Operation op = ScanOpNode::make(
name, scan_axis, init, update, state_placeholder);
Array<Tensor> res;
......
......@@ -61,7 +61,9 @@ Stmt Inline(Stmt stmt,
Expr body) {
CHECK_EQ(f->num_outputs(), 1)
<< "can only inline output single value operation";
return ConvertSSA(IRInline(f, args, body).Mutate(stmt));
Stmt ret = IRInline(f, args, body).Mutate(stmt);
if (ret.same_as(stmt)) return ret;
return ConvertSSA(ret);
}
} // namespace ir
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file bound.cc
......@@ -259,11 +260,14 @@ void BoundProp(const Operation& op,
init_dom->data[0].push_back(IntSet::range(
Range::make_with_min_extent(0, scan->init[i]->shape[0])));
}
if (update_dom) {
update_dom->data[0].push_back(dom_map.at(scan->scan_axis->var.get()));
}
// The update dimensions
for (size_t k = 0; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
IterVar sp_ax = scan->spatial_axis_[sp_idx];
if (init_dom) {
init_dom->data[k + 1].push_back(dom_map.at(sp_ax->var.get()));
init_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
}
if (update_dom) {
update_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
......@@ -277,10 +281,12 @@ void BoundProp(const Operation& op,
}
}
// Given the bound of output of op
// Pass the bound to the related axis in op.
void GatherOpBound(const ScanOpNode* scan,
const Operation& op,
const FeedGraph& fg,
const std::unordered_map<Tensor, TensorDom>& tmap,
std::unordered_map<IterVar, Range>* rmap) {
CHECK(!rmap->count(scan->scan_axis));
......@@ -299,21 +305,29 @@ void GatherOpBound(const ScanOpNode* scan,
Range r = arith::Union(time_dom).cover_range(sdom);
(*rmap)[scan->scan_axis] = Range::make_with_min_extent(
sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
Array<Operation> body = ScanGetBody_(scan, fg);
Map<IterVar, Expr> fix_pt = ScanFixPointAnalysis(op, body);
// Update for spatial axis.
size_t sp_idx = 0;
for (size_t i = 0; i < output.size(); ++i) {
for (size_t k = 0; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
const TensorDom& d = tmap.at(output[i]);
for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
IterVar sp_ax = scan->spatial_axis_[sp_idx];
CHECK(!rmap->count(sp_ax));
// In default, we always need all spatial axis
// Unless that axis only refers back to itself as a fixed point.
// TODO(tqchen): Add fix point detection.
(*rmap)[sp_ax] = sp_ax->dom;
CHECK(fix_pt.count(sp_ax));
if (fix_pt[sp_ax].as<ir::IntImm>()->value) {
// fix point, we can slice it.
(*rmap)[sp_ax] = arith::Union(d.data[k + 1]).cover_range(sp_ax->dom);
} else {
// not a fix point, need to include everything.
(*rmap)[sp_ax] = sp_ax->dom;
}
}
}
}
void GatherOpBound(const Operation& op,
const FeedGraph& fg,
const std::unordered_map<Tensor, TensorDom>& tmap,
std::unordered_map<IterVar, Range>* rmap) {
if (op.as<ComputeOpNode>()) {
......@@ -329,7 +343,7 @@ void GatherOpBound(const Operation& op,
(*rmap)[compute->reduce_axis[i]] = compute->reduce_axis[i]->dom;
}
} else if (op.as<ScanOpNode>()) {
GatherOpBound(op.as<ScanOpNode>(), op, tmap, rmap);
GatherOpBound(op.as<ScanOpNode>(), op, fg, tmap, rmap);
} else if (op.as<PlaceholderOpNode>()) {
// do nothing
} else {
......@@ -347,20 +361,14 @@ inline bool ScopeRelax(const IterVar& iv, const std::string& scope) {
return StorageScope::make(scope).rank <= ThreadScope::make(iv->thread_tag).rank;
}
// The map beteen tensor and operation it feeds ti
using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
// AttachPath maps op-> a list of IterVar
// That represents the loop nest op sits in from inner most to outermost
using AttachPath = Map<Operation, Array<IterVar> >;
void InferRootBound(const Stage& stage,
const FeedGraph& feed_graph,
const AttachPath& attach_path,
std::unordered_map<IterVar, Range>* rmap) {
if (stage->attach_type == kInline) return;
if (stage->attach_type == kRoot || stage->attach_type == kNone) {
CHECK_NE(stage->attach_type, kInline)
<< "call schedule.normalize before scheduleops";
if (stage->attach_type == kInlinedAlready) return;
if (stage->is_output || stage->op.as<PlaceholderOpNode>()) {
for (auto iv : OutputRelatedIterVars(stage->op)) {
CHECK(iv->dom.defined());
CHECK(!rmap->count(iv));
......@@ -368,11 +376,11 @@ void InferRootBound(const Stage& stage,
}
return;
}
// Infer root bounds for the attached node.
CHECK_EQ(stage->attach_type, kScope);
Stage parent = stage->attach_stage;
CHECK(parent.defined());
// parent stage, if any
Stage parent;
if (stage->attach_type == kScope || stage->attach_type == kScanUpdate) {
parent = stage->attach_stage;
}
// The tensor domain.
std::unordered_map<Tensor, TensorDom> tmap;
// consumers other than parent
......@@ -385,7 +393,7 @@ void InferRootBound(const Stage& stage,
auto it = feed_graph.find(t);
if (it != feed_graph.end()) {
for (const Operation& op : it->second) {
if (op != parent->op) {
if (!parent.defined() || op != parent->op) {
consumers.insert(op);
} else {
direct_consume_by_parent = true;
......@@ -404,16 +412,20 @@ void InferRootBound(const Stage& stage,
relax_set[iv->var.get()] = IntSet::range(rmap->at(iv));
}
}
if (direct_consume_by_parent) {
// parent stage if exist
Stage parent = stage->attach_stage;
// Bound inference logics in parent.
std::unordered_map<IterVar, IntSet> up_state;
bool fix_value = true;
for (auto iv : parent->leaf_iter_vars) {
Range vrange = rmap->at(iv);
auto it = rmap->find(iv);
CHECK(it != rmap->end());
Range vrange = it->second;
CHECK(is_zero(vrange->min))
<< "InferBound requires every leaf iter var's min equals 0, "
<< "call schedule.normalize to achieve this.";
<< " call schedule.normalize to achieve this. "
<< " stage=" << parent;
// special optimization to remove trivial loop
if (is_one(vrange->extent)) {
up_state[iv] = IntSet::single_point(vrange->min);
......@@ -464,8 +476,9 @@ void InferRootBound(const Stage& stage,
for (const Operation& op : consumers) {
std::unordered_map<const Variable*, IntSet> dom_map;
bool found = false;
Array<IterVar> attach = attach_path.at(stage->op);
for (IterVar iv : attach_path.at(op)) {
if (iv == stage->attach_ivar) {
if (attach.size() != 0 && iv == attach[0]) {
found = true; break;
}
Range vrange = rmap->at(iv);
......@@ -474,7 +487,7 @@ void InferRootBound(const Stage& stage,
<< "call schedule.normalize to achieve this.";
relax_set[iv->var.get()] = IntSet::range(vrange);
}
CHECK(found)
CHECK(found || attach.size() == 0)
<< "Invalid Schedule, cannot find the producer " << stage->op
<< " along the loop nest specified by compute_at of consumer " << op;
for (auto iv : OutputRelatedIterVars(op)) {
......@@ -483,50 +496,15 @@ void InferRootBound(const Stage& stage,
}
BoundProp(op, dom_map, &tmap);
}
GatherOpBound(stage->op, tmap, rmap);
GatherOpBound(stage->op, feed_graph, tmap, rmap);
}
FeedGraph CreateFeedGraph(const Schedule& sch) {
Map<IterVar, Range> InferBound(const Schedule& sch) {
Array<Operation> roots;
for (Operation op : sch->outputs) {
roots.push_back(sch->stage_map[op]->op);
}
auto g = CreateReadGraph(roots);
FeedGraph fg;
for (auto kv : g) {
for (Tensor t : kv.second) {
fg[t].push_back(kv.first);
}
}
return fg;
}
// Create AttachPath that maps op-> a list of IterVar
// That represents the loop nest op sits in from inner most to outermost
AttachPath CreateAttachPath(const Schedule& sch) {
AttachPath ret;
for (Stage stage : sch->stages) {
Array<IterVar> path;
for (Stage s = stage; s->attach_type == kScope;) {
IterVar attach_ivar = s->attach_ivar;
s = s->attach_stage;
bool start_attach = false;
for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
IterVar iv = s->leaf_iter_vars[i - 1];
if (iv == attach_ivar) start_attach = true;
if (start_attach) path.push_back(iv);
}
CHECK(start_attach)
<< "Invalid Schedule: cannot find attach point " << attach_ivar
<< " in the schedule of " << s->op;
}
ret.Set(stage->op, path);
}
return ret;
}
Map<IterVar, Range> InferBound(const Schedule& sch) {
FeedGraph feed_graph = CreateFeedGraph(sch);
FeedGraph feed_graph = CreateFeedGraph(CreateReadGraph(roots));
AttachPath attach_path = CreateAttachPath(sch);
std::unordered_map<IterVar, Range> ret;
......@@ -535,6 +513,11 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
InferRootBound(stage, feed_graph, attach_path, &ret);
// pass down to get bound of all iter vars.
PassDown(stage, &ret);
// setup outer most threads.
for (IterVar iv : stage->outermost_threads) {
CHECK(iv->dom.defined());
ret[iv] = iv->dom;
}
}
return Map<IterVar, Range>(ret.begin(), ret.end());
}
......
......@@ -10,6 +10,46 @@
namespace tvm {
namespace schedule {
/*!
 * \brief Key that identifies one dimension of one tensor.
 *
 *  The tensor is identified by the function that produces it (f)
 *  together with its output index (value_index); dim selects the dimension.
 */
struct TensorDimKey {
  /*! \brief The function producing the tensor. */
  FunctionRef f;
  /*! \brief Output index of the tensor within f. */
  int value_index{0};
  /*! \brief Index of the dimension within the tensor. */
  int dim{0};
  // The indices are zero-initialized so that comparing or hashing a
  // default-constructed key never reads indeterminate values.
  TensorDimKey() {}
  /*! \brief Key of the dim-th argument position of a call. */
  TensorDimKey(const ir::Call* op, int dim)
      : f(op->func), value_index(op->value_index), dim(dim) {
  }
  /*! \brief Key of the dim-th dimension of tensor t. */
  TensorDimKey(const Tensor& t, int dim)
      : f(t->op), value_index(t->value_index), dim(dim) {
  }
  /*! \brief Two keys are equal when function, output index and dim all match. */
  inline bool operator==(const TensorDimKey& other) const {
    return f == other.f &&
        value_index == other.value_index &&
        dim == other.dim;
  }
  inline bool operator!=(const TensorDimKey& other) const {
    return !operator==(other);
  }
};
} // namespace schedule
} // namespace tvm
namespace std {
/*!
 * \brief Hash of TensorDimKey so it can be used as key of unordered containers.
 *
 *  Combines the hash of the function with the packed (value_index, dim) pair.
 */
template <>
struct hash<::tvm::schedule::TensorDimKey> {
  std::size_t operator()(const ::tvm::schedule::TensorDimKey& k) const {
    size_t lhs = k.f.hash();
    // Pack the two indices in a type that is guaranteed to be at least 64 bits:
    // shifting a 32-bit size_t by 32 would be undefined behavior.
    // On platforms with 64-bit size_t the result is the same as before.
    unsigned long long packed =  // NOLINT(*)
        (static_cast<unsigned long long>(k.value_index) << 32) |  // NOLINT(*)
        static_cast<unsigned long long>(k.dim);  // NOLINT(*)
    size_t rhs = static_cast<size_t>(packed);
    // boost-style hash combine
    lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
    return lhs;
  }
};
} // namespace std
namespace tvm {
namespace schedule {
// construct a read graph that gives readers of each operation
// that the root depend on
......@@ -28,7 +68,7 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) {
stack.pop_back();
Array<Tensor> deps;
if (op.as<ComputeOpNode>()) {
auto fvisit = [&deps, &visited, &stack](const NodeRef& n) {
auto fvisit = [&deps](const NodeRef& n) {
auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Operation call_op(call->func.node_);
......@@ -59,7 +99,6 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) {
return rmap;
}
void PostDFSOrder(const Operation& op,
const ReadGraph& g,
std::unordered_set<Operation>* visited,
......@@ -83,5 +122,269 @@ Array<Operation> PostDFSOrder(
return post_order;
}
// Invert the read graph: for every tensor that is read,
// collect the operations that read it.
FeedGraph CreateFeedGraph(const ReadGraph& g) {
  FeedGraph feed;
  for (auto entry : g) {
    const Operation& reader = entry.first;
    Array<Tensor> inputs = entry.second;
    for (size_t i = 0; i < inputs.size(); ++i) {
      feed[inputs[i]].push_back(reader);
    }
  }
  return feed;
}
// Build, for every stage's op, the list of leaf iter vars of the stages it is
// attached to, walking the attach chain upwards.
// Stages attached as kScanUpdate first get their attach_ivar set to the last
// leaf iter var of their attach stage.
AttachPath CreateAttachPath(Schedule sch) {
  // Step 1: fill in the attach point of scan update stages.
  for (Stage stage : sch->stages) {
    if (stage->attach_type != kScanUpdate) continue;
    const Stage& parent = stage->attach_stage;
    const size_t last = parent->leaf_iter_vars.size() - 1;
    stage->attach_ivar = parent->leaf_iter_vars[last];
  }
  // Step 2: collect the loop nest of each stage.
  AttachPath ret;
  for (Stage stage : sch->stages) {
    Array<IterVar> loop_nest;
    Stage cur = stage;
    while (cur->attach_type == kScope || cur->attach_type == kScanUpdate) {
      IterVar attach_point = cur->attach_ivar;
      cur = cur->attach_stage;
      // Scan the leaf iter vars backwards; everything from the attach point
      // on (towards index 0) belongs to the path.
      bool reached = false;
      size_t idx = cur->leaf_iter_vars.size();
      while (idx != 0) {
        --idx;
        IterVar iv = cur->leaf_iter_vars[idx];
        if (iv == attach_point) reached = true;
        if (reached) loop_nest.push_back(iv);
      }
      CHECK(reached)
          << "Invalid Schedule: cannot find attach point " << attach_point
          << " in the schedule of " << cur->op;
    }
    // Keep the first path recorded for an op.
    if (!ret.count(stage->op)) {
      ret.Set(stage->op, loop_nest);
    }
  }
  return ret;
}
// graph of push reach relation of tensor dimensions
using ReachGraph = std::unordered_map<TensorDimKey, std::vector<TensorDimKey> >;
// Build the reach graph among tensor dimensions of the given operations.
// - For a scan op, dimension k (k >= 1) of each output reaches dimension k
//   of the corresponding update and init tensors.
// - For a compute op, output dimension i reaches every argument position of
//   a call to one of ops whose index expression mentions axis i.
ReachGraph GetReachGraph(const Array<Operation>& ops) {
  ReachGraph reach;
  // Nodes of the operations under analysis; calls to anything else are ignored.
  std::unordered_set<const Node*> body_set;
  for (size_t i = 0; i < ops.size(); ++i) {
    body_set.insert(ops[i].get());
  }

  for (Operation op : ops) {
    if (const ScanOpNode* scan = op.as<ScanOpNode>()) {
      for (size_t i = 0; i < scan->update.size(); ++i) {
        Tensor out = op.output(i);
        for (size_t k = 1; k < scan->update[i]->shape.size(); ++k) {
          std::vector<TensorDimKey>& edges = reach[TensorDimKey(out, k)];
          edges.emplace_back(TensorDimKey(scan->update[i], k));
          edges.emplace_back(TensorDimKey(scan->init[i], k));
        }
      }
    } else if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
      // axis variable -> output dimension it indexes
      std::unordered_map<const Node*, TensorDimKey> vmap;
      Tensor out = op.output(0);
      for (size_t i = 0; i < compute->axis.size(); ++i) {
        TensorDimKey self(out, i);
        vmap[compute->axis[i]->var.get()] = self;
        reach[self] = {};
      }
      auto visit_call = [&vmap, &reach, &body_set](const NodeRef& n) {
        const ir::Call* call = n.as<ir::Call>();
        if (call == nullptr || !call->func.defined()) return;
        if (body_set.count(call->func.get()) == 0) return;
        for (size_t i = 0; i < call->args.size(); ++i) {
          TensorDimKey callee_dim(call, i);
          // every axis variable appearing in the i-th index reaches callee_dim
          auto push_var = [&callee_dim, &vmap, &reach](const NodeRef& node) {
            auto it = vmap.find(node.as<Variable>());
            if (it != vmap.end()) {
              reach[it->second].push_back(callee_dim);
            }
          };
          ir::PostOrderVisit(call->args[i], push_var);
        }
      };
      ir::PostOrderVisit(compute->body, visit_call);
    }
  }
  return reach;
}
// Collect the operations that form the body of scan.
// Starting from op, follow the feed graph through consumers until the scan
// itself is reached; each consumer is appended to result after its own
// consumers have been visited.
void ScanGetBodyPostDFS_(
    Operation op,
    const ScanOpNode* scan,
    const FeedGraph& feed_graph,
    std::unordered_set<const Node*>* visited,
    Array<Operation>* result) {
  // The scan op terminates the walk.
  if (op.get() == scan) return;
  bool has_consumer = false;
  for (int i = 0; i < op->num_outputs(); ++i) {
    auto it = feed_graph.find(op.output(i));
    if (it == feed_graph.end() || it->second.empty()) continue;
    has_consumer = true;
    for (const Operation& consumer : it->second) {
      // insert(...).second is false when the consumer was already visited
      if (!visited->insert(consumer.get()).second) continue;
      ScanGetBodyPostDFS_(consumer, scan, feed_graph, visited, result);
      result->push_back(consumer);
    }
  }
  if (!has_consumer) {
    LOG(FATAL) << "Bad scan body, tensor reads scan_state but not connect to scan";
  }
}
// Gather the body operations of scan by walking the feed graph
// from each of its state placeholders.
Array<Operation> ScanGetBody_(
    const ScanOpNode* scan,
    const FeedGraph& feed_graph) {
  CHECK(scan != nullptr);
  Array<Operation> body;
  std::unordered_set<const Node*> seen;
  for (size_t i = 0; i < scan->state_placeholder.size(); ++i) {
    Tensor state = scan->state_placeholder[i];
    ScanGetBodyPostDFS_(state->op, scan, feed_graph, &seen, &body);
  }
  return body;
}
// Convenience overload: derive the feed graph from the read graph rooted at
// scan, then collect the scan body.
Array<Operation> ScanGetBody(const Operation& scan) {
  ReadGraph read_graph = CreateReadGraph({scan});
  FeedGraph feed_graph = CreateFeedGraph(read_graph);
  return ScanGetBody_(scan.as<ScanOpNode>(), feed_graph);
}
/*!
 * \brief Decide for each spatial axis of a scan whether it is a fix point.
 *
 *  The analysis runs in two phases:
 *  - Exact reach: starting from the state placeholder dimensions, propagate
 *    "this tensor dimension is indexed exactly by spatial axis X" through the
 *    body. An axis is put in fail_set when two different axes meet on one
 *    dimension, or when a tracked dimension is indexed by something other
 *    than a plain axis variable.
 *  - Interference: for axes that pass, walk the reach graph from the update
 *    dimension and make sure no other state placeholder dimension is reached.
 *
 * \param scan_op The scan operation.
 * \param body The body of the scan; body[0] must be the scan itself.
 * \return Map from each spatial axis to an Int(32) constant,
 *         1 if the axis was proven to be a fix point and 0 otherwise.
 */
Map<IterVar, Expr> ScanFixPointAnalysis(
    const Operation& scan_op, const Array<Operation>& body) {
  const ScanOpNode* scan = scan_op.as<ScanOpNode>();
  // body must be non-empty: body[0] is read here and the loop below
  // counts down to index 1.
  CHECK(body.size() != 0 && body[0].get() == scan);
  std::unordered_map<TensorDimKey, const Node*> exact_reach;
  std::unordered_set<const Node*> fail_set;

  // seed: each spatial dimension of a state placeholder maps to its axis.
  for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
    for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
      TensorDimKey key(scan->state_placeholder[i], k);
      exact_reach[key] = scan->spatial_axis_[sp_idx].get();
    }
  }
  // merge exact reach
  auto f_merge_key = [&exact_reach, &fail_set](
      const TensorDimKey& dst, const TensorDimKey& src) {
    auto sit = exact_reach.find(src);
    if (sit == exact_reach.end()) return;
    auto dit = exact_reach.find(dst);
    if (dit == exact_reach.end()) {
      exact_reach[dst] = sit->second;
    } else {
      if (dit->second != sit->second) {
        // two different axes meet on the same dimension: both fail.
        fail_set.insert(dit->second);
        fail_set.insert(sit->second);
      }
    }
  };
  // prop exact reach back, body[0] (the scan itself) is skipped.
  for (size_t i = body.size(); i != 1; --i) {
    const Operation& op = body[i - 1];
    if (op.as<ScanOpNode>()) {
      const auto& update = op.as<ScanOpNode>()->update;
      const auto& init = op.as<ScanOpNode>()->init;
      for (size_t j = 0; j < update.size(); ++j) {
        Tensor t = op.output(j);
        // iterate over the spatial dimensions; the bound must be tested on k.
        for (size_t k = 1; k < update[j]->shape.size(); ++k) {
          f_merge_key(TensorDimKey(t, k), TensorDimKey(update[j], k));
          f_merge_key(TensorDimKey(t, k), TensorDimKey(init[j], k));
        }
      }
    } else if (op.as<ComputeOpNode>()) {
      // axis variable -> output dimension it indexes
      std::unordered_map<const Node*, TensorDimKey> vmap;
      const auto& axis = op.as<ComputeOpNode>()->axis;
      Tensor t = op.output(0);
      for (size_t i = 0; i < axis.size(); ++i) {
        vmap[axis[i]->var.get()] = TensorDimKey(t, i);
      }
      auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](
          const NodeRef& n) {
        const ir::Call *call = n.as<ir::Call>();
        if (call != nullptr && call->func.defined()) {
          for (size_t i = 0; i < call->args.size(); ++i) {
            auto it = vmap.find(call->args[i].get());
            TensorDimKey src(call, i);
            if (it != vmap.end()) {
              // argument is exactly an axis variable of this op.
              f_merge_key(it->second, src);
            } else {
              // tracked dimension accessed with a non-axis index: fail.
              if (exact_reach.count(src)) {
                fail_set.insert(exact_reach.at(src));
              }
            }
          }
        }
      };
      ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
    }
  }
  ReachGraph reach;
  Map<IterVar, Expr> ret;
  // all dimensions of all state placeholders.
  std::unordered_set<TensorDimKey> place_holder_ref;
  for (size_t i = 0; i < scan->state_placeholder.size(); ++i) {
    for (size_t k = 0; k < scan->state_placeholder[i]->shape.size(); ++k) {
      place_holder_ref.insert(TensorDimKey(scan->state_placeholder[i], k));
    }
  }

  for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
    for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
      TensorDimKey key(scan->update[i], k);
      TensorDimKey target(scan->state_placeholder[i], k);
      IterVar sp_iv = scan->spatial_axis_[sp_idx];
      if (fail_set.count(sp_iv.get()) ||
          !exact_reach.count(key) ||
          exact_reach.at(key) != sp_iv.get()) {
        ret.Set(sp_iv, make_const(Int(32), 0));
      } else {
        // now we proved exact match, need to prove no interference with other graph.
        // the reach graph is built lazily, only when some axis gets this far.
        if (reach.size() == 0) reach = GetReachGraph(body);
        // do a DFS
        std::unordered_set<TensorDimKey> visited;
        std::vector<TensorDimKey> stack{key};
        visited.insert(key);
        while (!stack.empty()) {
          TensorDimKey k = stack.back();
          // reaching a placeholder dimension other than the target aborts
          // the walk with a non-empty stack, which marks failure below.
          if (k != target && place_holder_ref.count(k)) break;
          stack.pop_back();
          if (!reach.count(k)) {
            LOG(FATAL) << "cannot find reach of " << k.f << "-" << k.dim;
          }
          for (TensorDimKey kk : reach.at(k)) {
            if (visited.count(kk)) {
              continue;
            }
            visited.insert(kk);
            stack.push_back(kk);
          }
        }
        if (!stack.empty()) {
          // failed the prove.
          ret.Set(sp_iv, make_const(Int(32), 0));
        } else {
          ret.Set(sp_iv, make_const(Int(32), 1));
        }
      }
    }
  }
  return ret;
}
} // namespace schedule
} // namespace tvm
......@@ -9,6 +9,7 @@
#include <tvm/expr.h>
#include <tvm/schedule.h>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace tvm {
......@@ -20,6 +21,16 @@ namespace schedule {
using ReadGraph = Map<Operation, Array<Tensor> >;
/*!
* \brief The map between a tensor and the operations it feeds to
*/
using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
/*!
* \brief AttachPath maps op-> a list of IterVar
*/
using AttachPath = Map<Operation, Array<IterVar> >;
/*!
* \brief Get read graph of each operation to all the
* Tensors that it directly depends on.
*
......@@ -41,6 +52,49 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots);
Array<Operation> PostDFSOrder(
const Array<Operation>& roots, const ReadGraph& g);
/*!
* \brief Create the feed graph from the given read graph
* \param g The read graph.
* \return The created feedgraph.
*/
FeedGraph CreateFeedGraph(const ReadGraph& g);
/*!
* \brief Create AttachPath that maps each op to a list of IterVar
* that represents the loop nest the op sits in, from innermost to outermost.
* Also sets attach_ivar of scan update stages to the last leaf iter var of their attach stage.
*
* \param sch The schedule.
* \return The attach path.
*/
AttachPath CreateAttachPath(Schedule sch);
/*!
* \brief Get all operations inside the recursion of scan.
* \param scan The scan node.
* \param feed_graph The feed graph to help analysis.
* \return The body operations, in read dependency order.
*/
Array<Operation> ScanGetBody_(
const ScanOpNode* scan, const FeedGraph& feed_graph);
// same as ScanGetBody_, but create FeedGraph internally.
Array<Operation> ScanGetBody(const Operation& scan);
/*!
* \brief Analyze each spatial dimension of scan's result.
* Checks whether each dimension is a fix point.
* An axis is a fix point if it only refers back to itself in the recursion
* and it is not used in an axis of another recursion field.
*
* next_state[t, ..., axis, ...] = f(prev_state[t-1, ..., axis, ...])
*
* \param scan The scan node.
* \param body The body of scan, sorted in reverse PostDFSOrder.
* \return Map of spatial_axis -> IntImm
*/
Map<IterVar, Expr> ScanFixPointAnalysis(
const Operation& scan, const Array<Operation>& body);
} // namespace schedule
} // namespace tvm
......
/*!
* Copyright (c) 2017 by Contributors
* \file schedule_dataflow_rewrite.cc
*/
#include <tvm/schedule.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
namespace tvm {
// Find the position of the first entry in array_node that refers to the
// same node as v. Returns array_node->data.size() if there is none.
template<typename T>
size_t FindNodeRef(ArrayNode* array_node, const T& v) {
  const Node* target = v.get();
  size_t pos = 0;
  while (pos < array_node->data.size() &&
         array_node->data[pos].get() != target) {
    ++pos;
  }
  return pos;
}
using ir::TensorKey;
// Mutator that redirects Halide calls from one tensor to another.
// The replacement table is held by reference and must outlive the replacer.
class TensorReplacer : public ir::IRMutator {
 public:
  explicit TensorReplacer(const std::unordered_map<TensorKey, Tensor>& vmap)
      : vmap_(vmap) {}
  Expr Mutate_(const ir::Call* op, const Expr& e) {
    // Only Halide calls refer to tensors.
    if (op->call_type != ir::Call::Halide) {
      return IRMutator::Mutate_(op, e);
    }
    auto it = vmap_.find(ir::TensorKey{op->func, op->value_index});
    if (it == vmap_.end()) {
      return IRMutator::Mutate_(op, e);
    }
    // Rebuild the call against the replacement tensor, then keep mutating
    // the rebuilt call so that its arguments are visited as well.
    const Tensor& target = it->second;
    Expr ret = ir::Call::make(
        op->type, target->op->name, op->args,
        op->call_type, target->op, target->value_index);
    found = true;
    return IRMutator::Mutate_(ret.as<ir::Call>(), ret);
  }

  // set to true once at least one call has been replaced.
  bool found{false};

 private:
  const std::unordered_map<TensorKey, Tensor>& vmap_;
};
// Mutator that substitutes variables by expressions.
// The substitution table is held by reference and must outlive the replacer.
class VarReplacer : public ir::IRMutator {
 public:
  explicit VarReplacer(
      const std::unordered_map<const Variable*, Expr>& vsub)
      : vsub_(vsub) {}
  Expr Mutate_(const Variable* op, const Expr& e) {
    auto hit = vsub_.find(op);
    // variables without a substitution are returned unchanged.
    return hit == vsub_.end() ? e : hit->second;
  }

 private:
  const std::unordered_map<const Variable*, Expr>& vsub_;
};
// Replace data flow appears in all stages given the tensor change.
// Also update vmap if subsequent dataflow need to be replaced.
void ReplaceDataFlow(const Array<Stage>& stages,
std::unordered_map<TensorKey, Tensor>* vmap) {
for (Stage s : stages) {
if (s->op.as<ComputeOpNode>()) {
const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
TensorReplacer repl(*vmap);
Expr body = repl.Mutate(compute->body);
if (repl.found) {
Operation op = ComputeOpNode::make(
compute->name, compute->axis, body);
(*vmap)[TensorKey{s->op, 0}] = op.output(0);
s->op = op;
}
} else if (s->op.as<ScanOpNode>()) {
const ScanOpNode* scan = s->op.as<ScanOpNode>();
std::shared_ptr<ScanOpNode> n =
std::make_shared<ScanOpNode>(*scan);
// copy on write semantics ganrantees correctness
for (size_t i = 0; i < n->init.size(); ++i) {
TensorKey key{n->init[i]->op, n->init[i]->value_index};
if (vmap->count(key)) {
n->init.Set(i, vmap->at(key));
}
}
for (size_t i = 0; i < n->update.size(); ++i) {
TensorKey key{n->update[i]->op, n->update[i]->value_index};
if (vmap->count(key)) {
n->update.Set(i, vmap->at(key));
}
}
if (!n->init.same_as(scan->init) ||
!n->update.same_as(scan->update)) {
Operation op(n);
for (int i = 0; i < op->num_outputs(); ++i) {
(*vmap)[TensorKey{s->op, i}] = op.output(i);
}
s->op = op;
}
} else if (s->op.as<PlaceholderOpNode>()) {
} else {
LOG(FATAL) << "unhandled problem";
}
}
}
// Create a cache stage that copies tensor, and make the given readers
// read from the cache instead.
// The cache stage gets the given scope and is inserted right after
// the stage that produces tensor. Returns the cache tensor.
Tensor Schedule::cache_read(const Tensor& tensor,
                            const std::string& scope,
                            const Array<Operation>& readers) {
  // name of the cache: <op>[.v<index>].<scope>
  std::ostringstream cache_name;
  cache_name << tensor->op->name;
  if (tensor->op->num_outputs() != 1) {
    cache_name << ".v" << tensor->value_index;
  }
  cache_name << "." << scope;
  // create identity mapping.
  auto identity = [&tensor](const Array<Var>& i) {
    return tensor(Array<Expr>(i.begin(), i.end()));
  };
  Tensor cache = compute(tensor->shape, identity, cache_name.str());

  // tensor -> cache, applied to the bodies of the readers.
  std::unordered_map<TensorKey, Tensor> vsub;
  vsub[TensorKey{tensor->op, tensor->value_index}] = cache;

  // old reader output -> new reader output, for downstream stages.
  std::unordered_map<TensorKey, Tensor> vmap;
  for (Operation op : readers) {
    CHECK(op.as<ComputeOpNode>())
        << "cache read only take ComputeOp as readers";
    Stage s = operator[](op);
    // use the op currently held by the stage.
    const ComputeOpNode* reader = s->op.as<ComputeOpNode>();
    TensorReplacer repl(vsub);
    Expr body = repl.Mutate(reader->body);
    CHECK(repl.found)
        << "Cannot find " << tensor
        << " in the body of specified reader " << op;
    Operation repl_op = ComputeOpNode::make(
        reader->name, reader->axis, body);
    vmap[TensorKey{s->op, 0}] = repl_op.output(0);
    s->op = repl_op;
  }
  ReplaceDataFlow((*this)->stages, &vmap);

  // register the cache stage right after the producer of tensor.
  ArrayNode* stages = (*this)->stages.CopyOnWrite();
  size_t pos = FindNodeRef(stages, operator[](tensor->op));
  Stage cache_stage(cache->op);
  cache_stage.set_scope(scope);
  CHECK_LT(pos, stages->data.size());
  stages->data.insert(stages->data.begin() + pos + 1,
                      cache_stage.node_);
  (*this)->stage_map.Set(cache->op, cache_stage);
  return cache;
}
// Create a cache stage that computes the body of tensor's op on fresh axes,
// and turn the original op into a copy from that cache.
// The cache stage gets the given scope and is inserted right before the
// original stage. Returns the cache tensor.
Tensor Schedule::cache_write(const Tensor& tensor,
                             const std::string& scope) {
  Stage orig_stage = operator[](tensor->op);
  CHECK(tensor->op.as<ComputeOpNode>())
      << "cache write only take ComputeOp as writers";
  CHECK_EQ(orig_stage->relations.size(), 0U)
      << "Create cache_write before doing split/fuse/reorder";
  // use the op currently held by the stage.
  const ComputeOpNode* writer = orig_stage->op.as<ComputeOpNode>();
  CHECK(writer);

  // Make one new axis per original axis and remember the substitution.
  Array<Expr> args;
  Array<IterVar> new_axis;
  std::unordered_map<const Variable*, Expr> vsub;
  for (size_t i = 0; i < writer->axis.size(); ++i) {
    IterVar iv = writer->axis[i];
    args.push_back(iv->var);
    IterVar new_iv(iv->dom, iv->var->name_hint + ".c");
    new_axis.push_back(new_iv);
    vsub[iv->var.get()] = new_iv->var;
  }

  // cache op: original body expressed over the new axes.
  VarReplacer repl(vsub);
  Expr body = repl.Mutate(writer->body);
  Operation cache_op = ComputeOpNode::make(
      writer->name + "." + scope, new_axis, body);
  Tensor cache_tensor = cache_op.output(0);
  // new original op: reads the cache at the original axes.
  Operation orig_new_op = ComputeOpNode::make(
      writer->name, writer->axis,
      cache_tensor(args));

  std::unordered_map<TensorKey, Tensor> vmap;
  vmap[TensorKey{orig_stage->op, 0}] = orig_new_op.output(0);
  ReplaceDataFlow((*this)->stages, &vmap);

  // mutate orig stage
  orig_stage->op = orig_new_op;
  orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
  orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;

  // create schedule for new cached stage.
  ArrayNode* stages = (*this)->stages.CopyOnWrite();
  size_t pos = FindNodeRef(stages, orig_stage);
  Stage cache_stage(cache_op);
  cache_stage.set_scope(scope);
  CHECK_LT(pos, stages->data.size());
  stages->data.insert(stages->data.begin() + pos,
                      cache_stage.node_);
  (*this)->stage_map.Set(cache_op, cache_stage);
  return cache_tensor;
}
void RebaseNonZeroMinLoop(const Schedule& sch) {
std::unordered_map<IterVar, IterVar> rebase_map;
std::unordered_map<const Node*, int> attach_mark;
for (Stage s : sch->stages) {
if (s->attach_type == kScope) {
attach_mark[s->attach_stage.get()] = 1;
}
if (s->op.as<ScanOpNode>()) {
attach_mark[s.get()] = 1;
}
}
for (Stage s : sch->stages) {
if (!attach_mark.count(s.get())) continue;
auto root_iter_vars = s->op->root_iter_vars();
ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
for (IterVar iv : root_iter_vars) {
size_t idx = FindNodeRef(leaf_vars, iv);
if (idx < leaf_vars->data.size()) {
// insert rebase
IterVar rebased(Range(), iv->var->name_hint + ".rb");
s->relations.push_back(RebaseNode::make(iv, rebased));
leaf_vars->data[idx] = rebased.node_;
rebase_map[iv] = rebased;
}
}
}
// remap the parent relation
for (Stage s : sch->stages) {
if (s->attach_type != kScope) continue;
if (rebase_map.count(s->attach_ivar)) {
s->attach_ivar = rebase_map.at(s->attach_ivar);
}
}
}
/*!
 * \brief Fix the attach point of every scan-update stage.
 *
 *  Each stage whose attach_type is kScanUpdate is attached at the last
 *  leaf iteration variable of its attach_stage.
 *
 * \param sch The schedule to update in place.
 */
void SetScanAttach(const Schedule& sch) {  // NOLINT(*)
  for (Stage stage : sch->stages) {
    if (stage->attach_type != kScanUpdate) continue;
    const Stage& scan_stage = stage->attach_stage;
    const size_t last = scan_stage->leaf_iter_vars.size() - 1;
    stage->attach_ivar = scan_stage->leaf_iter_vars[last];
  }
}
void InjectInline(const Schedule& sch) {
std::vector<Expr> new_body(sch->stages.size());
// inline all the ops
for (size_t i = sch->stages.size(); i != 0; --i) {
Stage stage = sch->stages[i - 1];
if (stage->attach_type == kInline) {
stage->attach_type = kInlinedAlready;
Array<Var> args;
Expr body;
{
// setup args
const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
CHECK(compute)
<< "can only inline compute op";
for (auto iv : compute->axis) {
args.push_back(iv->var);
}
body = compute->body;
}
for (size_t j = i; j < sch->stages.size(); ++j) {
Stage s = sch->stages[j];
const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
if (compute) {
if (!new_body[j].defined()) {
new_body[j] = s->op.as<ComputeOpNode>()->body;
}
new_body[j] = ir::Inline(ir::Evaluate::make(new_body[j]),
stage->op, args, body).as<ir::Evaluate>()->value;
}
}
}
}
std::unordered_map<TensorKey, Tensor> repl;
// rewrite dataflow
for (size_t i = 0; i < sch->stages.size(); ++i) {
if (new_body[i].defined() &&
!new_body[i].same_as(sch->stages[i]->op)) {
const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
CHECK(compute);
Operation op = ComputeOpNode::make(
compute->name, compute->axis, new_body[i]);
repl[TensorKey{sch->stages[i]->op, 0}] = op.output(0);
Stage s = sch->stages[i];
s->op = op;
}
}
ReplaceDataFlow(sch->stages, &repl);
}
void Schedule::normalize() {
RebaseNonZeroMinLoop(*this);
SetScanAttach(*this);
InjectInline(*this);
}
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file schedule.cc
* \file schedule_lang.cc
*/
#include <tvm/schedule.h>
#include <tvm/ir_mutator.h>
......@@ -37,6 +37,10 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v)
void Split(StageNode* self, IterVar parent,
IterVar outer, IterVar inner, Expr factor) {
if (self->attach_type == kScanUpdate) {
CHECK(!parent.same_as(self->all_iter_vars[0]))
<< "Cannot split on axis[0] of scan update";
}
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
size_t pos = FindLeafVar(all_vars, leaf_vars, parent);
......@@ -83,6 +87,8 @@ Stage& Stage::set_scope(std::string scope) { // NOLINT(*)
}
Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*)
CHECK_NE((*this)->attach_type, kScanUpdate)
<< "Cannot specify compute_at for scan updates";
(*this)->attach_type = kScope;
(*this)->attach_ivar = scope;
(*this)->attach_stage = parent;
......@@ -93,16 +99,22 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*)
}
}
CHECK(found)
<< "Cannot find the axis in parent's leaf_iter_vars or outermost_threads";
<< "Cannot find the axis " << scope
<< " in parent's leaf_iter_vars or outermost_threads:"
<< " parent=" << parent;
return *this;
}
/*!
 * \brief Mark this stage to be inlined into its consumers.
 *  Rejected for stages that are scan updates.
 * \return reference to self.
 */
Stage& Stage::compute_inline() {   // NOLINT(*)
  CHECK_NE((*this)->attach_type, kScanUpdate)
      << "Cannot specify compute_inline for scan updates";
  (*this)->attach_type = kInline;
  return *this;
}
/*!
 * \brief Mark this stage to be computed at the root.
 *  Rejected for stages that are scan updates.
 * \return reference to self.
 */
Stage& Stage::compute_root() {   // NOLINT(*)
  CHECK_NE((*this)->attach_type, kScanUpdate)
      << "Cannot specify compute_root for scan updates";
  (*this)->attach_type = kRoot;
  return *this;
}
......@@ -128,9 +140,15 @@ Stage& Stage::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor
}
Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT(*)
StageNode* self = operator->();
if (self->attach_type == kScanUpdate) {
CHECK(!inner.same_as(self->all_iter_vars[0]))
<< "Cannot split on axis[0] of scan update";
CHECK(!outer.same_as(self->all_iter_vars[0]))
<< "Cannot split on axis[0] of scan update";
}
IterVar fused(Range(), outer->var->name_hint + "." + inner->var->name_hint + ".fused");
*p_target = fused;
StageNode* self = operator->();
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
......@@ -157,6 +175,10 @@ Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*)
std::vector<size_t> pos;
for (size_t i = 0; i < order.size(); ++i) {
if ((*this)->attach_type == kScanUpdate) {
CHECK(!order[i].same_as(self->all_iter_vars[0]))
<< "Cannot split on axis[0] of scan update";
}
pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i]));
}
std::vector<std::shared_ptr<Node> > temp;
......@@ -239,12 +261,25 @@ Schedule::Schedule(Array<Operation> ops) {
stage->is_output = output_set.count(op);
n->stages.push_back(stage);
n->stage_map.Set(op, stage);
// mark scan updates.
if (op.as<ScanOpNode>()) {
const ScanOpNode* scan = op.as<ScanOpNode>();
for (size_t i = 0; i < scan->update.size(); ++i) {
Stage s = n->stage_map[scan->update[i]->op];
s->attach_type = kScanUpdate;
s->attach_stage = stage;
}
}
}
node_ = std::move(n);
}
/*!
 * \brief Look up the stage that schedules the given operation.
 *  Fails with a message naming the operation when it has no stage here.
 * \param op The operation to look up.
 * \return The stage registered for op in stage_map.
 */
Stage Schedule::operator[](const Operation& op) {
  auto it = (*this)->stage_map.find(op);
  CHECK(it != (*this)->stage_map.end())
      << "Cannot find Stage for operator " << op
      << " in the schedule";
  return (*it).second;
}
IterVarRelation SplitNode::make(
......@@ -274,42 +309,6 @@ IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) {
return IterVarRelation(n);
}
void Schedule::normalize() {
std::unordered_map<IterVar, IterVar> rebase_map;
std::unordered_map<const Node*, int> attach_mark;
for (Stage s : (*this)->stages) {
if (s->attach_type == kScope) {
attach_mark[s->attach_stage.get()] = 1;
}
}
for (Stage s : (*this)->stages) {
if (!attach_mark.count(s.get())) continue;
auto root_iter_vars = s->op->root_iter_vars();
ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
for (IterVar iv : root_iter_vars) {
size_t idx = FindNodeRef(leaf_vars, iv);
if (idx < leaf_vars->data.size()) {
// insert rebase
IterVar rebased(Range(), iv->var->name_hint + ".rb");
s->relations.push_back(RebaseNode::make(iv, rebased));
leaf_vars->data[idx] = rebased.node_;
rebase_map[iv] = rebased;
}
}
}
// remap the parent relation
for (Stage s : (*this)->stages) {
if (s->attach_type != kScope) continue;
if (rebase_map.count(s->attach_ivar)) {
s->attach_ivar = rebase_map.at(s->attach_ivar);
}
}
}
IterVarAttr::IterVarAttr(IterVarType t) {
std::shared_ptr<IterVarAttrNode> n = std::make_shared<IterVarAttrNode>();
n->iter_type = t;
......@@ -323,190 +322,4 @@ TVM_REGISTER_NODE_TYPE(FuseNode);
TVM_REGISTER_NODE_TYPE(RebaseNode);
TVM_REGISTER_NODE_TYPE(ScheduleNode);
using ir::TensorKey;
// The replacer of cache.
/*!
 * \brief IR mutator that redirects Halide calls from one tensor to another.
 *
 *  Every Halide call whose (func, value_index) key appears in the map is
 *  rebuilt to call the mapped tensor with the same arguments, and `found`
 *  is set. The map is held by reference, so it must outlive the replacer.
 */
class TensorReplacer : public ir::IRMutator {
 public:
  // explicit: a map must never convert silently into a mutator.
  explicit TensorReplacer(const std::unordered_map<TensorKey, Tensor>& vmap)
      : vmap_(vmap) {}
  Expr Mutate_(const ir::Call* op, const Expr& e) {
    if (op->call_type == ir::Call::Halide) {
      ir::TensorKey key{op->func, op->value_index};
      auto it = vmap_.find(key);
      if (it != vmap_.end()) {
        Expr ret = ir::Call::make(
            op->type, it->second->op->name, op->args,
            op->call_type, it->second->op, it->second->value_index);
        found = true;
        // Continue through the base mutator so the arguments are visited too.
        return IRMutator::Mutate_(ret.as<ir::Call>(), ret);
      }
    }
    return IRMutator::Mutate_(op, e);
  }
  // Set to true once at least one call has been redirected.
  bool found{false};

 private:
  // Replacement table keyed by the tensor being read; not owned.
  const std::unordered_map<TensorKey, Tensor>& vmap_;
};
/*!
 * \brief IR mutator that substitutes variables according to a lookup table.
 *  Variables absent from the table are returned unchanged. The table is
 *  held by reference and is not owned.
 */
class VarReplacer : public ir::IRMutator {
 public:
  explicit VarReplacer(
      const std::unordered_map<const Variable*, Expr>& vsub)
      : vsub_(vsub) {}
  Expr Mutate_(const Variable* op, const Expr& e) {
    const auto hit = vsub_.find(op);
    return hit == vsub_.end() ? e : hit->second;
  }

 private:
  // Substitution table: variable node -> replacement expression.
  const std::unordered_map<const Variable*, Expr>& vsub_;
};
// Replace data flow appears in all stages given the tensor change.
// Also update vmap if subsequent dataflow need to be replaced.
void ReplaceDataFlow(const Array<Stage>& stages,
std::unordered_map<TensorKey, Tensor>* vmap) {
for (Stage s : stages) {
if (s->op.as<ComputeOpNode>()) {
const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
TensorReplacer repl(*vmap);
Expr body = repl.Mutate(compute->body);
if (repl.found) {
Operation op = ComputeOpNode::make(
compute->name, compute->axis, body);
(*vmap)[TensorKey{s->op, 0}] = op.output(0);
s->op = op;
}
} else if (s->op.as<ScanOpNode>()) {
const ScanOpNode* scan = s->op.as<ScanOpNode>();
std::shared_ptr<ScanOpNode> n =
std::make_shared<ScanOpNode>(*scan);
// copy on write semantics ganrantees correctness
for (size_t i = 0; i < n->init.size(); ++i) {
TensorKey key{n->init[i]->op, n->init[i]->value_index};
if (vmap->count(key)) {
n->init.Set(i, vmap->at(key));
}
}
for (size_t i = 0; i < n->update.size(); ++i) {
TensorKey key{n->update[i]->op, n->update[i]->value_index};
if (vmap->count(key)) {
n->update.Set(i, vmap->at(key));
}
}
if (!n->init.same_as(scan->init) ||
!n->update.same_as(scan->update)) {
Operation op(n);
for (int i = 0; i < op->num_outputs(); ++i) {
(*vmap)[TensorKey{s->op, i}] = op.output(i);
}
s->op = op;
}
} else if (s->op.as<PlaceholderOpNode>()) {
} else {
LOG(FATAL) << "unhandled problem";
}
}
}
/*!
 * \brief Create a cached copy of `tensor` in `scope` and make the given
 *  readers consume the copy instead of the original.
 *
 *  The cache is an identity compute op named "<op name>[.v<index>].<scope>";
 *  the ".v<index>" part appears only when the producer has several outputs.
 *  The cache stage is inserted right after the stage that produces `tensor`.
 *
 * \param tensor The tensor to be cached.
 * \param scope The storage scope assigned to the cache stage.
 * \param readers Operations to redirect; each must be a ComputeOp whose
 *  body reads `tensor`.
 * \return The cache tensor.
 */
Tensor Schedule::cache_read(const Tensor& tensor,
                            const std::string& scope,
                            const Array<Operation>& readers) {
  // create identity mapping.
  std::ostringstream os;
  os << tensor->op->name;
  if (tensor->op->num_outputs() != 1) {
    os << ".v" << tensor->value_index;
  }
  os << "." << scope;

  Tensor cache = compute(tensor->shape, [&tensor](const Array<Var>& i) {
      return tensor(Array<Expr>(i.begin(), i.end()));
    }, os.str());
  // vsub redirects reads of the original tensor to the cache.
  std::unordered_map<TensorKey, Tensor> vsub;
  vsub[TensorKey{tensor->op, tensor->value_index}] = cache;

  // vmap collects the readers' old outputs -> rebuilt outputs.
  std::unordered_map<TensorKey, Tensor> vmap;
  for (Operation op : readers) {
    const ComputeOpNode* compute = op.as<ComputeOpNode>();
    CHECK(compute)
        << "cache read only take ComputeOp as readers";
    Stage s = operator[](op);
    // The stage may already hold a rewritten op, so rewrite that one.
    compute = s->op.as<ComputeOpNode>();

    TensorReplacer repl(vsub);
    Expr body = repl.Mutate(compute->body);
    CHECK(repl.found)
        << "Cannot find " << tensor
        << " in the body of specified reader " << op;
    Operation repl_op = ComputeOpNode::make(
        compute->name, compute->axis, body);
    vmap[TensorKey{s->op, 0}] = repl_op.output(0);
    s->op = repl_op;
  }
  // Propagate the rebuilt readers to the stages that consume them.
  ReplaceDataFlow((*this)->stages, &vmap);
  ArrayNode* stages = (*this)->stages.CopyOnWrite();
  size_t pos = FindNodeRef(stages, operator[](tensor->op));
  Stage cache_stage = Stage(cache->op);
  cache_stage.set_scope(scope);
  CHECK_LT(pos, stages->data.size());
  // Place the cache stage right after its producer.
  stages->data.insert(stages->data.begin() + pos + 1,
                      cache_stage.node_);
  (*this)->stage_map.Set(cache->op, cache_stage);
  return cache;
}
/*!
 * \brief Insert a cache stage that computes `tensor` into `scope` first.
 *
 *  A new compute op named "<name>.<scope>" receives the original body
 *  (expressed over fresh ".c" axes), and the original op is rebuilt as a
 *  plain element-wise read of that cache over its original axes.
 *
 * \param tensor The tensor whose producer is to be cached; must be a ComputeOp.
 * \param scope The storage scope assigned to the new cache stage.
 * \return The output tensor of the new cache operation.
 */
Tensor Schedule::cache_write(const Tensor& tensor,
                             const std::string& scope) {
  Stage target_stage = operator[](tensor->op);
  const ComputeOpNode* requested = tensor->op.as<ComputeOpNode>();
  CHECK(requested)
      << "cache write only take ComputeOp as writers";
  CHECK(!target_stage.is_scheduled())
      << "Create cache_write before doing split/fuse/reorder";
  // The stage may already hold a rewritten op, so work from the stage's op.
  const ComputeOpNode* current = target_stage->op.as<ComputeOpNode>();
  CHECK(current);

  // Build one fresh ".c" axis per original axis, plus the substitution that
  // maps each original loop variable to its fresh counterpart.
  Array<Expr> read_index;
  Array<IterVar> cache_axis;
  std::unordered_map<const Variable*, Expr> var_subst;
  for (IterVar axis : current->axis) {
    read_index.push_back(axis->var);
    IterVar fresh(axis->dom, axis->var->name_hint + ".c");
    cache_axis.push_back(fresh);
    var_subst[axis->var.get()] = fresh->var;
  }

  VarReplacer renamer(var_subst);
  Expr cache_body = renamer.Mutate(current->body);
  Operation cache_op = ComputeOpNode::make(
      current->name + "." + scope, cache_axis, cache_body);
  Tensor cache_tensor = cache_op.output(0);

  // The original op becomes a copy out of the cache.
  Operation copy_op = ComputeOpNode::make(
      current->name, current->axis, cache_tensor(read_index));

  // Redirect every reader of the old output to the rebuilt op.
  std::unordered_map<TensorKey, Tensor> redirect;
  redirect[TensorKey{target_stage->op, 0}] = copy_op.output(0);
  ReplaceDataFlow((*this)->stages, &redirect);

  // Reset the original stage onto the rebuilt op and its root axes.
  target_stage->op = copy_op;
  target_stage->all_iter_vars = target_stage->op->root_iter_vars();
  target_stage->leaf_iter_vars = target_stage->all_iter_vars;

  // Register the cache stage immediately before the original stage.
  ArrayNode* stage_list = (*this)->stages.CopyOnWrite();
  size_t insert_at = FindNodeRef(stage_list, target_stage);
  Stage cache_stage = Stage(cache_op);
  cache_stage.set_scope(scope);
  CHECK_LT(insert_at, stage_list->data.size());
  stage_list->data.insert(stage_list->data.begin() + insert_at,
                          cache_stage.node_);
  (*this)->stage_map.Set(cache_op, cache_stage);
  return cache_tensor;
}
} // namespace tvm
......@@ -369,7 +369,7 @@ Stmt MakeRealize(const ScanOpNode* op,
CHECK_EQ(static_cast<size_t>(t->value_index), i);
Halide::Internal::Region bounds;
bounds.push_back(tdom);
for (size_t k = 0; k < op->update[i]->shape.size(); ++k, ++sp_idx) {
for (size_t k = 1; k < op->update[i]->shape.size(); ++k, ++sp_idx) {
IterVar sp_ax = op->spatial_axis_[sp_idx];
bounds.push_back(dom_map.at(sp_ax));
}
......@@ -561,6 +561,7 @@ class InjectScanStep : public IRMutator {
Stmt InjectInline(const Operation op, Stmt body) {
CHECK(body.defined());
const ComputeOpNode* compute = op.as<ComputeOpNode>();
CHECK(compute != nullptr)
<< "can only inline compute op";
......@@ -614,7 +615,7 @@ class SchedulePostProc : public IRMutator {
if (it->second.defined()) {
Stmt ret = AttrStmt::make(
it->second, op->type_key, op->value, op->body);
return this->Mutate_(ret.as<AttrStmt>(), ret);
return this->Mutate(ret);
} else {
return this->Mutate(op->body);
}
......@@ -631,7 +632,7 @@ class SchedulePostProc : public IRMutator {
Stmt ret = Realize::make(
it->second->op, it->second->value_index,
op->type, op->bounds, op->condition, op->body);
return this->Mutate_(ret.as<Realize>(), ret);
return this->Mutate(ret);
} else {
return this->Mutate(op->body);
}
......@@ -644,11 +645,10 @@ class SchedulePostProc : public IRMutator {
TensorKey key{op->func, op->value_index};
auto it = replace_buffer_.find(key);
if (it != replace_buffer_.end()) {
const Tensor& dst = it->second.first;
const Tensor& dst = it->second;
Stmt ret = Provide::make(
dst->op, dst->value_index, op->value,
RewriteArgs(it->second.second, op->args));
return IRMutator::Mutate_(ret.as<Provide>(), ret);
dst->op, dst->value_index, op->value, op->args);
return this->Mutate(ret);
} else {
return IRMutator::Mutate_(op, s);
}
......@@ -659,12 +659,11 @@ class SchedulePostProc : public IRMutator {
TensorKey key{op->func, op->value_index};
auto it = replace_buffer_.find(key);
if (it != replace_buffer_.end()) {
const Tensor& dst = it->second.first;
const Tensor& dst = it->second;
Expr ret = Call::make(
op->type, dst->op->name,
RewriteArgs(it->second.second, op->args),
op->type, dst->op->name, op->args,
op->call_type, dst->op, dst->value_index);
return IRMutator::Mutate_(ret.as<Call>(), ret);
return this->Mutate(ret);
}
}
return IRMutator::Mutate_(op, e);
......@@ -685,14 +684,14 @@ class SchedulePostProc : public IRMutator {
const ScanOpNode* scan = s->op.as<ScanOpNode>();
for (size_t i = 0; i < scan->update.size(); ++i) {
Tensor t = s->origin_op.output(i);
AddReplace(scan->init[i], t, Expr());
AddReplace(scan->update[i], t, scan->scan_axis->var);
AddReplace(scan->state_placeholder[i], t, Expr());
AddReplace(scan->init[i], t);
AddReplace(scan->update[i], t);
AddReplace(scan->state_placeholder[i], t);
}
} else if (!s->op.same_as(s->origin_op)) {
Tensor target = s->origin_op.output(0);
AddReplace(s->op.output(0), target,
Expr(), target, s->origin_op);
target, s->origin_op);
}
}
}
......@@ -700,26 +699,17 @@ class SchedulePostProc : public IRMutator {
private:
// Register replacements keyed by the source tensor (src->op, src->value_index):
// the buffer to use instead, the realize target, and the replacement op
// (keyed by the source op node).
// NOTE(review): head_idx and the two consecutive writes to replace_buffer_[key]
// look like the before/after lines of one change shown together; only one of
// the two assignments is expected to remain -- confirm against the real file.
void AddReplace(Tensor src,
Tensor dst,
Expr head_idx,
Tensor repl_realize = Tensor(),
Operation repl_op = Operation()) {
TensorKey key{src->op, src->value_index};
replace_buffer_[key] = std::make_pair(dst, head_idx);
replace_buffer_[key] = dst;
replace_realize_[key] = repl_realize;
replace_op_[src->op.get()] = repl_op;
}
// Return args with head prepended; when head is undefined, return args as is.
Array<Expr> RewriteArgs(Expr head, Array<Expr> args) {
  if (head.defined()) {
    Array<Expr> prefixed{head};
    for (Expr arg : args) {
      prefixed.push_back(arg);
    }
    return prefixed;
  }
  return args;
}
// The scan value
std::unordered_map<const Variable*, Expr> var_value_;
// buffer replacement
std::unordered_map<TensorKey, std::pair<Tensor, Expr> > replace_buffer_;
std::unordered_map<TensorKey, Tensor> replace_buffer_;
// buffere realization to be replaced
std::unordered_map<TensorKey, Tensor> replace_realize_;
// replace producer consumer.
......@@ -755,10 +745,13 @@ Stmt ScheduleOps(
// reverse the post DFS order.
for (size_t i = sch->stages.size(); i != 0; --i) {
Stage s = sch->stages[i - 1];
CHECK_NE(s->attach_type, kInline)
<< "call schedule.normalize before scheduleops";
// no need to specify place holder op.
if (s->op.as<PlaceholderOpNode>()) continue;
if (scan_attach.count(s->op)) {
CHECK(s->attach_type == kNone || s->attach_type == kInline)
CHECK(s->attach_type == kNone ||
s->attach_type == kScanUpdate)
<< "Cannot specify compute_at for scan's init/update";
CHECK(body.defined());
const auto& p = scan_attach.at(s->op);
......@@ -766,8 +759,8 @@ Stmt ScheduleOps(
body = mu.Mutate(body);
CHECK(mu.found_attach)
<< "did not find attachment point for scan.init/update";
} else if (s->attach_type == kInline) {
body = InjectInline(s->op, body);
} else if (s->attach_type == kInlinedAlready) {
// do nothing
} else if (s->attach_type == kRoot || s-> attach_type == kNone) {
body = MakePipeline(s, dom_map, body);
} else if (s->attach_type == kScope) {
......
......@@ -8,8 +8,8 @@ def test_scan():
X = tvm.placeholder((m, n), name="X")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + X[t, i])
res = tvm.scan(t, s_init, s_update, s_state)
s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
res = tvm.scan(s_init, s_update, s_state)
# schedule
s = tvm.Schedule(res.op)
......@@ -18,7 +18,7 @@ def test_scan():
thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x")
_, x = s[s_init].split(s_init.op.axis[1], factor=num_thread, outer=block_x)
_, x = s[s_init].split(x, outer=thread_x)
_, x = s[s_update].split(s_update.op.axis[0], factor=num_thread, outer=block_x)
_, x = s[s_update].split(s_update.op.axis[1], factor=num_thread, outer=block_x)
_, x = s[s_update].split(x, outer=thread_x)
# one line to build the function.
......
......@@ -40,9 +40,8 @@ def test_tensor_scan():
t = tvm.IterVar((1, m), "t")
x = tvm.placeholder((m, n))
s = tvm.placeholder((m, n))
res = tvm.scan(t,
tvm.compute((1, n), lambda _, i: x[0, i]),
tvm.compute((n,), lambda i: s[t-1, i] + x[t, i]),
res = tvm.scan(tvm.compute((1, n), lambda _, i: x[0, i]),
tvm.compute((m, n), lambda t, i: s[t-1, i] + x[t, i]),
s)
assert tuple(res.shape) == (m, n)
......
......@@ -50,25 +50,30 @@ def test_bound3():
assert(bounds[A1.op.axis[0]].extent.value==32)
assert(bounds[A1.op.axis[1]].extent.value==16)
def test_bound_scan():
m = tvm.Var("m")
n = tvm.Var("n")
X = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
s_scan = tvm.scan(s_init, s_update, s_state)
def test_create_read_graph():
m = tvm.Var('m')
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j])
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3)
assert tuple(s_scan.shape) == (m, n)
g = tvm.schedule.CreateReadGraph([A2.op])
s = tvm.Schedule(s_scan.op)
XX = s.cache_read(X, "local", s_update)
xo, xi = s[s_update].split(s_update.op.axis[1], factor=4)
s[XX].compute_at(s[s_update], xo)
assert g[A2.op][0] == A1
assert g[A1.op][0] == A
post_order = tvm.schedule.PostDFSOrder([A2.op], g)
assert(post_order[0] == A.op)
assert(post_order[1] == A1.op)
s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
assert bounds[XX.op.axis[1]].extent.value == 4
if __name__ == "__main__":
test_create_read_graph()
test_bound_scan()
test_bound3()
test_bound1()
test_bound2()
import tvm
def test_scan():
    """Build a scan whose update goes through an intermediate stage.

    The nested helpers hold the checks on scan body extraction, attach
    paths and fix-point analysis.
    """
    # NOTE(review): none of the nested helpers is invoked in this function
    # as shown, so their assertions never run -- confirm whether calls are
    # intended. test_fix_pt also reads s_scan.spatial_axis_, while the
    # sibling test reads s_scan.op.spatial_axis_.
    rows = tvm.Var("m")
    cols = tvm.Var("n")
    full_shape = (rows, cols)
    x = tvm.compute(full_shape, lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder(full_shape)
    s_init = tvm.compute((1, cols), lambda _, i: x[0, i], name="s_init")
    x_trans = tvm.compute(full_shape, lambda i, j: x[i, j] + 1, name="x_trans")
    s_up1 = tvm.compute(
        full_shape, lambda t, i: s_state[t - 1, i] + 1, name="up1")
    s_update = tvm.compute(
        full_shape, lambda t, i: s_up1[t, i] + x_trans[t, i], name="update")
    s_scan = tvm.scan(s_init, s_update, s_state)

    def test_getbody():
        # The scan body is the scan op plus the stages that depend on state.
        body = tvm.schedule.ScanGetBody(s_scan.op)
        expected = set([s_scan.op, s_update.op, s_up1.op])
        assert set(body) == expected

    def test_attach_path():
        sched = tvm.Schedule(s_scan.op)
        sched[x_trans].compute_at(sched[s_update], s_update.op.axis[0])
        apath = tvm.schedule.CreateAttachPath(sched)
        scan_axis = s_scan.op.scan_axis
        assert(tuple(apath[s_update.op]) == tuple([scan_axis]))
        assert(tuple(apath[x_trans.op]) ==
               tuple([s_update.op.axis[0], scan_axis]))

    def test_fix_pt():
        body = tvm.schedule.ScanGetBody(s_scan.op)
        fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body)
        assert(fxpt[s_scan.spatial_axis_[0]].value != 0)
def test_scan_fix_point():
    """Check ScanFixPointAnalysis on several state access patterns."""
    m = tvm.Var("m")
    n = tvm.Var("n")
    l = tvm.Var("l")
    x = tvm.compute((l, m, n), lambda *i: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder((l, m, n))
    s_init = tvm.compute((1, m, n), lambda _, i, j: x[0, i, j], name="s_init")

    def fix_point_flags(s_update):
        # Scan with the given update and report the analysis value of the
        # two spatial axes.
        s_scan = tvm.scan(s_init, s_update, s_state)
        body = tvm.schedule.ScanGetBody(s_scan.op)
        fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body)
        axes = s_scan.op.spatial_axis_
        return fxpt[axes[0]].value, fxpt[axes[1]].value

    def test_scan0():
        # State is read at the same spatial position it is written to.
        s_update = tvm.compute(
            (l, m, n),
            lambda t, i, j: x[t, j, i] + s_state[t-1, i, j], name="update")
        first, second = fix_point_flags(s_update)
        assert(first == 1)
        assert(second == 1)

    def test_scan1():
        # State is read with the two spatial indices swapped.
        s_update = tvm.compute(
            (l, m, n),
            lambda t, i, j: x[t, j, i] + s_state[t-1, j, i], name="update")
        first, second = fix_point_flags(s_update)
        assert(first == 0)
        assert(second == 0)

    def test_scan3_not_exact_reach():
        s_h1 = tvm.compute(
            (l, n, m), lambda t, j, i: s_state[t-1, i, j], name="h1")
        s_h2 = tvm.compute(
            (l, m, n), lambda t, i, j: s_state[t-1, i, 10] * 2, name="h1")
        s_update = tvm.compute(
            (l, m, n),
            lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update")
        first, second = fix_point_flags(s_update)
        assert(first == 1)
        assert(second == 0)

    def test_scan4_reach_other():
        s_h1 = tvm.compute(
            (l, n, m), lambda t, j, i: s_state[t-1, j, j], name="h1")
        s_h2 = tvm.compute(
            (l, m, n), lambda t, i, j: s_state[t-1, i, j] * 2, name="h1")
        s_update = tvm.compute(
            (l, m, n),
            lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update")
        first, second = fix_point_flags(s_update)
        assert(first == 0)
        assert(second == 0)

    for case in (test_scan0, test_scan1,
                 test_scan3_not_exact_reach, test_scan4_reach_other):
        case()
def test_create_read_graph():
    """Check the read graph and post-DFS order of a three-stage chain."""
    rows = tvm.Var('m')
    cols = tvm.Var('l')
    shape = (rows, cols)
    A = tvm.placeholder(shape, name='A')
    A1 = tvm.compute(shape, lambda i, j: A[i, j])
    A2 = tvm.compute(shape, lambda i, j: A1[i, j] + 3)

    roots = [A2.op]
    graph = tvm.schedule.CreateReadGraph(roots)

    # Each op lists the tensor it reads first.
    assert graph[A2.op][0] == A1
    assert graph[A1.op][0] == A
    order = tvm.schedule.PostDFSOrder(roots, graph)
    assert(order[0] == A.op)
    assert(order[1] == A1.op)
if __name__ == "__main__":
test_scan()
test_create_read_graph()
test_scan_fix_point()
......@@ -43,13 +43,11 @@ def test_schedule2():
def test_schedule_scan():
m = tvm.Var("m")
n = tvm.Var("n")
l = tvm.Var("l")
t = tvm.IterVar((1, m), name="t")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: x[0, i])
s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + x[t, i])
res = tvm.scan(t, s_init, s_update, s_state)
s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + x[t, i])
res = tvm.scan(s_init, s_update, s_state)
assert tuple(res.shape) == (m, n)
s = tvm.Schedule(res.op)
......@@ -59,7 +57,6 @@ def test_schedule_scan():
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
def test_auto_inline():
m = tvm.Var('m')
n = tvm.Var('n')
......@@ -71,9 +68,27 @@ def test_auto_inline():
s = tvm.Schedule(T2.op)
tvm.schedule.AutoInlineElemWise(s)
s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_inline_mixed():
    """Schedule a graph that mixes compute_at with compute_inline."""
    n = tvm.Var('n')
    A = tvm.placeholder((n, ), name='A')
    A1 = tvm.compute(A.shape, lambda *i: A(*i) + 1, name='A1')
    A2 = tvm.compute(A.shape, lambda *i: A1(*i) + 2, name='A2')
    C = tvm.compute((n,), lambda i: A2[i] + A1[i], name='C')

    sched = tvm.Schedule(C.op)
    outer, _ = sched[C].split(C.op.axis[0], factor=8)
    # A1 is attached under C's outer loop while A2 is inlined.
    sched[A1].compute_at(sched[C], outer)
    sched[A2].compute_inline()
    sched.normalize()
    bounds = tvm.schedule.InferBound(sched)
    stmt = tvm.schedule.ScheduleOps(sched, bounds)
    print(stmt)
def test_schedule_cache():
m = tvm.Var('m')
n = tvm.Var('n')
......@@ -90,9 +105,10 @@ def test_schedule_cache():
if __name__ == "__main__":
test_inline_mixed()
test_auto_inline()
test_schedule_scan()
test_schedule0()
test_schedule1()
test_schedule2()
test_auto_inline()
test_schedule_cache()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment