Commit 88377988 by Tianqi Chen Committed by GitHub

[PASS] Canonical form simplify (#34)

parent 2bcf3f2c
...@@ -63,6 +63,13 @@ bool HasSideEffect(const Expr& e); ...@@ -63,6 +63,13 @@ bool HasSideEffect(const Expr& e);
Stmt ConvertSSA(Stmt stmt); Stmt ConvertSSA(Stmt stmt);
/*! /*!
* \brief Simplify by applying canonical form.
* \param stmt The statement to be canonically simplified.
* \return Canonicalized statement.
*/
Stmt CanonicalSimplify(Stmt stmt);
/*!
* \brief Substitute the var specified in key->var to be value. * \brief Substitute the var specified in key->var to be value.
* \param stmt The source statement to be substituted * \param stmt The source statement to be substituted
* \param value_map The map of new values. * \param value_map The map of new values.
......
...@@ -17,7 +17,8 @@ def build(sch, ...@@ -17,7 +17,8 @@ def build(sch,
target, target,
name="default_function", name="default_function",
binds=None, binds=None,
record_codes=None): record_codes=None,
max_auto_unroll_step=8):
"""Build a function with arguments as signiture. """Build a function with arguments as signiture.
Parameters Parameters
...@@ -38,6 +39,9 @@ def build(sch, ...@@ -38,6 +39,9 @@ def build(sch,
Dictionary that maps the binding of symbolic buffer to Tensor. Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument. By default, a new buffer is created for each tensor in the argument.
max_auto_unroll_step: int
Maximum step to perform automatic unrolling
Returns Returns
------- -------
f : Function, or pair of functions f : Function, or pair of functions
...@@ -64,6 +68,8 @@ def build(sch, ...@@ -64,6 +68,8 @@ def build(sch,
bounds = schedule.InferBound(sch) bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds) stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.StorageFlatten(stmt, binds) stmt = ir_pass.StorageFlatten(stmt, binds)
stmt = ir_pass.CanonicalSimplify(stmt)
stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
stmt = ir_pass.Simplify(stmt) stmt = ir_pass.Simplify(stmt)
fapi = ir_pass.MakeAPI(stmt, name, arg_list, len(arg_list)) fapi = ir_pass.MakeAPI(stmt, name, arg_list, len(arg_list))
fsplits = ir_pass.SplitHostDevice(fapi) fsplits = ir_pass.SplitHostDevice(fapi)
......
...@@ -59,6 +59,7 @@ TVM_REGISTER_API(_pass_PostOrderVisit) ...@@ -59,6 +59,7 @@ TVM_REGISTER_API(_pass_PostOrderVisit)
REGISTER_PASS1(ConvertSSA); REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA); REGISTER_PASS1(VerifySSA);
REGISTER_PASS1(CanonicalSimplify);
REGISTER_PASS4(Inline); REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten); REGISTER_PASS2(StorageFlatten);
REGISTER_PASS2(UnrollLoop); REGISTER_PASS2(UnrollLoop);
......
/*!
* Copyright (c) 2017 by Contributors
* \file canonical.h
* \brief Internal canonicalized expression simplification engine.
*/
#ifndef TVM_ARITHMETIC_CANONICAL_H_
#define TVM_ARITHMETIC_CANONICAL_H_
#include <tvm/expr.h>
#include <tvm/schedule.h>
namespace tvm {
namespace arith {
/*!
 * \brief A stateful canonicalization engine over SSA.
 *
 *  Simplify and CSE by rewriting expressions into canonical form.
 *  The result of each call is cached, so a repeated call
 *  simply returns the cached result.
 */
class Canonical {
 public:
  /*! \brief constructor */
  Canonical();
  /*!
   * \brief Simplify an expression.
   * \param expr The expression to be simplified.
   * \return The simplified expression.
   */
  Expr Simplify(Expr expr);
  /*!
   * \brief Simplify a statement.
   * \param stmt The statement to be simplified.
   * \return The simplified statement.
   */
  Stmt Simplify(Stmt stmt);
  /*!
   * \brief Set the range and scope level of a variable.
   * \param v The variable.
   * \param r The range of the variable, can be undefined.
   * \param level The scope level of the variable,
   *  affects the order of formulas in commutative ops.
   */
  void SetRange(Var v, Range r, int level);
  /*! \brief Implementation class; only forward-declared in this header. */
  class Internal;

 private:
  // Internal pointer to the implementation object.
  std::shared_ptr<Internal> ptr_;
};
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_CANONICAL_H_
...@@ -94,6 +94,11 @@ bool IntSet::is_single_point() const { ...@@ -94,6 +94,11 @@ bool IntSet::is_single_point() const {
return (s_int && s_int->i.is_single_point()); return (s_int && s_int->i.is_single_point());
} }
bool IntSet::can_prove_positive() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
return (s_int && is_positive_const(ir::Simplify(s_int->i.min)));
}
Expr IntSet::point_value() const { Expr IntSet::point_value() const {
const IntervalSet* s_int = (*this).as<IntervalSet>(); const IntervalSet* s_int = (*this).as<IntervalSet>();
CHECK(s_int && s_int->i.is_single_point()); CHECK(s_int && s_int->i.is_single_point());
...@@ -358,6 +363,9 @@ inline IntSet Combine(const IntSet& a, const IntSet &b) { ...@@ -358,6 +363,9 @@ inline IntSet Combine(const IntSet& a, const IntSet &b) {
// Evaluator to evalute the epxression. // Evaluator to evalute the epxression.
class IntSetEvaluator { class IntSetEvaluator {
public: public:
// Construct an evaluator that looks variables up in the given domain map.
// NOTE(review): the dom_map member appears to be declared as a const
// reference, so the map passed here presumably has to outlive this
// evaluator -- confirm against the member declaration and the callers.
explicit IntSetEvaluator(const std::unordered_map<const Variable*, IntSet>& dom_map)
: dom_map(dom_map) {}
inline IntSet Eval(Expr expr) { inline IntSet Eval(Expr expr) {
static const FType& f = vtable(); static const FType& f = vtable();
if (f.can_dispatch(expr)) { if (f.can_dispatch(expr)) {
...@@ -373,7 +381,7 @@ class IntSetEvaluator { ...@@ -373,7 +381,7 @@ class IntSetEvaluator {
static FType inst; return inst; static FType inst; return inst;
} }
std::unordered_map<const Variable*, IntSet> dom_map; const std::unordered_map<const Variable*, IntSet>& dom_map;
}; };
inline IntSet ConstOp(const NodeRef&, const Expr& e, IntSetEvaluator*) { inline IntSet ConstOp(const NodeRef&, const Expr& e, IntSetEvaluator*) {
...@@ -424,21 +432,29 @@ TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable) ...@@ -424,21 +432,29 @@ TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
.set_dispatch<And>(Binary<And>) .set_dispatch<And>(Binary<And>)
.set_dispatch<Or>(Binary<Or>); .set_dispatch<Or>(Binary<Or>);
// Evaluate the integer set of expression e, using dom_map to supply the
// set for each variable. The map is handed to IntSetEvaluator as-is.
IntSet EvalSet(Expr e,
               const std::unordered_map<const Variable*, IntSet>& dom_map) {
  IntSetEvaluator evaluator(dom_map);
  return evaluator.Eval(e);
}
IntSet EvalSet(Expr e, IntSet EvalSet(Expr e,
const Map<IterVar, IntSet>& dom_map) { const Map<IterVar, IntSet>& dom_map) {
IntSetEvaluator m; std::unordered_map<const Variable*, IntSet> dmap;
for (auto kv : dom_map) { for (auto kv : dom_map) {
m.dom_map[kv.first->var.as<Variable>()] = kv.second; dmap[kv.first->var.as<Variable>()] = kv.second;
} }
IntSetEvaluator m(dmap);
return m.Eval(e); return m.Eval(e);
} }
IntSet EvalSet(Range r, IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map) { const Map<IterVar, IntSet>& dom_map) {
IntSetEvaluator m; std::unordered_map<const Variable*, IntSet> dmap;
for (auto kv : dom_map) { for (auto kv : dom_map) {
m.dom_map[kv.first->var.as<Variable>()] = kv.second; dmap[kv.first->var.as<Variable>()] = kv.second;
} }
IntSetEvaluator m(dmap);
IntSet min_set = m.Eval(r->min); IntSet min_set = m.Eval(r->min);
IntSet ext_set = m.Eval(r->extent).cover_interval(); IntSet ext_set = m.Eval(r->extent).cover_interval();
const Interval& ei = ext_set.as<IntervalSet>()->i; const Interval& ei = ext_set.as<IntervalSet>()->i;
......
...@@ -44,6 +44,8 @@ class IntSet : public NodeRef { ...@@ -44,6 +44,8 @@ class IntSet : public NodeRef {
bool is_everything() const; bool is_everything() const;
/*! \return Whether the set is a single point */ /*! \return Whether the set is a single point */
bool is_single_point() const; bool is_single_point() const;
/*! \return Whether the set is proved to be bigger than 0 */
bool can_prove_positive() const;
/*! /*!
* \brief The single point value, call only if is_single_point is true * \brief The single point value, call only if is_single_point is true
* \return The point value. * \return The point value.
...@@ -88,6 +90,8 @@ struct IntSetNode : public Node { ...@@ -88,6 +90,8 @@ struct IntSetNode : public Node {
*/ */
IntSet EvalSet(Expr e, IntSet EvalSet(Expr e,
const Map<IterVar, IntSet>& dom_map); const Map<IterVar, IntSet>& dom_map);
IntSet EvalSet(Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map);
/*! /*!
* \brief Find an symbolic integer set that contains is union over * \brief Find an symbolic integer set that contains is union over
......
...@@ -45,7 +45,7 @@ MakeNVRTC(Array<LoweredFunc> funcs) { ...@@ -45,7 +45,7 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
std::ostringstream os; std::ostringstream os;
os << "typedef int int32_t;\n" os << "typedef int int32_t;\n"
<< "typedef unsigned unt32_t;\n"; << "typedef unsigned unt32_t;\n";
bool output_ssa = true; bool output_ssa = false;
for (LoweredFunc f : funcs) { for (LoweredFunc f : funcs) {
os << CodeGenCUDA().Compile(f, output_ssa); os << CodeGenCUDA().Compile(f, output_ssa);
os << '\n'; os << '\n';
......
...@@ -57,7 +57,7 @@ MakeOpenCL(Array<LoweredFunc> funcs) { ...@@ -57,7 +57,7 @@ MakeOpenCL(Array<LoweredFunc> funcs) {
std::ostringstream os; std::ostringstream os;
os << "typedef int int32_t;\n" os << "typedef int int32_t;\n"
<< "typedef unsigned unt32_t;\n"; << "typedef unsigned unt32_t;\n";
bool output_ssa = true; bool output_ssa = false;
for (LoweredFunc f : funcs) { for (LoweredFunc f : funcs) {
os << CodeGenOpenCL().Compile(f, output_ssa); os << CodeGenOpenCL().Compile(f, output_ssa);
os << '\n'; os << '\n';
......
...@@ -3,9 +3,9 @@ import numpy as np ...@@ -3,9 +3,9 @@ import numpy as np
def test_gemm(): def test_gemm():
# graph # graph
nn = 1235 nn = 1024
n = tvm.Var('n') n = tvm.Var('n')
#n = tvm.convert(nn) n = tvm.convert(nn)
m = n m = n
l = n l = n
A = tvm.placeholder((n, l), name='A') A = tvm.placeholder((n, l), name='A')
...@@ -52,12 +52,14 @@ def test_gemm(): ...@@ -52,12 +52,14 @@ def test_gemm():
_, xi = s[BB].split(s[BB].op.axis[0], outer=thread_y) _, xi = s[BB].split(s[BB].op.axis[0], outer=thread_y)
_, xi = s[BB].split(xi, outer=thread_x) _, xi = s[BB].split(xi, outer=thread_x)
max_auto_unroll_step = 0
# lowering test # lowering test
s.normalize() s.normalize()
def check_device(target): def check_device(target):
codes = [] codes = []
f = tvm.build(s, [A, B, C], target, record_codes=codes) f = tvm.build(s, [A, B, C], target, record_codes=codes,
max_auto_unroll_step=max_auto_unroll_step)
for c in codes[1:]: for c in codes[1:]:
print(c) print(c)
if target == "cuda": if target == "cuda":
......
import tvm
import numpy
def test_simplify():
    """Build a nested-loop statement and run CanonicalSimplify on it.

    Flagged by the author as not yet working (mock design): nothing is
    asserted, the statement is only printed before and after the pass.
    """
    dtype = 'int64'
    size = tvm.Var('n')
    buf = tvm.Buffer((size, ), dtype)
    outer = tvm.Var('i')
    inner = tvm.Var('j')

    # Guard on i + 2 < n; the else branch is left empty (None).
    guard = tvm.make.LT(outer + 2, size)
    # Store built from: buffer data, Load(dtype, data, i + 4) + 1, and the
    # expression (j + 1) * 4 - 4 * j + i, which gives the pass something
    # to simplify.
    store = tvm.make.Store(
        buf.data,
        tvm.make.Load(dtype, buf.data, outer + 4) + 1,
        (inner + 1) * 4 - 4 * inner + outer)
    guarded_store = tvm.make.IfThenElse(guard, store, None)

    # NOTE(review): For arguments are presumably
    # (loop_var, min, extent, for_type, device_api, body) -- confirm.
    inner_loop = tvm.make.For(inner, 0, size, 0, 0, guarded_store)
    stmt = tvm.make.For(outer, 2, size, 0, 0, inner_loop)

    print(stmt)
    stmt = tvm.ir_pass.CanonicalSimplify(stmt)
    print(stmt)
# Run the test directly when this file is executed as a script.
if __name__ == "__main__":
test_simplify()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment