Commit 3f3bf29d by Tianqi Chen Committed by GitHub

[DOC] Make range related function consistent (#249)

parent 2ab0bfb5
...@@ -27,6 +27,7 @@ using Halide::ExprEqual; ...@@ -27,6 +27,7 @@ using Halide::ExprEqual;
using Halide::Expr; using Halide::Expr;
using Halide::VarExpr; using Halide::VarExpr;
using Halide::IR::RangeNode;
using Halide::IR::FunctionRef; using Halide::IR::FunctionRef;
using Halide::IR::FunctionBaseNode; using Halide::IR::FunctionBaseNode;
using Halide::Internal::Stmt; using Halide::Internal::Stmt;
...@@ -113,7 +114,7 @@ class Range : public Halide::IR::Range { ...@@ -113,7 +114,7 @@ class Range : public Halide::IR::Range {
*/ */
Range(Expr begin, Expr end); Range(Expr begin, Expr end);
static Range make_with_min_extent(Expr min, Expr extent); static Range make_by_min_extent(Expr min, Expr extent);
}; };
/*! /*!
......
...@@ -8,4 +8,24 @@ You can use make function to build the IR node. ...@@ -8,4 +8,24 @@ You can use make function to build the IR node.
""" """
from ._ffi.function import _init_api from ._ffi.function import _init_api
def range_by_min_extent(min_value, extent):
"""Construct a Range by min and extent.
This constructs a range in [min_value, min_value + extent)
Parameters
----------
min_value : Expr
The minimum value of the range.
extent : Expr
The extent of the range.
Returns
-------
rng : Range
The constructed range.
"""
return _range_by_min_extent(min_value, extent)
_init_api("tvm.make") _init_api("tvm.make")
...@@ -16,6 +16,11 @@ TVM_REGISTER_API("_Var") ...@@ -16,6 +16,11 @@ TVM_REGISTER_API("_Var")
*ret = Variable::make(args[1], args[0]); *ret = Variable::make(args[1], args[0]);
}); });
TVM_REGISTER_API("make._range_by_min_extent")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Range::make_by_min_extent(args[0], args[1]);
});
TVM_REGISTER_API("make.For") TVM_REGISTER_API("make.For")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = For::make(args[0], *ret = For::make(args[0],
...@@ -147,6 +152,7 @@ REGISTER_MAKE3(AssertStmt); ...@@ -147,6 +152,7 @@ REGISTER_MAKE3(AssertStmt);
REGISTER_MAKE3(ProducerConsumer); REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE5(Allocate); REGISTER_MAKE5(Allocate);
REGISTER_MAKE4(Provide); REGISTER_MAKE4(Provide);
REGISTER_MAKE4(Prefetch);
REGISTER_MAKE1(Free); REGISTER_MAKE1(Free);
REGISTER_MAKE2(Block); REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse); REGISTER_MAKE3(IfThenElse);
......
...@@ -309,7 +309,7 @@ class Canonical::Internal : public IRMutator { ...@@ -309,7 +309,7 @@ class Canonical::Internal : public IRMutator {
++level_counter_; ++level_counter_;
Var loop_var(op->loop_var.node_); Var loop_var(op->loop_var.node_);
this->SetRange(loop_var, this->SetRange(loop_var,
Range::make_with_min_extent(op->min, op->extent), Range::make_by_min_extent(op->min, op->extent),
level_counter_); level_counter_);
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = IRMutator::Mutate_(op, s);
--level_counter_; --level_counter_;
...@@ -324,7 +324,7 @@ class Canonical::Internal : public IRMutator { ...@@ -324,7 +324,7 @@ class Canonical::Internal : public IRMutator {
CHECK_NE(iv->thread_tag.length(), 0U); CHECK_NE(iv->thread_tag.length(), 0U);
if (!var_level_.count(iv->var.get())) { if (!var_level_.count(iv->var.get())) {
this->SetRange(iv->var, this->SetRange(iv->var,
Range::make_with_min_extent(0, op->value), Range::make_by_min_extent(0, op->value),
level_counter_); level_counter_);
} }
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = IRMutator::Mutate_(op, s);
......
...@@ -40,7 +40,7 @@ Range IntSet::cover_range(Range max_range) const { ...@@ -40,7 +40,7 @@ Range IntSet::cover_range(Range max_range) const {
s_int = temp.as<IntervalSet>(); s_int = temp.as<IntervalSet>();
} }
if (s_int->i.is_bounded()) { if (s_int->i.is_bounded()) {
return Range::make_with_min_extent( return Range::make_by_min_extent(
s_int->i.min, Simplify(s_int->i.max + 1 - s_int->i.min)); s_int->i.min, Simplify(s_int->i.max + 1 - s_int->i.min));
} }
return max_range; return max_range;
......
...@@ -22,7 +22,7 @@ Range::Range(Expr begin, Expr end) ...@@ -22,7 +22,7 @@ Range::Range(Expr begin, Expr end)
is_zero(begin) ? end : (end - begin))) { is_zero(begin) ? end : (end - begin))) {
} }
Range Range::make_with_min_extent(Expr min, Expr extent) { Range Range::make_by_min_extent(Expr min, Expr extent) {
return Range(std::make_shared<Halide::IR::RangeNode>(min, extent)); return Range(std::make_shared<Halide::IR::RangeNode>(min, extent));
} }
......
...@@ -93,7 +93,7 @@ void ExternOpNode::PropBoundToInputs( ...@@ -93,7 +93,7 @@ void ExternOpNode::PropBoundToInputs(
TensorDom& dom = it->second; TensorDom& dom = it->second;
for (size_t i = 0; i < t->shape.size(); ++i) { for (size_t i = 0; i < t->shape.size(); ++i) {
dom.data[i].emplace_back(IntSet::range( dom.data[i].emplace_back(IntSet::range(
Range::make_with_min_extent( Range::make_by_min_extent(
make_const(t->shape[i].type(), 0), t->shape[i]))); make_const(t->shape[i].type(), 0), t->shape[i])));
} }
} }
...@@ -116,7 +116,7 @@ Stmt ExternOpNode::BuildRealize( ...@@ -116,7 +116,7 @@ Stmt ExternOpNode::BuildRealize(
Halide::Internal::Region bounds; Halide::Internal::Region bounds;
for (size_t i = 0; i < t->shape.size(); ++i) { for (size_t i = 0; i < t->shape.size(); ++i) {
bounds.push_back( bounds.push_back(
Range::make_with_min_extent( Range::make_by_min_extent(
make_const(t->shape[i].type(), 0), t->shape[i])); make_const(t->shape[i].type(), 0), t->shape[i]));
} }
realize_body = ir::Realize::make( realize_body = ir::Realize::make(
......
...@@ -76,7 +76,7 @@ Operation ScanOpNode::make(std::string name, ...@@ -76,7 +76,7 @@ Operation ScanOpNode::make(std::string name,
spatial_name << name << ".out" << i << ".i" << k; spatial_name << name << ".out" << i << ".i" << k;
n->spatial_axis_.push_back( n->spatial_axis_.push_back(
IterVarNode::make( IterVarNode::make(
Range::make_with_min_extent(0, update[i]->shape[k]), Range::make_by_min_extent(0, update[i]->shape[k]),
Var(spatial_name.str()), kOpaque)); Var(spatial_name.str()), kOpaque));
} }
} }
...@@ -104,7 +104,7 @@ Array<Tensor> scan(Array<Tensor> init, ...@@ -104,7 +104,7 @@ Array<Tensor> scan(Array<Tensor> init,
std::string tag) { std::string tag) {
IterVar scan_axis = IterVar scan_axis =
IterVarNode::make( IterVarNode::make(
Range::make_with_min_extent( Range::make_by_min_extent(
init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]), init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
Var(name + ".idx"), kOrdered); Var(name + ".idx"), kOrdered);
Operation op = ScanOpNode::make( Operation op = ScanOpNode::make(
...@@ -165,7 +165,7 @@ void ScanOpNode::PropBoundToInputs( ...@@ -165,7 +165,7 @@ void ScanOpNode::PropBoundToInputs(
// first dimension, always needed. // first dimension, always needed.
if (init_dom) { if (init_dom) {
init_dom->data[0].push_back(IntSet::range( init_dom->data[0].push_back(IntSet::range(
Range::make_with_min_extent(0, this->init[i]->shape[0]))); Range::make_by_min_extent(0, this->init[i]->shape[0])));
} }
if (update_dom) { if (update_dom) {
update_dom->data[0].push_back(dom_map.at(this->scan_axis->var.get())); update_dom->data[0].push_back(dom_map.at(this->scan_axis->var.get()));
...@@ -203,7 +203,7 @@ void ScanOpNode::GatherBound( ...@@ -203,7 +203,7 @@ void ScanOpNode::GatherBound(
CHECK(!out_dom_map->count(this->scan_axis)); CHECK(!out_dom_map->count(this->scan_axis));
Range sdom = this->scan_axis->dom; Range sdom = this->scan_axis->dom;
Range r = arith::Union(time_dom).cover_range(sdom); Range r = arith::Union(time_dom).cover_range(sdom);
(*out_dom_map)[this->scan_axis] = Range::make_with_min_extent( (*out_dom_map)[this->scan_axis] = Range::make_by_min_extent(
sdom->min, ir::Simplify(r->extent + r->min - sdom->min)); sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
Map<IterVar, Expr> fix_pt = ScanFixPointAnalysis(self); Map<IterVar, Expr> fix_pt = ScanFixPointAnalysis(self);
// Update for spatial axis. // Update for spatial axis.
...@@ -231,7 +231,7 @@ Stmt ScanOpNode::BuildRealize( ...@@ -231,7 +231,7 @@ Stmt ScanOpNode::BuildRealize(
const Stmt& body) const { const Stmt& body) const {
CHECK_EQ(self.operator->(), this); CHECK_EQ(self.operator->(), this);
Range sdom = dom_map.at(this->scan_axis); Range sdom = dom_map.at(this->scan_axis);
Range tdom = Range::make_with_min_extent( Range tdom = Range::make_by_min_extent(
0, ir::Simplify(sdom->extent + sdom->min)); 0, ir::Simplify(sdom->extent + sdom->min));
Stmt ret = body; Stmt ret = body;
size_t sp_idx = 0; size_t sp_idx = 0;
......
...@@ -32,7 +32,7 @@ inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) { ...@@ -32,7 +32,7 @@ inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) {
if (!r->min.same_as(new_min)) changed = true; if (!r->min.same_as(new_min)) changed = true;
if (!r->extent.same_as(new_extent)) changed = true; if (!r->extent.same_as(new_extent)) changed = true;
new_dom[i] = IterVarNode::make( new_dom[i] = IterVarNode::make(
Range::make_with_min_extent(new_min, new_extent), Range::make_by_min_extent(new_min, new_extent),
v->var, v->iter_type, v->thread_tag); v->var, v->iter_type, v->thread_tag);
} }
if (!changed) { if (!changed) {
......
...@@ -171,8 +171,8 @@ class ChannelAccessRewriter : public IRMutator { ...@@ -171,8 +171,8 @@ class ChannelAccessRewriter : public IRMutator {
Channel ch(adv_op->node.node_); Channel ch(adv_op->node.node_);
ChannelAccessBound acc(ch->handle_var.get(), read_access); ChannelAccessBound acc(ch->handle_var.get(), read_access);
IntSet iset = acc.Eval(for_op->body); IntSet iset = acc.Eval(for_op->body);
Range r = iset.cover_range(Range::make_with_min_extent(0, window)); Range r = iset.cover_range(Range::make_by_min_extent(0, window));
r = Range::make_with_min_extent( r = Range::make_by_min_extent(
ir::Simplify(r->min), ir::Simplify(r->extent)); ir::Simplify(r->min), ir::Simplify(r->extent));
if (ExprUseVar(r->extent, var)) return body; if (ExprUseVar(r->extent, var)) return body;
Array<Expr> linear_eq = DetectLinearEquation(r->min, var); Array<Expr> linear_eq = DetectLinearEquation(r->min, var);
......
...@@ -53,14 +53,14 @@ void PassDownDomain(const Stage& stage, ...@@ -53,14 +53,14 @@ void PassDownDomain(const Stage& stage,
CHECK(!state.count(r->inner)); CHECK(!state.count(r->inner));
const Range& range_parent = state.at(r->parent); const Range& range_parent = state.at(r->parent);
if (r->factor.defined()) { if (r->factor.defined()) {
Update(p_state, r->inner, Range::make_with_min_extent(0, r->factor)); Update(p_state, r->inner, Range::make_by_min_extent(0, r->factor));
Update(p_state, r->outer, Update(p_state, r->outer,
Range::make_with_min_extent( Range::make_by_min_extent(
0, DivCeil(range_parent->extent, r->factor))); 0, DivCeil(range_parent->extent, r->factor)));
} else { } else {
Update(p_state, r->outer, Range::make_with_min_extent(0, r->nparts)); Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts));
Update(p_state, r->inner, Update(p_state, r->inner,
Range::make_with_min_extent( Range::make_by_min_extent(
0, DivCeil(range_parent->extent, r->nparts))); 0, DivCeil(range_parent->extent, r->nparts)));
} }
} else if (const FuseNode* r = rel.as<FuseNode>()) { } else if (const FuseNode* r = rel.as<FuseNode>()) {
...@@ -70,7 +70,7 @@ void PassDownDomain(const Stage& stage, ...@@ -70,7 +70,7 @@ void PassDownDomain(const Stage& stage,
} }
const Range& range_outer = state.at(r->outer); const Range& range_outer = state.at(r->outer);
const Range& range_inner = state.at(r->inner); const Range& range_inner = state.at(r->inner);
state[r->fused] = Range::make_with_min_extent( state[r->fused] = Range::make_by_min_extent(
0, range_outer->extent * range_inner->extent); 0, range_outer->extent * range_inner->extent);
} else if (const RebaseNode* r = rel.as<RebaseNode>()) { } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
if (!state.count(r->parent)) { if (!state.count(r->parent)) {
...@@ -78,7 +78,7 @@ void PassDownDomain(const Stage& stage, ...@@ -78,7 +78,7 @@ void PassDownDomain(const Stage& stage,
continue; continue;
} }
Update(p_state, r->rebased, Update(p_state, r->rebased,
Range::make_with_min_extent( Range::make_by_min_extent(
0, state.at(r->parent)->extent)); 0, state.at(r->parent)->extent));
} else { } else {
LOG(FATAL) << "unknown relation type"; LOG(FATAL) << "unknown relation type";
......
...@@ -37,6 +37,21 @@ def test_if(): ...@@ -37,6 +37,21 @@ def test_if():
assert isinstance(body.then_case.index, tvm.expr.Var) assert isinstance(body.then_case.index, tvm.expr.Var)
assert body.else_case.index.value == 0 assert body.else_case.index.value == 0
def test_prefetch():
A = tvm.placeholder((10, 20), name="A")
ib = tvm.ir_builder.create()
n = tvm.var("n")
with ib.for_range(0, n, name="i") as i:
ib.emit(
tvm.make.Prefetch(
A.op, A.value_index, A.dtype,
[tvm.make.range_by_min_extent(i+1, 2),
tvm.make.range_by_min_extent(0, 20)]))
body = ib.get()
assert body.body.bounds[0].extent.value == 2
if __name__ == "__main__": if __name__ == "__main__":
test_prefetch()
test_if() test_if()
test_for() test_for()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment