Commit 3f3bf29d by Tianqi Chen Committed by GitHub

[DOC] Make range related function consistent (#249)

parent 2ab0bfb5
......@@ -27,6 +27,7 @@ using Halide::ExprEqual;
using Halide::Expr;
using Halide::VarExpr;
using Halide::IR::RangeNode;
using Halide::IR::FunctionRef;
using Halide::IR::FunctionBaseNode;
using Halide::Internal::Stmt;
......@@ -113,7 +114,7 @@ class Range : public Halide::IR::Range {
*/
Range(Expr begin, Expr end);
static Range make_with_min_extent(Expr min, Expr extent);
static Range make_by_min_extent(Expr min, Expr extent);
};
/*!
......
......@@ -8,4 +8,24 @@ You can use make function to build the IR node.
"""
from ._ffi.function import _init_api
def range_by_min_extent(min_value, extent):
"""Construct a Range by min and extent.
This constructs a range in [min_value, min_value + extent)
Parameters
----------
min_value : Expr
The minimum value of the range.
extent : Expr
The extent of the range.
Returns
-------
rng : Range
The constructed range.
"""
return _range_by_min_extent(min_value, extent)
_init_api("tvm.make")
......@@ -16,6 +16,11 @@ TVM_REGISTER_API("_Var")
*ret = Variable::make(args[1], args[0]);
});
TVM_REGISTER_API("make._range_by_min_extent")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Range::make_by_min_extent(args[0], args[1]);
});
TVM_REGISTER_API("make.For")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = For::make(args[0],
......@@ -147,6 +152,7 @@ REGISTER_MAKE3(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE5(Allocate);
REGISTER_MAKE4(Provide);
REGISTER_MAKE4(Prefetch);
REGISTER_MAKE1(Free);
REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse);
......
......@@ -309,7 +309,7 @@ class Canonical::Internal : public IRMutator {
++level_counter_;
Var loop_var(op->loop_var.node_);
this->SetRange(loop_var,
Range::make_with_min_extent(op->min, op->extent),
Range::make_by_min_extent(op->min, op->extent),
level_counter_);
Stmt stmt = IRMutator::Mutate_(op, s);
--level_counter_;
......@@ -324,7 +324,7 @@ class Canonical::Internal : public IRMutator {
CHECK_NE(iv->thread_tag.length(), 0U);
if (!var_level_.count(iv->var.get())) {
this->SetRange(iv->var,
Range::make_with_min_extent(0, op->value),
Range::make_by_min_extent(0, op->value),
level_counter_);
}
Stmt stmt = IRMutator::Mutate_(op, s);
......
......@@ -40,7 +40,7 @@ Range IntSet::cover_range(Range max_range) const {
s_int = temp.as<IntervalSet>();
}
if (s_int->i.is_bounded()) {
return Range::make_with_min_extent(
return Range::make_by_min_extent(
s_int->i.min, Simplify(s_int->i.max + 1 - s_int->i.min));
}
return max_range;
......
......@@ -22,7 +22,7 @@ Range::Range(Expr begin, Expr end)
is_zero(begin) ? end : (end - begin))) {
}
Range Range::make_with_min_extent(Expr min, Expr extent) {
Range Range::make_by_min_extent(Expr min, Expr extent) {
return Range(std::make_shared<Halide::IR::RangeNode>(min, extent));
}
......
......@@ -93,7 +93,7 @@ void ExternOpNode::PropBoundToInputs(
TensorDom& dom = it->second;
for (size_t i = 0; i < t->shape.size(); ++i) {
dom.data[i].emplace_back(IntSet::range(
Range::make_with_min_extent(
Range::make_by_min_extent(
make_const(t->shape[i].type(), 0), t->shape[i])));
}
}
......@@ -116,7 +116,7 @@ Stmt ExternOpNode::BuildRealize(
Halide::Internal::Region bounds;
for (size_t i = 0; i < t->shape.size(); ++i) {
bounds.push_back(
Range::make_with_min_extent(
Range::make_by_min_extent(
make_const(t->shape[i].type(), 0), t->shape[i]));
}
realize_body = ir::Realize::make(
......
......@@ -76,7 +76,7 @@ Operation ScanOpNode::make(std::string name,
spatial_name << name << ".out" << i << ".i" << k;
n->spatial_axis_.push_back(
IterVarNode::make(
Range::make_with_min_extent(0, update[i]->shape[k]),
Range::make_by_min_extent(0, update[i]->shape[k]),
Var(spatial_name.str()), kOpaque));
}
}
......@@ -104,7 +104,7 @@ Array<Tensor> scan(Array<Tensor> init,
std::string tag) {
IterVar scan_axis =
IterVarNode::make(
Range::make_with_min_extent(
Range::make_by_min_extent(
init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
Var(name + ".idx"), kOrdered);
Operation op = ScanOpNode::make(
......@@ -165,7 +165,7 @@ void ScanOpNode::PropBoundToInputs(
// first dimension, always needed.
if (init_dom) {
init_dom->data[0].push_back(IntSet::range(
Range::make_with_min_extent(0, this->init[i]->shape[0])));
Range::make_by_min_extent(0, this->init[i]->shape[0])));
}
if (update_dom) {
update_dom->data[0].push_back(dom_map.at(this->scan_axis->var.get()));
......@@ -203,7 +203,7 @@ void ScanOpNode::GatherBound(
CHECK(!out_dom_map->count(this->scan_axis));
Range sdom = this->scan_axis->dom;
Range r = arith::Union(time_dom).cover_range(sdom);
(*out_dom_map)[this->scan_axis] = Range::make_with_min_extent(
(*out_dom_map)[this->scan_axis] = Range::make_by_min_extent(
sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
Map<IterVar, Expr> fix_pt = ScanFixPointAnalysis(self);
// Update for spatial axis.
......@@ -231,7 +231,7 @@ Stmt ScanOpNode::BuildRealize(
const Stmt& body) const {
CHECK_EQ(self.operator->(), this);
Range sdom = dom_map.at(this->scan_axis);
Range tdom = Range::make_with_min_extent(
Range tdom = Range::make_by_min_extent(
0, ir::Simplify(sdom->extent + sdom->min));
Stmt ret = body;
size_t sp_idx = 0;
......
......@@ -32,7 +32,7 @@ inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) {
if (!r->min.same_as(new_min)) changed = true;
if (!r->extent.same_as(new_extent)) changed = true;
new_dom[i] = IterVarNode::make(
Range::make_with_min_extent(new_min, new_extent),
Range::make_by_min_extent(new_min, new_extent),
v->var, v->iter_type, v->thread_tag);
}
if (!changed) {
......
......@@ -171,8 +171,8 @@ class ChannelAccessRewriter : public IRMutator {
Channel ch(adv_op->node.node_);
ChannelAccessBound acc(ch->handle_var.get(), read_access);
IntSet iset = acc.Eval(for_op->body);
Range r = iset.cover_range(Range::make_with_min_extent(0, window));
r = Range::make_with_min_extent(
Range r = iset.cover_range(Range::make_by_min_extent(0, window));
r = Range::make_by_min_extent(
ir::Simplify(r->min), ir::Simplify(r->extent));
if (ExprUseVar(r->extent, var)) return body;
Array<Expr> linear_eq = DetectLinearEquation(r->min, var);
......
......@@ -53,14 +53,14 @@ void PassDownDomain(const Stage& stage,
CHECK(!state.count(r->inner));
const Range& range_parent = state.at(r->parent);
if (r->factor.defined()) {
Update(p_state, r->inner, Range::make_with_min_extent(0, r->factor));
Update(p_state, r->inner, Range::make_by_min_extent(0, r->factor));
Update(p_state, r->outer,
Range::make_with_min_extent(
Range::make_by_min_extent(
0, DivCeil(range_parent->extent, r->factor)));
} else {
Update(p_state, r->outer, Range::make_with_min_extent(0, r->nparts));
Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts));
Update(p_state, r->inner,
Range::make_with_min_extent(
Range::make_by_min_extent(
0, DivCeil(range_parent->extent, r->nparts)));
}
} else if (const FuseNode* r = rel.as<FuseNode>()) {
......@@ -70,7 +70,7 @@ void PassDownDomain(const Stage& stage,
}
const Range& range_outer = state.at(r->outer);
const Range& range_inner = state.at(r->inner);
state[r->fused] = Range::make_with_min_extent(
state[r->fused] = Range::make_by_min_extent(
0, range_outer->extent * range_inner->extent);
} else if (const RebaseNode* r = rel.as<RebaseNode>()) {
if (!state.count(r->parent)) {
......@@ -78,7 +78,7 @@ void PassDownDomain(const Stage& stage,
continue;
}
Update(p_state, r->rebased,
Range::make_with_min_extent(
Range::make_by_min_extent(
0, state.at(r->parent)->extent));
} else {
LOG(FATAL) << "unknown relation type";
......
......@@ -37,6 +37,21 @@ def test_if():
assert isinstance(body.then_case.index, tvm.expr.Var)
assert body.else_case.index.value == 0
def test_prefetch():
A = tvm.placeholder((10, 20), name="A")
ib = tvm.ir_builder.create()
n = tvm.var("n")
with ib.for_range(0, n, name="i") as i:
ib.emit(
tvm.make.Prefetch(
A.op, A.value_index, A.dtype,
[tvm.make.range_by_min_extent(i+1, 2),
tvm.make.range_by_min_extent(0, 20)]))
body = ib.get()
assert body.body.bounds[0].extent.value == 2
if __name__ == "__main__":
test_prefetch()
test_if()
test_for()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment