Commit 6d798778 by Tianqi Chen, committed by GitHub

[LANG/SCHEDULE] Reduction factor, predicate in reduction. (#77)

parent 7cd6b35d
......@@ -40,6 +40,8 @@ using Halide::Internal::as_const_uint;
using Halide::Internal::const_true;
using Halide::Internal::const_false;
using Halide::Internal::is_no_op;
using Halide::likely;
using Halide::likely_if_innermost;
inline Type TVMShapeIndexType() {
if (std::is_signed<tvm_index_t>::value) {
......
......@@ -41,7 +41,7 @@ struct Reduce : public ExprNode<Reduce> {
/*! \brief construct expr from op and rdom */
static Expr make(std::string op, Expr src,
Array<IterVar> rdom,
Expr condition = make_const(Bool(1), true));
Expr condition = const_true());
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
......
......@@ -211,6 +211,18 @@ class Schedule : public NodeRef {
*/
Tensor cache_write(const Tensor& tensor, const std::string& scope);
/*!
* \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
* This will create a new stage that generates the new tensor, with axis
* as the first dimension. The tensor's body will be rewritten as a reduction
* over the factored tensor.
*
* \param tensor The tensor to be factored.
* \param axis The reduction axis in tensor's schedule to be factored.
* \return The created factored tensor.
*/
Tensor rfactor(const Tensor& tensor,
const IterVar& axis);
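In compute terms, the rewrite can be sketched as follows (illustrative Python pseudocode, not the literal implementation; BF and k_inner are hypothetical names):

    # Before: B reduces over k, where k was split with outer part kf.
    #   B[i] = sum(A[i, k], axis=k)
    # After rfactor(B, kf), roughly:
    #   BF[kf, i] = sum(A[i, ...], axis=k_inner)   # new factored stage
    #   B[i]      = sum(BF[kf, i], axis=kf)        # rewritten reduction over BF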
/*!
* \brief Normalize the schedule.
* This is needed before bound inference.
* Insert necessary RebaseNode to make sure all leaf_iter_vars
......
......@@ -19,4 +19,4 @@ from .ndarray import cpu, gpu, opencl, cl, vpi
from ._base import TVMError
from .api import *
from .build import build
from .build import build, lower
......@@ -372,7 +372,7 @@ def reduce_axis(dom, name="rv"):
return _IterVar(dom, name, 2)
def sum(expr, axis):
def sum(expr, axis, where=None):
"""Create a sum expression over axis
Parameters
......@@ -382,13 +382,16 @@ def sum(expr, axis):
axis : IterVar
The reduction IterVar axis
where : optional, Expr
Filtering predicate of the reduction.
"""
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Add", expr, axis)
x = _make.Reduce("Add", expr, axis, where)
return x
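A minimal usage sketch of the new where parameter, mirroring the updated test further below; the predicate is stored on the Reduce node and later lowered into a guard around the reduction update:

    import tvm

    n = tvm.Var('n')
    m = tvm.Var('m')
    A = tvm.placeholder((n, m), name='A')
    k = tvm.reduce_axis((0, m))
    # Rows with i <= 1 never execute the update and keep the init value.
    B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k, where=(i > 1)),
                    name='B')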
def min(lhs, rhs=None, axis=None):
def min(lhs, rhs=None, axis=None, where=None):
"""Create a min expression.
Parameters
......@@ -401,6 +404,9 @@ def min(lhs, rhs=None, axis=None):
axis : IterVar, optional
The reduction IterVar axis
where : optional, Expr
Filtering predicate of the reduction.
"""
if rhs and axis:
raise ValueError("Can only take one argument, rhs or axis")
......@@ -409,11 +415,11 @@ def min(lhs, rhs=None, axis=None):
if rhs:
return _make.Min(lhs, rhs)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Min", expr, axis)
x = _make.Reduce("Min", expr, axis, where)
return x
def max(lhs, rhs=None, axis=None):
def max(lhs, rhs=None, axis=None, where=None):
"""Create a max expression.
Parameters
......@@ -426,6 +432,9 @@ def max(lhs, rhs=None, axis=None):
axis : IterVar, optional
The reduction IterVar axis
where : optional, Expr
Filtering predicate of the reduction.
"""
if rhs and axis:
raise ValueError("Can only take one argument, rhs or axis")
......@@ -434,7 +443,7 @@ def max(lhs, rhs=None, axis=None):
if rhs:
return _make.Max(lhs, rhs)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Max", expr, axis)
x = _make.Reduce("Max", expr, axis, where)
return x
......
......@@ -9,16 +9,15 @@ from . import tensor
from . import schedule
from . import expr
from . import ir_pass
from . import collections
from . import codegen
def build(sch,
def lower(sch,
args,
name="default_function",
binds=None,
max_auto_unroll_step=8):
"""Build a function with arguments as signiture.
"""Lowering step before build into target.
Parameters
----------
......@@ -28,12 +27,6 @@ def build(sch,
args : list of Buffer or Tensor or Var
The argument list to the function.
target : str
The target of the compilation.
target_host :
Host compilation target, if target is device.
name : str
The name of result function.
......@@ -46,10 +39,8 @@ def build(sch,
Returns
-------
f : Function, or pair of functions
f : LoweredFunc
The result function.
If the function requires host space allocation,
a pair of functions will be returned.
"""
binds = {} if binds is None else binds.copy()
arg_list = []
......@@ -77,6 +68,62 @@ def build(sch,
stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
stmt = ir_pass.Simplify(stmt)
fapi = ir_pass.MakeAPI(stmt, name, arg_list, 0)
return fapi
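A hedged sketch of the two call forms this split enables, assuming a schedule s over tensors A and B as in the tests: lower once and build the resulting LoweredFunc, or let build drive both steps:

    # two-step: inspect or reuse the lowered function before codegen
    fapi = tvm.lower(s, args=[A, B], name="mysum")
    fsum = tvm.build(fapi, target="llvm", name="mysum")
    # one-step: build directly from the schedule
    fsum2 = tvm.build(s, args=[A, B], target="llvm", name="mysum")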
def build(sch,
args=None,
target="llvm",
target_host="stackvm",
name="default_function",
binds=None,
max_auto_unroll_step=8):
"""Build a function with arguments as signiture.
Parameters
----------
sch : tvm.Schedule, or LoweredFunc
The schedule to be built
args : list of Buffer or Tensor or Var
The argument list to the function.
target : str
The target of the compilation.
target_host :
Host compilation target, if target is device.
name : str
The name of result function.
binds : dict, optional
Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument.
max_auto_unroll_step: int
Maximum step to perform automatic unrolling
Returns
-------
f : Function, or pair of functions
The result function.
"""
if isinstance(sch, schedule.Schedule):
if args is None:
raise ValueError("args must be given for build from schedule")
fapi = lower(sch, args,
name=name,
binds=binds,
max_auto_unroll_step=max_auto_unroll_step)
elif isinstance(sch, collections.LoweredFunc):
if args:
raise ValueError("args must be done when build from LoweredFunc")
fapi = sch
else:
raise ValueError("sch have to be Schedule or LoweredFunc")
fsplits = ir_pass.SplitHostDevice(fapi)
fsplits = [x for x in fsplits]
for i in range(1, len(fsplits)):
......
......@@ -87,6 +87,27 @@ class Schedule(NodeBase):
"""
return _api_internal._ScheduleCacheWrite(self, tensor, scope)
def rfactor(self, tensor, axis):
""" Factor a reduction axis in tensor's schedule to be an explicit axis.
This will create a new stage that generates the new tensor, with axis
as the first dimension. The tensor's body will be rewritten as a reduction
over the factored tensor.
Parameters
----------
tensor : Tensor
The tensor to be factored.
axis : IterVar
The reduction axis in the schedule to be factored.
Returns
-------
tfactor : Tensor
The created factored tensor.
"""
return _api_internal._ScheduleRFactor(self, tensor, axis)
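A minimal usage sketch mirroring the new tests: split the reduction axis against an explicit outer reduce_axis, factor it out, then parallelize the factored stage:

    n = tvm.convert(1027)
    A = tvm.placeholder((n,), name='A')
    k = tvm.reduce_axis((0, n))
    B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')
    kf = tvm.reduce_axis((0, 4))
    s = tvm.Schedule(B.op)
    _, ki = s[B].split(k, outer=kf)
    BF = s.rfactor(B, kf)           # BF has shape (4,); B now sums over BF
    s[BF].parallel(BF.op.axis[0])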
@register_node
class Stage(NodeBase):
......@@ -114,8 +135,6 @@ class Stage(NodeBase):
The inner variable of iteration.
"""
if outer is not None:
if outer.thread_tag == '':
raise ValueError("split by outer must have special thread_tag")
inner = _api_internal._StageSplitByOuter(self, parent, outer, factor)
else:
if factor is None:
......
......@@ -89,7 +89,7 @@ TVM_REGISTER_API(_make_Allocate)
*ret = Node::make(a, b); \
})
REGISTER_MAKE3(Reduce);
REGISTER_MAKE4(Reduce);
REGISTER_MAKE4(AttrStmt);
REGISTER_MAKE2(IntImm);
......
......@@ -318,4 +318,10 @@ TVM_REGISTER_API(_ScheduleCacheWrite)
.cache_write(args[1], args[2]);
});
TVM_REGISTER_API(_ScheduleRFactor)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
.rfactor(args[1], args[2]);
});
} // namespace tvm
......@@ -526,8 +526,8 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitStmt_(const Store* op) {
Type t = op->value.type();
if (t.lanes() == 1) {
this->PrintIndent();
std::string value = this->PrintExpr(op->value);
this->PrintIndent();
this->PrintBufferRef(op->buffer_var.get(), t, op->index, stream);
stream << " = " << value << ";\n";
} else {
......
......@@ -28,7 +28,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->print(op->source);
p->stream << ", axis=" << op->axis;
if (!is_const(op->condition, 1)) {
p->stream << ", condition=" << op->condition;
p->stream << ", where=" << op->condition;
}
p->stream << ")";
});
......@@ -45,6 +45,9 @@ Expr Reduce::make(std::string op, Expr source,
CHECK_EQ(axis[i]->iter_type, kCommReduce)
<< "Can only take axis created by reduce_axis";
}
if (!condition.defined()) {
condition = const_true();
}
auto n = std::make_shared<Reduce>();
CHECK(source.defined());
for (size_t i = 0; i < axis.size(); ++i) {
......
/*!
* Copyright (c) 2016 by Contributors
* \file operation.cc
*/
#include <tvm/operation.h>
#include <tvm/tensor.h>
......@@ -10,6 +10,7 @@
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "./op_util.h"
#include "../schedule/message_passing.h"
namespace tvm {
......@@ -64,10 +65,7 @@ Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name) {
args.push_back(axis.back()->var);
}
op_node->axis = Array<IterVar>(axis);
op_node->body = fcompute(args);
op_node->name = name;
return Operation(op_node).output(0);
return ComputeOpNode::make(name, axis, fcompute(args)).output(0);
}
Operation ComputeOpNode::make(std::string name,
......@@ -191,6 +189,9 @@ void MakeReduction(const ComputeOpNode* op,
}
*init = Provide::make(t->op, t->value_index, init_value, args);
*provide = Provide::make(t->op, t->value_index, update_value, args);
if (!is_one(reduce->condition)) {
*provide = IfThenElse::make(reduce->condition, *provide);
}
}
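For illustration, for tvm.sum(A[i, k], axis=k, where=(i > 1)) the emitted statements now look roughly like (names illustrative):

    init:    B[i] = 0
    update:  if (i > 1) { B[i] = B[i] + A[i, k] }

so elements failing the predicate simply keep the initialization value.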
Stmt MakeProvide(const ComputeOpNode* op,
......@@ -202,31 +203,6 @@ Stmt MakeProvide(const ComputeOpNode* op,
return Provide::make(t->op, t->value_index, op->body, args);
}
// message passing to find if IterVar is related to reduction.
void PassDownReduceFlag(const Stage& s,
std::unordered_map<IterVar, int>* p_state) {
auto& state = *p_state;
for (IterVarRelation rel : s->relations) {
if (rel.as<SplitNode>()) {
const SplitNode* s = rel.as<SplitNode>();
int flag = state.at(s->parent);
state[s->outer] = flag;
state[s->inner] = flag;
} else if (rel.as<FuseNode>()) {
const FuseNode* s = rel.as<FuseNode>();
int flag_outer = state.at(s->outer);
int flag_inner = state.at(s->inner);
state[s->fused] = flag_outer | flag_inner;
} else if (rel.as<RebaseNode>()) {
const RebaseNode* s = rel.as<RebaseNode>();
int flag = state.at(s->parent);
state[s->rebased] = flag;
} else {
LOG(FATAL) << "unknown relation type";
}
}
}
Stmt Substitute(Stmt s,
const std::unordered_map<IterVar, Expr>& value_map) {
Map<Var, Expr> temp;
......@@ -267,7 +243,7 @@ Stmt ComputeOpNode::BuildProvide(
update_state[iv] = 1;
}
// find which iter var is related to reduction and which is related to axis.
PassDownReduceFlag(stage, &update_state);
schedule::PassDownBitMaskOr(stage, &update_state);
auto leaf_iter_vars = stage->leaf_iter_vars;
std::unordered_map<IterVar, Expr> init_value_map;
// find the first loop that is related to reduction.
......
......@@ -8,6 +8,7 @@
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include "./op_util.h"
#include "../schedule/message_passing.h"
#include "../arithmetic/compute_expr.h"
namespace tvm {
......@@ -16,61 +17,6 @@ namespace op {
using namespace arith;
using namespace ir;
/*!
* \brief use message passing to calculate the assignment of each Var inside the loop body.
* \param s The schedule to be used.
* \param dom_map The domain map of each iteration variable's domain
* \param p_state The message passing state
* IterVar->The assignment.
*/
void PassUpOffset(const Stage& s,
const Map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, Expr>* p_state) {
auto& state = *p_state;
for (size_t i = s->relations.size(); i != 0; --i) {
IterVarRelation rel = s->relations[i - 1];
if (rel.as<SplitNode>()) {
const SplitNode* s = rel.as<SplitNode>();
Expr outer = state.at(s->outer);
Expr inner = state.at(s->inner);
Expr factor = dom_map.at(s->inner)->extent;
Expr parent_min = dom_map.at(s->parent)->min;
state[s->parent] = inner + outer * factor;
// add min if they exist
if (!is_zero(parent_min)) {
state[s->parent] = state[s->parent] + parent_min;
}
} else if (rel.as<FuseNode>()) {
const FuseNode* s = rel.as<FuseNode>();
Expr value = state.at(s->fused);
Expr factor = dom_map.at(s->inner)->extent;
Expr outer_min = dom_map.at(s->outer)->min;
Expr inner_min = dom_map.at(s->inner)->min;
state[s->outer] = value / factor;
state[s->inner] = value % factor;
// add min if they exist
if (!is_zero(outer_min)) {
state[s->outer] = state[s->outer] + outer_min;
}
if (!is_zero(inner_min)) {
state[s->inner] = state[s->inner] + inner_min;
}
} else if (rel.as<RebaseNode>()) {
const RebaseNode* s = rel.as<RebaseNode>();
Expr value = state.at(s->rebased);
Expr parent_min = dom_map.at(s->parent)->min;
// add min if they exist
if (!is_zero(parent_min)) {
state[s->parent] = value + parent_min;
} else {
state[s->parent] = value;
}
} else {
LOG(FATAL) << "unknown relation type";
}
}
}
std::vector<std::vector<Stmt> >
MakeLoopNest(const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
......@@ -166,7 +112,7 @@ MakeLoopNest(const Stage& stage,
}
}
// message passing to get offset of root iter vars.
PassUpOffset(stage, dom_map, &value_map);
schedule::PassUpIndex(stage, dom_map, &value_map);
return nest;
}
......
......@@ -3,200 +3,18 @@
* \file bound.cc
* \brief The bound inference logic.
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <tvm/schedule_pass.h>
#include <tvm/arithmetic.h>
#include <tvm/operation.h>
#include <unordered_map>
#include <unordered_set>
#include "./graph.h"
#include "./message_passing.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace schedule {
using namespace arith;
// result = ceil(a / b); both a and b are positive integers, e.g. DivCeil(1027, 4) == 257
inline Expr DivCeil(Expr a, Expr b) {
return ir::Simplify((a + b - 1) / b);
}
inline bool prove_equal(Expr lhs, Expr rhs) {
return is_zero(ir::Simplify(lhs - rhs));
}
// Downward message passing algorithm on stage schedule s,
// pass the range state down from the root to the leaves
// after this pass, every IterVar in the stage hyper graph will have a range (domain)
void PassDown(const Stage& s,
std::unordered_map<IterVar, Range>* p_state) {
auto& state = *p_state;
// forward iteration on relations
for (IterVarRelation rel : s->relations) {
if (rel.as<SplitNode>()) {
const SplitNode* r = rel.as<SplitNode>();
CHECK(state.count(r->parent));
CHECK(!state.count(r->inner));
const Range& range_parent = state.at(r->parent);
if (r->factor.defined()) {
state[r->inner] = Range::make_with_min_extent(0, r->factor);
if (r->outer->dom.defined()) {
state[r->outer] = r->outer->dom;
} else {
if (!state.count(r->outer)) {
state[r->outer] = Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->factor));
} else {
Expr outer_ext = DivCeil(range_parent->extent, r->factor);
Range outer_rng = state.at(r->outer);
bool match = is_zero(outer_rng->min);
if (!prove_equal(outer_ext, outer_rng->extent)) match = false;
CHECK(match)
<< r->outer
<< "IterVar is used in two places as outer scope,"
<< " cannot prove their extents are the same "
<< outer_ext << " vs " << outer_rng->extent;
}
}
} else {
CHECK(r->outer->dom.defined());
state[r->outer] = r->outer->dom;
state[r->inner] = Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->outer->dom->extent));
}
} else if (rel.as<FuseNode>()) {
const FuseNode* r = rel.as<FuseNode>();
CHECK(state.count(r->outer));
CHECK(state.count(r->inner));
const Range& range_outer = state.at(r->outer);
const Range& range_inner = state.at(r->inner);
state[r->fused] = Range::make_with_min_extent(
0, range_outer->extent * range_inner->extent);
} else if (rel.as<RebaseNode>()) {
const RebaseNode* r = rel.as<RebaseNode>();
CHECK(state.count(r->parent));
state[r->rebased] = Range::make_with_min_extent(
0, state.at(r->parent)->extent);
} else {
LOG(FATAL) << "unknown relation type";
}
}
}
// upward message passing algorithm
// pass the integer set on each leaf loop up to the root
// dom_map is the result of PassDown, it records the domain of each IterVar.
// dom_map can be used to get cached result in reverse construction.
// Implementation of the evaluation and passing rules.
void PassUp(const SplitNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& outer,
const IntSet& inner,
IntSet* parent) {
if (dom_map.count(s->outer) &&
dom_map.count(s->inner) &&
dom_map.count(s->parent) &&
outer.match_range(dom_map.at(s->outer)) &&
inner.match_range(dom_map.at(s->inner))) {
*parent = IntSet::range(dom_map.at(s->parent));
return;
}
Expr factor = dom_map.at(s->inner)->extent;
Expr parent_min = dom_map.at(s->parent)->min;
CHECK(outer.defined());
CHECK(inner.defined());
CHECK(factor.defined());
*parent = EvalSet(
s->outer->var * factor + s->inner->var + parent_min,
{{s->outer, outer}, {s->inner, inner}});
}
void PassUp(const FuseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& fused,
IntSet* outer,
IntSet* inner) {
CHECK(dom_map.count(s->outer));
CHECK(dom_map.count(s->inner));
CHECK(dom_map.count(s->fused));
if (fused.match_range(dom_map.at(s->fused))) {
*outer = IntSet::range(dom_map.at(s->outer));
*inner = IntSet::range(dom_map.at(s->inner));
return;
}
Expr outer_min = dom_map.at(s->outer)->min;
Expr inner_min = dom_map.at(s->inner)->min;
if (fused.is_single_point()) {
Expr value = fused.point_value();
Expr factor = dom_map.at(s->inner)->extent;
Expr v_outer = value / factor;
Expr v_inner = value % factor;
if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
*outer = IntSet::single_point(v_outer);
*inner = IntSet::single_point(v_inner);
} else {
LOG(WARNING) << "use fallback inference rule in fuse";
// simply use the entire set, this rule can be enhanced.
*outer = IntSet::range(dom_map.at(s->outer));
*inner = IntSet::range(dom_map.at(s->inner));
return;
}
}
void PassUp(const RebaseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& rebased,
IntSet* parent) {
CHECK(dom_map.count(s->parent));
if (rebased.match_range(dom_map.at(s->rebased))) {
*parent = IntSet::range(dom_map.at(s->parent));
return;
}
Expr parent_min = dom_map.at(s->parent)->min;
*parent = EvalSet(s->rebased->var + parent_min,
{{s->rebased, rebased}});
}
void PassUp(const Stage& s,
const std::unordered_map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, IntSet>* p_state) {
auto& state = *p_state;
for (size_t i = s->relations.size(); i != 0; --i) {
IterVarRelation rel = s->relations[i - 1];
if (rel.as<SplitNode>()) {
IntSet parent;
const SplitNode* r = rel.as<SplitNode>();
PassUp(r, dom_map,
state.at(r->outer), state.at(r->inner),
&parent);
state[r->parent] = parent;
} else if (rel.as<FuseNode>()) {
IntSet outer, inner;
const FuseNode* r = rel.as<FuseNode>();
PassUp(r, dom_map,
state.at(r->fused),
&outer, &inner);
state[r->outer] = outer;
state[r->inner] = inner;
} else if (rel.as<RebaseNode>()) {
IntSet parent;
const RebaseNode* r = rel.as<RebaseNode>();
PassUp(r, dom_map,
state.at(r->rebased),
&parent);
state[r->parent] = parent;
} else {
LOG(FATAL) << "unknown relation type";
}
}
}
// check whether the iter var should be relaxed under the scope
inline bool ScopeRelax(const IterVar& iv, const std::string& scope) {
using runtime::ThreadScope;
......@@ -285,7 +103,7 @@ void InferRootBound(const Stage& stage,
}
}
// get the bound of the root IterVars given current location.
PassUp(parent, *rmap, &up_state);
PassUpDomain(parent, *rmap, &up_state);
std::unordered_map<const Variable*, IntSet> dom_map;
for (auto iv : parent->op->root_iter_vars()) {
......@@ -358,7 +176,7 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
const Stage& stage = sch->stages[i - 1];
InferRootBound(stage, ctx, attach_path, &ret);
// pass down to get bound of all iter vars.
PassDown(stage, &ret);
PassDownDomain(stage, &ret);
// setup outer most threads.
for (IterVar iv : stage->outermost_threads) {
CHECK(iv->dom.defined());
......
/*!
* Copyright (c) 2017 by Contributors
* \file message_passing.h
* \brief Common utilities to do message passing
* on the schedule hyper graph.
*/
#ifndef TVM_SCHEDULE_MESSAGE_PASSING_H_
#define TVM_SCHEDULE_MESSAGE_PASSING_H_
#include <tvm/expr.h>
#include <tvm/schedule.h>
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace tvm {
namespace schedule {
/*!
* \brief Downward inference of domain of each IterVar.
* The caller sets the range of the root, then the function
* propagates it towards the leaves.
*
* \param stage The stage to operate on.
* \param p_state The state of the message passing.
* \param allow_missing Whether allow missing value.
*/
void PassDownDomain(
const Stage& stage,
std::unordered_map<IterVar, Range>* p_state,
bool allow_missing = false);
/*!
* \brief Upward inference of the index of each IterVar,
* given the index assignment of the leaves.
*
* \param stage The stage to operate on.
* \param dom_map The domain map of each iteration variable's domain.
* \param p_state The index state of each IterVar.
* \param allow_missing Whether allow missing value.
*/
void PassUpIndex(const Stage& stage,
const Map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, Expr>* p_state,
bool allow_missing = false);
/*!
* \brief Upward inference of the domain set of each IterVar,
* given the domain assignment of the leaves.
*
* \param stage The stage to operate on.
* \param dom_map The domain map of each iteration variable's maximum domain.
* \param p_state The index state of each IterVar.
*/
void PassUpDomain(const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, IntSet>* p_state);
/*!
* \brief Upward message passing of bitmask with or relation.
* \param stage The stage to operate on.
* \param p_state The index state of each IterVar.
* \param allow_missing Whether allow missing value.
*/
void PassUpBitMaskOr(const Stage& stage,
std::unordered_map<IterVar, int>* p_state,
bool allow_missing = false);
/*!
* \brief Downward message passing of bitmask with or relation.
* \param stage The stage to operate on.
* \param p_state The index state of each IterVar.
* \param allow_missing Whether allow missing value.
*/
void PassDownBitMaskOr(const Stage& stage,
std::unordered_map<IterVar, int>* p_state,
bool allow_missing = false);
} // namespace schedule
} // namespace tvm
#endif // TVM_SCHEDULE_MESSAGE_PASSING_H_
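For intuition, a Python sketch of the downward bitmask pass, mirroring the PassDownReduceFlag body removed from operation.cc above; Split, Fuse, and Rebase here are hypothetical stand-ins for the C++ relation nodes:

    def pass_down_bitmask_or(stage, state):
        # Each derived IterVar inherits the OR of its sources' flags.
        for rel in stage.relations:
            if isinstance(rel, Split):
                state[rel.outer] = state[rel.parent]
                state[rel.inner] = state[rel.parent]
            elif isinstance(rel, Fuse):
                state[rel.fused] = state[rel.outer] | state[rel.inner]
            elif isinstance(rel, Rebase):
                state[rel.rebased] = state[rel.parent]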
......@@ -7,6 +7,7 @@
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "./message_passing.h"
namespace tvm {
......@@ -139,7 +140,6 @@ Tensor Schedule::cache_write(const Tensor& tensor,
return cache_tensor;
}
void RebaseNonZeroMinLoop(const Schedule& sch) {
std::unordered_map<IterVar, IterVar> rebase_map;
std::unordered_map<const Node*, int> attach_mark;
......@@ -244,4 +244,151 @@ void Schedule::normalize() {
InjectInline(*this);
}
// Handle reduction factor.
Tensor Schedule::rfactor(const Tensor& tensor,
const IterVar& axis) {
using ir::Reduce;
CHECK_EQ(axis->iter_type, kCommReduce)
<< "Can only factor reduction axis";
Stage reduce_stage = operator[](tensor->op);
const ComputeOpNode* compute_op = reduce_stage->op.as<ComputeOpNode>();
CHECK(compute_op) << "Can only factor ComputeOp";
ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite();
{
size_t axis_pos = FindNodeRef(leaf_vars, axis);
CHECK_NE(axis_pos, leaf_vars->data.size())
<< "Cannot find IterVar " << axis << " in leaf iter vars";
}
// Find touched reduction axis.
std::unordered_map<IterVar, int> touch_map;
touch_map[axis] = 1;
schedule::PassUpBitMaskOr(reduce_stage, &touch_map, true);
schedule::PassDownBitMaskOr(reduce_stage, &touch_map, true);
// Verify normal axis are not touched.
for (IterVar iv : compute_op->axis) {
CHECK(!touch_map.count(iv))
<< "Factor axis touches normal axis.";
}
// Get the replace index
std::unordered_map<IterVar, Range> dom_map;
std::unordered_map<IterVar, Expr> value_map;
for (IterVar iv : compute_op->reduce_axis) {
if (touch_map.count(iv)) dom_map[iv] = iv->dom;
}
schedule::PassDownDomain(reduce_stage, &dom_map, true);
for (IterVar iv : reduce_stage->leaf_iter_vars) {
if (touch_map.count(iv)) {
Range dom = dom_map.at(iv);
if (is_one(dom->extent)) {
value_map[iv] = dom->min;
} else {
value_map[iv] = iv->var;
}
}
}
schedule::PassUpIndex(reduce_stage, dom_map, &value_map, true);
// Get the factored op node.
auto n = std::make_shared<ComputeOpNode>();
n->name = compute_op->name + ".rf";
{
// axis replacement.
auto iv_node = std::make_shared<IterVarNode>();
iv_node->dom = dom_map.at(axis);
CHECK(is_zero(iv_node->dom->min))
<< "Can only factor reduction domain starting from 0";
iv_node->var = axis->var;
iv_node->iter_type = kDataPar;
n->axis.push_back(IterVar(iv_node));
for (IterVar iv : compute_op->axis) {
n->axis.push_back(iv);
}
}
// predicate generation; copy untouched axes.
std::unordered_map<const Variable*, Expr> vsub;
Expr predicate;
for (IterVar iv : compute_op->reduce_axis) {
if (!touch_map.count(iv)) {
n->reduce_axis.push_back(iv);
} else {
CHECK(value_map.count(iv));
Expr index = value_map.at(iv);
vsub[iv->var.get()] = index;
if (!index.same_as(iv->var)) {
Expr cond = (index < dom_map.at(iv)->extent);
if (predicate.defined()) {
predicate = predicate && cond;
} else {
predicate = cond;
}
}
}
}
// Copy touched axes.
for (IterVar iv : reduce_stage->leaf_iter_vars) {
if (touch_map.count(iv) && !iv.same_as(axis)) {
CHECK_EQ(iv->iter_type, kCommReduce);
auto ncpy = std::make_shared<IterVarNode>(*iv.operator->());
ncpy->dom = dom_map.at(iv);
n->reduce_axis.push_back(IterVar(ncpy));
}
}
const Reduce* reduce = compute_op->body.as<Reduce>();
CHECK(reduce) << "Can only rfactor non-inline reductions";
n->body = Reduce::make(reduce->op,
VarReplacer(vsub).Mutate(reduce->source),
n->reduce_axis,
predicate);
// refresh relations, keeping the untouched ones.
Array<IterVarRelation> rels;
for (IterVarRelation rel : reduce_stage->relations) {
bool touched = false;
if (const SplitNode* r = rel.as<SplitNode>()) {
if (touch_map.count(r->parent)) touched = true;
} else if (const FuseNode* r = rel.as<FuseNode>()) {
if (touch_map.count(r->fused)) touched = true;
} else if (const RebaseNode* r = rel.as<RebaseNode>()) {
if (touch_map.count(r->parent)) touched = true;
} else {
LOG(FATAL) << "unknown relation type";
}
if (!touched) {
rels.push_back(rel);
}
}
// initialize the factored stage.
Operation factor_op(n);
ArrayNode* stages = (*this)->stages.CopyOnWrite();
size_t stage_pos = FindNodeRef(stages, reduce_stage);
Stage factor_stage = Stage(factor_op);
factor_stage->relations = rels;
CHECK_LT(stage_pos, stages->data.size());
stages->data.insert(stages->data.begin() + stage_pos,
factor_stage.node_);
(*this)->stage_map.Set(factor_op, factor_stage);
// Replace the old reduction.
IterVar repl_red_axis = reduce_axis(
dom_map.at(axis), axis->var->name_hint + ".v");
Tensor factor_tensor = factor_op.output(0);
Tensor old_tensor = reduce_stage->op.output(0);
Tensor repl_tensor = compute(old_tensor->shape, [&](const Array<Var>& i) {
Array<Expr> indices;
indices.push_back(repl_red_axis->var);
for (Var v : i) {
indices.push_back(v);
}
return Reduce::make(
reduce->op, factor_tensor(indices), {repl_red_axis}, const_true());
}, old_tensor->op->name + ".repl");
std::unordered_map<Tensor, Tensor> vmap;
vmap[old_tensor] = repl_tensor;
ReplaceDataFlow((*this)->stages, &vmap);
// revamp the reduction stage.
reduce_stage->op = repl_tensor->op;
reduce_stage->all_iter_vars = repl_tensor->op->root_iter_vars();
reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars;
reduce_stage->relations = Array<IterVarRelation>();
return factor_tensor;
}
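A worked example of the guard generated above, assuming the extents from the new tests: a reduction of extent 1027 split with an outer kf of extent 4 gives the inner part extent DivCeil(1027, 4) = 257, so the index recovered by PassUpIndex can overrun the parent domain and the predicate masks the tail:

    index     = kf * 257 + ki      # ranges over [0, 1028)
    predicate = index < 1027       # masks the single out-of-range step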
} // namespace tvm
......@@ -43,11 +43,18 @@ void CheckSplit(StageNode* self, IterVar parent, IterVar outer) {
<< "Cannot split on axis[0] of scan update";
}
if (outer.defined()) {
CHECK_EQ(outer->iter_type, kThreadIndex)
<< "outer in split have to be ThreadIndex";
CHECK_EQ(parent->iter_type, kDataPar)
<< "Split by by kThreadIndex requires kDataPar IterVar "
<< " given " << IterVarType2String(parent->iter_type);
if (outer->iter_type == kThreadIndex) {
CHECK_EQ(parent->iter_type, kDataPar)
<< "Split by by kThreadIndex requires kDataPar IterVar "
<< " given " << IterVarType2String(parent->iter_type);
} else if (outer->iter_type == kCommReduce) {
CHECK_EQ(parent->iter_type, kCommReduce)
<< "Split by by kCommReduce requires kCommReduce IterVar "
<< " given " << IterVarType2String(parent->iter_type);
} else {
LOG(FATAL) << "Cannot take " << IterVarType2String(parent->iter_type)
<< " as outer IterVar";
}
} else {
CHECK(parent->iter_type == kDataPar ||
parent->iter_type == kCommReduce ||
......@@ -73,18 +80,6 @@ void Split(StageNode* self, IterVar parent,
} // namespace
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<StageNode>([](const StageNode *op, IRPrinter *p) {
p->stream << "stage("
<< op->op
<< ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IterVarAttrNode>([](const IterVarAttrNode *op, IRPrinter *p) {
p->stream << IterVarType2String(op->iter_type);
});
Stage::Stage(Operation op) {
auto n = std::make_shared<StageNode>();
n->op = op;
......@@ -374,4 +369,42 @@ TVM_REGISTER_NODE_TYPE(FuseNode);
TVM_REGISTER_NODE_TYPE(RebaseNode);
TVM_REGISTER_NODE_TYPE(ScheduleNode);
// Printer
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<StageNode>([](const StageNode *op, IRPrinter *p) {
p->stream << "stage(" << op->origin_op->name << ", " << op << ")";
})
.set_dispatch<IterVarAttrNode>([](const IterVarAttrNode *op, IRPrinter *p) {
p->stream << IterVarType2String(op->iter_type);
})
.set_dispatch<SplitNode>([](const SplitNode *op, IRPrinter *p) {
p->stream << "split(parent=";
p->print(op->parent);
p->stream << ", outer=";
p->print(op->outer);
p->stream << ", inner=";
p->print(op->inner);
p->stream << ')';
})
.set_dispatch<FuseNode>([](const FuseNode *op, IRPrinter *p) {
p->stream << "split(";
p->stream << "outer=";
p->print(op->outer);
p->stream << ", inner=";
p->print(op->inner);
p->stream << ", fused=";
p->print(op->fused);
p->stream << ')';
})
.set_dispatch<RebaseNode>([](const RebaseNode *op, IRPrinter *p) {
p->stream << "rebase(";
p->stream << "parent=";
p->print(op->parent);
p->stream << ", rebased=";
p->print(op->rebased);
p->stream << ')';
})
.set_dispatch<ScheduleNode>([](const ScheduleNode *op, IRPrinter *p) {
p->stream << "schedule(" << op << ")";
});
} // namespace tvm
......@@ -7,7 +7,7 @@ def test_sum():
m = tvm.Var('m')
A = tvm.placeholder((n, m), name='A')
k = tvm.reduce_axis((0, m))
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k, where=(i>1)), name='B')
# schedule
s = tvm.Schedule(B.op)
# create iter var and assign them tags.
......@@ -28,14 +28,17 @@ def test_sum():
args=[A, B],
target=device, target_host=host,
name="mysum")
print(fsum.imported_modules[0].get_source())
# launch the kernel.
n = 1028
m = 129
a = tvm.nd.array(np.random.uniform(size=(n, m)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
fsum(a, b)
res = np.sum(a.asnumpy(), axis=1)
res[:2] = 0
np.testing.assert_allclose(
b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4)
b.asnumpy(), res, rtol=1e-4)
if tvm.module.enabled("opencl"):
tvm.module.init_opencl()
......@@ -43,5 +46,38 @@ def test_sum():
check_device("cuda")
check_device("opencl")
def test_rfactor():
n = tvm.convert(1027)
A = tvm.placeholder((n,), name='A')
k = tvm.reduce_axis((0, n))
B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k, where=(i>1)), name='B')
kf = tvm.reduce_axis((0, 4))
# schedule
s = tvm.Schedule(B.op)
_, ki = s[B].split(k, outer=kf)
BF = s.rfactor(B, kf)
s[BF].parallel(BF.op.axis[0])
# one line to build the function.
def check_target(target="llvm"):
if not tvm.codegen.enabled(target):
return
ctx = tvm.cpu(0)
fapi = tvm.lower(s, args=[A, B])
fsum = tvm.build(fapi,
target=target,
name="mysum")
# launch the kernel.
n = 1027
a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
fsum(a, b)
res = np.sum(a.asnumpy(), axis=0)
np.testing.assert_allclose(
b.asnumpy(), res, rtol=1e-4)
check_target()
if __name__ == "__main__":
test_rfactor()
test_sum()
......@@ -91,8 +91,33 @@ def test_vectorize():
assert s[T].iter_var_attrs[xi].iter_type == UNROLL
assert s[T].iter_var_attrs[yi].iter_type == VECTORIZE
def test_rfactor():
n = tvm.Var('n')
k1 = tvm.reduce_axis((0, n), name="k1")
k2 = tvm.reduce_axis((0, n), name="k2")
A = tvm.placeholder((n, n, n), name='A')
B = tvm.compute((n, ), lambda i: tvm.sum(A[i, k1, k2], axis=[k1, k2]))
# normal schedule
s = tvm.Schedule(B.op)
BF = s.rfactor(B, k1)
assert(tuple(BF.shape) == (n, n))
assert(set(BF.op.body.axis) == set([k2]))
assert(s[B].op.body.axis[0].dom.extent == n)
assert(len(s[B].all_iter_vars) == 2)
# schedule with split
s = tvm.Schedule(B.op)
ko, ki = s[B].split(k1, factor=4)
xo, xi = s[B].split(B.op.axis[0], factor=8)
BF = s.rfactor(B, ki)
assert(BF.shape[0].value == 4)
assert(BF.shape[1] == n)
assert(BF.op.body.axis[0] == k2)
assert(BF.op.body.axis[1].var == ko.var)
assert(s[B].op.body.axis[0].dom.extent.value == 4)
if __name__ == "__main__":
test_rfactor()
test_schedule_create()
test_reorder()
test_tile()
......
......@@ -100,7 +100,24 @@ def test_bound_blur():
assert(bounds[A.op.axis[0]].extent.value == 3)
assert(bounds[A.op.axis[1]].extent.value == 3)
def test_bound_rfactor():
n = tvm.Var('n')
A = tvm.placeholder((n,), name='A')
k = tvm.reduce_axis((0, n))
B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k, where=(i>1)), name='B')
kf = tvm.reduce_axis((0, 4))
# schedule
s = tvm.Schedule(B.op)
_, ki = s[B].split(k, outer=kf)
BF = s.rfactor(B, kf)
s.normalize()
bounds = tvm.schedule.InferBound(s)
assert(bounds[BF.op.axis[0]].extent.value == 4)
assert(bounds[BF.op.axis[1]].extent.value == 1)
if __name__ == "__main__":
test_bound_rfactor()
test_bound_blur()
test_bound_conv1d()
test_bound_scan()
......