Commit 88377988 by Tianqi Chen Committed by GitHub

[PASS] Canonical form simplify (#34)

parent 2bcf3f2c
......@@ -63,6 +63,13 @@ bool HasSideEffect(const Expr& e);
Stmt ConvertSSA(Stmt stmt);
/*!
* \brief Simplify by applying canonical form.
* \param stmt The statement to be canonically simplifed.
* \return Canonicalized statement.
*/
Stmt CanonicalSimplify(Stmt stmt);
/*!
* \brief Substitute the var specified in key->var to be value.
* \param stmt The source statement to be substituted
* \param value_map The map of new values.
......
......@@ -17,7 +17,8 @@ def build(sch,
target,
name="default_function",
binds=None,
record_codes=None):
record_codes=None,
max_auto_unroll_step=8):
"""Build a function with arguments as signiture.
Parameters
......@@ -38,6 +39,9 @@ def build(sch,
Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument.
max_auto_unroll_step: int
Maximum step to perform automatic unrolling
Returns
-------
f : Function, or pair of functions
......@@ -64,6 +68,8 @@ def build(sch,
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.StorageFlatten(stmt, binds)
stmt = ir_pass.CanonicalSimplify(stmt)
stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
stmt = ir_pass.Simplify(stmt)
fapi = ir_pass.MakeAPI(stmt, name, arg_list, len(arg_list))
fsplits = ir_pass.SplitHostDevice(fapi)
......
......@@ -59,6 +59,7 @@ TVM_REGISTER_API(_pass_PostOrderVisit)
REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS1(CanonicalSimplify);
REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten);
REGISTER_PASS2(UnrollLoop);
......
/*!
* Copyright (c) 2017 by Contributors
* \file canonical.h
* \brief Internal canonicalized expression simplification engine.
*/
#ifndef TVM_ARITHMETIC_CANONICAL_H_
#define TVM_ARITHMETIC_CANONICAL_H_
#include <tvm/expr.h>
#include <tvm/schedule.h>
namespace tvm {
namespace arith {
/*!
* \brief A stateful CanonicalEngine over SSA.
*
* Simplify and CSE with canonicalization expressions.
* Each call's result will get cached, so next call will
* simply return the cached result.
*/
class Canonical {
public:
/*! \brief constructor */
Canonical();
/*!
* \brief simplify expression e.
* \param expr The expression to be simplified.
*/
Expr Simplify(Expr expr);
/*!
* \brief simplify stmt.
* \param stmt The stmt to be simplified.
*/
Stmt Simplify(Stmt expr);
/*!
* \brief Set range and level variable
* \param v The variable
* \param r The range of the variable, can be undefined.
* \param level The scope level of the variable,
* affect the order of formula in communicative ops.
*/
void SetRange(Var v, Range r, int level);
class Internal;
private:
// Internal pointer
std::shared_ptr<Internal> ptr_;
};
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_CANONICAL_H_
......@@ -94,6 +94,11 @@ bool IntSet::is_single_point() const {
return (s_int && s_int->i.is_single_point());
}
bool IntSet::can_prove_positive() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
return (s_int && is_positive_const(ir::Simplify(s_int->i.min)));
}
Expr IntSet::point_value() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
CHECK(s_int && s_int->i.is_single_point());
......@@ -358,6 +363,9 @@ inline IntSet Combine(const IntSet& a, const IntSet &b) {
// Evaluator to evalute the epxression.
class IntSetEvaluator {
public:
explicit IntSetEvaluator(const std::unordered_map<const Variable*, IntSet>& dom_map)
: dom_map(dom_map) {}
inline IntSet Eval(Expr expr) {
static const FType& f = vtable();
if (f.can_dispatch(expr)) {
......@@ -373,7 +381,7 @@ class IntSetEvaluator {
static FType inst; return inst;
}
std::unordered_map<const Variable*, IntSet> dom_map;
const std::unordered_map<const Variable*, IntSet>& dom_map;
};
inline IntSet ConstOp(const NodeRef&, const Expr& e, IntSetEvaluator*) {
......@@ -424,21 +432,29 @@ TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
.set_dispatch<And>(Binary<And>)
.set_dispatch<Or>(Binary<Or>);
IntSet EvalSet(Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map) {
return IntSetEvaluator(dom_map).Eval(e);
}
IntSet EvalSet(Expr e,
const Map<IterVar, IntSet>& dom_map) {
IntSetEvaluator m;
std::unordered_map<const Variable*, IntSet> dmap;
for (auto kv : dom_map) {
m.dom_map[kv.first->var.as<Variable>()] = kv.second;
dmap[kv.first->var.as<Variable>()] = kv.second;
}
IntSetEvaluator m(dmap);
return m.Eval(e);
}
IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map) {
IntSetEvaluator m;
std::unordered_map<const Variable*, IntSet> dmap;
for (auto kv : dom_map) {
m.dom_map[kv.first->var.as<Variable>()] = kv.second;
dmap[kv.first->var.as<Variable>()] = kv.second;
}
IntSetEvaluator m(dmap);
IntSet min_set = m.Eval(r->min);
IntSet ext_set = m.Eval(r->extent).cover_interval();
const Interval& ei = ext_set.as<IntervalSet>()->i;
......
......@@ -44,6 +44,8 @@ class IntSet : public NodeRef {
bool is_everything() const;
/*! \return Whether the set is a single point */
bool is_single_point() const;
/*! \return Whether the set is proved to be bigger than 0 */
bool can_prove_positive() const;
/*!
* \brief The single point value, call only if is_single_point is true
* \return The point value.
......@@ -88,6 +90,8 @@ struct IntSetNode : public Node {
*/
IntSet EvalSet(Expr e,
const Map<IterVar, IntSet>& dom_map);
IntSet EvalSet(Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map);
/*!
* \brief Find an symbolic integer set that contains is union over
......
......@@ -45,7 +45,7 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
std::ostringstream os;
os << "typedef int int32_t;\n"
<< "typedef unsigned unt32_t;\n";
bool output_ssa = true;
bool output_ssa = false;
for (LoweredFunc f : funcs) {
os << CodeGenCUDA().Compile(f, output_ssa);
os << '\n';
......
......@@ -57,7 +57,7 @@ MakeOpenCL(Array<LoweredFunc> funcs) {
std::ostringstream os;
os << "typedef int int32_t;\n"
<< "typedef unsigned unt32_t;\n";
bool output_ssa = true;
bool output_ssa = false;
for (LoweredFunc f : funcs) {
os << CodeGenOpenCL().Compile(f, output_ssa);
os << '\n';
......
......@@ -3,9 +3,9 @@ import numpy as np
def test_gemm():
# graph
nn = 1235
nn = 1024
n = tvm.Var('n')
#n = tvm.convert(nn)
n = tvm.convert(nn)
m = n
l = n
A = tvm.placeholder((n, l), name='A')
......@@ -52,12 +52,14 @@ def test_gemm():
_, xi = s[BB].split(s[BB].op.axis[0], outer=thread_y)
_, xi = s[BB].split(xi, outer=thread_x)
max_auto_unroll_step = 0
# lowering test
s.normalize()
def check_device(target):
codes = []
f = tvm.build(s, [A, B, C], target, record_codes=codes)
f = tvm.build(s, [A, B, C], target, record_codes=codes,
max_auto_unroll_step=max_auto_unroll_step)
for c in codes[1:]:
print(c)
if target == "cuda":
......
import tvm
import numpy
def test_simplify():
"""Not yet working, mock design"""
dtype = 'int64'
n = tvm.Var('n')
Ab = tvm.Buffer((n, ), dtype)
i = tvm.Var('i')
j = tvm.Var('j')
# for i in 0 to n-1:
stmt = tvm.make.For(
i, 2, n, 0, 0,
tvm.make.For(j, 0, n, 0, 0,
tvm.make.IfThenElse(
tvm.make.LT(i + 2, n),
tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i + 4) + 1,
(j + 1) * 4 - 4 * j + i),
None)))
print(stmt)
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
print(stmt)
if __name__ == "__main__":
test_simplify()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment