Commit 5198c100 by Ziheng Jiang Committed by Tianqi Chen

[ARITH] DeduceBound (#40)

* [PYTHON/API] Add compare and logic build-in op for Expr

* remove 'and', 'or'

* add deducer

* [WIP] bound_deducer.cc

* move IntervalSet and StrideSet into int_set_internal.h

* add multiple failure for VariablePathFinder, add EvalSign

* consider round in deduce, add success flag

* remove Visit_(Div)

* add comment, update HalideIR

* expose intset to python

* check the sign of every expr

* set return type as ExprSignType

* fine tune

* add min & max python api for interval set

* support for conditional expr

* refactor test

* add checker for BoundDeducer

* add python check test

* fix

* fix

* change range to interval; remove converter

* remove converter declaration

* remove int_set_internal.h
parent d114dfc9
Subproject commit 642ae50ac749c91c04483db04500163304d4334e
Subproject commit e68ae61cd541ac29efc9fafe2ad061479bcaa9c9
......@@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs
from ._ctypes._node import register_node
from . import tensor
from . import arith
from . import expr
from . import stmt
from . import make
......
......@@ -244,6 +244,7 @@ def _init_api_functions(root_namespace):
module_internal = sys.modules["%s._api_internal" % root_namespace]
namespace_match = {
"_make_": sys.modules["%s.make" % root_namespace],
"_arith_": sys.modules["%s.arith" % root_namespace],
"_pass_": sys.modules["%s.ir_pass" % root_namespace],
"_codegen_": sys.modules["%s.codegen" % root_namespace],
"_schedule_": sys.modules["%s.schedule" % root_namespace]
......
# pylint: disable=protected-access, no-member
"""Arithmetic data structure and utility"""
from __future__ import absolute_import as _abs
from ._ctypes._node import NodeBase, register_node
from . import _api_internal
@register_node
class IntSet(NodeBase):
"""Represent a set of integer in one dimension."""
def is_nothing(self):
"""Whether the set represent nothing"""
return _api_internal._IntSetIsNothing(self)
def is_everything(self):
"""Whether the set represent everything"""
return _api_internal._IntSetIsEverything(self)
@register_node
class IntervalSet(IntSet):
"""Represent set of continuous interval"""
def min(self):
"""get the minimum value"""
return _api_internal._IntervalSetGetMin(self)
def max(self):
"""get the maximum value"""
return _api_internal._IntervalSetGetMax(self)
@register_node
class StrideSet(IntSet):
"""Represent set of strided integers"""
pass
/*!
* Copyright (c) 2016 by Contributors
* Implementation of API functions related to arith
* \file api_arith.cc
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/api_registry.h>
#include "../arithmetic/int_set.h"
namespace tvm {
namespace arith {
TVM_REGISTER_API(_arith_intset_single_point)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::single_point(args[0]);
});
TVM_REGISTER_API(_arith_intset_interval)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::interval(args[0], args[1]);
});
TVM_REGISTER_API(_arith_DeduceBound)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DeduceBound(args[0], args[1], args[2]);
});
TVM_REGISTER_API(_IntervalSetGetMin)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().min();
});
TVM_REGISTER_API(_IntervalSetGetMax)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().max();
});
TVM_REGISTER_API(_IntSetIsNothing)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().is_nothing();
});
TVM_REGISTER_API(_IntSetIsEverything)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().is_everything();
});
} // namespace arith
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file bound_deducer.cc
* \brief Utility to deduce bound of expression
*/
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/api_registry.h>
#include <unordered_set>
#include <unordered_map>
#include "./int_set.h"
namespace tvm {
namespace arith {
using namespace ir;
using Halide::Internal::Interval;
// a visitor to find the path to the target variable
// from a expression.
class VariablePathFinder: public IRVisitor {
public:
explicit VariablePathFinder(Var target) : target_(target) {}
void Visit(const NodeRef& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());
if (!found_) path_.push_back(node.get());
if (node.same_as(target_)) found_ = true;
IRVisitor::Visit(node);
if (!found_) path_.pop_back();
}
std::vector<const Node*> path_;
private:
bool found_{false};
Var target_;
std::unordered_set<const Node*> visited_;
};
// get the path to the variable,
// return empty vector to represent failure
std::vector<const Node*> GetPath(Var target, Expr expr) {
VariablePathFinder v(target);
v.Visit(expr);
return v.path_;
}
class BoundDeduceIntputChecker;
// a visitor to deduce the bound of a variable from a expression
class BoundDeducer: public IRVisitor {
public:
friend class BoundDeduceInputChecker;
friend class Converter;
BoundDeducer(Var target, Expr expr,
const std::unordered_map<const Variable*, IntSet>& dom_map)
: target_(target), expr_(expr), dom_map_(dom_map) {}
bool Init();
void Deduce();
void Visit(const NodeRef& e) final {
if (!success) return;
if (e.get() == path_[iter_++]) {
IRVisitor::Visit(e);
} else {
success = false;
return;
}
}
void Visit_(const LT* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}
void Visit_(const LE* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}
void Visit_(const GT* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}
void Visit_(const GE* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}
void Visit_(const Add* op) final {
bool left = op->a.get() == path_[iter_];
result -= left ? op->b : op->a;
Visit(left ? op->a : op->b);
}
void Visit_(const Sub* op) final {
bool left = op->a.get() == path_[iter_];
if (left) {
result += op->b;
} else {
result -= op->a;
result = - result;
is_greater = !is_greater;
}
Visit(left ? op->a : op->b);
}
void Visit_(const Mul* op) final {
bool left = op->a.get() == path_[iter_];
Expr operand = left ? op->b : op->a;
SignType sign;
if (operand.type().is_uint()) {
sign = kPositive;
} else {
sign = expr_map_[operand].sign_type();
}
if (sign == SignType::kNegative) {
is_greater = !is_greater;
} else if (sign == SignType::kUnknown) {
// unable to get the sign of operand
success = false;
return;
}
// always use relax bound
result = result / operand + (is_greater ? 1 : -1);
Visit(left ? op->a : op->b);
}
Expr result;
bool is_greater{true};
bool is_equal{true};
bool success{true};
private:
Var target_;
Expr expr_;
const std::unordered_map<const Variable*, IntSet>& dom_map_;
ExprIntSetMap expr_map_;
std::vector<const Node*> path_;
size_t iter_{0};
};
class BoundDeduceInputChecker: public IRVisitor {
public:
bool Check(BoundDeducer* deducer) {
deducer_ = deducer;
Visit(deducer_->expr_);
return target_count == 1;
}
void Visit(const NodeRef& e) final {
if (e.same_as(deducer_->target_)) ++target_count;
IRVisitor::Visit(e);
}
private:
BoundDeducer* deducer_;
size_t target_count{0};
};
bool BoundDeducer::Init() {
BoundDeduceInputChecker checker;
if (!checker.Check(this)) success = false;
if (const LT* op = expr_.as<LT>()) {
is_greater = false;
is_equal = false;
expr_ = op->a;
result = op->b;
} else if (const LE* op = expr_.as<LE>()) {
is_greater = false;
is_equal = true;
expr_ = op->a;
result = op->b;
} else if (const GT* op = expr_.as<GT>()) {
is_greater = true;
is_equal = false;
expr_ = op->a;
result = op->b;
} else if (const GE* op = expr_.as<GE>()) {
is_greater = true;
is_equal = true;
expr_ = op->a;
result = op->b;
} else {
success = false;
}
return success;
}
void BoundDeducer::Deduce() {
Init();
if (!success) return;
// get the path
path_ = GetPath(target_, expr_);
// get the sign of every subexpr
expr_map_ = EvalSetForEachSubExpr(expr_, dom_map_);
Visit(expr_);
}
// assuming e >= 0, deduce the bound of variable from it.
// return empty set to represent deduce failure.
IntSet DeduceBound(Var v, Expr e,
const Map<Var, IntSet>& dom_map) {
std::unordered_map<const Variable*, IntSet> dmap;
for (auto kv : dom_map) {
dmap[kv.first.get()] = kv.second;
}
BoundDeducer d(v, e, dmap);
d.Deduce();
if (!d.success) return IntSet::nothing();
Expr min = Interval::neg_inf, max = Interval::pos_inf;
if (d.is_greater) {
min = d.is_equal ? d.result : d.result + 1;
} else {
max = d.is_equal ? d.result : d.result - 1;
}
return IntSet::interval(min, max);
}
} // namespace arith
} // namespace tvm
......@@ -6,55 +6,17 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <pass/Interval.h>
#include <unordered_map>
#include "./int_set.h"
#include "./compute_expr.h"
#include "./int_set_internal.h"
namespace tvm {
namespace arith {
using Halide::Internal::Interval;
using namespace ir;
/*! \brief Set of continuous interval */
struct IntervalSet : public IntSetNode {
/*! \brief the internal interval*/
Interval i;
static IntSet make(Interval i) {
std::shared_ptr<IntervalSet> n =
std::make_shared<IntervalSet>();
n->i = i;
return IntSet(n);
}
static IntSet make(Expr min, Expr max) {
std::shared_ptr<IntervalSet> n =
std::make_shared<IntervalSet>();
n->i.min = min;
n->i.max = max;
return IntSet(n);
}
static constexpr const char* _type_key = "IntervalSet";
TVM_DECLARE_NODE_TYPE_INFO(IntervalSet);
};
/*!
* \brief set represented by strided integers
* Reserved for cases where strided access is supported.
*/
struct StrideSet : public IntSetNode {
/*! \brief the base inetrval */
Interval base;
/*! \brief additional extents in positive number */
Array<Expr> extents;
/*! \brief additional strides in positive number */
Array<Expr> strides;
static constexpr const char* _type_key = "StrideSet";
TVM_DECLARE_NODE_TYPE_INFO(StrideSet);
};
inline IntSet IntSet::cover_interval() const {
if ((*this).as<IntervalSet>()) return *this;
const StrideSet* s = (*this).as<StrideSet>();
......@@ -84,6 +46,23 @@ Range IntSet::cover_range(Range max_range) const {
return max_range;
}
Expr IntSet::min() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
CHECK(s_int);
return s_int->i.min;
}
Expr IntSet::max() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
CHECK(s_int);
return s_int->i.max;
}
bool IntSet::is_nothing() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
return (s_int && s_int->i.is_empty());
}
bool IntSet::is_everything() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
return (s_int && s_int->i.is_everything());
......@@ -99,12 +78,32 @@ bool IntSet::can_prove_positive() const {
return (s_int && is_positive_const(ir::Simplify(s_int->i.min)));
}
bool IntSet::can_prove_negative() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
return (s_int && is_negative_const(ir::Simplify(s_int->i.max)));
}
SignType IntSet::sign_type() const {
if (can_prove_positive()) {
return kPositive;
} else if (can_prove_negative()) {
return kNegative;
} else if (is_single_point() && is_zero(point_value())) {
return kZero;
} else {
return kUnknown;
}
}
Expr IntSet::point_value() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
CHECK(s_int && s_int->i.is_single_point());
return s_int->i.min;
}
IntSet IntSet::nothing() {
return IntervalSet::make(Interval::nothing());
}
IntSet IntSet::everything() {
return IntervalSet::make(Interval::everything());
}
......@@ -125,6 +124,13 @@ IntSet IntSet::range(Range r) {
return IntervalSet::make(r->min, (r->extent + r->min) - 1);
}
IntSet IntSet::interval(Expr min, Expr max) {
if (min.same_as(max)) {
return IntSet::single_point(min);
}
return IntervalSet::make(min, max);
}
// Check if a is created from b.
bool IntSet::match_range(const Range& b) const {
const IntSet& a = *this;
......@@ -366,13 +372,13 @@ class IntSetEvaluator {
explicit IntSetEvaluator(const std::unordered_map<const Variable*, IntSet>& dom_map)
: dom_map(dom_map) {}
inline IntSet Eval(Expr expr) {
inline virtual IntSet Eval(Expr expr) {
static const FType& f = vtable();
if (f.can_dispatch(expr)) {
return f(expr, expr, this);
} else {
LOG(WARNING) << "cannot evaluate set type " << expr->type_key();
return IntSet::everything();
return IntSet::nothing();
}
}
......@@ -384,7 +390,7 @@ class IntSetEvaluator {
const std::unordered_map<const Variable*, IntSet>& dom_map;
};
inline IntSet ConstOp(const NodeRef&, const Expr& e, IntSetEvaluator*) {
inline IntSet ConstOp(const NodeRef&, const Expr& e, IntSetEvaluator* m) {
return IntSet::single_point(e);
}
......@@ -411,8 +417,7 @@ inline IntSet Binary(const T* op, const Expr& e, IntSetEvaluator* m) {
if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) {
return IntSet::single_point(e);
}
IntSet r = Combine<T>(a, b);
return r;
return Combine<T>(a, b);
}
TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
......@@ -457,6 +462,27 @@ IntSet EvalSet(Range r,
return Combine<Add>(min_set, ext_set);
}
class SubExprIntSetEvaluator : public IntSetEvaluator {
public:
explicit SubExprIntSetEvaluator(const std::unordered_map<const Variable*, IntSet>& dom_map)
: IntSetEvaluator(dom_map) {}
inline IntSet Eval(Expr expr) override {
IntSet ret = IntSetEvaluator::Eval(expr);
expr_map[expr] = ret;
return ret;
}
ExprIntSetMap expr_map;
};
ExprIntSetMap EvalSetForEachSubExpr(Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map) {
SubExprIntSetEvaluator m(dom_map);
m.Eval(e);
return m.expr_map;
}
IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map) {
std::unordered_map<const Variable*, IntSet> dmap;
......@@ -468,7 +494,7 @@ IntSet EvalSet(Range r,
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntervalSet>([](const IntervalSet *op, IRPrinter *p) {
p->stream << "interval-set["
p->stream << "interval-set"
<< "[" << op->i.min << ", "
<< op->i.max << ']';
});
......
......@@ -12,6 +12,13 @@
namespace tvm {
namespace arith {
enum SignType {
kPositive,
kNegative,
kZero,
kUnknown
};
// internal node container of int set.
class IntSetNode;
......@@ -40,12 +47,22 @@ class IntSet : public NodeRef {
* \return The covering interval set.
*/
IntSet cover_interval() const;
/*! \return Lower bound of the set */
Expr min() const;
/*! \return upper bound of the set */
Expr max() const;
/*! \return Whether the set represent nothing */
bool is_nothing() const;
/*! \return Whether the set represent everything */
bool is_everything() const;
/*! \return Whether the set is a single point */
bool is_single_point() const;
/*! \return Whether the set is proved to be bigger than 0 */
bool can_prove_positive() const;
/*! \return Whether the set is proved to be smaller than 0 */
bool can_prove_negative() const;
/*! \return The sign of the elements in the integer set */
SignType sign_type() const;
/*!
* \brief The single point value, call only if is_single_point is true
* \return The point value.
......@@ -58,7 +75,9 @@ class IntSet : public NodeRef {
* \return true if we can prove they are the same.
*/
bool match_range(const Range& r) const;
/*! \return Whether the set contains everything */
/*! \return The set contains nothing */
static IntSet nothing();
/*! \return The set contains everything */
static IntSet everything();
/*!
* \brief construct a point set.
......@@ -72,6 +91,13 @@ class IntSet : public NodeRef {
* \return constructed set.
*/
static IntSet range(Range r);
/*!
* \brief Construct a set representing a interval.
* \param min The minimum value of the interval.
* \param max The maximum value of the interval.
* \return constructed set.
*/
static IntSet interval(Expr min, Expr max);
};
/*!
......@@ -80,6 +106,9 @@ class IntSet : public NodeRef {
struct IntSetNode : public Node {
};
using ExprIntSetMap = std::unordered_map<Expr, IntSet,
Halide::ExprHash, Halide::ExprEqual>;
/*!
* \brief Find an symbolic integer set that contains all possible values of
* e given the domain of each iteration variables.
......@@ -107,6 +136,18 @@ IntSet EvalSet(Range r,
const std::unordered_map<const Variable*, IntSet>& dom_map);
/*!
* \brief Find the integer set of every sub-expression, given the
* domain of each iteration variables.
*
* \param e The expression to be evaluated.
* \param dom_map The domain of each variable.
* \return the map from the expression to its possible value.
*/
ExprIntSetMap EvalSetForEachSubExpr(Expr r,
const std::unordered_map<const Variable*, IntSet>& dom_map);
/*!
* \brief Create an union set of all sets
* \param sets The sets to be unioned
......@@ -119,6 +160,19 @@ inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get());
}
/*!
* \brief Deduce the bound of the target variable in a expression,
* give the domain of each variables. Return undefined IntSet to
* represent failure.
*
* \param v The target variable to be deduced.
* \param cond The conditional expression.
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values.
*/
IntSet DeduceBound(Var v, Expr cond,
const Map<Var, IntSet>& dom_map);
} // namespace arith
} // namespace tvm
......
/*!
* Copyright (c) 2017 by Contributors
* \file int_set_internal.h
* \brief Implementations of integer set
*/
#ifndef TVM_ARITHMETIC_INT_SET_INTERNAL_H_
#define TVM_ARITHMETIC_INT_SET_INTERNAL_H_
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "./int_set.h"
namespace tvm {
namespace arith {
using Halide::Internal::Interval;
/*! \brief Set of continuous interval */
struct IntervalSet : public IntSetNode {
/*! \brief the internal interval*/
Interval i;
static IntSet make(Interval i) {
std::shared_ptr<IntervalSet> n =
std::make_shared<IntervalSet>();
n->i = i;
return IntSet(n);
}
static IntSet make(Expr min, Expr max) {
std::shared_ptr<IntervalSet> n =
std::make_shared<IntervalSet>();
n->i.min = min;
n->i.max = max;
return IntSet(n);
}
static constexpr const char* _type_key = "IntervalSet";
TVM_DECLARE_NODE_TYPE_INFO(IntervalSet);
};
/*!
* \brief set represented by strided integers
* Reserved for cases where strided access is supported.
*/
struct StrideSet : public IntSetNode {
/*! \brief the base inetrval */
Interval base;
/*! \brief additional extents in positive number */
Array<Expr> extents;
/*! \brief additional strides in positive number */
Array<Expr> strides;
static constexpr const char* _type_key = "StrideSet";
TVM_DECLARE_NODE_TYPE_INFO(StrideSet);
};
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_INT_SET_INTERNAL_H_
import tvm
def test_basic():
s = tvm.arith.intset_interval(2, 3)
assert s.min().value == 2
assert s.max().value == 3
def test_deduce():
a = tvm.Var('a')
b = tvm.Var('b')
c = tvm.Var('c')
d = tvm.Var('d')
b_s = tvm.arith.intset_interval(2, 3)
c_s = tvm.arith.intset_interval(10, 15)
d_s = tvm.arith.intset_interval(-3, -1)
e0 = (-b)*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s})
ans0 = (d-c)/(-b)+(-1)
assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)
e1 = (a*4+b < c)
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s})
ans1 = (c-b)/4+(-2)
assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1)
def test_check():
a = tvm.Var('a')
b = tvm.Var('b')
c = tvm.Var('c')
d = tvm.Var('d')
b_s = tvm.arith.intset_interval(2, 3)
c_s = tvm.arith.intset_interval(5, 7)
d_s = tvm.arith.intset_interval(-3, -1)
# no compare operator
res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s})
assert res1.is_nothing()
# multiple compare operators
res2 = tvm.arith.DeduceBound(a, a+b>3>c , {b: b_s, c: c_s})
assert res1.is_nothing()
# multiple target variable
res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s})
assert res1.is_nothing()
if __name__ == "__main__":
test_basic()
test_deduce()
test_check()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment