Commit 9fad94cc by Sergei Grechanik Committed by Tianqi Chen

[ARITH][BOUND] Fix bound inference to avoid allocating too much (#3526)

* [TVM] Fix bound inference to avoid allocating too much

* [ARITH][BOUND] Pass analyzer to PropBoundToInputs
parent 75892d2b
......@@ -100,6 +100,7 @@ class OperationNode : public ir::FunctionBaseNode {
/*!
* \brief Propagate the bounds to inputs
* \param self The reference to self.
* \param analyzer The analyzer to be used in the function.
* \param dom_map the domain map of Variables(corresponds to root_iter_vars)
* \param out_dom_map The output domain.
* The function is only asked to fill the bounds for Tensors that
......@@ -107,6 +108,7 @@ class OperationNode : public ir::FunctionBaseNode {
*/
virtual void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0;
/*!
......@@ -170,6 +172,7 @@ class PlaceholderOpNode : public OperationNode {
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
......@@ -247,6 +250,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
Stmt BuildProvide(
......@@ -299,6 +303,7 @@ class TensorComputeOpNode : public BaseComputeOpNode {
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
Stmt BuildProvide(
......@@ -373,6 +378,7 @@ class ScanOpNode : public OperationNode {
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
......@@ -439,6 +445,7 @@ class ExternOpNode : public OperationNode {
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
......@@ -506,6 +513,7 @@ class HybridOpNode : public OperationNode {
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
......
......@@ -33,6 +33,7 @@
#include "op_util.h"
#include "../schedule/message_passing.h"
#include "../arithmetic/compute_expr.h"
#include "../arithmetic/int_set.h"
namespace tvm {
......@@ -209,17 +210,41 @@ Operation ComputeOpNode::ReplaceInputs(
void ComputeOpNode::PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
auto fvisit = [&dom_map, out_dom_map](const NodeRef& n) {
auto fvisit = [&dom_map, out_dom_map, analyzer](const NodeRef& n) {
auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t = Operation(call->func.node_).output(call->value_index);
if (t->op.defined() && out_dom_map->count(t)) {
TensorDom& dom = out_dom_map->at(t);
for (size_t i = 0; i < t.ndim(); ++i) {
dom.data[i].push_back(EvalSet(call->args[i], dom_map));
// We assume that the value of the argument cannot be out of bounds (otherwise it is
// undefined behaviour), so we can intersect the estimated set of the argument with the
// range expected by the tensor. However, intersection may result in overly complex
// expressions, so we perform a more relaxed form of intersection.
IntSet arg_intset = EvalSet(call->args[i], dom_map);
const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>();
if (arg_interval) {
Expr shape_i_min_value = make_zero(t->shape[i].type());
Expr shape_i_max_value = t->shape[i] - 1;
Expr min_value = arg_interval->min_value;
Expr max_value = arg_interval->max_value;
// Prefer the shape bounds only when we can prove they are tighter.
if (arith::is_neg_inf(min_value) ||
analyzer->CanProve(shape_i_min_value >= min_value)) {
min_value = shape_i_min_value;
}
if (arith::is_pos_inf(max_value) ||
analyzer->CanProve(shape_i_max_value <= max_value)) {
max_value = shape_i_max_value;
}
dom.data[i].push_back(IntSet::interval(min_value, max_value));
} else {
dom.data[i].push_back(arg_intset);
}
}
}
}
......
......@@ -111,6 +111,7 @@ Operation ExternOpNode::ReplaceInputs(
void ExternOpNode::PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
for (Tensor t : this->inputs) {
......
......@@ -108,6 +108,7 @@ Operation HybridOpNode::ReplaceInputs(
void HybridOpNode::PropBoundToInputs(
const Operation &self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet> &dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
for (Tensor t : this->inputs) {
......
......@@ -78,6 +78,7 @@ Operation PlaceholderOpNode::ReplaceInputs(
void PlaceholderOpNode::PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
}
......
......@@ -175,6 +175,7 @@ Operation ScanOpNode::ReplaceInputs(
void ScanOpNode::PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
......
......@@ -110,6 +110,7 @@ Operation TensorComputeOpNode::ReplaceInputs(
void TensorComputeOpNode::PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
for (size_t i = 0; i < this->inputs.size(); ++i) {
......
......@@ -85,17 +85,20 @@ size_t InferTensorizeRegion(
// Get domains if inputs
std::unordered_map<Tensor, TensorDom> in_dom;
std::unordered_map<const Variable*, IntSet> temp_dmap;
arith::Analyzer analyzer;
Array<Tensor> inputs = self->InputTensors();
for (Tensor t : inputs) {
in_dom.emplace(t, TensorDom(t.ndim()));
}
for (IterVar iv : self->root_iter_vars()) {
IntSet iset = up_state.at(iv);
(*out_dom)[iv] = iset.cover_range(dom_map.at(iv));
Range iv_range = iset.cover_range(dom_map.at(iv));
(*out_dom)[iv] = iv_range;
analyzer.Bind(iv->var, iv_range);
temp_dmap[iv->var.get()] = iset;
}
// Input domains
self->PropBoundToInputs(stage->op, temp_dmap, &in_dom);
self->PropBoundToInputs(stage->op, &analyzer, temp_dmap, &in_dom);
Range none;
for (const auto& kv : in_dom) {
Array<Range> vec;
......
......@@ -191,6 +191,7 @@ void InferRootBound(const Stage& stage,
PassUpDomain(op_stage, *rmap, &up_state);
// Relax if needed.
std::unordered_map<const Variable*, IntSet> dom_map;
arith::Analyzer analyzer;
for (auto iv : op->root_iter_vars()) {
Range r;
if (up_state.count(iv)) {
......@@ -203,8 +204,9 @@ void InferRootBound(const Stage& stage,
} else {
dom_map[iv->var.get()] = IntSet::range(r);
}
analyzer.Bind(iv->var, r);
}
op->PropBoundToInputs(op, dom_map, &tmap);
op->PropBoundToInputs(op, &analyzer, dom_map, &tmap);
}
stage->op->GatherBound(stage->op, tmap, rmap);
}
......
......@@ -306,6 +306,27 @@ def test_bound_tensor_compute_op():
assert isinstance(bounds, tvm.container.Map)
assert(bounds[B.op.axis[0]].extent.value == 10)
def test_bound_simplification_failure():
# Check that the bounds are not expanded
A = tvm.compute((2,), lambda j: j, "A")
def _check(B, A=A):
s = tvm.create_schedule(B.op)
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.lower(s, [B, A], simple_mode=True)
if not bounds[A.op.axis[0]].extent.value <= 2:
print(stmt)
assert bounds[A.op.axis[0]].extent.value <= 2
# These are hard to simplify, moreover we don't simplify them
_check(tvm.compute((10,), lambda i: A[tvm.min(3*i, 4*i) + tvm.min(-3*i, -2*i)]))
_check(tvm.compute((10,), lambda i: A[tvm.min(3*i, 4*i) + tvm.max(-3*i, -4*i)]))
_check(tvm.compute((10,), lambda i: A[-2*(i/2) - tvm.min(i, 0-i)]))
_check(tvm.compute((10,), lambda i: A[i + (0 - i)]))
# This would cause out of bounds, but we nevertheless include it
_check(tvm.compute((10,), lambda i: A[i]))
if __name__ == "__main__":
test_bound_nest_thread()
test_bound1()
......@@ -320,3 +341,4 @@ if __name__ == "__main__":
test_gemm_bound()
test_bound_warp()
test_bound_tensor_compute_op()
test_bound_simplification_failure()
......@@ -286,20 +286,6 @@ def test_schedule_cache_relayout4():
stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_bound_condition():
A = tvm.placeholder((64,), name='A', dtype="float32")
Apad = tvm.compute((66,), lambda i: tvm.if_then_else(
tvm.all(i>0, i < 65), A[i-1], tvm.const(0., "float32")), name='Apad')
Apad2 = tvm.compute((66,), lambda i: Apad[i]*2, name='Apad2')
s = tvm.create_schedule(Apad2.op)
AL1 = s.cache_read(A,"local",[Apad])
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.Simplify(stmt)
assert (isinstance(stmt.body.body.first.body.body.then_case, tvm.stmt.IfThenElse))
def intrin_gemv(m, n):
w = tvm.placeholder((m, n), name='w')
x = tvm.placeholder((n,), name='x')
......@@ -514,7 +500,6 @@ if __name__ == "__main__":
test_schedule1()
test_schedule2()
test_schedule_cache()
test_schedule_bound_condition()
test_schedule_tensor_compute1()
test_schedule_tensor_compute2()
test_schedule_tensor_compute3()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment