Commit f467f66e by ziheng Committed by GitHub

Support for Tuple Inputs of Reducer and ComputeOp (#175)

* Support for batch ComputeOp

* Support for batch ComputeOp

* Fix CrossThreadReduction

* Fix lint

* Add UpdateArray, remove support for batch reduce

* Tuple input support for reduce

* rfactor works with multiple reducer; support multiple reducers with different types

* Small fix

* Small fix

* Change return type of rfactor to Array<Expr>

* Fix lint

* Improve

* Add tutorial

* Improve tutorial

* Improve tutorial
parent ef50162b
......@@ -47,23 +47,27 @@ struct CommReducer : public NodeRef {
* binary operator with identity element
*/
struct CommReducerNode : public Node {
/*! \brief The arguments of reducer */
Array<Var> args;
/*! \brief The left argument of reducer */
Array<Var> lhs;
/*! \brief The right argument of reducer */
Array<Var> rhs;
/*! \brief The result of reducer */
Expr result;
Array<Expr> result;
/*!
* \brief The identity element of reducer, which leaves other
* elements unchanged when combined with it, with respect to
* the binary operation of this reducer uses.
*/
Expr identity_element;
Array<Expr> identity_element;
/*! \brief Function call operator to combine a and b */
Expr operator()(Expr a, Expr b) const;
Array<Expr> operator()(Array<Expr> a, Array<Expr> b) const;
/*! \brief construct CommReducer from args, result and identity_element */
static CommReducer make(Array<Var> args, Expr result, Expr identity_element);
static CommReducer make(Array<Var> lhs, Array<Var> rhs,
Array<Expr> result, Array<Expr> identity_element);
void VisitAttrs(AttrVisitor* v) final {
v->Visit("args", &args);
v->Visit("lhs", &lhs);
v->Visit("rhs", &rhs);
v->Visit("result", &result);
v->Visit("identity_element", &identity_element);
}
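For orientation, a minimal Python sketch of how the new tuple-based fields are populated, following the argmax reducer used in the tests and tutorial further below: the two tuple arguments of fcombine become the lhs and rhs Var arrays, its return tuple becomes result, and the return tuple of fidentity becomes identity_element.

import tvm

# x and y are tuples of (index, value); their elements map onto the
# lhs and rhs Var arrays of CommReducerNode.
def fcombine(x, y):
    idx = tvm.make.Select(x[1] >= y[1], x[0], y[0])
    val = tvm.make.Select(x[1] >= y[1], x[1], y[1])
    return idx, val                      # becomes Array<Expr> result

# one identity element per operand dtype, becomes Array<Expr> identity_element
def fidentity(t0, t1):
    return tvm.const(-1, t0), tvm.min_value(t1)

argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax')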
......@@ -84,7 +88,7 @@ struct Reduce : public ExprNode<Reduce> {
/*! \brief The commutative combiner */
CommReducer combiner;
/*! \brief The source operand */
Expr source;
Array<Expr> source;
/*! \brief The reduction axis */
Array<IterVar> axis;
/*!
......@@ -92,18 +96,22 @@ struct Reduce : public ExprNode<Reduce> {
* Only add the body to reduction if condition is true.
*/
Expr condition;
/*! \brief the index of this reduce node */
int value_index;
/*! \brief construct expr from op and rdom */
static Expr make(CommReducer combiner,
Expr src,
Array<Expr> src,
Array<IterVar> rdom,
Expr condition = const_true());
Expr condition,
int value_index);
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("source", &source);
v->Visit("axis", &axis);
v->Visit("condition", &condition);
v->Visit("value_index", &value_index);
}
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "Reduce";
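A short sketch of what value_index means from the Python side, continuing the argmax sketch above and assuming the fields registered in VisitAttrs are readable as node attributes: every output of a tuple reduction carries the same combiner, source and axis, and only value_index selects which result it reads.

# continuing the argmax sketch above
m = tvm.var('m')
n = tvm.var('n')
idx = tvm.placeholder((m, n), name='idx', dtype='int32')
val = tvm.placeholder((m, n), name='val', dtype='float32')
k = tvm.reduce_axis((0, n), 'k')
T0, T1 = tvm.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name='T')
# both bodies are Reduce exprs sharing combiner, source and axis;
# they differ only in value_index
assert T0.op.body[0].value_index == 0
assert T0.op.body[1].value_index == 1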
......@@ -292,10 +300,11 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
/*!
* \brief See pseudo code
*
* Expr tvm_thread_allreduce(CommReducer combiner, Expr value, Expr cond,
* Var thread_idx1, thread_idx2...) {
* void tvm_thread_allreduce(UIntImm size, Expr source0, ..., Expr cond,
* Var reduce_temp0, .., Var thread_idx1, ...) {
* // constraint by the other thread_idx remain the same.
* return reduce(combiner, value, cond,
* // reduce_temp is used to save intermediate result.
* reduce_temp0, ... = reduce(combiner, source0, ..., cond
* over [thread_idx1, thread_idx2] passed by any caller)
* }
*/
......
......@@ -96,10 +96,10 @@ Expr Substitute(Expr expr, const Map<Var, Expr>& value_map);
/*!
* \brief inline all calls of f in stmt.
*
* \param stmt The statement to apply inline optimization.
* \param f The function reference to be inlined
* \param args The arguments variable of the function.
* \param body The defintion body of the function.
* \param stmt The statement to apply inline optimization.
* \param body The definition body of the function.
* \return The result stmt
*
* \note All the passes in this file uses SSA form and outputs SSA form.
......
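Because a ComputeOp body is now an Array<Expr>, callers of the Inline pass hand it a single element of op.body; a small sketch mirroring the updated inline tests further below:

import tvm
m = tvm.var('m')
A = tvm.placeholder((m,), name='A')
T = tvm.compute((m,), lambda i: A[i] + 10, name='T')
stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
# op.body is an Array<Expr>; pass one element as the definition body
stmt = tvm.ir_pass.Inline(stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])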
......@@ -182,7 +182,7 @@ class ComputeOpNode : public OperationNode {
/*! \brief IterVar on each reduction axis, if the body is a Reduce */
Array<IterVar> reduce_axis;
/*! \brief the compute expression */
Expr body;
Array<Expr> body;
/*! \brief constructor */
ComputeOpNode() {}
// override functions
......@@ -218,7 +218,7 @@ class ComputeOpNode : public OperationNode {
}
static Operation make(std::string name,
Array<IterVar> axis,
Expr body);
Array<Expr> body);
static constexpr const char* _type_key = "ComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode);
......@@ -358,6 +358,9 @@ class ExternOpNode : public OperationNode {
/*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>;
/*! \brief The compute function to specify the inputs source of Tensors */
using FBatchCompute = std::function<Array<Expr> (const Array<Var>& i)>;
/*!
* \brief create a place holder tensor.
* \param shape The shape of the tensor.
......@@ -378,6 +381,15 @@ Tensor placeholder(Array<Expr> shape,
Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
/*!
* \brief Construct a new tensor by computing over shape,
* using the computation rule: result_tensor[axis] = fcompute(axis)
* \param shape Shape of the tensor.
* \param fcompute The compute function to create the tensors.
* \param name The optional name of the tensor.
*/
Array<Tensor> compute(Array<Expr> shape, FBatchCompute fcompute, std::string name = "tensor");
/*!
* \brief Construct new tensors by scan.
*
* \param init The initialize tensor of first K steps.
......
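On the Python side this batch overload is reached simply by returning a tuple from the compute lambda; a brief sketch based on the new test_tuple_inputs test:

import tvm
m = tvm.var('m')
n = tvm.var('n')
A0 = tvm.placeholder((m, n), name='A0')
A1 = tvm.placeholder((m, n), name='A1')
# one ComputeOp with two outputs; both tensors share the same op
T0, T1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='T')
assert T0.op == T1.op
assert T0.value_index == 0 and T1.value_index == 1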
......@@ -252,14 +252,14 @@ class Schedule : public NodeRef {
/*!
* \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
* This will create a new stage that generated the new tensor with axis
* as the first dimension. The tensor's body wil be rewriten as a reduction
* as the first dimension. The tensor's body will be rewritten as a reduction
* over the factored tensor.
*
* \param tensor The tensor to be factored.
* \param axis The reduction axis in tensor's schedule to be factored.
* \return The created factored tensor.
* \return The created factored tensors.
*/
Tensor rfactor(const Tensor& tensor,
Array<Tensor> rfactor(const Tensor& tensor,
const IterVar& axis);
/*!
* \brief Normalize the schedule.
......
......@@ -174,10 +174,14 @@ def compute(shape, fcompute, name="compute"):
dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)]
body = fcompute(*[v.var for v in dim_var])
if not isinstance(body, (list, tuple)):
body = [body]
body = convert(body)
op_node = _api_internal._ComputeOp(
name, dim_var, body)
return op_node.output(0)
num = op_node.num_outputs
outputs = tuple(op_node.output(i) for i in range(num))
return outputs[0] if num == 1 else outputs
def scan(init, update, state_placeholder, inputs=None, name="scan"):
......@@ -525,18 +529,45 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
return res
def _make_reduce(expr, axis, where=None):
expr = convert(expr)
dtype = expr.dtype
code = fcombine.__code__
assert fcombine.__code__.co_argcount == 2
arg_vars = [var(name, dtype) for name in code.co_varnames]
result = fcombine(*[v for v in arg_vars])
expr = convert(expr)
if isinstance(expr, _collections.Array):
size = len(expr)
larr = []
rarr = []
dtypes = []
for i in range(size):
dtype = expr[i].dtype
dtypes.append(dtype)
lname = code.co_varnames[0] + '_' + str(i)
larr.append(var(lname, dtype))
rname = code.co_varnames[1] + '_' + str(i)
rarr.append(var(rname, dtype))
lhs = convert(larr)
rhs = convert(rarr)
result = fcombine(lhs, rhs)
id_elem = fidentity(*dtypes)
else:
assert isinstance(expr, _expr.Expr)
size = 1
dtype = expr.dtype
lvar = var(code.co_varnames[0], dtype)
rvar = var(code.co_varnames[1], dtype)
result = [fcombine(lvar, rvar)]
id_elem = [fidentity(dtype)]
lhs = convert([lvar])
rhs = convert([rvar])
expr = convert([expr])
result = convert(result)
id_elem = fidentity(dtype)
assert isinstance(id_elem, _expr.Expr)
combiner = _make.CommReducer(arg_vars, result, id_elem)
axis = axis if isinstance(axis, list) else [axis]
return _make.Reduce(combiner, expr, axis, where)
id_elem = convert(id_elem)
combiner = _make.CommReducer(lhs, rhs, result, id_elem)
axis = convert(axis if isinstance(axis, list) else [axis])
if where is None:
where = convert(True)
outputs = tuple(_make.Reduce(combiner, expr, axis, where, i)
for i in range(size))
return outputs[0] if size == 1 else outputs
def reducer(expr, axis, where=None, *args):
if isinstance(axis, (_schedule.IterVar, list)):
......
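The dispatch in _make_reduce above is on whether the converted expression is an Array: a tuple argument builds one Reduce per output, while a plain Expr keeps the old single-value behaviour. A brief sketch, assuming the usual single-value product reducer from the reduction tutorial:

import tvm
# single-value reducer: fcombine and fidentity take and return plain exprs
product = tvm.comm_reducer(lambda x, y: x * y,
                           lambda t: tvm.const(1, t), name='product')
n = tvm.var('n')
m = tvm.var('m')
A = tvm.placeholder((n, m), name='A')
k = tvm.reduce_axis((0, m), 'k')
B = tvm.compute((n,), lambda i: product(A[i, k], axis=k), name='B')  # one output
# a tuple reducer such as argmax instead yields one output per tuple element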
......@@ -181,7 +181,7 @@ class Schedule(NodeBase):
""" Factor a reduction axis in tensor's schedule to be an explicit axis.
This will create a new stage that generated the new tensor with axis
as the first dimension. The tensor's body wil be rewriten as a reduction
as the first dimension. The tensor's body will be rewritten as a reduction
over the factored tensor.
Parameters
......@@ -193,10 +193,11 @@ class Schedule(NodeBase):
Returns
-------
tfactor : Tensor
tfactor : Tensor or Array of Tensor
The created factored tensor.
"""
return _api_internal._ScheduleRFactor(self, tensor, axis)
factored = _api_internal._ScheduleRFactor(self, tensor, axis)
return factored[0] if len(factored) == 1 else factored
@register_node
......
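A brief sketch of the new unpacking behaviour of the Python wrapper: a single-output reduction still gets one Tensor back, while a tuple reduction such as argmax (see test_rfactor_argmax below) gets one factored tensor per output.

import tvm
n = tvm.var('n')
m = tvm.var('m')
A = tvm.placeholder((n, m), name='A')
k = tvm.reduce_axis((0, m), 'k')
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')
s = tvm.create_schedule(B.op)
ko, ki = s[B].split(k, factor=16)
BF = s.rfactor(B, ki)      # single-output reduce: one Tensor comes back
# for a tuple reduction the same call returns a pair:
#   BF0, BF1 = s.rfactor(B0, kf)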
......@@ -69,10 +69,12 @@ TVM_REGISTER_API("make.Call")
TVM_REGISTER_API("make.CommReducer")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = CommReducerNode::make(args[0], args[1], args[2]);
*ret = CommReducerNode::make(args[0],
args[1],
args[2],
args[3]);
});
// make from two arguments
#define REGISTER_MAKE1(Node) \
TVM_REGISTER_API("make."#Node) \
......@@ -112,7 +114,7 @@ TVM_REGISTER_API("make.CommReducer")
*ret = Node::make(a, b); \
})
REGISTER_MAKE4(Reduce);
REGISTER_MAKE5(Reduce);
REGISTER_MAKE4(AttrStmt);
REGISTER_MAKE2(IntImm);
......
......@@ -50,24 +50,27 @@ Expr sum(Expr source, Array<IterVar> rdom) {
Var x("x"), y("y");
Expr result = ir::Add::make(x, y);
Expr identity_element = make_zero(source.type());
ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}
Expr max(Expr source, Array<IterVar> rdom) {
Var x("x"), y("y");
Expr result = ir::Max::make(x, y);
Expr identity_element = source.type().min();
ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}
Expr min(Expr source, Array<IterVar> rdom) {
Var x("x"), y("y");
Expr result = ir::Min::make(x, y);
Expr identity_element = source.type().max();
ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}
std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*)
......
......@@ -9,6 +9,7 @@
#include <ir/IR.h>
#include <ir/IRPrinter.h>
#include <memory>
#include "../pass/ir_util.h"
namespace Halide {
namespace Internal {
......@@ -25,23 +26,20 @@ void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) {
p->stream << "reduce(combiner="
<< op->combiner
<< ", ";
p->print(op->source);
<< op->combiner;
p->stream << ", source=" << op->source;
p->stream << ", axis=" << op->axis;
if (!is_const(op->condition, 1)) {
p->stream << ", where=" << op->condition;
}
p->stream << ", value_index=" << op->value_index;
p->stream << ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<CommReducerNode>([](const CommReducerNode *op, IRPrinter *p) {
p->stream << "comm_reducer(result="
<< op->result
<< ", args=" << op->args
<< ", identity_element="
<< op->identity_element
p->stream << "comm_reducer(result=" << op->result
<< ", lhs=" << op->lhs
<< ", rhs=" << op->rhs
<< ", identity_element=" << op->identity_element
<< ")";
});
} // namespace Internal
......@@ -50,23 +48,34 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
namespace tvm {
namespace ir {
CommReducer CommReducerNode::make(Array<Var> args, Expr result, Expr identity_element) {
CommReducer CommReducerNode::make(Array<Var> lhs,
Array<Var> rhs,
Array<Expr> result,
Array<Expr> identity_element) {
auto node = std::make_shared<CommReducerNode>();
node->args = args;
node->lhs = lhs;
node->rhs = rhs;
node->result = result;
node->identity_element = identity_element;
return CommReducer(node);
}
Expr CommReducerNode::operator()(Expr a, Expr b) const {
Array<Expr> CommReducerNode::operator()(Array<Expr> a, Array<Expr> b) const {
CHECK_EQ(a.size(), b.size());
CHECK_EQ(lhs.size(), a.size());
CHECK_EQ(rhs.size(), b.size());
Map<Var, Expr> value_map;
value_map.Set(args[0], a);
value_map.Set(args[1], b);
return Substitute(result, value_map);
for (size_t i = 0; i < a.size(); ++i) {
value_map.Set(lhs[i], a[i]);
value_map.Set(rhs[i], b[i]);
}
return UpdateArray(result, [&value_map] (const Expr& e) {
return Substitute(e, value_map);
});
}
Expr Reduce::make(CommReducer combiner, Expr source,
Array<IterVar> axis, Expr condition) {
Expr Reduce::make(CommReducer combiner, Array<Expr> source,
Array<IterVar> axis, Expr condition, int value_index) {
for (size_t i = 0; i < axis.size(); ++i) {
CHECK_EQ(axis[i]->iter_type, kCommReduce)
<< "Can only take axis created by reduce_axis";
......@@ -79,11 +88,12 @@ Expr Reduce::make(CommReducer combiner, Expr source,
for (size_t i = 0; i < axis.size(); ++i) {
CHECK(axis[i].defined());
}
n->type = source.type();
n->combiner = combiner;
n->source = source;
n->axis = axis;
n->type = source[value_index].type();
n->combiner = std::move(combiner);
n->source = std::move(source);
n->axis = std::move(axis);
n->condition = condition;
n->value_index = value_index;
return Expr(n);
}
......
......@@ -24,7 +24,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(ComputeOpNode);
int ComputeOpNode::num_outputs() const {
return 1;
return body.size();
}
Array<IterVar> ComputeOpNode::root_iter_vars() const {
......@@ -36,13 +36,14 @@ Array<IterVar> ComputeOpNode::root_iter_vars() const {
return ret;
}
Type ComputeOpNode::output_dtype(size_t i) const {
CHECK_EQ(i, 0U);
return body.type();
Type ComputeOpNode::output_dtype(size_t idx) const {
CHECK_LT(idx, num_outputs());
return body[idx].type();
}
Array<Expr> ComputeOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0U);
Array<Expr> ComputeOpNode::output_shape(size_t idx) const {
CHECK_LT(idx, num_outputs());
// for now, all outputs of ComputeOp have the same shape
std::vector<Expr> shape;
for (size_t i = 0; i < axis.size(); ++i) {
const Range& r = axis[i]->dom;
......@@ -65,18 +66,55 @@ Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name) {
args.push_back(axis.back()->var);
}
return ComputeOpNode::make(name, axis, fcompute(args)).output(0);
return ComputeOpNode::make(name, axis, {fcompute(args)}).output(0);
}
Array<Tensor> compute(Array<Expr> shape, FBatchCompute fcompute, std::string name) {
auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension.
size_t ndim = shape.size();
std::vector<IterVar> axis;
std::vector<Var> args;
for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os;
os << "ax" << i;
axis.emplace_back(IterVarNode::make(
Range(0, shape[i]), Var(os.str(), shape[i].type()), kDataPar));
args.push_back(axis.back()->var);
}
Operation op = ComputeOpNode::make(name, axis, fcompute(args));
Array<Tensor> outputs;
for (int idx = 0; idx < op->num_outputs(); ++idx) {
outputs.push_back(op.output(idx));
}
return outputs;
}
bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
return (a->combiner.same_as(b->combiner)) &&
(a->source.same_as(b->source)) &&
(a->axis.same_as(b->axis)) &&
(a->condition.same_as(b->condition));
}
Operation ComputeOpNode::make(std::string name,
Array<IterVar> axis,
Expr body) {
Array<Expr> body) {
auto n = std::make_shared<ComputeOpNode>();
n->name = name;
n->axis = axis;
n->body = body;
if (n->body->is_type<ir::Reduce>()) {
n->reduce_axis = n->body.as<ir::Reduce>()->axis;
if (n->body[0]->is_type<ir::Reduce>()) {
const ir::Reduce* reduce = n->body[0].as<ir::Reduce>();
for (size_t i = 1; i < n->body.size(); ++i) {
const ir::Reduce* reduce_ = n->body[i].as<ir::Reduce>();
CHECK(reduce_);
CHECK(ReduceEqual(reduce_, reduce))
<< "The Reduce inputs of ComputeOp should "
<< "have the same attribute except value_index";
}
n->reduce_axis = reduce->axis;
}
return Operation(n);
}
......@@ -85,7 +123,8 @@ Operation ComputeOpNode::make(std::string name,
Array<Tensor> ComputeOpNode::InputTensors() const {
Array<Tensor> ret;
std::unordered_set<Tensor> visited;
ir::PostOrderVisit(body, [&ret, &visited](const NodeRef& n) {
for (auto& e : body) {
ir::PostOrderVisit(e, [&ret, &visited](const NodeRef& n) {
const ir::Call *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t = Operation(call->func.node_).output(call->value_index);
......@@ -95,6 +134,7 @@ Array<Tensor> ComputeOpNode::InputTensors() const {
}
}
});
}
return ret;
}
......@@ -102,9 +142,11 @@ Operation ComputeOpNode::ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this);
Expr new_body = op::ReplaceTensor(this->body, rmap);
if (!new_body.same_as(this->body)) {
return ComputeOpNode::make(name, axis, new_body);
Array<Expr> arr = UpdateArray(this->body, [&rmap] (const Expr& e) {
return op::ReplaceTensor(e, rmap);
});
if (!arr.same_as(this->body)) {
return ComputeOpNode::make(name, axis, arr);
} else {
return self;
}
......@@ -127,7 +169,7 @@ void ComputeOpNode::PropBoundToInputs(
}
}
};
ir::PostOrderVisit(body, fvisit);
for (auto& e : body) ir::PostOrderVisit(e, fvisit);
}
void ComputeOpNode::GatherBound(
......@@ -151,34 +193,50 @@ Stmt ComputeOpNode::BuildRealize(
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& realize_body) const {
CHECK_EQ(self.operator->(), this);
Tensor t = self.output(0);
Halide::Internal::Region bounds;
for (IterVar iv : this->axis) {
bounds.push_back(realize_map.at(iv));
}
return ir::Realize::make(t->op, t->value_index, t->dtype,
bounds, const_true(), realize_body);
Stmt realize = realize_body;
for (int i = self->num_outputs(); i > 0; --i) {
Tensor t = self.output(i-1);
realize = ir::Realize::make(t->op, t->value_index,
t->dtype, bounds, const_true(), realize);
}
return realize;
}
// Build a reduction body.
void MakeReduction(const ComputeOpNode* op,
const Tensor& t,
const Array<Tensor>& tensors,
Stmt* init,
Stmt* provide) {
Stmt no_op = Evaluate::make(0);
std::vector<Stmt> nest;
Array<Expr> args;
for (IterVar iv : op->axis) {
args.push_back(iv->var);
}
const Reduce* reduce = op->body.as<Reduce>();
std::vector<Stmt> inits, provides;
size_t size = op->body.size();
const Reduce* reduce = op->body[0].as<Reduce>();
CHECK(reduce);
const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>();
CHECK(combiner);
Expr init_value = combiner->identity_element;
Expr update_value = (*combiner)(t(args), reduce->source);
*init = Provide::make(t->op, t->value_index, init_value, args);
*provide = Provide::make(t->op, t->value_index, update_value, args);
Array<Expr> lhs;
for (size_t i = 0; i < size; ++i) {
lhs.push_back(tensors[i](args));
}
Array<Expr> init_value = combiner->identity_element;
Array<Expr> update_value = (*combiner)(lhs, reduce->source);
for (size_t i = 0; i < size; ++i) {
Tensor t = tensors[i];
inits.emplace_back(Provide::make(
t->op, t->value_index, init_value[i], args));
provides.emplace_back(Provide::make(
t->op, t->value_index, update_value[i], args));
}
*init = Block::make(inits);
*provide = Block::make(provides);
if (!is_one(reduce->condition)) {
*provide = IfThenElse::make(reduce->condition, *provide);
}
......@@ -225,22 +283,36 @@ Stmt MakeCrossThreadReduction(
for (IterVar iv : self->axis) {
args.push_back(iv->var);
}
const Reduce* reduce = self->body.as<Reduce>();
CHECK(reduce);
std::unordered_map<IterVar, Expr> value_map;
auto nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
auto conds = op::MakeBoundCheck(
stage, dom_map, false,
std::unordered_set<IterVar>(), value_map);
Expr cond = reduce->condition;
size_t size = self->body.size();
CHECK_GT(size, 0);
std::vector<const Reduce*> reduces(size);
for (size_t i = 0; i < size; ++i) {
const Reduce* reduce = self->body[i].as<Reduce>();
CHECK(reduce);
reduces[i] = reduce;
}
Expr cond = reduces[0]->condition;
for (Expr v : conds) {
cond = cond && v;
}
Var res_handle("reduce_temp", Handle());
Array<Expr> freduce_args;
freduce_args.push_back(reduce->source);
freduce_args.push_back(make_const(UInt(32), size));
for (size_t i = 0; i < size; ++i) {
freduce_args.push_back(reduces[0]->source[i]);
}
freduce_args.push_back(cond);
std::vector<Var> res_handles(size);
for (size_t idx = 0; idx < size; ++idx) {
res_handles[idx] = Var("reduce_temp" + std::to_string(idx), Handle());
freduce_args.push_back(res_handles[idx]);
}
for (IterVar iv : stage->leaf_iter_vars) {
if (iv->iter_type == kCommReduce) {
......@@ -257,28 +329,33 @@ Stmt MakeCrossThreadReduction(
if (stage->store_predicate.defined()) {
thread_head_check.emplace_back(stage->store_predicate);
}
Type t = reduce->type;
Expr pred = const_true(t.lanes());
Stmt reduce_body = Store::make(res_handle,
Call::make(
reduce->type,
Stmt reduce_body = Evaluate::make(Call::make(
Handle(),
ir::intrinsic::tvm_thread_allreduce,
freduce_args, Call::Intrinsic),
0, pred);
freduce_args, Call::Intrinsic));
reduce_body = AttrStmt::make(
reduce->combiner,
reduces[0]->combiner,
attr::reduce_scope,
make_zero(reduce->type),
make_zero(Handle()),
reduce_body);
Stmt assign_body = Provide::make(
stage->op, 0, Load::make(reduce->type, res_handle, 0, pred), args);
std::vector<Stmt> assigns(size);
for (size_t idx = 0; idx < size; ++idx) {
Type t = reduces[idx]->type;
assigns[idx] = Provide::make(
stage->op, idx,
Load::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
}
Stmt assign_body = Block::make(assigns);
assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
Stmt body = Allocate::make(
res_handle, reduce->type, {1}, const_true(),
Block::make(reduce_body, assign_body));
Stmt body = Block::make(reduce_body, assign_body);
for (size_t idx = size; idx != 0; --idx) {
body = Allocate::make(
res_handles[idx - 1], reduces[idx - 1]->type, {1}, const_true(), body);
body = AttrStmt::make(
res_handle, attr::storage_scope, StringImm::make("local"), body);
res_handles[idx - 1], attr::storage_scope, StringImm::make("local"), body);
}
body = Substitute(body, value_map);
return MergeNest(nest, body);
}
......@@ -289,7 +366,7 @@ Stmt MakeProvide(const ComputeOpNode* op,
for (IterVar iv : op->axis) {
args.push_back(iv->var);
}
return Provide::make(t->op, t->value_index, op->body, args);
return Provide::make(t->op, t->value_index, op->body[t->value_index], args);
}
Stmt ComputeOpNode::BuildProvide(
......@@ -301,12 +378,24 @@ Stmt ComputeOpNode::BuildProvide(
// specially handle cross thread reduction.
return MakeCrossThreadReduction(this, stage, dom_map);
}
Stmt init, provide;
size_t size = this->body.size();
Stmt init;
Stmt provide;
if (this->reduce_axis.size() == 0) {
provide = MakeProvide(this, stage->op.output(0));
std::vector<Stmt> provides;
for (size_t i = 0; i < size; ++i) {
provides.emplace_back(MakeProvide(this, stage->op.output(i)));
}
provide = Block::make(provides);
} else {
MakeReduction(this, stage->op.output(0), &init, &provide);
Array<Tensor> source;
for (size_t i = 0; i < size; ++i) {
source.push_back(stage->op.output(i));
}
MakeReduction(this, source, &init, &provide);
}
// make loop nest
std::unordered_map<IterVar, Expr> value_map;
auto nest = op::MakeLoopNest(
......
......@@ -4,6 +4,7 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include "./ir_util.h"
namespace tvm {
namespace ir {
......@@ -17,19 +18,7 @@ IRMutator::FMutateStmt& IRMutator::vtable_stmt() { // NOLINT(*)
}
inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
std::vector<Expr> new_arr(arr.size());
bool changed = false;
for (size_t i = 0; i < arr.size(); i++) {
Expr old_elem = arr[i];
Expr new_elem = m->Mutate(old_elem);
if (!new_elem.same_as(old_elem)) changed = true;
new_arr[i] = new_elem;
}
if (!changed) {
return arr;
} else {
return Array<Expr>(new_arr);
}
return UpdateArray(arr, [&m] (const Expr& e) { return m->Mutate(e); });
}
inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) {
......@@ -323,14 +312,15 @@ DEFINE_BIOP_EXPR_MUTATE_(Or)
Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) {
Array<IterVar> new_axis = MutateIterVarArr(op->axis, this);
Expr new_source = this->Mutate(op->source);
Array<Expr> new_source = MutateArray(op->source, this);
Expr new_cond = this->Mutate(op->condition);
if (op->axis.same_as(new_axis) &&
op->source.same_as(new_source) &&
op->condition.same_as(new_cond)) {
return e;
} else {
return Reduce::make(op->combiner, new_source, new_axis, new_cond);
return Reduce::make(
op->combiner, new_source, new_axis, new_cond, op->value_index);
}
}
......
......@@ -13,6 +13,32 @@ namespace tvm {
namespace ir {
/*!
* \brief update array with an unary function
* \param arr array
* \param fupdate an unary function
* \tparam T type of array element
* \tparam F type of the unary function
* \return if update happens, return the new array, else return the
* original array
*/
template<typename T, typename F>
inline Array<T> UpdateArray(Array<T> arr, F fupdate) {
std::vector<T> new_arr(arr.size());
bool changed = false;
for (size_t i = 0; i < arr.size(); ++i) {
T old_elem = arr[i];
T new_elem = fupdate(old_elem);
if (!new_elem.same_as(old_elem)) changed = true;
new_arr[i] = new_elem;
}
if (!changed) {
return arr;
} else {
return Array<T>(new_arr);
}
}
/*!
* \brief combine the nest stmt, whose body is not defined.
* \param nest A list of For and LetStmt, whose body is not defined.
* \param body body
......
......@@ -133,7 +133,7 @@ DEFINE_BINOP_VISIT_(Or)
void IRVisitor::Visit_(const Reduce* op) {
VisitRDom(op->axis, this);
this->Visit(op->source);
VisitArray(op->source, this);
}
void IRVisitor::Visit_(const Cast* op) {
......
......@@ -45,12 +45,12 @@ class ThreadAllreduceBuilder : public IRMutator {
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt Mutate_(const Evaluate* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Store>();
op = stmt.as<Evaluate>();
const Call* call = op->value.as<Call>();
if (call && call->is_intrinsic(intrinsic::tvm_thread_allreduce)) {
return MakeAllreduce(op, call);
return MakeAllreduce(call);
} else {
return stmt;
}
......@@ -97,18 +97,34 @@ class ThreadAllreduceBuilder : public IRMutator {
}
};
// make allreduce.
Stmt MakeAllreduce(const Store* op, const Call* call) {
Stmt MakeAllreduce(const Call* call) {
CHECK(!reduce_combiner_.empty());
const CommReducerNode *combiner = reduce_combiner_.back();
Expr init = combiner->identity_element;
Expr value = call->args[0];
Expr cond = call->args[1];
size_t size = combiner->result.size();
const UIntImm *size_of_args = call->args[0].as<UIntImm>();
CHECK(size_of_args) << call->args[0]->type_key();
CHECK_EQ(size, size_of_args->value);
Array<Expr> inits = combiner->identity_element;
std::vector<Expr> values(size);
std::vector<Type> types(size);
Expr cond = call->args[size+1];
for (size_t idx = 0; idx < size; ++idx) {
values[idx] = call->args[1+idx];
if (!is_one(cond)) {
value = Select::make(cond, value, init);
values[idx] = Select::make(cond, values[idx], inits[idx]);
}
types[idx] = values[idx].type();
}
std::vector<const Variable*> buffers(size);
for (size_t idx = 0; idx < size; ++idx) {
const Variable* buffer = call->args[2+size+idx].as<Variable>();
CHECK(buffer);
buffers[idx] = buffer;
}
std::unordered_set<const Variable*> reduce_set;
for (size_t i = 2; i < call->args.size(); ++i) {
for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) {
const Variable* v = call->args[i].as<Variable>();
CHECK(v);
reduce_set.insert(v);
......@@ -143,40 +159,50 @@ class ThreadAllreduceBuilder : public IRMutator {
int threadx_extent = 1;
Expr reduce_index = FlattenThread(vred, &reduce_extent);
Expr group_index = FlattenThread(vpar, &group_extent);
Expr pred = const_true(value.type().lanes());
if (reduce_extent == 1) {
// special case, no reduction is needed.
return Store::make(op->buffer_var, value, 0, pred);
std::vector<Stmt> stores(size);
for (size_t i = 0; i < size; ++i) {
Expr pred = const_true(types[i].lanes());
Var buffer_var(call->args[2+size+i].node_);
stores[i] = Store::make(buffer_var, values[i], 0, pred);
}
return Block::make(stores);
}
// Whether the threadIdx.x is involved in reduction.
if (vred[0].scope.dim_index == 0) {
threadx_extent = vred[0].extent;
}
Var shared_buf("red_buf", Handle());
std::vector<Stmt> seq;
std::vector<Var> shared_bufs(size);
for (size_t idx = 0; idx < size; ++idx) {
shared_bufs[idx] = Var("red_buf"+std::to_string(idx), Handle());
Expr pred = const_true(types[idx].lanes());
seq.emplace_back(Store::make(
shared_buf, value,
shared_bufs[idx], values[idx],
BufIndex(reduce_index, group_index, reduce_extent), pred));
}
seq.emplace_back(SyncThread("shared"));
seq.emplace_back(MakeBufAllreduce(
combiner, value.type(), shared_buf,
combiner, types, shared_bufs,
reduce_index, group_index, reduce_extent, threadx_extent));
CHECK(!load_remap_.count(op->buffer_var.get()));
load_remap_[op->buffer_var.get()] =
Load::make(
value.type(), shared_buf,
BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent),
pred);
alloc_remap_[op->buffer_var.get()] =
Allocate::make(shared_buf, value.type(),
for (size_t idx = 0; idx < size; ++idx) {
CHECK(!load_remap_.count(buffers[idx]));
Expr pred = const_true(types[idx].lanes());
load_remap_[buffers[idx]] = Load::make(
types[idx], shared_bufs[idx],
BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent), pred);
alloc_remap_[buffers[idx]] = Allocate::make(
shared_bufs[idx], types[idx],
{Expr(group_extent), Expr(reduce_extent)},
pred, Evaluate::make(0));
}
return MergeSeq(seq);
}
// make allreduce.
Stmt MakeBufAllreduce(const CommReducerNode *combiner,
Type type,
Var shared_buf,
const std::vector<Type>& types,
const Array<Var>& shared_bufs,
Expr reduce_index,
Expr group_index,
int reduce_extent,
......@@ -189,14 +215,23 @@ class ThreadAllreduceBuilder : public IRMutator {
CHECK_GT(reduce_align, 1);
std::vector<Stmt> seq;
size_t size = shared_bufs.size();
Expr buf_index = BufIndex(reduce_index, group_index, reduce_extent);
// make reduction
auto freduce = [&](int offset) {
Expr b = Load::make(
type, shared_buf,
BufIndex(reduce_index + offset, group_index, reduce_extent), const_true());
Expr a = Load::make(type, shared_buf, buf_index, const_true());
return Store::make(shared_buf, (*combiner)(a, b), buf_index, const_true());
Array<Expr> a, b;
for (size_t i = 0; i < size; ++i) {
b.push_back(Load::make(types[i], shared_bufs[i],
BufIndex(reduce_index + offset, group_index, reduce_extent),
const_true()));
a.push_back(Load::make(types[i], shared_bufs[i], buf_index, const_true()));
}
Array<Expr> ret = (*combiner)(a, b);
std::vector<Stmt> stores(size);
for (size_t i = 0; i < size; ++i) {
stores[i] = Store::make(shared_bufs[i], ret[i], buf_index, const_true());
}
return Block::make(stores);
};
// Step one, check for
if (reduce_align > reduce_extent) {
......
......@@ -157,7 +157,9 @@ class StorageFlattener : public IRMutator {
CHECK_EQ(extern_buf_remap_.size(), 0U);
for (size_t i = 0; i < ext_op->output_placeholders.size(); ++i) {
TensorKey key{func, static_cast<int>(i)};
CHECK(buf_map_.count(key));
CHECK(buf_map_.count(key))
<< "Cannot find allocated buffer for " << key.f
<< "(" << key.value_index << ")";
extern_buf_remap_[ext_op->output_placeholders[i]->data.get()] =
buf_map_.at(key).buffer->data;
}
......
......@@ -46,7 +46,7 @@ class ElemWiseDetector : public ir::IRVisitor {
bool IsElemWise(const Operation& op) {
if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
ElemWiseDetector v = ElemWiseDetector(compute->axis);
v.Visit(compute->body);
for (auto& e : compute->body) v.Visit(e);
return v.is_elem_wise_;
}
return false;
......
......@@ -260,7 +260,9 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
}
}
};
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
for (auto& e : op.as<ComputeOpNode>()->body) {
ir::PostOrderVisit(e, fvisit);
}
}
}
return reach;
......@@ -321,11 +323,14 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
}
}
} else if (op.as<ComputeOpNode>()) {
std::unordered_map<const Node*, TensorDimKey> vmap;
std::unordered_map<const Node*, std::vector<TensorDimKey> > vmap;
const auto& axis = op.as<ComputeOpNode>()->axis;
Tensor t = op.output(0);
for (size_t i = 0; i < axis.size(); ++i) {
vmap[axis[i]->var.get()] = TensorDimKey(t, i);
std::vector<TensorDimKey> keys;
for (int j = 0; j < op->num_outputs(); ++j) {
keys.emplace_back(op.output(j), i);
}
vmap[axis[i]->var.get()] = std::move(keys);
}
auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](
const NodeRef& n) {
......@@ -335,7 +340,10 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
auto it = vmap.find(call->args[i].get());
TensorDimKey src(call, static_cast<int>(i));
if (it != vmap.end()) {
f_merge_key(it->second, src);
const std::vector<TensorDimKey>& keys = it->second;
for (const auto& key : keys) {
f_merge_key(key, src);
}
} else {
if (exact_reach.count(src)) {
fail_set.insert(exact_reach.at(src));
......@@ -344,7 +352,9 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
}
}
};
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
for (auto& e : op.as<ComputeOpNode>()->body) {
ir::PostOrderVisit(e, fvisit);
}
}
}
ReachGraph reach;
......
......@@ -27,7 +27,7 @@ using ReadGraph = Map<Operation, Array<Tensor> >;
using AttachPath = Map<Operation, Array<IterVar> >;
/*!
* \brief The map beteen tensor and operation it feeds to.
* \brief The map between tensor and operation it feeds to.
*/
using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
......@@ -46,7 +46,7 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots);
* The operations contains node which input-reachable from any inputs
* output reachable to any outputs.
*
* The inputs won't be included in the subgraph, the outputs will be inclued.
* The inputs won't be included in the subgraph, the outputs will be included.
*
* \param outputs The outputs of the subgraph
* \param inputs The inputs to the subgraph.
......
......@@ -8,6 +8,7 @@
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "./message_passing.h"
#include "../pass/ir_util.h"
namespace tvm {
......@@ -120,13 +121,13 @@ Tensor Schedule::cache_write(const Tensor& tensor,
vsub[iv->var.get()] = new_iv->var;
}
VarReplacer repl(vsub);
Expr body = repl.Mutate(compute->body);
Expr body = repl.Mutate(compute->body[tensor->value_index]);
Operation cache_op = ComputeOpNode::make(
compute->name + "." + scope, new_axis, body);
compute->name + "." + scope, new_axis, {body});
Tensor cache_tensor = cache_op.output(0);
Operation orig_new_op = ComputeOpNode::make(
compute->name, compute->axis,
cache_tensor(args));
{cache_tensor(args)});
std::unordered_map<Tensor, Tensor> vmap;
vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
......@@ -198,14 +199,15 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
void InjectInline(ScheduleNode* sch) {
sch->InvalidateCache();
std::vector<Expr> new_body(sch->stages.size());
std::vector<Array<Expr>> new_body(sch->stages.size());
std::vector<bool> changed(sch->stages.size(), false);
// inline all the ops
for (size_t i = sch->stages.size(); i != 0; --i) {
Stage stage = sch->stages[i - 1];
if (stage->attach_type == kInline) {
stage->attach_type = kInlinedAlready;
Array<Var> args;
Expr body;
Array<Expr> body;
{
// setup args
const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
......@@ -220,11 +222,14 @@ void InjectInline(ScheduleNode* sch) {
Stage s = sch->stages[j];
const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
if (compute) {
if (!new_body[j].defined()) {
if (!new_body[j].size()) {
new_body[j] = s->op.as<ComputeOpNode>()->body;
}
new_body[j] = ir::Inline(ir::Evaluate::make(new_body[j]),
stage->op, args, body).as<ir::Evaluate>()->value;
for (size_t k = 0; k < body.size(); ++k) {
changed[j] = true;
new_body[j].Set(k, ir::Inline(ir::Evaluate::make(new_body[j][k]),
stage->op, args, body[k]).as<ir::Evaluate>()->value);
}
}
}
}
......@@ -234,20 +239,22 @@ void InjectInline(ScheduleNode* sch) {
for (size_t i = 0; i < sch->stages.size(); ++i) {
Stage s = sch->stages[i];
if (s->attach_type == kInlinedAlready) continue;
if (new_body[i].defined()) {
if (new_body[i].size()) {
// Logics from ReplaceDataFlow
const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
CHECK(compute);
Operation op = s->op;
if (!new_body[i].same_as(compute->body)) {
if (changed[i]) {
op = ComputeOpNode::make(
compute->name, compute->axis, new_body[i]);
}
op = op->ReplaceInputs(op, repl);
if (!op.same_as(s->op)) {
repl[s->op.output(0)] = op.output(0);
for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
repl[s->op.output(idx)] = op.output(idx);
s->op = op;
}
}
} else {
Operation op = s->op->ReplaceInputs(s->op, repl);
if (!op.same_as(s->op)) {
......@@ -268,7 +275,7 @@ Schedule Schedule::normalize() {
}
// Handle reduction factor.
Tensor Schedule::rfactor(const Tensor& tensor,
Array<Tensor> Schedule::rfactor(const Tensor& tensor,
const IterVar& axis) {
(*this)->InvalidateCache();
using ir::Reduce;
......@@ -329,7 +336,8 @@ Tensor Schedule::rfactor(const Tensor& tensor,
}
}
// predicate generation, copy not touched axis.
const Reduce* reduce = compute_op->body.as<Reduce>();
int idx = tensor->value_index;
const Reduce* reduce = compute_op->body[idx].as<Reduce>();
CHECK(reduce) << "Can only rfactor non-inline reductions";
Expr predicate = reduce->condition;
std::unordered_map<const Variable*, Expr> vsub;
......@@ -359,10 +367,18 @@ Tensor Schedule::rfactor(const Tensor& tensor,
n->reduce_axis.push_back(IterVar(ncpy));
}
}
n->body = Reduce::make(reduce->combiner,
VarReplacer(vsub).Mutate(reduce->source),
VarReplacer replacer(vsub);
Array<Expr> new_source = ir::UpdateArray(reduce->source,
[&replacer] (const Expr& e) { return replacer.Mutate(e); });
std::vector<Expr> body;
for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
body.emplace_back(Reduce::make(reduce->combiner,
new_source,
n->reduce_axis,
predicate);
predicate,
idx));
}
n->body = Array<Expr>(body);
// refresh relations, keep the un-touched relations.
Array<IterVarRelation> rels;
for (IterVarRelation rel : reduce_stage->relations) {
......@@ -397,26 +413,44 @@ Tensor Schedule::rfactor(const Tensor& tensor,
// Replace the old reduction.
IterVar repl_red_axis = reduce_axis(
dom_map.at(axis), axis->var->name_hint + ".v");
Tensor factor_tensor = factor_op.output(0);
Tensor old_tensor = reduce_stage->op.output(0);
Tensor repl_tensor = compute(old_tensor->shape, [&](const Array<Var>& i) {
Array<Tensor> factor_tensors;
Array<Tensor> old_tensors;
int size = factor_op->num_outputs();
for (int idx = 0; idx < size; ++idx) {
factor_tensors.push_back(factor_op.output(idx));
old_tensors.push_back(reduce_stage->op.output(idx));
}
Array<Tensor> repl_tensors = compute(old_tensors[0]->shape,
[&](const Array<Var>& i) {
Array<Expr> indices;
indices.push_back(repl_red_axis->var);
for (Var v : i) {
indices.push_back(v);
}
return Reduce::make(reduce->combiner,
factor_tensor(indices), {repl_red_axis}, const_true());
}, old_tensor->op->name + ".repl");
Array<Expr> factor_exprs;
for (int idx = 0; idx < size; ++idx) {
factor_exprs.push_back(factor_tensors[idx](indices));
}
Array<Expr> reductions;
Array<IterVar> axis = {repl_red_axis};
Expr cond = const_true();
for (int idx = 0; idx < size; ++idx) {
reductions.push_back(Reduce::make(reduce->combiner,
factor_exprs, axis, cond, idx));
}
return reductions;
}, reduce_stage->op->name + ".repl");
std::unordered_map<Tensor, Tensor> vmap;
vmap[old_tensor] = repl_tensor;
for (int idx = 0; idx < size; ++idx) {
vmap[old_tensors[idx]] = repl_tensors[idx];
}
ReplaceDataFlow((*this)->stages, &vmap);
// revamp the reduction stage.
reduce_stage->op = repl_tensor->op;
reduce_stage->all_iter_vars = repl_tensor->op->root_iter_vars();
reduce_stage->op = repl_tensors[0]->op;
reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars();
reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars;
reduce_stage->relations = Array<IterVarRelation>();
return factor_tensor;
return factor_tensors;
}
} // namespace tvm
......@@ -253,7 +253,7 @@ class SchedulePostProc : public IRMutator {
// This must be checked for all ops, including scan.
if (!s->op.same_as(s->origin_op)) {
for (int i = 0; i < s->op->num_outputs(); ++i) {
Tensor target = s->origin_op.output(0);
Tensor target = s->origin_op.output(i);
AddReplace(s->op.output(i), target,
target, s->origin_op);
}
......
......@@ -49,7 +49,6 @@ def test_reduce_prims():
test_prim(tvm.max, np.amax)
def test_rfactor():
n = tvm.convert(1027)
A = tvm.placeholder((n,), name='A')
......@@ -128,7 +127,115 @@ def test_rfactor_threads():
check_target("metal")
check_target("opencl")
def test_argmax():
def fcombine(x, y):
lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
return lhs, rhs
def fidentity(t0, t1):
return tvm.const(-1, t0), tvm.min_value(t1)
argmax = tvm.comm_reducer(fcombine,
fidentity,
name='argmax')
m = tvm.var('m')
n = tvm.var('n')
idx = tvm.placeholder((m, n), name='idx', dtype='int32')
val = tvm.placeholder((m, n), name='val', dtype='float32')
k = tvm.reduce_axis((0, n), 'k')
T0, T1 = tvm.compute((m,), lambda i: argmax((idx[i,k], val[i,k]), axis=k), name='T')
s = tvm.create_schedule(T0.op)
def check_target():
device = 'cpu'
if not tvm.module.enabled(device):
print("skip because %s is not enabled.." % device)
return
ctx = tvm.context(device, 0)
fapi = tvm.lower(s, args=[idx, val, T0, T1])
fargmax = tvm.build(fapi,
target='llvm',
name="argmax")
mm = 12
nn = 16
np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0)
np_val = np.random.uniform(size=(mm, nn)).astype('float32')
np_res = np.argmax(np_val, axis=1)
nd_idx = tvm.nd.array(np_idx, ctx)
nd_val = tvm.nd.array(np_val, ctx)
nd_res0 = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx)
nd_res1 = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx)
fargmax(nd_idx, nd_val, nd_res0, nd_res1)
np.testing.assert_allclose(np_res, nd_res0.asnumpy())
check_target()
def test_rfactor_argmax():
def fcombine(x, y):
lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
return lhs, rhs
def fidentity(t0, t1):
return tvm.const(-1, t0), tvm.min_value(t1)
argmax = tvm.comm_reducer(fcombine,
fidentity,
name='argmax')
nn = 1027
mm = 10
n = tvm.convert(nn)
m = tvm.convert(mm)
A0 = tvm.placeholder((m, n), name='A0', dtype='int32')
A1 = tvm.placeholder((m, n), name='A1', dtype='float32')
k = tvm.reduce_axis((0, n))
B0, B1 = tvm.compute((m,), lambda i: argmax((A0[i, k], A1[i, k]), axis=k), name='B')
# schedule
s = tvm.create_schedule(B0.op)
nthread = 16
ko, kf = s[B0].split(k, factor=nthread)
BF0, BF1 = s.rfactor(B0, kf)
bx, ty = s[B0].split(s[B0].op.axis[0], factor=nthread)
s[B0].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B0].bind(ty, tvm.thread_axis("threadIdx.y"))
tx = s[B0].op.reduce_axis[0]
thread_x = tvm.thread_axis("threadIdx.x")
s[B0].bind(tx, thread_x)
s[BF0.op].compute_at(s[B0], tx)
s[B0].set_store_predicate(thread_x.var.equal(0))
def check_target(device):
if not tvm.module.enabled(device):
print("skip because %s is not enabled.." % device)
return
ctx = tvm.context(device, 0)
fapi = tvm.lower(s, args=[A0, A1, B0, B1])
fargmax = tvm.build(fapi,
target=device,
name="argmax")
np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0)
np_val = np.random.uniform(size=(mm, nn)).astype('float32')
np_res = np.argmax(np_val, axis=1)
nd_idx = tvm.nd.array(np_idx, ctx)
nd_val = tvm.nd.array(np_val, ctx)
nd_res0 = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx)
nd_res1 = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx)
fargmax(nd_idx, nd_val, nd_res0, nd_res1)
np.testing.assert_allclose(np_res, nd_res0.asnumpy())
check_target("cuda")
if __name__ == "__main__":
test_rfactor_threads()
test_rfactor()
test_reduce_prims()
test_argmax()
test_rfactor_argmax()
......@@ -101,8 +101,8 @@ def test_rfactor():
s = tvm.create_schedule(B.op)
BF = s.rfactor(B, k1)
assert(tuple(BF.shape) == (n, n))
assert(set(BF.op.body.axis) == set([k2]))
assert(s[B].op.body.axis[0].dom.extent == n)
assert(set(BF.op.body[0].axis) == set([k2]))
assert(s[B].op.body[0].axis[0].dom.extent == n)
assert(len(s[B].all_iter_vars) == 2)
# schedule with split
s = tvm.create_schedule(B.op)
......@@ -111,9 +111,9 @@ def test_rfactor():
BF = s.rfactor(B, ki)
assert(BF.shape[0].value == 4)
assert(BF.shape[1] == n)
assert(BF.op.body.axis[0] == k2)
assert(BF.op.body.axis[1].var == ko.var)
assert(s[B].op.body.axis[0].dom.extent.value == 4)
assert(BF.op.body[0].axis[0] == k2)
assert(BF.op.body[0].axis[1].var == ko.var)
assert(s[B].op.body[0].axis[0].dom.extent.value == 4)
if __name__ == "__main__":
......
......@@ -118,6 +118,43 @@ def test_extern_multi_out():
assert(len(res) == 2)
assert(res[1].value_index == 1)
def test_tuple_inputs():
m = tvm.var('m')
n = tvm.var('n')
A0 = tvm.placeholder((m, n), name='A0')
A1 = tvm.placeholder((m, n), name='A1')
T0, T1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='T')
s = tvm.create_schedule(T0.op)
for i in range(len(T0.shape)):
assert(T0.shape[i] == T1.shape[i])
assert(T0.op == T1.op)
assert(T0.value_index == 0)
assert(T1.value_index == 1)
def test_tuple_with_different_deps():
m = tvm.var('m')
n = tvm.var('n')
A0 = tvm.placeholder((m, n), name='A1')
A1 = tvm.placeholder((m, n), name='A2')
B0, B1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='B')
C = tvm.compute((m, n), lambda i, j: B0[i, j] + 4, name='C')
s = tvm.create_schedule(C.op)
xo, xi = s[C].split(C.op.axis[0], factor=10)
s[B0.op].compute_at(s[C], xo)
sch = s.normalize()
bounds = tvm.schedule.InferBound(sch)
stmt = tvm.schedule.ScheduleOps(sch, bounds)
def get_B1_realize(x):
if isinstance(x, tvm.stmt.Realize) and \
x.func == B1.op and x.value_index == 1:
ret.append(x)
ret = []
tvm.ir_pass.PostOrderVisit(stmt, get_B1_realize)
assert stmt.node == C.op and len(ret) == 1
if __name__ == "__main__":
test_conv1d()
......@@ -128,3 +165,5 @@ if __name__ == "__main__":
test_scan_multi_out()
test_extern()
test_extern_multi_out()
test_tuple_inputs()
test_tuple_with_different_deps()
......@@ -6,7 +6,7 @@ def test_inline():
T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
stmt = tvm.ir_pass.Inline(
stmt, T.op, [x.var for x in T.op.axis], T.op.body)
stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
print(stmt)
assert(tvm.ir_pass.VerifySSA(stmt))
......@@ -25,7 +25,7 @@ def test_inline2():
T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
stmt = tvm.make.Evaluate(tvm.exp(T[10]) + 11 * T[100])
stmt = tvm.ir_pass.Inline(
stmt, T.op, [x.var for x in T.op.axis], T.op.body)
stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
def check(op):
if isinstance(op, tvm.expr.Call):
assert op.func != T.op
......
......@@ -89,7 +89,7 @@ def test_inline_mixed():
def check(x):
if isinstance(x, tvm.expr.Call):
assert x.func != A2
tvm.ir_pass.PostOrderVisit(s[C].op.body, check)
tvm.ir_pass.PostOrderVisit(s[C].op.body[0], check)
def test_scan_inline1():
......
......@@ -125,6 +125,8 @@ np.testing.assert_allclose(
b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4)
######################################################################
# .. _general-reduction:
#
# Define General Commutative Reduction Operation
# ----------------------------------------------
# Besides the built-in reduction operations like :any:`tvm.sum`,
......@@ -140,6 +142,12 @@ A = tvm.placeholder((n, m), name='A')
k = tvm.reduce_axis((0, m), name='k')
B = tvm.compute((n,), lambda i: product(A[i, k], axis=k), name='B')
######################################################################
# .. note::
#
# Sometimes we would like to perform a reduction that involves multiple
# values, such as :code:`argmax`; this can be done with tuple inputs.
# See :ref:`reduction-with-tuple-inputs` for more detail.
######################################################################
# Summary
......
"""
Compute and Reduction with Tuple Inputs
=======================================
**Author**: `Ziheng Jiang <https://github.com/ZihengJiang>`_
Often we want to compute multiple outputs with the same shape within
a single loop, or perform a reduction that involves multiple values,
such as :code:`argmax`. These problems can be addressed by tuple inputs.
In this tutorial, we will introduce the usage of tuple inputs in TVM.
"""
from __future__ import absolute_import, print_function
import tvm
import numpy as np
######################################################################
# Describe Batchwise Computation
# ------------------------------
# For operators which have the same shape, we can put them together as
# the inputs of :any:`tvm.compute`, if we want them to be scheduled
# together in the next schedule procedure.
#
n = tvm.var("n")
m = tvm.var("m")
A0 = tvm.placeholder((m, n), name='A0')
A1 = tvm.placeholder((m, n), name='A1')
B0, B1 = tvm.compute((m, n), lambda i, j: (A0[i, j] + 2, A1[i, j] * 3), name='B')
# The generated IR code would be:
s = tvm.create_schedule(B0.op)
print(tvm.lower(s, [A0, A1, B0, B1], simple_mode=True))
######################################################################
# .. _reduction-with-tuple-inputs:
#
# Describe Reduction with Collaborative Inputs
# --------------------------------------------
# Sometimes we require multiple inputs to express some reduction
# operators, and the inputs will collaborate, e.g. :code:`argmax`.
# In the reduction procedure, :code:`argmax` needs to compare the value of
# the operands and also keep the index of the operand. It can be expressed
# with :any:`comm_reducer` as below:
# x and y are the operands of the reduction; each of them is a tuple of index
# and value.
def fcombine(x, y):
lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
return lhs, rhs
# our identity element also needs to be a tuple, so `fidentity` accepts
# two dtypes as inputs.
def fidentity(t0, t1):
return tvm.const(-1, t0), tvm.min_value(t1)
argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax')
# describe the reduction computation
m = tvm.var('m')
n = tvm.var('n')
idx = tvm.placeholder((m, n), name='idx', dtype='int32')
val = tvm.placeholder((m, n), name='val', dtype='int32')
k = tvm.reduce_axis((0, n), 'k')
T0, T1 = tvm.compute((m, ), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name='T')
# the generated IR code would be:
s = tvm.create_schedule(T0.op)
print(tvm.lower(s, [idx, val, T0, T1], simple_mode=True))
######################################################################
# .. note::
#
# For ones who are not familiar with reduction, please refer to
# :ref:`general-reduction`.
######################################################################
# Schedule Operation with Tuple Inputs
# ------------------------------------
# It is worth mentioning that although you will get multiple outputs
# from one batch operation, they can only be scheduled together
# in terms of operation.
n = tvm.var("n")
m = tvm.var("m")
A0 = tvm.placeholder((m, n), name='A0')
B0, B1 = tvm.compute((m, n), lambda i, j: (A0[i, j] + 2, A0[i, j] * 3), name='B')
A1 = tvm.placeholder((m, n), name='A1')
C = tvm.compute((m, n), lambda i, j: A1[i, j] + B0[i, j], name='C')
s = tvm.create_schedule(C.op)
s[B0].compute_at(s[C], C.op.axis[0])
# as you can see in the below generated IR code:
print(tvm.lower(s, [A0, A1, C], simple_mode=True))
######################################################################
# Summary
# -------
# This tutorial introduces the usage of tuple input operations.
#
# - Describe normal batchwise computation.
# - Describe reduction operations with tuple inputs.
# - Notice that you can only schedule computation in terms of operations, not tensors.