Commit e42cc112 by Tianqi Chen Committed by GitHub

[PASS] UnrollLoop, isolate arithmetic module. (#32)

parent d89917b6
......@@ -68,7 +68,7 @@ Stmt ConvertSSA(Stmt stmt);
* \param value_map The map of new values.
* \return The converted form.
*/
Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map);
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map);
/*!
* \brief inline all calls of f in stmt.
......@@ -98,6 +98,13 @@ Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer);
/*!
* \brief unroll the constant loops
* \param stmt The statment to be unrolled.
* \param max_auto_step The maximum step to stop performing automatic unrolling.
*/
Stmt UnrollLoop(Stmt stmt, int max_auto_step);
/*!
* \brief Make an user callable API LoweredFunc.
*
* The main task of this function is to create code to :
......@@ -153,6 +160,7 @@ Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
*/
LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope);
} // namespace ir
} // namespace tvm
......
......@@ -562,7 +562,7 @@ inline TVMArgValue TVMArgs::operator[](int i) const {
CHECK_LT(i, num_args)
<< "not enough argument passed, "
<< num_args << " passed"
<< "but request arg" << i;
<< " but request arg[" << i << "].";
return TVMArgValue(values[i], type_codes[i]);
}
......
......@@ -70,7 +70,6 @@ def build(sch,
fsplits = [x for x in fsplits]
for i in range(1, len(fsplits)):
fsplits[i] = ir_pass.StorageSync(fsplits[i], "shared")
fsplits[i] = ir_pass.StorageSync(fsplits[i], "global")
if record_codes is not None:
output_ssa = False
......
......@@ -3,5 +3,6 @@
- api API functionr registration
- lang The definition of DSL related data structure
- schedule The operations on the schedule graph before converting to IR.
- arithmetic Arithmetic expression and set simplification
- pass The optimization pass on the IR structure
- runtime Minimum runtime related codes.
......@@ -6,6 +6,7 @@
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/api_registry.h>
namespace tvm {
......@@ -29,6 +30,14 @@ TVM_REGISTER_API(_pass_Equal)
}
});
TVM_REGISTER_API(_pass_PostOrderVisit)
.set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc f = args[1];
ir::PostOrderVisit(args[0], [f](const NodeRef& n) {
f(n);
});
});
// make from two arguments
#define REGISTER_PASS1(PassName) \
TVM_REGISTER_API(_pass_## PassName) \
......@@ -52,6 +61,7 @@ REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten);
REGISTER_PASS2(UnrollLoop);
REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI);
REGISTER_PASS1(SplitHostDevice);
......
......@@ -4,14 +4,14 @@
* \brief Utility integer expression with quick eager simplification.
* This is weaker than Simplify but can be done Eagerly.
*/
#ifndef TVM_SCHEDULE_COMPUTE_EXPR_H_
#define TVM_SCHEDULE_COMPUTE_EXPR_H_
#ifndef TVM_ARITHMETIC_COMPUTE_EXPR_H_
#define TVM_ARITHMETIC_COMPUTE_EXPR_H_
#include <tvm/ir.h>
#include <pass/Interval.h>
namespace tvm {
namespace schedule {
namespace arith {
using Halide::Internal::add_would_overflow;
using Halide::Internal::sub_would_overflow;
......@@ -104,6 +104,6 @@ inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
return Halide::Internal::Interval::make_min(a, b);
}
} // namespace schedule
} // namespace arith
} // namespace tvm
#endif // TVM_SCHEDULE_COMPUTE_EXPR_H_
#endif // TVM_ARITHMETIC_COMPUTE_EXPR_H_
/*!
* Copyright (c) 2016 by Contributors
* \file int_set_impl.cc
* Copyright (c) 2017 by Contributors
* \file int_set.cc
* \brief The integer set functions
*/
#include <tvm/ir.h>
......@@ -10,7 +10,7 @@
#include "./compute_expr.h"
namespace tvm {
namespace schedule {
namespace arith {
using Halide::Internal::Interval;
......@@ -94,6 +94,12 @@ bool IntSet::is_single_point() const {
return (s_int && s_int->i.is_single_point());
}
Expr IntSet::point_value() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
CHECK(s_int && s_int->i.is_single_point());
return s_int->i.min;
}
IntSet IntSet::everything() {
return IntervalSet::make(Interval::everything());
}
......@@ -115,8 +121,8 @@ IntSet IntSet::range(Range r) {
}
// Check if a is created from b.
inline bool MatchRange(const IntSet& a,
const Range& b) {
bool IntSet::match_range(const Range& b) const {
const IntSet& a = *this;
const IntervalSet* a_int = a.as<IntervalSet>();
if (!a_int) return false;
const Interval& i = a_int->i;
......@@ -349,84 +355,6 @@ inline IntSet Combine(const IntSet& a, const IntSet &b) {
return CombineSets<OP>(a, b);
}
// Implementation of Evaluations and passing.
void PassUp(const SplitNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& outer,
const IntSet& inner,
IntSet* parent) {
if (dom_map.count(s->outer) &&
dom_map.count(s->inner) &&
dom_map.count(s->parent) &&
MatchRange(outer, dom_map.at(s->outer)) &&
MatchRange(inner, dom_map.at(s->inner))) {
*parent = IntSet::range(dom_map.at(s->parent));
return;
}
Expr factor = dom_map.at(s->inner)->extent;
Expr parent_min = dom_map.at(s->parent)->min;
CHECK(outer.defined());
CHECK(inner.defined());
CHECK(factor.defined());
*parent = Combine<Add>(
Combine<Add>(
Combine<Mul>(outer, IntSet::single_point(factor)), inner),
IntSet::single_point(parent_min));
}
void PassUp(const FuseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& fused,
IntSet* outer,
IntSet* inner) {
CHECK(dom_map.count(s->outer));
CHECK(dom_map.count(s->inner));
CHECK(dom_map.count(s->fused));
if (MatchRange(fused, dom_map.at(s->fused))) {
*outer = IntSet::range(dom_map.at(s->outer));
*inner = IntSet::range(dom_map.at(s->inner));
return;
}
Expr outer_min = dom_map.at(s->outer)->min;
Expr inner_min = dom_map.at(s->inner)->min;
const IntervalSet* fused_int = fused.as<IntervalSet>();
if (fused_int && fused_int->i.is_single_point()) {
Expr value = fused_int->i.min;
Expr factor = dom_map.at(s->inner)->extent;
Expr v_outer = value / factor;
Expr v_inner = value % factor;
if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
*outer = IntSet::single_point(v_outer);
*inner = IntSet::single_point(v_inner);
} else {
LOG(WARNING) << "use fallback inference rule in fuse";
// simply use the entire set, this rule can be enhanced.
*outer = IntSet::range(dom_map.at(s->outer));
*inner = IntSet::range(dom_map.at(s->inner));
return;
}
}
void PassUp(const RebaseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& rebased,
IntSet* parent) {
CHECK(dom_map.count(s->parent));
if (MatchRange(rebased, dom_map.at(s->rebased))) {
*parent = IntSet::range(dom_map.at(s->parent));
return;
}
Expr parent_min = dom_map.at(s->parent)->min;
*parent = Combine<Add>(rebased, IntSet::single_point(parent_min));
}
// Evaluator to evalute the epxression.
class IntSetEvaluator {
public:
......@@ -527,5 +455,5 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
});
} // namespace schedule
} // namespace arith
} // namespace tvm
......@@ -3,14 +3,14 @@
* \file int_set.h
* \brief Abstraction for all integer set operations.
*/
#ifndef TVM_SCHEDULE_INT_SET_H_
#define TVM_SCHEDULE_INT_SET_H_
#ifndef TVM_ARITHMETIC_INT_SET_H_
#define TVM_ARITHMETIC_INT_SET_H_
#include <tvm/expr.h>
#include <tvm/schedule.h>
namespace tvm {
namespace schedule {
namespace arith {
// internal node container of int set.
class IntSetNode;
......@@ -44,6 +44,18 @@ class IntSet : public NodeRef {
bool is_everything() const;
/*! \return Whether the set is a single point */
bool is_single_point() const;
/*!
* \brief The single point value, call only if is_single_point is true
* \return The point value.
*/
Expr point_value() const;
/*!
* \brief Try to match IntSet with range r.
*
* \note It is guanrateed that IntSet::range(r).match_range(r) == true
* \return true if we can prove they are the same.
*/
bool match_range(const Range& r) const;
/*! \return Whether the set contains everything */
static IntSet everything();
/*!
......@@ -89,59 +101,6 @@ IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map);
/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Split relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param outer domain of outer iteration.
* \param inner domain of inner iteration.
* \param parent The result domain of parent.
*/
void PassUp(const SplitNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& outer,
const IntSet& inner,
IntSet* parent);
/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Fuse relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param fused domain of fused iteration.
* \param outer The result domain of outer iteration.
* \param inner The result domain of inner iteration.
*/
void PassUp(const FuseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& fused,
IntSet* outer,
IntSet* inner);
/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Fuse relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param rebased domain of rebased iteration.
* \param parent The result domain of parent iteration.
*/
void PassUp(const RebaseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& fused,
IntSet* parent);
/*!
* \brief Create an union set of all sets
* \param sets The sets to be unioned
* \return the set after union
......@@ -153,7 +112,7 @@ inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get());
}
} // namespace schedule
} // namespace arith
} // namespace tvm
#endif // TVM_SCHEDULE_INT_SET_H_
#endif // TVM_ARITHMETIC_INT_SET_H_
......@@ -24,10 +24,24 @@ class IRInline : public IRMutator {
if (op->func == f_) {
CHECK_EQ(op->value_index, 0);
Expr expr = body_;
CHECK_EQ(args_.size(), op->args.size())
<< op->args.size() << " vs " << args_.size();
for (size_t i = 0; i < args_.size(); ++i) {
expr = Let::make(args_[i], op->args[i], expr);
CHECK_EQ(args_.size(), op->args.size());
bool has_side_effect = false;
for (size_t i = 0; i < op->args.size(); ++i) {
if (HasSideEffect(op->args[i])) has_side_effect = true;
}
if (has_side_effect) {
for (size_t i = 0; i < args_.size(); ++i) {
expr = Let::make(args_[i], op->args[i], expr);
}
} else {
Map<Var, Expr> vmap;
for (size_t i = 0; i < args_.size(); ++i) {
vmap.Set(args_[i], op->args[i]);
}
expr = Substitute(
Evaluate::make(expr), vmap).as<Evaluate>()->value;
}
return expr;
} else {
......
......@@ -47,10 +47,10 @@ class IRSubstitue : public IRMutator {
std::unordered_map<const Variable*, Expr> smap;
};
Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map) {
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
IRSubstitue m;
for (auto kv : value_map) {
m.smap[kv.first->var.get()] = kv.second;
m.smap[kv.first.get()] = kv.second;
}
return m.Mutate(stmt);
}
......
/*!
* Copyright (c) 2016 by Contributors
* SSA related checks and pass.
* \file ssa.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <unordered_set>
#include <unordered_map>
#include <vector>
#include "../arithmetic//compute_expr.h"
namespace tvm {
namespace ir {
class LoopUnroller : public IRMutator {
public:
explicit LoopUnroller(int max_auto_step)
: max_auto_step_(max_auto_step) {
}
Stmt Mutate_(const For* op, const Stmt& s) {
Stmt stmt = s;
// constant folding.
Expr extent = ir::Simplify(op->extent);
const IntImm* v1 = extent.as<IntImm>();
const UIntImm* v2 = extent.as<UIntImm>();
int value = -1;
if (v1 != nullptr) {
value = static_cast<int>(v1->value);
}
if (v2 != nullptr) {
value = static_cast<int>(v2->value);
}
bool allow_unroll = value >= 0 && value <= max_auto_step_;
if (op->for_type == ForType::Unrolled) {
CHECK_GE(value, 0)
<< "Cannot unroll non-constant loop";
allow_unroll = true;
}
if (allow_unroll) {
using arith::ComputeExpr;
if (value == 0) return Evaluate::make(0);
Stmt body = op->body;
Map<Var, Expr> vmap;
Stmt unrolled;
for (int i = 0; i < value; ++i) {
Var lv(op->loop_var.node_);
vmap.Set(lv,
ComputeExpr<Add>(
op->min, make_const(op->loop_var.type(), i)));
Stmt step = Substitute(body, vmap);
if (unrolled.defined()) {
unrolled = Block::make(unrolled, step);
} else {
unrolled = step;
}
}
return this->Mutate(unrolled);
} else {
return IRMutator::Mutate_(op, stmt);
}
}
private:
int max_auto_step_;
};
Stmt UnrollLoop(Stmt stmt, int max_auto_step) {
Stmt ret = LoopUnroller(max_auto_step).Mutate(stmt);
return ConvertSSA(ret);
}
} // namespace ir
} // namespace tvm
......@@ -7,13 +7,15 @@
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <tvm/schedule_pass.h>
#include "./int_set.h"
#include "./graph.h"
#include "../arithmetic/int_set.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace schedule {
using namespace arith;
// result = ceil((a / b)), both a and b are positive integer
inline Expr DivCeil(Expr a, Expr b) {
return ir::Simplify((a + b - 1) / b);
......@@ -70,6 +72,80 @@ void PassDown(const Stage& s,
// pass the integer set on each leave loop up to the root
// dom_map is the result of PassDown, it records the domain of each IterVar.
// dom_map can be used to get cached result in reverse construction.
// Implementation of Evaluations and passing.
void PassUp(const SplitNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& outer,
const IntSet& inner,
IntSet* parent) {
if (dom_map.count(s->outer) &&
dom_map.count(s->inner) &&
dom_map.count(s->parent) &&
outer.match_range(dom_map.at(s->outer)) &&
inner.match_range(dom_map.at(s->inner))) {
*parent = IntSet::range(dom_map.at(s->parent));
return;
}
Expr factor = dom_map.at(s->inner)->extent;
Expr parent_min = dom_map.at(s->parent)->min;
CHECK(outer.defined());
CHECK(inner.defined());
CHECK(factor.defined());
*parent = EvalSet(
s->outer->var * factor + s->inner->var + parent_min,
{{s->outer, outer}, {s->inner, inner}});
}
void PassUp(const FuseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& fused,
IntSet* outer,
IntSet* inner) {
CHECK(dom_map.count(s->outer));
CHECK(dom_map.count(s->inner));
CHECK(dom_map.count(s->fused));
if (fused.match_range(dom_map.at(s->fused))) {
*outer = IntSet::range(dom_map.at(s->outer));
*inner = IntSet::range(dom_map.at(s->inner));
return;
}
Expr outer_min = dom_map.at(s->outer)->min;
Expr inner_min = dom_map.at(s->inner)->min;
if (fused.is_single_point()) {
Expr value = fused.point_value();
Expr factor = dom_map.at(s->inner)->extent;
Expr v_outer = value / factor;
Expr v_inner = value % factor;
if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
*outer = IntSet::single_point(v_outer);
*inner = IntSet::single_point(v_inner);
} else {
LOG(WARNING) << "use fallback inference rule in fuse";
// simply use the entire set, this rule can be enhanced.
*outer = IntSet::range(dom_map.at(s->outer));
*inner = IntSet::range(dom_map.at(s->inner));
return;
}
}
void PassUp(const RebaseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& rebased,
IntSet* parent) {
CHECK(dom_map.count(s->parent));
if (rebased.match_range(dom_map.at(s->rebased))) {
*parent = IntSet::range(dom_map.at(s->parent));
return;
}
Expr parent_min = dom_map.at(s->parent)->min;
*parent = EvalSet(s->rebased->var + parent_min,
{{s->rebased, rebased}});
}
void PassUp(const Stage& s,
const std::unordered_map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, IntSet>* p_state) {
......
......@@ -6,7 +6,6 @@
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <unordered_set>
#include "./int_set.h"
#include "./graph.h"
namespace tvm {
......
......@@ -9,13 +9,13 @@
#include <tvm/schedule_pass.h>
#include "../pass/ir_util.h"
#include "./int_set.h"
#include "../arithmetic/compute_expr.h"
#include "./graph.h"
#include "./compute_expr.h"
namespace tvm {
namespace schedule {
using namespace arith;
using namespace ir;
/*!
......@@ -230,6 +230,15 @@ MakeLoopNest(const Stage& sch,
return nest;
}
Stmt Substitute(Stmt s,
const std::unordered_map<IterVar, Expr>& value_map) {
Map<Var, Expr> temp;
for (const auto& kv : value_map) {
temp.Set(kv.first->var, kv.second);
}
return ir::Substitute(s, temp);
}
Stmt MakeLoop(const Stage& s,
const Map<IterVar, Range>& dom_map,
Stmt provide,
......@@ -244,7 +253,6 @@ Stmt MakeLoop(const Stage& s,
auto nest = MakeLoopNest(s, dom_map, 0, false,
bound_state, {}, &value_map);
provide = Substitute(provide, value_map);
if (init.defined()) {
// try to find the location to insert the initialization.
......
import tvm
def test_unroll_loop():
dtype = 'int64'
n = tvm.Var('n')
Ab = tvm.Buffer((n, ), dtype)
i = tvm.Var('i')
j = tvm.Var('j')
# for i in 0 to n-1:
stmt = tvm.make.For(
i, n, 2, 0, 0,
tvm.make.For(j, 0, n, 0, 0,
tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 1,
j + 1)))
stmt = tvm.ir_pass.UnrollLoop(stmt, 8)
print(stmt)
if __name__ == "__main__":
test_unroll_loop()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment