Commit 0f693212 by tqchen

Pass first basic case of bound inference

parent c5395a1f
-Subproject commit adaea9e85bc0a213d4eb63edfa4762f2147c73ec
+Subproject commit 5d1bd103c2abe19392b4d8def7e3ff1c854e8683
...
@@ -6,8 +6,9 @@
 #include <tvm/expr.h>
 #include <tvm/tensor.h>
 #include <tvm/schedule.h>
-#include "../schedule/bound.h"
 #include "./c_api_registry.h"
+#include "../schedule/bound.h"
+#include "../schedule/graph.h"

 namespace tvm {
 namespace schedule {
@@ -20,8 +21,16 @@ using RetValue = APIVariantValue;
     *ret = PassName(args.at(0)); \
   }) \

+#define REGISTER_SCHEDULE_PASS2(PassName) \
+  TVM_REGISTER_API(_schedule_## PassName) \
+  .set_body([](const ArgStack& args, RetValue *ret) { \
+      *ret = PassName(args.at(0), args.at(1)); \
+    }) \
+
 REGISTER_SCHEDULE_PASS1(InferBound);
+REGISTER_SCHEDULE_PASS1(CreateReadGraph);
+REGISTER_SCHEDULE_PASS2(PostDFSOrder);

 } // namespace schedule
 } // namespace tvm
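A note on the macro pattern above: each REGISTER_SCHEDULE_PASSn macro wraps an n-argument C++ pass behind a uniform packed-call signature, so the frontend can look the pass up by its `_schedule_` name and invoke it with an argument stack. Below is a minimal standalone sketch of that pattern, with plain `int` values standing in for TVM's ArgStack/APIVariantValue; all names here (`Registry`, `PackedFunc`, `REGISTER_PASS1/2`) are hypothetical, not TVM's actual FFI.

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-ins for TVM's ArgStack/RetValue: plain ints for simplicity.
using ArgStack = std::vector<int>;
using PackedFunc = std::function<int(const ArgStack&)>;

// Global name -> packed function registry.
std::unordered_map<std::string, PackedFunc>& Registry() {
  static std::unordered_map<std::string, PackedFunc> reg;
  return reg;
}

// One macro per arity, mirroring REGISTER_SCHEDULE_PASS1/PASS2: each wraps
// an n-argument pass behind the uniform ArgStack calling convention.
#define REGISTER_PASS1(PassName)                                          \
  static bool _reg1_##PassName =                                          \
      (Registry()["_schedule_" #PassName] =                               \
           [](const ArgStack& args) { return PassName(args.at(0)); },     \
       true)
#define REGISTER_PASS2(PassName)                                          \
  static bool _reg2_##PassName =                                          \
      (Registry()["_schedule_" #PassName] =                               \
           [](const ArgStack& args) {                                     \
             return PassName(args.at(0), args.at(1));                     \
           },                                                             \
       true)

int Twice(int x) { return 2 * x; }
int Add(int a, int b) { return a + b; }
REGISTER_PASS1(Twice);
REGISTER_PASS2(Add);

int main() {
  // The frontend only needs the string name and an argument stack.
  std::cout << Registry()["_schedule_Twice"]({21}) << "\n";   // 42
  std::cout << Registry()["_schedule_Add"]({40, 2}) << "\n";  // 42
}
```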
...
@@ -9,6 +9,11 @@
 namespace tvm {

+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<ComputeOpNode>([](const ComputeOpNode *op, IRPrinter *p) {
+    p->stream << "op(" << op << ")";
+  });
+
 Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
   auto op_node = std::make_shared<ComputeOpNode>();
   // compute dimension.
...
...
@@ -7,6 +7,7 @@
 #include <tvm/ir_visitor.h>
 #include "./int_set.h"
 #include "./bound.h"
+#include "./graph.h"

 namespace tvm {
 namespace schedule {
@@ -62,7 +63,7 @@ void PassDown(const Schedule& s,
 // pass the integer set on each leaf loop up to the root
 // dom_map is the result of PassDown; it records the domain of each IterVar.
 // dom_map can be used to get cached results in reverse construction.
-void PassUp(const Schedule& s,
+void PassUp(const ScheduleNode* s,
             const std::unordered_map<IterVar, Range>& dom_map,
             std::unordered_map<IterVar, IntSet>* p_state) {
   auto& state = *p_state;
@@ -89,62 +90,145 @@ void PassUp(const Schedule& s,
   }
 }
-void PassBound(
+/*!
+ * \brief Pass the bound of a tensor read
+ *  to the corresponding bound of the IterVars of the operation.
+ * \param tensor The tensor to be passed.
+ * \param dim_bounds The read index set on each dimension.
+ * \param result The resulting IterVar bounds.
+ */
+void PassToOperation(
     const Tensor& tensor,
-    const std::vector<IntSet>& arg_bounds,
+    const std::vector<IntSet>& dim_bounds,
     std::unordered_map<IterVar, std::vector<IntSet> >* result) {
   if (tensor->op.as<ComputeOpNode>()) {
     auto root_iter_vars = tensor->op->root_iter_vars();
     CHECK_EQ(tensor.ndim(), root_iter_vars.size());
     for (size_t i = 0; i < tensor.ndim(); ++i) {
-      (*result)[root_iter_vars[i]].push_back(arg_bounds[i]);
+      (*result)[root_iter_vars[i]].push_back(dim_bounds[i]);
     }
   } else {
     LOG(FATAL) << "unknown operation mode";
   }
 }
-void PassBound(
-    Operation op,
-    std::unordered_map<IterVar, IntSet>* ebound) {
-  if (op.as<ComputeOpNode>()) {
-    auto fvisit = [ebound](const NodeRef& n) {
-      auto *call = n.as<ir::Call>();
-      if (call != nullptr && call->func.defined()) {
-        Tensor t(call->func.node_);
-        std::vector<IntSet> arg_bounds;
-        for (size_t i = 0; i < t.ndim(); ++i) {
-          arg_bounds.push_back(Eval(call->args[i], *ebound));
-        }
-      }
-    };
-    ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
-  } else {
-    LOG(FATAL) << "unknown operation mode";
-  }
-}
+/*!
+ * \brief Recursively propagate bounds.
+ * \param post_order The propagation order.
+ * \param p_state The index sets gathered for each IterVar, updated in place.
+ * \return The resulting bound of each IterVar.
+ */
+std::unordered_map<IterVar, IntSet>
+BoundProp(const Array<Operation>& post_order,
+          std::unordered_map<IterVar, std::vector<IntSet> > *p_state) {
+  std::unordered_map<IterVar, IntSet> result;
+
+  for (size_t i = post_order.size(); i != 0; --i) {
+    Operation op = post_order[i - 1];
+    if (op.as<ComputeOpNode>()) {
+      for (auto iv : op->root_iter_vars()) {
+        CHECK(p_state->count(iv))
+            << "Bound of root operator must exist";
+        CHECK(!result.count(iv));
+        result[iv] = Union(p_state->at(iv));
+      }
+      auto fvisit = [p_state, &result](const NodeRef& n) {
+        auto *call = n.as<ir::Call>();
+        if (call != nullptr && call->func.defined()) {
+          Tensor t(call->func.node_);
+          if (t->op.defined()) {
+            std::vector<IntSet> arg_bounds;
+            for (size_t i = 0; i < t.ndim(); ++i) {
+              arg_bounds.push_back(EvalSet(call->args[i], result));
+            }
+            PassToOperation(t, arg_bounds, p_state);
+          }
+        }
+      };
+      ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
+    } else {
+      LOG(FATAL) << "unknown operation mode";
+    }
+  }
+  return result;
+}
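The shape of BoundProp is worth spelling out: it walks the post-DFS order backwards, so each operation's demanded index sets are complete (unioned into `result`) before the operation pushes requirements onto its producers via PassToOperation. Below is a self-contained sketch of the same flow over a toy three-op chain, with closed integer intervals standing in for IntSet; the ops, index arithmetic, and `state`/`result` maps are illustrative, not the committed code.

```cpp
#include <algorithm>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Closed integer interval standing in for TVM's IntSet.
struct Interval { int lo, hi; };

// Union of gathered index sets, as in result[iv] = Union(p_state->at(iv)).
Interval Union(const std::vector<Interval>& sets) {
  Interval r = sets.front();
  for (const Interval& s : sets) {
    r.lo = std::min(r.lo, s.lo);
    r.hi = std::max(r.hi, s.hi);
  }
  return r;
}

int main() {
  // Toy chain: C reads B[i-1..i+1], B reads A[i].  Post-DFS order is
  // {A, B, C}; requirements flow consumer -> producer, so walk it backwards.
  std::vector<std::string> post_order = {"A", "B", "C"};
  // Gathered per-op index sets; the root requirement seeds the output op.
  std::map<std::string, std::vector<Interval>> state = {{"C", {{0, 9}}}};
  std::map<std::string, Interval> result;
  for (size_t i = post_order.size(); i != 0; --i) {
    const std::string& op = post_order[i - 1];
    result[op] = Union(state.at(op));
    // The "PassToOperation" step: push each read's index set one hop down.
    if (op == "C") state["B"].push_back({result[op].lo - 1, result[op].hi + 1});
    if (op == "B") state["A"].push_back({result[op].lo, result[op].hi});
  }
  for (const auto& kv : result)  // A: [-1, 10]  B: [-1, 10]  C: [0, 9]
    std::cout << kv.first << ": [" << kv.second.lo << ", "
              << kv.second.hi << "]\n";
}
```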
-void InferBound(const Schedule& sch,
-                std::unordered_map<IterVar, Range>* rmap) {
-  CHECK_NE(sch->attach_type, kNone);
+// check whether an IterVar is relaxed (covers its whole range)
+// given the storage scope
+bool ScopeRelax(const IterVar& iv, const std::string& scope) {
+  if (iv->thread_tag.length() == 0) return false;
+  if (scope.length() == 0) return false;
+
+  static std::unordered_map<std::string, int> scope_rank{
+    {"global", 0},
+    {"shared", 1},
+    {"local", 2}
+  };
+  return scope_rank.at(scope) <= scope_rank.at(iv->thread_tag);
+}
+
+void InferBound(
+    const ScheduleNode* parent,
+    const Schedule& sch,
+    std::unordered_map<IterVar, Range>* rmap) {
   if (sch->attach_type == kInline) return;
-  if (sch->attach_type == kRoot) {
+  if (sch->attach_type == kRoot || sch->attach_type == kNone) {
     auto root_iter_vars = sch->op->root_iter_vars();
-    for (size_t i = 0; i < root_iter_vars.size(); ++i) {
-      auto v = root_iter_vars[i];
-      CHECK(v->dom.defined());
-      CHECK(!rmap->count(v));
-      (*rmap)[v] = v->dom;
+    for (auto iv : root_iter_vars) {
+      CHECK(iv->dom.defined());
+      CHECK(!rmap->count(iv));
+      (*rmap)[iv] = iv->dom;
     }
   }
   // get range of all child iter vars.
   PassDown(sch, rmap);
+  // pass iteration variable to children
+  if (sch->attach_type == kScope) {
+    CHECK(parent != nullptr);
+    auto g = CreateReadGraph(parent->op);
+    auto post_order = PostDFSOrder(parent->op, g);
+    std::unordered_map<IterVar, IntSet> up_state;
+    bool fix_value = true;
+    for (auto iv : parent->leaf_iter_vars) {
+      if (fix_value && !ScopeRelax(iv, sch->scope)) {
+        up_state[iv] = IntSet::make_point(iv->var);
+      } else {
+        up_state[iv] = IntSet::make_range(rmap->at(iv));
+      }
+      if (sch->attach_parent == iv) {
+        fix_value = false;
+      }
+    }
+    // get the bound of the root IterVars given the current condition
+    PassUp(parent, *rmap, &up_state);
+    std::unordered_map<IterVar, std::vector<IntSet> > bp_state;
+    for (auto iv : parent->op->root_iter_vars()) {
+      CHECK(up_state.count(iv));
+      bp_state[iv] = {up_state.at(iv)};
+    }
+    auto result = BoundProp(post_order, &bp_state);
+
+    for (auto iv : sch->op->root_iter_vars()) {
+      CHECK(result.count(iv));
+      CHECK(!rmap->count(iv));
+      (*rmap)[iv] = result.at(iv).GetCoverRange();
+    }
+  }
+  // also call infer bound on children
+  for (Schedule child : sch->children) {
+    InferBound(sch.operator->(), child, rmap);
+  }
 }

 Map<IterVar, Range> InferBound(Schedule sch) {
-  return {};
+  std::unordered_map<IterVar, Range> ret;
+  CHECK(sch->attach_type != kInline && sch->attach_type != kScope)
+      << "the Schedule is not a root Schedule";
+  InferBound(nullptr, sch, &ret);
+  return Map<IterVar, Range>(ret.begin(), ret.end());
 }
 } // namespace schedule
...
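The kScope branch above decides, per leaf IterVar of the parent schedule, whether the child sees the loop variable as a single fixed point or as its whole relaxed range; ScopeRelax makes that call by comparing the storage scope against the scope rank of the IterVar's thread tag. Below is a standalone sketch of that rank test; `ScopeRelaxSketch` and its tag-to-scope mapping are illustrative assumptions, not TVM's real thread-tag set.

```cpp
#include <iostream>
#include <string>
#include <unordered_map>

// Rank of storage scopes and of the scopes that thread tags live in;
// higher rank means more private.  Illustrative mapping only.
static const std::unordered_map<std::string, int> scope_rank{
    {"global", 0}, {"shared", 1}, {"local", 2}};

// An IterVar bound to a thread tag must stay relaxed (cover its full range)
// when the storage scope written to is at least as widely shared as the
// scope the thread tag addresses into.
bool ScopeRelaxSketch(const std::string& tag_scope,
                      const std::string& storage_scope) {
  if (tag_scope.empty() || storage_scope.empty()) return false;
  return scope_rank.at(storage_scope) <= scope_rank.at(tag_scope);
}

int main() {
  // A threadIdx-bound loop (thread scope "local") writing into shared
  // memory cannot be fixed to one point: every thread's slice is needed.
  std::cout << ScopeRelaxSketch("local", "shared") << "\n";  // 1 (relax)
  // Writing into thread-local storage, the loop can be fixed per thread.
  std::cout << ScopeRelaxSketch("shared", "local") << "\n";  // 0 (fix)
}
```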
@@ -14,26 +14,29 @@ namespace schedule {

 // construct a read graph that records the tensors read by each operation
 // that the root depends on
-ReadGraph CreateReadGraph(Operation root) {
-  std::unordered_map<Operation, std::vector<Tensor> > rmap;
-  rmap[root] = {};
+ReadGraph CreateReadGraph(const Operation& root) {
+  ReadGraph rmap;
   std::vector<Operation> stack{root};
+  std::unordered_set<const Node*> visited{root.get()};
   while (!stack.empty()) {
-    Operation r = stack.back();
+    Operation op = stack.back();
     stack.pop_back();
-    auto& vec = rmap.at(r);
-    if (r.as<ComputeOpNode>()) {
-      auto fvisit = [&vec, &rmap, &stack](const NodeRef& n) {
+    Array<Tensor> deps;
+    if (op.as<ComputeOpNode>()) {
+      auto fvisit = [&deps, &visited, &stack](const NodeRef& n) {
         auto *call = n.as<ir::Call>();
         if (call != nullptr && call->func.defined()) {
           Tensor t(call->func.node_);
-          vec.push_back(t);
-          if (t->op.defined() && rmap.count(t->op) == 0) {
-            rmap[t->op] = {}; stack.push_back(t->op);
+          deps.push_back(t);
+          if (t->op.defined() && visited.count(t->op.get()) == 0) {
+            visited.insert(t->op.get());
+            stack.push_back(t->op);
           }
         }
       };
-      ir::PostOrderVisit(r.as<ComputeOpNode>()->body, fvisit);
+      ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
+      rmap.Set(op, deps);
     } else {
       LOG(FATAL) << "unknown operation mode";
     }
@@ -43,9 +46,9 @@ ReadGraph CreateReadGraph(const Operation& root) {
 void PostDFSOrder(const Operation& op,
                   const ReadGraph& g,
                   std::unordered_set<Operation>* visited,
-                  std::vector<Operation>* post_order) {
+                  Array<Operation>* post_order) {
   visited->insert(op);
   for (const auto& t : g.at(op)) {
     if (t->op.defined() && !visited->count(t->op)) {
@@ -55,10 +58,10 @@ void PostDFSOrder(const Operation& op,
   post_order->push_back(op);
 }
-std::vector<Operation> PostDFSOrder(
+Array<Operation> PostDFSOrder(
     const Operation& root, const ReadGraph& g) {
   std::unordered_set<Operation> visited;
-  std::vector<Operation> post_order;
+  Array<Operation> post_order;
   PostDFSOrder(root, g, &visited, &post_order);
   return post_order;
 }
...
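PostDFSOrder emits an operation only after every operation it reads from, which is exactly the topological order that BoundProp later walks in reverse. Here is a self-contained sketch over the A → A1 → A2 chain used in the tests; string names stand in for Operation handles, and unlike TVM the placeholder leaf is included in the output for simplicity.

```cpp
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

// Operation -> tensors/ops it reads, standing in for TVM's ReadGraph.
using ReadGraph = std::map<std::string, std::vector<std::string>>;

// Post-DFS: recurse into all dependencies first, then emit the op itself.
void PostDFS(const std::string& op, const ReadGraph& g,
             std::set<std::string>* visited,
             std::vector<std::string>* post_order) {
  visited->insert(op);
  for (const std::string& dep : g.at(op)) {
    if (!visited->count(dep)) PostDFS(dep, g, visited, post_order);
  }
  post_order->push_back(op);
}

int main() {
  // A2 reads A1, A1 reads the placeholder A (which reads nothing).
  ReadGraph g{{"A2", {"A1"}}, {"A1", {"A"}}, {"A", {}}};
  std::set<std::string> visited;
  std::vector<std::string> order;
  PostDFS("A2", g, &visited, &order);
  for (const auto& op : order) std::cout << op << " ";  // A A1 A2
  std::cout << "\n";
}
```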
@@ -17,7 +17,7 @@ namespace schedule {
 /*!
  * \brief data structure of Operation->Tensors it reads
  */
-using ReadGraph = std::unordered_map<Operation, std::vector<Tensor> >;
+using ReadGraph = Map<Operation, Array<Tensor> >;

 /*!
  * \brief Get read graph of each operation to all the
@@ -38,7 +38,7 @@ ReadGraph CreateReadGraph(const Operation& root);
  * \note PostDFSOrder is a special case of topological order,
  *  and can be used when topological order is needed.
  */
-std::vector<Operation> PostDFSOrder(
+Array<Operation> PostDFSOrder(
     const Operation& root, const ReadGraph& g);

 } // namespace schedule
...
@@ -176,17 +176,37 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
     p->stream << ')';
   });

-IntSet IntSet::make(Range dom) {
+IntSet IntSet::make_range(Range dom) {
   auto n = std::make_shared<IntSetNode>();
   n->base = dom;
   return IntSet(n);
 }
+Range IntSet::GetCoverRange() const {
+  const IntSetNode* s = operator->();
+  CHECK(s != nullptr) << "empty set";
+  if (s->domain.size() == 0 && s->concrete.size() == 0) {
+    return s->base;
+  }
+  LOG(FATAL) << "not yet implemented";
+  return Range();
+}
+
+IntSet IntSet::make_point(Expr point) {
+  return IntSet::make_range(Range::make_with_min_extent(point, 1));
+}
+
 IntSet IntSet::make_all_set() {
   LOG(FATAL) << "TODO";
   return IntSet();
 }

+IntSet Union(const Array<IntSet>& set) {
+  if (set.size() == 1) return set[0];
+  LOG(FATAL) << "TODO";
+  return IntSet();
+}
+
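Union above only handles the single-element case so far. One plausible way its TODO could be filled for base ranges is a cover union: the smallest single range containing all inputs. The sketch below also shows how make_point is just a degenerate make_range of extent 1; the `UnionCover` helper is hypothetical, not the committed code.

```cpp
#include <algorithm>
#include <iostream>
#include <vector>

// An integer range in TVM style: min plus extent, covering
// [min, min + extent).  A point p is the degenerate range {p, 1},
// mirroring make_point(p) == make_range(Range(p, 1)).
struct Range { int min, extent; };

// A cover-range union such as the TODO above might compute: the
// smallest single range containing all inputs.
Range UnionCover(const std::vector<Range>& sets) {
  int lo = sets.front().min;
  int hi = sets.front().min + sets.front().extent;
  for (const Range& r : sets) {
    lo = std::min(lo, r.min);
    hi = std::max(hi, r.min + r.extent);
  }
  return {lo, hi - lo};
}

int main() {
  // Union of the point {5} and the range [0, 4) covers [0, 6).
  Range u = UnionCover({{5, 1}, {0, 4}});
  std::cout << "min=" << u.min << " extent=" << u.extent << "\n";
}
```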
 void PassUp(const SplitNode* s,
             const std::unordered_map<IterVar, Range>& dom_map,
             const IntSet& outer,
@@ -197,7 +217,7 @@ void PassUp(const SplitNode* s,
       dom_map.count(s->parent) &&
       Match(outer, dom_map.at(s->outer)) &&
       Match(inner, dom_map.at(s->inner))) {
-    *parent = IntSet::make(dom_map.at(s->parent));
+    *parent = IntSet::make_range(dom_map.at(s->parent));
     return;
   }
   // copy construct
@@ -230,21 +250,21 @@ void PassUp(const FuseNode* s,
   CHECK(dom_map.count(s->fused));

   if (Match(fused, dom_map.at(s->fused))) {
-    *outer = IntSet::make(dom_map.at(s->outer));
-    *inner = IntSet::make(dom_map.at(s->inner));
+    *outer = IntSet::make_range(dom_map.at(s->outer));
+    *inner = IntSet::make_range(dom_map.at(s->inner));
     return;
   }

   if (IsNumber(fused)) {
     Expr value = AsNumber(fused);
     Expr factor = dom_map.at(s->outer)->extent;
-    *outer = IntSet::make(Range::make_with_min_extent(value / factor, 1));
-    *inner = IntSet::make(Range::make_with_min_extent(value % factor, 1));
+    *outer = IntSet::make_point(value / factor);
+    *inner = IntSet::make_point(value % factor);
   } else {
     LOG(WARNING) << "use fallback inference rule in fuse";
     // simply use the entire set, this rule can be enhanced.
-    *outer = IntSet::make(dom_map.at(s->outer));
-    *inner = IntSet::make(dom_map.at(s->inner));
+    *outer = IntSet::make_range(dom_map.at(s->outer));
+    *inner = IntSet::make_range(dom_map.at(s->inner));
     return;
   }
 }
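The two PassUp overloads invert iteration relations: a split defines parent = outer * factor + inner, so a parent bound follows from child bounds by interval arithmetic, while for a fuse at a known point, division and remainder recover outer and inner, matching the `value / factor` and `value % factor` lines above. A standalone sketch with plain intervals, assuming factor is the inner extent for the split case (`PassUpSplit` and `PassUpFusePoint` are illustrative names):

```cpp
#include <iostream>

struct Interval { int lo, hi; };

// Upward rule for a split relation: with parent = outer * factor + inner,
// interval arithmetic recovers a parent bound from the child bounds.
Interval PassUpSplit(Interval outer, Interval inner, int factor) {
  return {outer.lo * factor + inner.lo, outer.hi * factor + inner.hi};
}

// Upward rule for fuse on a single point: fused = outer * factor + inner,
// so a point value splits back via division and remainder.
void PassUpFusePoint(int fused, int factor, int* outer, int* inner) {
  *outer = fused / factor;
  *inner = fused % factor;
}

int main() {
  Interval p = PassUpSplit({0, 3}, {0, 7}, 8);
  std::cout << "parent: [" << p.lo << ", " << p.hi << "]\n";  // [0, 31]
  int o, i;
  PassUpFusePoint(42, 8, &o, &i);
  std::cout << "outer=" << o << " inner=" << i << "\n";  // outer=5 inner=2
}
```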
@@ -272,7 +292,7 @@ class IRSetEvaluator {
 };

 inline IntSet ConstOp(const NodeRef&, const Expr& e, IRSetEvaluator*) {
-  return IntSet::make(Range::make_with_min_extent(e, 1));
+  return IntSet::make_point(e);
 }
 TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
@@ -286,7 +306,7 @@ TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
     if (it != m->dom_map.end()) {
       return it->second;
     } else {
-      return IntSet::make(Range::make_with_min_extent(e, 1));
+      return IntSet::make_point(e);
     }
   });
@@ -298,10 +318,9 @@ inline IntSet Binary(const T* op, const Expr& e, IRSetEvaluator* m) {
   if (IsNumber(a) && IsNumber(b)) {
     if (Match(a, op->a) &&
         Match(b, op->b)) {
-      return IntSet::make(Range::make_with_min_extent(e, 1));
+      return IntSet::make_point(e);
     } else {
-      return IntSet::make(Range::make_with_min_extent(
-          T::make(AsNumber(a), AsNumber(b)), 1));
+      return IntSet::make_point(T::make(AsNumber(a), AsNumber(b)));
     }
   } else {
     return BinaryCombine<T>(a, b);
@@ -319,7 +338,7 @@ TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
 // use simple bounds for logical expressions for now.
 inline IntSet Logical(const NodeRef&, const Expr& e, IRSetEvaluator*) {
-  return IntSet::make(Range::make_with_min_extent(0, 2));
+  return IntSet::make_range(Range::make_with_min_extent(0, 2));
 }
 TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
@@ -334,8 +353,8 @@ TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
 } // namespace

-IntSet Eval(Expr e,
-            const std::unordered_map<IterVar, IntSet>& dom_map) {
+IntSet EvalSet(Expr e,
+               const Map<IterVar, IntSet>& dom_map) {
   IRSetEvaluator m;
   for (auto kv : dom_map) {
     m.dom_map[kv.first->var.as<Variable>()] = kv.second;
...
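EvalSet is the entry point of the IRSetEvaluator dispatch above: it evaluates an expression over interval-valued variables and returns a set covering every value the expression can take. Below is a self-contained sketch of the underlying interval arithmetic for `2*i + 1` with i in [0, 9]; only add and multiply-by-constant are shown, whereas TVM dispatches per IR node type.

```cpp
#include <iostream>
#include <map>
#include <string>

struct Interval { int lo, hi; };

// Interval addition: lower bounds add, upper bounds add.
Interval Add(Interval a, Interval b) { return {a.lo + b.lo, a.hi + b.hi}; }

// Multiplication by a constant; a negative constant flips the endpoints.
Interval MulConst(Interval a, int c) {
  return c >= 0 ? Interval{a.lo * c, a.hi * c} : Interval{a.hi * c, a.lo * c};
}

int main() {
  std::map<std::string, Interval> dom{{"i", {0, 9}}};
  // EvalSet(2*i + 1) with i in [0, 9] -> [1, 19]
  Interval r = Add(MulConst(dom["i"], 2), Interval{1, 1});
  std::cout << "[" << r.lo << ", " << r.hi << "]\n";
}
```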
@@ -29,6 +29,10 @@ class IntSet : public NodeRef {
     return !defined();
   }
   /*!
+   * \return a range that covers the IntSet
+   */
+  Range GetCoverRange() const;
+  /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
    */
@@ -37,7 +41,12 @@ class IntSet : public NodeRef {
   * \param dom The domain to be created.
   * \return create integer set from existing domain
   */
-  static IntSet make(Range dom);
+  static IntSet make_range(Range dom);
+  /*!
+   * \param point The point to be contained.
+   * \return create integer set that only contains one point
+   */
+  static IntSet make_point(Expr point);
   /*!
   * \return create integer set that represents everything
   */
@@ -52,8 +61,8 @@ class IntSet : public NodeRef {
  * \param dom_map The domain of each variable.
  * \return An integer set that can cover all the possible values of e.
  */
-IntSet Eval(Expr e,
-            const std::unordered_map<IterVar, IntSet>& dom_map);
+IntSet EvalSet(Expr e,
+               const Map<IterVar, IntSet>& dom_map);

 /*!
  * \brief Conditional upward message passing.
  *
...
@@ -4,16 +4,32 @@ def test_bound_inference():
     m = tvm.Var('m')
     l = tvm.Var('l')
     A = tvm.placeholder((m, l), name='A')
-    A1 = tvm.compute((m, l), lambda i, j: A[i, j])
-    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3)
+    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
+    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
     sA1 = tvm.Schedule(A1.op)
     sA2 = tvm.Schedule(A2.op)
-    xo, xi = sA1.split(A1.op.dim_var[0], factor=8)
-    sA2.compute_at(sA1, xi)
+    xo, xi = sA2.split(A2.op.dim_var[0], 8)
+    sA1.compute_at(sA2, xo)

-    bounds = tvm.schedule.InferBound(sA1)
+    bounds = tvm.schedule.InferBound(sA2)
     assert isinstance(bounds, tvm.collections.Map)
-    print(bounds)
+    print(bounds[A1.op.dim_var[0]])
+    print(bounds[A1.op.dim_var[1]])
+
+def test_create_read_graph():
+    m = tvm.Var('m')
+    l = tvm.Var('l')
+    A = tvm.placeholder((m, l), name='A')
+    A1 = tvm.compute((m, l), lambda i, j: A[i, j])
+    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3)
+    g = tvm.schedule.CreateReadGraph(A2.op)
+
+    assert g[A2.op][0] == A1
+    assert g[A1.op][0] == A
+
+    post_order = tvm.schedule.PostDFSOrder(A2.op, g)
+    assert(post_order[0] == A1.op)
+    assert(post_order[1] == A2.op)

 if __name__ == "__main__":
     test_bound_inference()
+    test_create_read_graph()