Commit e42cc112 by Tianqi Chen Committed by GitHub

[PASS] UnrollLoop, isolate arithmetic module. (#32)

parent d89917b6
...@@ -68,7 +68,7 @@ Stmt ConvertSSA(Stmt stmt); ...@@ -68,7 +68,7 @@ Stmt ConvertSSA(Stmt stmt);
* \param value_map The map of new values. * \param value_map The map of new values.
* \return The converted form. * \return The converted form.
*/ */
Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map); Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map);
/*! /*!
* \brief inline all calls of f in stmt. * \brief inline all calls of f in stmt.
...@@ -98,6 +98,13 @@ Stmt StorageFlatten(Stmt stmt, ...@@ -98,6 +98,13 @@ Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer); Map<Tensor, Buffer> extern_buffer);
/*! /*!
* \brief unroll the constant loops
* \param stmt The statment to be unrolled.
* \param max_auto_step The maximum step to stop performing automatic unrolling.
*/
Stmt UnrollLoop(Stmt stmt, int max_auto_step);
/*!
* \brief Make an user callable API LoweredFunc. * \brief Make an user callable API LoweredFunc.
* *
* The main task of this function is to create code to : * The main task of this function is to create code to :
...@@ -153,6 +160,7 @@ Array<LoweredFunc> SplitHostDevice(LoweredFunc func); ...@@ -153,6 +160,7 @@ Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
*/ */
LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope); LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
......
...@@ -562,7 +562,7 @@ inline TVMArgValue TVMArgs::operator[](int i) const { ...@@ -562,7 +562,7 @@ inline TVMArgValue TVMArgs::operator[](int i) const {
CHECK_LT(i, num_args) CHECK_LT(i, num_args)
<< "not enough argument passed, " << "not enough argument passed, "
<< num_args << " passed" << num_args << " passed"
<< "but request arg" << i; << " but request arg[" << i << "].";
return TVMArgValue(values[i], type_codes[i]); return TVMArgValue(values[i], type_codes[i]);
} }
......
...@@ -70,7 +70,6 @@ def build(sch, ...@@ -70,7 +70,6 @@ def build(sch,
fsplits = [x for x in fsplits] fsplits = [x for x in fsplits]
for i in range(1, len(fsplits)): for i in range(1, len(fsplits)):
fsplits[i] = ir_pass.StorageSync(fsplits[i], "shared") fsplits[i] = ir_pass.StorageSync(fsplits[i], "shared")
fsplits[i] = ir_pass.StorageSync(fsplits[i], "global")
if record_codes is not None: if record_codes is not None:
output_ssa = False output_ssa = False
......
...@@ -3,5 +3,6 @@ ...@@ -3,5 +3,6 @@
- api API functionr registration - api API functionr registration
- lang The definition of DSL related data structure - lang The definition of DSL related data structure
- schedule The operations on the schedule graph before converting to IR. - schedule The operations on the schedule graph before converting to IR.
- arithmetic Arithmetic expression and set simplification
- pass The optimization pass on the IR structure - pass The optimization pass on the IR structure
- runtime Minimum runtime related codes. - runtime Minimum runtime related codes.
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
namespace tvm { namespace tvm {
...@@ -29,6 +30,14 @@ TVM_REGISTER_API(_pass_Equal) ...@@ -29,6 +30,14 @@ TVM_REGISTER_API(_pass_Equal)
} }
}); });
TVM_REGISTER_API(_pass_PostOrderVisit)
.set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc f = args[1];
ir::PostOrderVisit(args[0], [f](const NodeRef& n) {
f(n);
});
});
// make from two arguments // make from two arguments
#define REGISTER_PASS1(PassName) \ #define REGISTER_PASS1(PassName) \
TVM_REGISTER_API(_pass_## PassName) \ TVM_REGISTER_API(_pass_## PassName) \
...@@ -52,6 +61,7 @@ REGISTER_PASS1(ConvertSSA); ...@@ -52,6 +61,7 @@ REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA); REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline); REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten); REGISTER_PASS2(StorageFlatten);
REGISTER_PASS2(UnrollLoop);
REGISTER_PASS2(StorageSync); REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI); REGISTER_PASS4(MakeAPI);
REGISTER_PASS1(SplitHostDevice); REGISTER_PASS1(SplitHostDevice);
......
...@@ -4,14 +4,14 @@ ...@@ -4,14 +4,14 @@
* \brief Utility integer expression with quick eager simplification. * \brief Utility integer expression with quick eager simplification.
* This is weaker than Simplify but can be done Eagerly. * This is weaker than Simplify but can be done Eagerly.
*/ */
#ifndef TVM_SCHEDULE_COMPUTE_EXPR_H_ #ifndef TVM_ARITHMETIC_COMPUTE_EXPR_H_
#define TVM_SCHEDULE_COMPUTE_EXPR_H_ #define TVM_ARITHMETIC_COMPUTE_EXPR_H_
#include <tvm/ir.h> #include <tvm/ir.h>
#include <pass/Interval.h> #include <pass/Interval.h>
namespace tvm { namespace tvm {
namespace schedule { namespace arith {
using Halide::Internal::add_would_overflow; using Halide::Internal::add_would_overflow;
using Halide::Internal::sub_would_overflow; using Halide::Internal::sub_would_overflow;
...@@ -104,6 +104,6 @@ inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) { ...@@ -104,6 +104,6 @@ inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
return Halide::Internal::Interval::make_min(a, b); return Halide::Internal::Interval::make_min(a, b);
} }
} // namespace schedule } // namespace arith
} // namespace tvm } // namespace tvm
#endif // TVM_SCHEDULE_COMPUTE_EXPR_H_ #endif // TVM_ARITHMETIC_COMPUTE_EXPR_H_
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2017 by Contributors
* \file int_set_impl.cc * \file int_set.cc
* \brief The integer set functions * \brief The integer set functions
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include "./compute_expr.h" #include "./compute_expr.h"
namespace tvm { namespace tvm {
namespace schedule { namespace arith {
using Halide::Internal::Interval; using Halide::Internal::Interval;
...@@ -94,6 +94,12 @@ bool IntSet::is_single_point() const { ...@@ -94,6 +94,12 @@ bool IntSet::is_single_point() const {
return (s_int && s_int->i.is_single_point()); return (s_int && s_int->i.is_single_point());
} }
Expr IntSet::point_value() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
CHECK(s_int && s_int->i.is_single_point());
return s_int->i.min;
}
IntSet IntSet::everything() { IntSet IntSet::everything() {
return IntervalSet::make(Interval::everything()); return IntervalSet::make(Interval::everything());
} }
...@@ -115,8 +121,8 @@ IntSet IntSet::range(Range r) { ...@@ -115,8 +121,8 @@ IntSet IntSet::range(Range r) {
} }
// Check if a is created from b. // Check if a is created from b.
inline bool MatchRange(const IntSet& a, bool IntSet::match_range(const Range& b) const {
const Range& b) { const IntSet& a = *this;
const IntervalSet* a_int = a.as<IntervalSet>(); const IntervalSet* a_int = a.as<IntervalSet>();
if (!a_int) return false; if (!a_int) return false;
const Interval& i = a_int->i; const Interval& i = a_int->i;
...@@ -349,84 +355,6 @@ inline IntSet Combine(const IntSet& a, const IntSet &b) { ...@@ -349,84 +355,6 @@ inline IntSet Combine(const IntSet& a, const IntSet &b) {
return CombineSets<OP>(a, b); return CombineSets<OP>(a, b);
} }
// Implementation of Evaluations and passing.
void PassUp(const SplitNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& outer,
const IntSet& inner,
IntSet* parent) {
if (dom_map.count(s->outer) &&
dom_map.count(s->inner) &&
dom_map.count(s->parent) &&
MatchRange(outer, dom_map.at(s->outer)) &&
MatchRange(inner, dom_map.at(s->inner))) {
*parent = IntSet::range(dom_map.at(s->parent));
return;
}
Expr factor = dom_map.at(s->inner)->extent;
Expr parent_min = dom_map.at(s->parent)->min;
CHECK(outer.defined());
CHECK(inner.defined());
CHECK(factor.defined());
*parent = Combine<Add>(
Combine<Add>(
Combine<Mul>(outer, IntSet::single_point(factor)), inner),
IntSet::single_point(parent_min));
}
void PassUp(const FuseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& fused,
IntSet* outer,
IntSet* inner) {
CHECK(dom_map.count(s->outer));
CHECK(dom_map.count(s->inner));
CHECK(dom_map.count(s->fused));
if (MatchRange(fused, dom_map.at(s->fused))) {
*outer = IntSet::range(dom_map.at(s->outer));
*inner = IntSet::range(dom_map.at(s->inner));
return;
}
Expr outer_min = dom_map.at(s->outer)->min;
Expr inner_min = dom_map.at(s->inner)->min;
const IntervalSet* fused_int = fused.as<IntervalSet>();
if (fused_int && fused_int->i.is_single_point()) {
Expr value = fused_int->i.min;
Expr factor = dom_map.at(s->inner)->extent;
Expr v_outer = value / factor;
Expr v_inner = value % factor;
if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
*outer = IntSet::single_point(v_outer);
*inner = IntSet::single_point(v_inner);
} else {
LOG(WARNING) << "use fallback inference rule in fuse";
// simply use the entire set, this rule can be enhanced.
*outer = IntSet::range(dom_map.at(s->outer));
*inner = IntSet::range(dom_map.at(s->inner));
return;
}
}
void PassUp(const RebaseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& rebased,
IntSet* parent) {
CHECK(dom_map.count(s->parent));
if (MatchRange(rebased, dom_map.at(s->rebased))) {
*parent = IntSet::range(dom_map.at(s->parent));
return;
}
Expr parent_min = dom_map.at(s->parent)->min;
*parent = Combine<Add>(rebased, IntSet::single_point(parent_min));
}
// Evaluator to evalute the epxression. // Evaluator to evalute the epxression.
class IntSetEvaluator { class IntSetEvaluator {
public: public:
...@@ -527,5 +455,5 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -527,5 +455,5 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
} // namespace schedule } // namespace arith
} // namespace tvm } // namespace tvm
...@@ -3,14 +3,14 @@ ...@@ -3,14 +3,14 @@
* \file int_set.h * \file int_set.h
* \brief Abstraction for all integer set operations. * \brief Abstraction for all integer set operations.
*/ */
#ifndef TVM_SCHEDULE_INT_SET_H_ #ifndef TVM_ARITHMETIC_INT_SET_H_
#define TVM_SCHEDULE_INT_SET_H_ #define TVM_ARITHMETIC_INT_SET_H_
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/schedule.h> #include <tvm/schedule.h>
namespace tvm { namespace tvm {
namespace schedule { namespace arith {
// internal node container of int set. // internal node container of int set.
class IntSetNode; class IntSetNode;
...@@ -44,6 +44,18 @@ class IntSet : public NodeRef { ...@@ -44,6 +44,18 @@ class IntSet : public NodeRef {
bool is_everything() const; bool is_everything() const;
/*! \return Whether the set is a single point */ /*! \return Whether the set is a single point */
bool is_single_point() const; bool is_single_point() const;
/*!
* \brief The single point value, call only if is_single_point is true
* \return The point value.
*/
Expr point_value() const;
/*!
* \brief Try to match IntSet with range r.
*
* \note It is guanrateed that IntSet::range(r).match_range(r) == true
* \return true if we can prove they are the same.
*/
bool match_range(const Range& r) const;
/*! \return Whether the set contains everything */ /*! \return Whether the set contains everything */
static IntSet everything(); static IntSet everything();
/*! /*!
...@@ -89,59 +101,6 @@ IntSet EvalSet(Range r, ...@@ -89,59 +101,6 @@ IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map); const Map<IterVar, IntSet>& dom_map);
/*! /*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Split relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param outer domain of outer iteration.
* \param inner domain of inner iteration.
* \param parent The result domain of parent.
*/
void PassUp(const SplitNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& outer,
const IntSet& inner,
IntSet* parent);
/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Fuse relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param fused domain of fused iteration.
* \param outer The result domain of outer iteration.
* \param inner The result domain of inner iteration.
*/
void PassUp(const FuseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& fused,
IntSet* outer,
IntSet* inner);
/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Fuse relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param rebased domain of rebased iteration.
* \param parent The result domain of parent iteration.
*/
void PassUp(const RebaseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& fused,
IntSet* parent);
/*!
* \brief Create an union set of all sets * \brief Create an union set of all sets
* \param sets The sets to be unioned * \param sets The sets to be unioned
* \return the set after union * \return the set after union
...@@ -153,7 +112,7 @@ inline const IntSetNode* IntSet::operator->() const { ...@@ -153,7 +112,7 @@ inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get()); return static_cast<const IntSetNode*>(node_.get());
} }
} // namespace schedule } // namespace arith
} // namespace tvm } // namespace tvm
#endif // TVM_SCHEDULE_INT_SET_H_ #endif // TVM_ARITHMETIC_INT_SET_H_
...@@ -24,11 +24,25 @@ class IRInline : public IRMutator { ...@@ -24,11 +24,25 @@ class IRInline : public IRMutator {
if (op->func == f_) { if (op->func == f_) {
CHECK_EQ(op->value_index, 0); CHECK_EQ(op->value_index, 0);
Expr expr = body_; Expr expr = body_;
CHECK_EQ(args_.size(), op->args.size()) CHECK_EQ(args_.size(), op->args.size());
<< op->args.size() << " vs " << args_.size();
bool has_side_effect = false;
for (size_t i = 0; i < op->args.size(); ++i) {
if (HasSideEffect(op->args[i])) has_side_effect = true;
}
if (has_side_effect) {
for (size_t i = 0; i < args_.size(); ++i) { for (size_t i = 0; i < args_.size(); ++i) {
expr = Let::make(args_[i], op->args[i], expr); expr = Let::make(args_[i], op->args[i], expr);
} }
} else {
Map<Var, Expr> vmap;
for (size_t i = 0; i < args_.size(); ++i) {
vmap.Set(args_[i], op->args[i]);
}
expr = Substitute(
Evaluate::make(expr), vmap).as<Evaluate>()->value;
}
return expr; return expr;
} else { } else {
return e; return e;
......
...@@ -47,10 +47,10 @@ class IRSubstitue : public IRMutator { ...@@ -47,10 +47,10 @@ class IRSubstitue : public IRMutator {
std::unordered_map<const Variable*, Expr> smap; std::unordered_map<const Variable*, Expr> smap;
}; };
Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map) { Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
IRSubstitue m; IRSubstitue m;
for (auto kv : value_map) { for (auto kv : value_map) {
m.smap[kv.first->var.get()] = kv.second; m.smap[kv.first.get()] = kv.second;
} }
return m.Mutate(stmt); return m.Mutate(stmt);
} }
......
/*!
* Copyright (c) 2016 by Contributors
* SSA related checks and pass.
* \file ssa.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <unordered_set>
#include <unordered_map>
#include <vector>
#include "../arithmetic//compute_expr.h"
namespace tvm {
namespace ir {
class LoopUnroller : public IRMutator {
public:
explicit LoopUnroller(int max_auto_step)
: max_auto_step_(max_auto_step) {
}
Stmt Mutate_(const For* op, const Stmt& s) {
Stmt stmt = s;
// constant folding.
Expr extent = ir::Simplify(op->extent);
const IntImm* v1 = extent.as<IntImm>();
const UIntImm* v2 = extent.as<UIntImm>();
int value = -1;
if (v1 != nullptr) {
value = static_cast<int>(v1->value);
}
if (v2 != nullptr) {
value = static_cast<int>(v2->value);
}
bool allow_unroll = value >= 0 && value <= max_auto_step_;
if (op->for_type == ForType::Unrolled) {
CHECK_GE(value, 0)
<< "Cannot unroll non-constant loop";
allow_unroll = true;
}
if (allow_unroll) {
using arith::ComputeExpr;
if (value == 0) return Evaluate::make(0);
Stmt body = op->body;
Map<Var, Expr> vmap;
Stmt unrolled;
for (int i = 0; i < value; ++i) {
Var lv(op->loop_var.node_);
vmap.Set(lv,
ComputeExpr<Add>(
op->min, make_const(op->loop_var.type(), i)));
Stmt step = Substitute(body, vmap);
if (unrolled.defined()) {
unrolled = Block::make(unrolled, step);
} else {
unrolled = step;
}
}
return this->Mutate(unrolled);
} else {
return IRMutator::Mutate_(op, stmt);
}
}
private:
int max_auto_step_;
};
Stmt UnrollLoop(Stmt stmt, int max_auto_step) {
Stmt ret = LoopUnroller(max_auto_step).Mutate(stmt);
return ConvertSSA(ret);
}
} // namespace ir
} // namespace tvm
...@@ -7,13 +7,15 @@ ...@@ -7,13 +7,15 @@
#include <tvm/ir_visitor.h> #include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/schedule_pass.h> #include <tvm/schedule_pass.h>
#include "./int_set.h"
#include "./graph.h" #include "./graph.h"
#include "../arithmetic/int_set.h"
#include "../runtime/thread_storage_scope.h" #include "../runtime/thread_storage_scope.h"
namespace tvm { namespace tvm {
namespace schedule { namespace schedule {
using namespace arith;
// result = ceil((a / b)), both a and b are positive integer // result = ceil((a / b)), both a and b are positive integer
inline Expr DivCeil(Expr a, Expr b) { inline Expr DivCeil(Expr a, Expr b) {
return ir::Simplify((a + b - 1) / b); return ir::Simplify((a + b - 1) / b);
...@@ -70,6 +72,80 @@ void PassDown(const Stage& s, ...@@ -70,6 +72,80 @@ void PassDown(const Stage& s,
// pass the integer set on each leave loop up to the root // pass the integer set on each leave loop up to the root
// dom_map is the result of PassDown, it records the domain of each IterVar. // dom_map is the result of PassDown, it records the domain of each IterVar.
// dom_map can be used to get cached result in reverse construction. // dom_map can be used to get cached result in reverse construction.
// Implementation of Evaluations and passing.
void PassUp(const SplitNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& outer,
const IntSet& inner,
IntSet* parent) {
if (dom_map.count(s->outer) &&
dom_map.count(s->inner) &&
dom_map.count(s->parent) &&
outer.match_range(dom_map.at(s->outer)) &&
inner.match_range(dom_map.at(s->inner))) {
*parent = IntSet::range(dom_map.at(s->parent));
return;
}
Expr factor = dom_map.at(s->inner)->extent;
Expr parent_min = dom_map.at(s->parent)->min;
CHECK(outer.defined());
CHECK(inner.defined());
CHECK(factor.defined());
*parent = EvalSet(
s->outer->var * factor + s->inner->var + parent_min,
{{s->outer, outer}, {s->inner, inner}});
}
void PassUp(const FuseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& fused,
IntSet* outer,
IntSet* inner) {
CHECK(dom_map.count(s->outer));
CHECK(dom_map.count(s->inner));
CHECK(dom_map.count(s->fused));
if (fused.match_range(dom_map.at(s->fused))) {
*outer = IntSet::range(dom_map.at(s->outer));
*inner = IntSet::range(dom_map.at(s->inner));
return;
}
Expr outer_min = dom_map.at(s->outer)->min;
Expr inner_min = dom_map.at(s->inner)->min;
if (fused.is_single_point()) {
Expr value = fused.point_value();
Expr factor = dom_map.at(s->inner)->extent;
Expr v_outer = value / factor;
Expr v_inner = value % factor;
if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
*outer = IntSet::single_point(v_outer);
*inner = IntSet::single_point(v_inner);
} else {
LOG(WARNING) << "use fallback inference rule in fuse";
// simply use the entire set, this rule can be enhanced.
*outer = IntSet::range(dom_map.at(s->outer));
*inner = IntSet::range(dom_map.at(s->inner));
return;
}
}
void PassUp(const RebaseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& rebased,
IntSet* parent) {
CHECK(dom_map.count(s->parent));
if (rebased.match_range(dom_map.at(s->rebased))) {
*parent = IntSet::range(dom_map.at(s->parent));
return;
}
Expr parent_min = dom_map.at(s->parent)->min;
*parent = EvalSet(s->rebased->var + parent_min,
{{s->rebased, rebased}});
}
void PassUp(const Stage& s, void PassUp(const Stage& s,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, IntSet>* p_state) { std::unordered_map<IterVar, IntSet>* p_state) {
......
...@@ -6,7 +6,6 @@ ...@@ -6,7 +6,6 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_visitor.h>
#include <unordered_set> #include <unordered_set>
#include "./int_set.h"
#include "./graph.h" #include "./graph.h"
namespace tvm { namespace tvm {
......
...@@ -9,13 +9,13 @@ ...@@ -9,13 +9,13 @@
#include <tvm/schedule_pass.h> #include <tvm/schedule_pass.h>
#include "../pass/ir_util.h" #include "../pass/ir_util.h"
#include "./int_set.h" #include "../arithmetic/compute_expr.h"
#include "./graph.h" #include "./graph.h"
#include "./compute_expr.h"
namespace tvm { namespace tvm {
namespace schedule { namespace schedule {
using namespace arith;
using namespace ir; using namespace ir;
/*! /*!
...@@ -230,6 +230,15 @@ MakeLoopNest(const Stage& sch, ...@@ -230,6 +230,15 @@ MakeLoopNest(const Stage& sch,
return nest; return nest;
} }
Stmt Substitute(Stmt s,
const std::unordered_map<IterVar, Expr>& value_map) {
Map<Var, Expr> temp;
for (const auto& kv : value_map) {
temp.Set(kv.first->var, kv.second);
}
return ir::Substitute(s, temp);
}
Stmt MakeLoop(const Stage& s, Stmt MakeLoop(const Stage& s,
const Map<IterVar, Range>& dom_map, const Map<IterVar, Range>& dom_map,
Stmt provide, Stmt provide,
...@@ -244,7 +253,6 @@ Stmt MakeLoop(const Stage& s, ...@@ -244,7 +253,6 @@ Stmt MakeLoop(const Stage& s,
auto nest = MakeLoopNest(s, dom_map, 0, false, auto nest = MakeLoopNest(s, dom_map, 0, false,
bound_state, {}, &value_map); bound_state, {}, &value_map);
provide = Substitute(provide, value_map); provide = Substitute(provide, value_map);
if (init.defined()) { if (init.defined()) {
// try to find the location to insert the initialization. // try to find the location to insert the initialization.
......
import tvm
def test_unroll_loop():
dtype = 'int64'
n = tvm.Var('n')
Ab = tvm.Buffer((n, ), dtype)
i = tvm.Var('i')
j = tvm.Var('j')
# for i in 0 to n-1:
stmt = tvm.make.For(
i, n, 2, 0, 0,
tvm.make.For(j, 0, n, 0, 0,
tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 1,
j + 1)))
stmt = tvm.ir_pass.UnrollLoop(stmt, 8)
print(stmt)
if __name__ == "__main__":
test_unroll_loop()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment