Commit f467f66e by ziheng, committed by GitHub

Support for Tuple Inputs of Reducer and ComputeOp (#175)

* Support for batch ComputeOp

* Support for batch ComputeOp

* Fix CrossThreadReduction

* Fix lint

* Add UpdateArray, remove support for batch reduce

* Tuple input support for reduce

* rfactor works with multiple reducers; support multiple reducers with different types

* Small fix

* Small fix

* Change return type of rfactor to Array<Expr>

* Fix lint

* Improve

* Add tutorial

* Improve tutorial

* Improve tutorial
parent ef50162b
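
In short, both compute and the reducers now accept tuples: an fcompute may return a tuple of expressions, and a comm_reducer may combine tuples of values, yielding one output tensor per element. A minimal sketch of the resulting user-facing API, distilled from the tests and tutorial in the diff below (TVM 0.x-era API; the sketch itself is not part of the commit):

import tvm

m = tvm.var('m')
n = tvm.var('n')

# Batch compute: one operation producing two same-shaped outputs.
A0 = tvm.placeholder((m, n), name='A0')
A1 = tvm.placeholder((m, n), name='A1')
B0, B1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='B')

# Tuple reducer: argmax carries (index, value) pairs through the reduction.
def fcombine(x, y):
    lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
    rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
    return lhs, rhs

def fidentity(t0, t1):
    return tvm.const(-1, t0), tvm.min_value(t1)

argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax')
idx = tvm.placeholder((m, n), name='idx', dtype='int32')
val = tvm.placeholder((m, n), name='val', dtype='float32')
k = tvm.reduce_axis((0, n), 'k')
T0, T1 = tvm.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name='T')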
@@ -47,23 +47,27 @@ struct CommReducer : public NodeRef {
  *  binary operator with identity element
  */
 struct CommReducerNode : public Node {
-  /*! \brief The arguments of reducer */
-  Array<Var> args;
+  /*! \brief The left argument of reducer */
+  Array<Var> lhs;
+  /*! \brief The right argument of reducer */
+  Array<Var> rhs;
   /*! \brief The result of reducer */
-  Expr result;
+  Array<Expr> result;
   /*!
    * \brief The identity element of reducer, which leaves other
    *  elements unchanged when combined with it, with respect to
    *  the binary operation this reducer uses.
    */
-  Expr identity_element;
+  Array<Expr> identity_element;
   /*! \brief Function call operator to combine a and b */
-  Expr operator()(Expr a, Expr b) const;
+  Array<Expr> operator()(Array<Expr> a, Array<Expr> b) const;
   /*! \brief construct CommReducer from args, result and identity_element */
-  static CommReducer make(Array<Var> args, Expr result, Expr identity_element);
+  static CommReducer make(Array<Var> lhs, Array<Var> rhs,
+                          Array<Expr> result, Array<Expr> identity_element);
   void VisitAttrs(AttrVisitor* v) final {
-    v->Visit("args", &args);
+    v->Visit("lhs", &lhs);
+    v->Visit("rhs", &rhs);
     v->Visit("result", &result);
     v->Visit("identity_element", &identity_element);
   }
@@ -84,7 +88,7 @@ struct Reduce : public ExprNode<Reduce> {
   /*! \brief The commutative combiner */
   CommReducer combiner;
   /*! \brief The source operand */
-  Expr source;
+  Array<Expr> source;
   /*! \brief The reduction axis */
   Array<IterVar> axis;
   /*!
@@ -92,18 +96,22 @@ struct Reduce : public ExprNode<Reduce> {
    *  Only add the body to reduction if condition is true.
    */
   Expr condition;
+  /*! \brief the index of this reduce node */
+  int value_index;
   /*! \brief construct expr from op and rdom */
   static Expr make(CommReducer combiner,
-                   Expr src,
+                   Array<Expr> src,
                    Array<IterVar> rdom,
-                   Expr condition = const_true());
+                   Expr condition,
+                   int value_index);
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("dtype", &type);
     v->Visit("source", &source);
     v->Visit("axis", &axis);
     v->Visit("condition", &condition);
+    v->Visit("value_index", &value_index);
   }
   static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
   static constexpr const char* _type_key = "Reduce";
@@ -292,11 +300,12 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
 /*!
  * \brief See pseudo code
  *
- *  Expr tvm_thread_allreduce(CommReducer combiner, Expr value, Expr cond,
- *                            Var thread_idx1, thread_idx2...) {
+ *  void tvm_thread_allreduce(UIntImm size, Expr source0, ..., Expr cond,
+ *                            Var reduce_temp0, .., Var thread_idx1, ...) {
  *   // constraint by the other thread_idx remain the same.
- *   return reduce(combiner, value, cond,
- *                 over [thread_idx1, thread_idx2] passed by any caller)
+ *   // reduce_temp is used to save intermediate result.
+ *   reduce_temp0, ... = reduce(combiner, source0, ..., cond
+ *     over [thread_idx1, thread_idx2] passed by any caller)
  *  }
  */
 constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
......
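
For a two-value reducer such as the argmax over (index, value) pairs used in the tests below, the rewritten intrinsic call therefore takes the following shape. This is an illustrative sketch in the same pseudo-code style as the comment above (argument names are hypothetical); it matches the layout that MakeAllreduce decodes later in this diff:

  tvm_thread_allreduce(2,                           // UIntImm size: number of values
                       src0, src1,                  // source expressions being reduced
                       cond,                        // reduction predicate
                       reduce_temp0, reduce_temp1,  // buffer vars receiving the results
                       threadIdx.x)                 // participating thread indices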
@@ -96,10 +96,10 @@ Expr Substitute(Expr expr, const Map<Var, Expr>& value_map);
 /*!
  * \brief inline all calls of f in stmt.
  *
+ * \param stmt The statement to apply inline optimization.
  * \param f The function reference to be inlined
  * \param args The arguments variable of the function.
- * \param body The defintion body of the function.
- * \param stmt The statement to apply inline optimization.
+ * \param body The definition body of the function.
  * \return The result stmt
  *
  * \note All the passes in this file use SSA form and output SSA form.
......
@@ -182,7 +182,7 @@ class ComputeOpNode : public OperationNode {
   /*! \brief IterVar on each reduction axis, if the body is a Reduce */
   Array<IterVar> reduce_axis;
   /*! \brief the compute expression */
-  Expr body;
+  Array<Expr> body;
   /*! \brief constructor */
   ComputeOpNode() {}
   // override functions
@@ -218,7 +218,7 @@ class ComputeOpNode : public OperationNode {
   }
   static Operation make(std::string name,
                         Array<IterVar> axis,
-                        Expr body);
+                        Array<Expr> body);
   static constexpr const char* _type_key = "ComputeOp";
   TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode);
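
Since ComputeOpNode::body is now an Array<Expr>, consumers index into it, and the change is visible from Python as well: op.body is an array even for single-output ops. A small sketch, grounded in the updated tests later in this diff:

import tvm

m = tvm.var('m')
A = tvm.placeholder((m,), name='A')
T = tvm.compute((m,), lambda i: A[i] + 10, name='T')
body0 = T.op.body[0]   # body is an array; single-output ops use index 0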
@@ -358,6 +358,9 @@ class ExternOpNode : public OperationNode {
 /*! \brief The compute function to specify the input source of a Tensor */
 using FCompute = std::function<Expr (const Array<Var>& i)>;

+/*! \brief The compute function to specify the input sources of Tensors */
+using FBatchCompute = std::function<Array<Expr> (const Array<Var>& i)>;
+
 /*!
  * \brief create a place holder tensor.
  * \param shape The shape of the tensor.
@@ -378,6 +381,15 @@ Tensor placeholder(Array<Expr> shape,
 Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");

 /*!
+ * \brief Construct new tensors by computing over shape,
+ *  using the computation rule: result_tensor[axis] = fcompute(axis)
+ * \param shape Shape of the tensor.
+ * \param fcompute The compute function to create the tensors.
+ * \param name The optional name of the tensor.
+ */
+Array<Tensor> compute(Array<Expr> shape, FBatchCompute fcompute, std::string name = "tensor");
+
+/*!
  * \brief Construct new tensors by scan.
  *
  * \param init The initialize tensor of first K steps.
......
@@ -252,15 +252,15 @@ class Schedule : public NodeRef {
  /*!
   * \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
   *  This will create a new stage that generated the new tensor with axis
-  *  as the first dimension. The tensor's body wil be rewriten as a reduction
+  *  as the first dimension. The tensor's body will be rewritten as a reduction
   *  over the factored tensor.
   *
   * \param tensor The tensor to be factored.
   * \param axis The reduction axis in tensor's schedule to be factored.
-  * \return The created factored tensor.
+  * \return The created factored tensors.
   */
-  Tensor rfactor(const Tensor& tensor,
-                 const IterVar& axis);
+  Array<Tensor> rfactor(const Tensor& tensor,
+                        const IterVar& axis);
  /*!
   * \brief Normalize the schedule.
   *  This is needed before bound inference.
......
@@ -174,10 +174,14 @@ def compute(shape, fcompute, name="compute"):
     dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)]
     body = fcompute(*[v.var for v in dim_var])
+    if not isinstance(body, (list, tuple)):
+        body = [body]
     body = convert(body)
     op_node = _api_internal._ComputeOp(
         name, dim_var, body)
-    return op_node.output(0)
+    num = op_node.num_outputs
+    outputs = tuple(op_node.output(i) for i in range(num))
+    return outputs[0] if num == 1 else outputs

 def scan(init, update, state_placeholder, inputs=None, name="scan"):
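
The net effect: fcompute may now return a tuple, and compute hands back one tensor per element, all sharing a single op. A short sketch based on test_tuple_inputs added later in this diff:

import tvm

m = tvm.var('m')
n = tvm.var('n')
A0 = tvm.placeholder((m, n), name='A0')
A1 = tvm.placeholder((m, n), name='A1')
# single-output calls are unchanged
C = tvm.compute((m, n), lambda i, j: A0[i, j] + 1, name='C')
# tuple-output calls return one tensor per expression, same underlying op
T0, T1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='T')
assert T0.op == T1.op and T0.value_index == 0 and T1.value_index == 1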
@@ -525,18 +529,45 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
         return res

     def _make_reduce(expr, axis, where=None):
-        expr = convert(expr)
-        dtype = expr.dtype
         code = fcombine.__code__
         assert fcombine.__code__.co_argcount == 2
-        arg_vars = [var(name, dtype) for name in code.co_varnames]
-        result = fcombine(*[v for v in arg_vars])
+        expr = convert(expr)
+        if isinstance(expr, _collections.Array):
+            size = len(expr)
+            larr = []
+            rarr = []
+            dtypes = []
+            for i in range(size):
+                dtype = expr[i].dtype
+                dtypes.append(dtype)
+                lname = code.co_varnames[0] + '_' + str(i)
+                larr.append(var(lname, dtype))
+                rname = code.co_varnames[1] + '_' + str(i)
+                rarr.append(var(rname, dtype))
+            lhs = convert(larr)
+            rhs = convert(rarr)
+            result = fcombine(lhs, rhs)
+            id_elem = fidentity(*dtypes)
+        else:
+            assert isinstance(expr, _expr.Expr)
+            size = 1
+            dtype = expr.dtype
+            lvar = var(code.co_varnames[0], dtype)
+            rvar = var(code.co_varnames[1], dtype)
+            result = [fcombine(lvar, rvar)]
+            id_elem = [fidentity(dtype)]
+            lhs = convert([lvar])
+            rhs = convert([rvar])
+            expr = convert([expr])
         result = convert(result)
-        id_elem = fidentity(dtype)
-        assert isinstance(id_elem, _expr.Expr)
-        combiner = _make.CommReducer(arg_vars, result, id_elem)
-        axis = axis if isinstance(axis, list) else [axis]
-        return _make.Reduce(combiner, expr, axis, where)
+        id_elem = convert(id_elem)
+        combiner = _make.CommReducer(lhs, rhs, result, id_elem)
+        axis = convert(axis if isinstance(axis, list) else [axis])
+        if where is None:
+            where = convert(True)
+        outputs = tuple(_make.Reduce(combiner, expr, axis, where, i)
+                        for i in range(size))
+        return outputs[0] if size == 1 else outputs

     def reducer(expr, axis, where=None, *args):
         if isinstance(axis, (_schedule.IterVar, list)):
......
@@ -181,7 +181,7 @@ class Schedule(NodeBase):
         """ Factor a reduction axis in tensor's schedule to be an explicit axis.
         This will create a new stage that generated the new tensor with axis
-        as the first dimension. The tensor's body wil be rewriten as a reduction
+        as the first dimension. The tensor's body will be rewritten as a reduction
         over the factored tensor.

         Parameters
@@ -193,10 +193,11 @@ class Schedule(NodeBase):
         Returns
         -------
-        tfactor : Tensor
+        tfactor : Tensor or Array of Tensor
             The created factored tensor.
         """
-        return _api_internal._ScheduleRFactor(self, tensor, axis)
+        factored = _api_internal._ScheduleRFactor(self, tensor, axis)
+        return factored[0] if len(factored) == 1 else factored

 @register_node
......
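
So rfactor now factors every output of the reduce stage at once and unwraps the result in the common single-output case. A usage sketch (the single-output form mirrors test_rfactor; the tuple form mirrors test_rfactor_argmax further down):

import tvm

n = tvm.convert(1027)
A = tvm.placeholder((n,), name='A')
k = tvm.reduce_axis((0, n))
B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')
s = tvm.create_schedule(B.op)
ko, ki = s[B].split(k, factor=16)
BF = s.rfactor(B, ki)   # single-output reduction: a single Tensor comes back
# for a tuple reduction (e.g. argmax), the same call returns one factored
# tensor per output: BF0, BF1 = s.rfactor(B0, kf)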
@@ -68,11 +68,13 @@ TVM_REGISTER_API("make.Call")
   });

 TVM_REGISTER_API("make.CommReducer")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = CommReducerNode::make(args[0], args[1], args[2]);
+    *ret = CommReducerNode::make(args[0],
+                                 args[1],
+                                 args[2],
+                                 args[3]);
   });

 // make from two arguments
 #define REGISTER_MAKE1(Node) \
   TVM_REGISTER_API("make."#Node) \
@@ -112,7 +114,7 @@ TVM_REGISTER_API("make.CommReducer")
     *ret = Node::make(a, b); \
   })

-REGISTER_MAKE4(Reduce);
+REGISTER_MAKE5(Reduce);
 REGISTER_MAKE4(AttrStmt);
 REGISTER_MAKE2(IntImm);
......
@@ -50,24 +50,27 @@ Expr sum(Expr source, Array<IterVar> rdom) {
   Var x("x"), y("y");
   Expr result = ir::Add::make(x, y);
   Expr identity_element = make_zero(source.type());
-  ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
-  return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
+  ir::CommReducer combiner =
+    ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+  return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
 }

 Expr max(Expr source, Array<IterVar> rdom) {
   Var x("x"), y("y");
   Expr result = ir::Max::make(x, y);
   Expr identity_element = source.type().min();
-  ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
-  return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
+  ir::CommReducer combiner =
+    ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+  return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
 }

 Expr min(Expr source, Array<IterVar> rdom) {
   Var x("x"), y("y");
   Expr result = ir::Min::make(x, y);
   Expr identity_element = source.type().max();
-  ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
-  return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
+  ir::CommReducer combiner =
+    ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+  return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
 }

 std::ostream& operator<<(std::ostream& os, const NodeRef& n) {  // NOLINT(*)
......
@@ -9,6 +9,7 @@
 #include <ir/IR.h>
 #include <ir/IRPrinter.h>
 #include <memory>
+#include "../pass/ir_util.h"

 namespace Halide {
 namespace Internal {

@@ -25,23 +26,20 @@ void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 .set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) {
   p->stream << "reduce(combiner="
-            << op->combiner
-            << ", ";
-  p->print(op->source);
+            << op->combiner;
+  p->stream << ", source=" << op->source;
   p->stream << ", axis=" << op->axis;
-  if (!is_const(op->condition, 1)) {
-    p->stream << ", where=" << op->condition;
-  }
+  p->stream << ", where=" << op->condition;
+  p->stream << ", value_index=" << op->value_index;
   p->stream << ")";
 });

 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 .set_dispatch<CommReducerNode>([](const CommReducerNode *op, IRPrinter *p) {
-  p->stream << "comm_reducer(result="
-            << op->result
-            << ", args=" << op->args
-            << ", identity_element="
-            << op->identity_element
+  p->stream << "comm_reducer(result=" << op->result
+            << ", lhs=" << op->lhs
+            << ", rhs=" << op->rhs
+            << ", identity_element=" << op->identity_element
             << ")";
 });
 }  // namespace Internal
@@ -50,23 +48,34 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 namespace tvm {
 namespace ir {

-CommReducer CommReducerNode::make(Array<Var> args, Expr result, Expr identity_element) {
+CommReducer CommReducerNode::make(Array<Var> lhs,
+                                  Array<Var> rhs,
+                                  Array<Expr> result,
+                                  Array<Expr> identity_element) {
   auto node = std::make_shared<CommReducerNode>();
-  node->args = args;
+  node->lhs = lhs;
+  node->rhs = rhs;
   node->result = result;
   node->identity_element = identity_element;
   return CommReducer(node);
 }

-Expr CommReducerNode::operator()(Expr a, Expr b) const {
+Array<Expr> CommReducerNode::operator()(Array<Expr> a, Array<Expr> b) const {
+  CHECK_EQ(a.size(), b.size());
+  CHECK_EQ(lhs.size(), a.size());
+  CHECK_EQ(rhs.size(), b.size());
   Map<Var, Expr> value_map;
-  value_map.Set(args[0], a);
-  value_map.Set(args[1], b);
-  return Substitute(result, value_map);
+  for (size_t i = 0; i < a.size(); ++i) {
+    value_map.Set(lhs[i], a[i]);
+    value_map.Set(rhs[i], b[i]);
+  }
+  return UpdateArray(result, [&value_map] (const Expr& e) {
+      return Substitute(e, value_map);
+    });
 }

-Expr Reduce::make(CommReducer combiner, Expr source,
-                  Array<IterVar> axis, Expr condition) {
+Expr Reduce::make(CommReducer combiner, Array<Expr> source,
+                  Array<IterVar> axis, Expr condition, int value_index) {
   for (size_t i = 0; i < axis.size(); ++i) {
     CHECK_EQ(axis[i]->iter_type, kCommReduce)
         << "Can only take axis created by reduce_axis";
@@ -79,11 +88,12 @@ Expr Reduce::make(CommReducer combiner, Expr source,
   for (size_t i = 0; i < axis.size(); ++i) {
     CHECK(axis[i].defined());
   }
-  n->type = source.type();
-  n->combiner = combiner;
-  n->source = source;
-  n->axis = axis;
+  n->type = source[value_index].type();
+  n->combiner = std::move(combiner);
+  n->source = std::move(source);
+  n->axis = std::move(axis);
   n->condition = condition;
+  n->value_index = value_index;
   return Expr(n);
 }
......
@@ -4,6 +4,7 @@
  */
 #include <tvm/ir.h>
 #include <tvm/ir_mutator.h>
+#include "./ir_util.h"

 namespace tvm {
 namespace ir {

@@ -17,19 +18,7 @@ IRMutator::FMutateStmt& IRMutator::vtable_stmt() {  // NOLINT(*)
 }

 inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
-  std::vector<Expr> new_arr(arr.size());
-  bool changed = false;
-  for (size_t i = 0; i < arr.size(); i++) {
-    Expr old_elem = arr[i];
-    Expr new_elem = m->Mutate(old_elem);
-    if (!new_elem.same_as(old_elem)) changed = true;
-    new_arr[i] = new_elem;
-  }
-  if (!changed) {
-    return arr;
-  } else {
-    return Array<Expr>(new_arr);
-  }
+  return UpdateArray(arr, [&m] (const Expr& e) { return m->Mutate(e); });
 }

 inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) {
@@ -323,14 +312,15 @@ DEFINE_BIOP_EXPR_MUTATE_(Or)
 Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) {
   Array<IterVar> new_axis = MutateIterVarArr(op->axis, this);
-  Expr new_source = this->Mutate(op->source);
+  Array<Expr> new_source = MutateArray(op->source, this);
   Expr new_cond = this->Mutate(op->condition);
   if (op->axis.same_as(new_axis) &&
       op->source.same_as(new_source) &&
       op->condition.same_as(new_cond)) {
     return e;
   } else {
-    return Reduce::make(op->combiner, new_source, new_axis, new_cond);
+    return Reduce::make(
+        op->combiner, new_source, new_axis, new_cond, op->value_index);
   }
 }
......
@@ -13,6 +13,32 @@ namespace tvm {
 namespace ir {

 /*!
+ * \brief update array with a unary function
+ * \param arr array
+ * \param fupdate a unary function
+ * \tparam T type of array element
+ * \tparam F type of the unary function
+ * \return if update happens, return the new array, else return the
+ *  original array
+ */
+template<typename T, typename F>
+inline Array<T> UpdateArray(Array<T> arr, F fupdate) {
+  std::vector<T> new_arr(arr.size());
+  bool changed = false;
+  for (size_t i = 0; i < arr.size(); ++i) {
+    T old_elem = arr[i];
+    T new_elem = fupdate(old_elem);
+    if (!new_elem.same_as(old_elem)) changed = true;
+    new_arr[i] = new_elem;
+  }
+  if (!changed) {
+    return arr;
+  } else {
+    return Array<T>(new_arr);
+  }
+}
+
+/*!
  * \brief combine the nest stmt, whose body is not defined.
  * \param nest A list of For and LetStmt, whose body is not defined.
  * \param body body
......
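
UpdateArray factors out the copy-on-change loop that MutateArray used to hand-roll, so callers keep cheap same_as identity checks when nothing changed. A rough Python analogue of its semantics, purely for illustration (not a TVM API):

def update_array(arr, fupdate):
    # apply fupdate to every element; if no element changed, hand back the
    # original array so callers can cheaply test identity
    new_arr = [fupdate(x) for x in arr]
    if all(new is old for new, old in zip(new_arr, arr)):
        return arr
    return new_arr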
@@ -133,7 +133,7 @@ DEFINE_BINOP_VISIT_(Or)
 void IRVisitor::Visit_(const Reduce* op) {
   VisitRDom(op->axis, this);
-  this->Visit(op->source);
+  VisitArray(op->source, this);
 }

 void IRVisitor::Visit_(const Cast* op) {
......
@@ -45,12 +45,12 @@ class ThreadAllreduceBuilder : public IRMutator {
       return IRMutator::Mutate_(op, s);
     }
   }
-  Stmt Mutate_(const Store* op, const Stmt& s) final {
+  Stmt Mutate_(const Evaluate* op, const Stmt& s) final {
     Stmt stmt = IRMutator::Mutate_(op, s);
-    op = stmt.as<Store>();
+    op = stmt.as<Evaluate>();
     const Call* call = op->value.as<Call>();
     if (call && call->is_intrinsic(intrinsic::tvm_thread_allreduce)) {
-      return MakeAllreduce(op, call);
+      return MakeAllreduce(call);
     } else {
       return stmt;
     }
@@ -97,18 +97,34 @@ class ThreadAllreduceBuilder : public IRMutator {
     }
   };
   // make allreduce.
-  Stmt MakeAllreduce(const Store* op, const Call* call) {
+  Stmt MakeAllreduce(const Call* call) {
     CHECK(!reduce_combiner_.empty());
     const CommReducerNode *combiner = reduce_combiner_.back();
-    Expr init = combiner->identity_element;
-    Expr value = call->args[0];
-    Expr cond = call->args[1];
-    if (!is_one(cond)) {
-      value = Select::make(cond, value, init);
+    size_t size = combiner->result.size();
+
+    const UIntImm *size_of_args = call->args[0].as<UIntImm>();
+    CHECK(size_of_args) << call->args[0]->type_key();
+    CHECK_EQ(size, size_of_args->value);
+
+    Array<Expr> inits = combiner->identity_element;
+    std::vector<Expr> values(size);
+    std::vector<Type> types(size);
+    Expr cond = call->args[size+1];
+    for (size_t idx = 0; idx < size; ++idx) {
+      values[idx] = call->args[1+idx];
+      if (!is_one(cond)) {
+        values[idx] = Select::make(cond, values[idx], inits[idx]);
+      }
+      types[idx] = values[idx].type();
+    }
+    std::vector<const Variable*> buffers(size);
+    for (size_t idx = 0; idx < size; ++idx) {
+      const Variable* buffer = call->args[2+size+idx].as<Variable>();
+      CHECK(buffer);
+      buffers[idx] = buffer;
     }
+
     std::unordered_set<const Variable*> reduce_set;
-    for (size_t i = 2; i < call->args.size(); ++i) {
+    for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) {
       const Variable* v = call->args[i].as<Variable>();
       CHECK(v);
       reduce_set.insert(v);
@@ -143,40 +159,50 @@ class ThreadAllreduceBuilder : public IRMutator {
     int threadx_extent = 1;
     Expr reduce_index = FlattenThread(vred, &reduce_extent);
     Expr group_index = FlattenThread(vpar, &group_extent);
-    Expr pred = const_true(value.type().lanes());
     if (reduce_extent == 1) {
       // special case, no reduction is needed.
-      return Store::make(op->buffer_var, value, 0, pred);
+      std::vector<Stmt> stores(size);
+      for (size_t i = 0; i < size; ++i) {
+        Expr pred = const_true(types[i].lanes());
+        Var buffer_var(call->args[2+size+i].node_);
+        stores[i] = Store::make(buffer_var, values[i], 0, pred);
+      }
+      return Block::make(stores);
     }
     // Whether the threadIdx.x is involved in reduction.
     if (vred[0].scope.dim_index == 0) {
       threadx_extent = vred[0].extent;
     }
-    Var shared_buf("red_buf", Handle());
     std::vector<Stmt> seq;
-    seq.emplace_back(Store::make(
-        shared_buf, value,
-        BufIndex(reduce_index, group_index, reduce_extent), pred));
+    std::vector<Var> shared_bufs(size);
+    for (size_t idx = 0; idx < size; ++idx) {
+      shared_bufs[idx] = Var("red_buf"+std::to_string(idx), Handle());
+      Expr pred = const_true(types[idx].lanes());
+      seq.emplace_back(Store::make(
+          shared_bufs[idx], values[idx],
+          BufIndex(reduce_index, group_index, reduce_extent), pred));
+    }
     seq.emplace_back(SyncThread("shared"));
     seq.emplace_back(MakeBufAllreduce(
-        combiner, value.type(), shared_buf,
+        combiner, types, shared_bufs,
         reduce_index, group_index, reduce_extent, threadx_extent));
-    CHECK(!load_remap_.count(op->buffer_var.get()));
-    load_remap_[op->buffer_var.get()] =
-        Load::make(
-            value.type(), shared_buf,
-            BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent),
-            pred);
-    alloc_remap_[op->buffer_var.get()] =
-        Allocate::make(shared_buf, value.type(),
-                       {Expr(group_extent), Expr(reduce_extent)},
-                       pred, Evaluate::make(0));
+    for (size_t idx = 0; idx < size; ++idx) {
+      CHECK(!load_remap_.count(buffers[idx]));
+      Expr pred = const_true(types[idx].lanes());
+      load_remap_[buffers[idx]] = Load::make(
+          types[idx], shared_bufs[idx],
+          BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent), pred);
+      alloc_remap_[buffers[idx]] = Allocate::make(
+          shared_bufs[idx], types[idx],
+          {Expr(group_extent), Expr(reduce_extent)},
+          pred, Evaluate::make(0));
+    }
     return MergeSeq(seq);
   }
   // make allreduce.
   Stmt MakeBufAllreduce(const CommReducerNode *combiner,
-                        Type type,
-                        Var shared_buf,
+                        const std::vector<Type>& types,
+                        const Array<Var>& shared_bufs,
                         Expr reduce_index,
                         Expr group_index,
                         int reduce_extent,
@@ -189,14 +215,23 @@ class ThreadAllreduceBuilder : public IRMutator {
     CHECK_GT(reduce_align, 1);
     std::vector<Stmt> seq;

+    size_t size = shared_bufs.size();
     Expr buf_index = BufIndex(reduce_index, group_index, reduce_extent);
     // make reduction
     auto freduce = [&](int offset) {
-      Expr b = Load::make(
-          type, shared_buf,
-          BufIndex(reduce_index + offset, group_index, reduce_extent), const_true());
-      Expr a = Load::make(type, shared_buf, buf_index, const_true());
-      return Store::make(shared_buf, (*combiner)(a, b), buf_index, const_true());
+      Array<Expr> a, b;
+      for (size_t i = 0; i < size; ++i) {
+        b.push_back(Load::make(types[i], shared_bufs[i],
+          BufIndex(reduce_index + offset, group_index, reduce_extent),
+          const_true()));
+        a.push_back(Load::make(types[i], shared_bufs[i], buf_index, const_true()));
+      }
+      Array<Expr> ret = (*combiner)(a, b);
+      std::vector<Stmt> stores(size);
+      for (size_t i = 0; i < size; ++i) {
+        stores[i] = Store::make(shared_bufs[i], ret[i], buf_index, const_true());
+      }
+      return Block::make(stores);
     };
     // Step one, check for
     if (reduce_align > reduce_extent) {
......
@@ -157,7 +157,9 @@ class StorageFlattener : public IRMutator {
       CHECK_EQ(extern_buf_remap_.size(), 0U);
       for (size_t i = 0; i < ext_op->output_placeholders.size(); ++i) {
         TensorKey key{func, static_cast<int>(i)};
-        CHECK(buf_map_.count(key));
+        CHECK(buf_map_.count(key))
+            << "Cannot find allocated buffer for " << key.f
+            << "(" << key.value_index << ")";
         extern_buf_remap_[ext_op->output_placeholders[i]->data.get()] =
             buf_map_.at(key).buffer->data;
       }
......
@@ -46,7 +46,7 @@ class ElemWiseDetector : public ir::IRVisitor {
 bool IsElemWise(const Operation& op) {
   if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
     ElemWiseDetector v = ElemWiseDetector(compute->axis);
-    v.Visit(compute->body);
+    for (auto& e : compute->body) v.Visit(e);
     return v.is_elem_wise_;
   }
   return false;
......
@@ -260,7 +260,9 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
           }
         }
       };
-      ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
+      for (auto& e : op.as<ComputeOpNode>()->body) {
+        ir::PostOrderVisit(e, fvisit);
+      }
     }
   }
   return reach;
@@ -321,11 +323,14 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
         }
       }
     } else if (op.as<ComputeOpNode>()) {
-      std::unordered_map<const Node*, TensorDimKey> vmap;
+      std::unordered_map<const Node*, std::vector<TensorDimKey> > vmap;
       const auto& axis = op.as<ComputeOpNode>()->axis;
-      Tensor t = op.output(0);
       for (size_t i = 0; i < axis.size(); ++i) {
-        vmap[axis[i]->var.get()] = TensorDimKey(t, i);
+        std::vector<TensorDimKey> keys;
+        for (int j = 0; j < op->num_outputs(); ++j) {
+          keys.emplace_back(op.output(j), i);
+        }
+        vmap[axis[i]->var.get()] = std::move(keys);
       }
       auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](
           const NodeRef& n) {
@@ -335,7 +340,10 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
           auto it = vmap.find(call->args[i].get());
           TensorDimKey src(call, static_cast<int>(i));
           if (it != vmap.end()) {
-            f_merge_key(it->second, src);
+            const std::vector<TensorDimKey>& keys = it->second;
+            for (const auto& key : keys) {
+              f_merge_key(key, src);
+            }
           } else {
             if (exact_reach.count(src)) {
               fail_set.insert(exact_reach.at(src));
@@ -344,7 +352,9 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
           }
         }
       };
-      ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
+      for (auto& e : op.as<ComputeOpNode>()->body) {
+        ir::PostOrderVisit(e, fvisit);
+      }
     }
   }
   ReachGraph reach;
......
@@ -27,7 +27,7 @@ using ReadGraph = Map<Operation, Array<Tensor> >;
 using AttachPath = Map<Operation, Array<IterVar> >;

 /*!
- * \brief The map beteen tensor and operation it feeds to.
+ * \brief The map between tensor and operation it feeds to.
 */
 using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;

@@ -46,7 +46,7 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots);
 *  The operations contains node which input-reachable from any inputs
 *  output reachable to any outputs.
 *
- *  The inputs won't be included in the subgraph, the outputs will be inclued.
+ *  The inputs won't be included in the subgraph, the outputs will be included.
 *
 * \param outputs The outputs of the subgraph
 * \param inputs The inputs to the subgraph.
......
@@ -8,6 +8,7 @@
 #include <tvm/ir_pass.h>
 #include <unordered_set>
 #include "./message_passing.h"
+#include "../pass/ir_util.h"

 namespace tvm {
@@ -120,13 +121,13 @@ Tensor Schedule::cache_write(const Tensor& tensor,
     vsub[iv->var.get()] = new_iv->var;
   }
   VarReplacer repl(vsub);
-  Expr body = repl.Mutate(compute->body);
+  Expr body = repl.Mutate(compute->body[tensor->value_index]);
   Operation cache_op = ComputeOpNode::make(
-      compute->name + "." + scope, new_axis, body);
+      compute->name + "." + scope, new_axis, {body});
   Tensor cache_tensor = cache_op.output(0);
   Operation orig_new_op = ComputeOpNode::make(
       compute->name, compute->axis,
-      cache_tensor(args));
+      {cache_tensor(args)});

   std::unordered_map<Tensor, Tensor> vmap;
   vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
@@ -198,14 +199,15 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
 void InjectInline(ScheduleNode* sch) {
   sch->InvalidateCache();

-  std::vector<Expr> new_body(sch->stages.size());
+  std::vector<Array<Expr>> new_body(sch->stages.size());
+  std::vector<bool> changed(sch->stages.size(), false);
   // inline all the ops
   for (size_t i = sch->stages.size(); i != 0; --i) {
     Stage stage = sch->stages[i - 1];
     if (stage->attach_type == kInline) {
       stage->attach_type = kInlinedAlready;
       Array<Var> args;
-      Expr body;
+      Array<Expr> body;
       {
         // setup args
         const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
@@ -220,11 +222,14 @@ void InjectInline(ScheduleNode* sch) {
         Stage s = sch->stages[j];
         const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
         if (compute) {
-          if (!new_body[j].defined()) {
+          if (!new_body[j].size()) {
            new_body[j] = s->op.as<ComputeOpNode>()->body;
           }
-          new_body[j] = ir::Inline(ir::Evaluate::make(new_body[j]),
-                                   stage->op, args, body).as<ir::Evaluate>()->value;
+          for (size_t k = 0; k < body.size(); ++k) {
+            changed[j] = true;
+            new_body[j].Set(k, ir::Inline(ir::Evaluate::make(new_body[j][k]),
+                            stage->op, args, body[k]).as<ir::Evaluate>()->value);
+          }
         }
       }
     }
@@ -234,19 +239,21 @@ void InjectInline(ScheduleNode* sch) {
   for (size_t i = 0; i < sch->stages.size(); ++i) {
     Stage s = sch->stages[i];
     if (s->attach_type == kInlinedAlready) continue;
-    if (new_body[i].defined()) {
+    if (new_body[i].size()) {
       // Logics from ReplaceDataFlow
       const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
       CHECK(compute);
       Operation op = s->op;
-      if (!new_body[i].same_as(compute->body)) {
+      if (changed[i]) {
         op = ComputeOpNode::make(
             compute->name, compute->axis, new_body[i]);
       }
       op = op->ReplaceInputs(op, repl);
       if (!op.same_as(s->op)) {
-        repl[s->op.output(0)] = op.output(0);
-        s->op = op;
+        for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
+          repl[s->op.output(idx)] = op.output(idx);
+          s->op = op;
+        }
       }
     } else {
       Operation op = s->op->ReplaceInputs(s->op, repl);
@@ -268,15 +275,15 @@ Schedule Schedule::normalize() {
 }

 // Handle reduction factor.
-Tensor Schedule::rfactor(const Tensor& tensor,
-                         const IterVar& axis) {
+Array<Tensor> Schedule::rfactor(const Tensor& tensor,
+                                const IterVar& axis) {
   (*this)->InvalidateCache();
   using ir::Reduce;
   CHECK_EQ(axis->iter_type, kCommReduce)
       << "Can only factor reduction axis";
   Stage reduce_stage = operator[](tensor->op);
   const ComputeOpNode* compute_op = reduce_stage->op.as<ComputeOpNode>();
   CHECK(compute_op) << "Can only factor ComputeOp";
   ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite();
   {
     size_t axis_pos = FindNodeRef(leaf_vars, axis);
@@ -329,7 +336,8 @@ Tensor Schedule::rfactor(const Tensor& tensor,
     }
   }
   // predicate generation, copy not touched axis.
-  const Reduce* reduce = compute_op->body.as<Reduce>();
+  int idx = tensor->value_index;
+  const Reduce* reduce = compute_op->body[idx].as<Reduce>();
   CHECK(reduce) << "Can only rfactor non-inline reductions";
   Expr predicate = reduce->condition;
   std::unordered_map<const Variable*, Expr> vsub;
@@ -359,10 +367,18 @@ Tensor Schedule::rfactor(const Tensor& tensor,
       n->reduce_axis.push_back(IterVar(ncpy));
     }
   }
-  n->body = Reduce::make(reduce->combiner,
-                         VarReplacer(vsub).Mutate(reduce->source),
-                         n->reduce_axis,
-                         predicate);
+  VarReplacer replacer(vsub);
+  Array<Expr> new_source = ir::UpdateArray(reduce->source,
+    [&replacer] (const Expr& e) { return replacer.Mutate(e); });
+  std::vector<Expr> body;
+  for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
+    body.emplace_back(Reduce::make(reduce->combiner,
+                                   new_source,
+                                   n->reduce_axis,
+                                   predicate,
+                                   idx));
+  }
+  n->body = Array<Expr>(body);
   // refresh relations, keep the un-touched relations.
   Array<IterVarRelation> rels;
   for (IterVarRelation rel : reduce_stage->relations) {
@@ -397,26 +413,44 @@ Tensor Schedule::rfactor(const Tensor& tensor,
   // Replace the old reduction.
   IterVar repl_red_axis = reduce_axis(
       dom_map.at(axis), axis->var->name_hint + ".v");
-  Tensor factor_tensor = factor_op.output(0);
-  Tensor old_tensor = reduce_stage->op.output(0);
-  Tensor repl_tensor = compute(old_tensor->shape, [&](const Array<Var>& i) {
+  Array<Tensor> factor_tensors;
+  Array<Tensor> old_tensors;
+  int size = factor_op->num_outputs();
+  for (int idx = 0; idx < size; ++idx) {
+    factor_tensors.push_back(factor_op.output(idx));
+    old_tensors.push_back(reduce_stage->op.output(idx));
+  }
+  Array<Tensor> repl_tensors = compute(old_tensors[0]->shape,
+    [&](const Array<Var>& i) {
       Array<Expr> indices;
       indices.push_back(repl_red_axis->var);
       for (Var v : i) {
        indices.push_back(v);
      }
-      return Reduce::make(reduce->combiner,
-                          factor_tensor(indices), {repl_red_axis}, const_true());
-    }, old_tensor->op->name + ".repl");
+      Array<Expr> factor_exprs;
+      for (int idx = 0; idx < size; ++idx) {
+        factor_exprs.push_back(factor_tensors[idx](indices));
+      }
+      Array<Expr> reductions;
+      Array<IterVar> axis = {repl_red_axis};
+      Expr cond = const_true();
+      for (int idx = 0; idx < size; ++idx) {
+        reductions.push_back(Reduce::make(reduce->combiner,
+          factor_exprs, axis, cond, idx));
+      }
+      return reductions;
+    }, reduce_stage->op->name + ".repl");

   std::unordered_map<Tensor, Tensor> vmap;
-  vmap[old_tensor] = repl_tensor;
+  for (int idx = 0; idx < size; ++idx) {
+    vmap[old_tensors[idx]] = repl_tensors[idx];
+  }
   ReplaceDataFlow((*this)->stages, &vmap);
   // revamp the reduction stage.
-  reduce_stage->op = repl_tensor->op;
-  reduce_stage->all_iter_vars = repl_tensor->op->root_iter_vars();
+  reduce_stage->op = repl_tensors[0]->op;
+  reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars();
   reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars;
   reduce_stage->relations = Array<IterVarRelation>();
-  return factor_tensor;
+  return factor_tensors;
 }
 }  // namespace tvm
@@ -253,7 +253,7 @@ class SchedulePostProc : public IRMutator {
     // This must be checked for all ops, including scan.
     if (!s->op.same_as(s->origin_op)) {
       for (int i = 0; i < s->op->num_outputs(); ++i) {
-        Tensor target = s->origin_op.output(0);
+        Tensor target = s->origin_op.output(i);
         AddReplace(s->op.output(i), target,
                    target, s->origin_op);
       }
......
@@ -49,7 +49,6 @@ def test_reduce_prims():
     test_prim(tvm.max, np.amax)

-
 def test_rfactor():
     n = tvm.convert(1027)
     A = tvm.placeholder((n,), name='A')
@@ -128,7 +127,115 @@ def test_rfactor_threads():
     check_target("metal")
     check_target("opencl")

+def test_argmax():
+    def fcombine(x, y):
+        lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
+        rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
+        return lhs, rhs
+
+    def fidentity(t0, t1):
+        return tvm.const(-1, t0), tvm.min_value(t1)
+
+    argmax = tvm.comm_reducer(fcombine,
+                              fidentity,
+                              name='argmax')
+    m = tvm.var('m')
+    n = tvm.var('n')
+    idx = tvm.placeholder((m, n), name='idx', dtype='int32')
+    val = tvm.placeholder((m, n), name='val', dtype='float32')
+    k = tvm.reduce_axis((0, n), 'k')
+    T0, T1 = tvm.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name='T')
+    s = tvm.create_schedule(T0.op)
+
+    def check_target():
+        device = 'cpu'
+        if not tvm.module.enabled(device):
+            print("skip because %s is not enabled.." % device)
+            return
+        ctx = tvm.context(device, 0)
+        fapi = tvm.lower(s, args=[idx, val, T0, T1])
+        fargmax = tvm.build(fapi,
+                            target='llvm',
+                            name="argmax")
+
+        mm = 12
+        nn = 16
+        np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0)
+        np_val = np.random.uniform(size=(mm, nn)).astype('float32')
+        np_res = np.argmax(np_val, axis=1)
+
+        nd_idx = tvm.nd.array(np_idx, ctx)
+        nd_val = tvm.nd.array(np_val, ctx)
+        nd_res0 = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx)
+        nd_res1 = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx)
+        fargmax(nd_idx, nd_val, nd_res0, nd_res1)
+        np.testing.assert_allclose(np_res, nd_res0.asnumpy())
+
+    check_target()
+
+def test_rfactor_argmax():
+    def fcombine(x, y):
+        lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
+        rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
+        return lhs, rhs
+
+    def fidentity(t0, t1):
+        return tvm.const(-1, t0), tvm.min_value(t1)
+
+    argmax = tvm.comm_reducer(fcombine,
+                              fidentity,
+                              name='argmax')
+
+    nn = 1027
+    mm = 10
+    n = tvm.convert(nn)
+    m = tvm.convert(mm)
+    A0 = tvm.placeholder((m, n), name='A0', dtype='int32')
+    A1 = tvm.placeholder((m, n), name='A1', dtype='float32')
+    k = tvm.reduce_axis((0, n))
+    B0, B1 = tvm.compute((m,), lambda i: argmax((A0[i, k], A1[i, k]), axis=k), name='B')
+
+    # schedule
+    s = tvm.create_schedule(B0.op)
+    nthread = 16
+    ko, kf = s[B0].split(k, factor=nthread)
+    BF0, BF1 = s.rfactor(B0, kf)
+    bx, ty = s[B0].split(s[B0].op.axis[0], factor=nthread)
+    s[B0].bind(bx, tvm.thread_axis("blockIdx.x"))
+    s[B0].bind(ty, tvm.thread_axis("threadIdx.y"))
+    tx = s[B0].op.reduce_axis[0]
+    thread_x = tvm.thread_axis("threadIdx.x")
+    s[B0].bind(tx, thread_x)
+    s[BF0.op].compute_at(s[B0], tx)
+    s[B0].set_store_predicate(thread_x.var.equal(0))
+
+    def check_target(device):
+        if not tvm.module.enabled(device):
+            print("skip because %s is not enabled.." % device)
+            return
+        ctx = tvm.context(device, 0)
+        fapi = tvm.lower(s, args=[A0, A1, B0, B1])
+        fargmax = tvm.build(fapi,
+                            target=device,
+                            name="argmax")
+
+        np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0)
+        np_val = np.random.uniform(size=(mm, nn)).astype('float32')
+        np_res = np.argmax(np_val, axis=1)
+
+        nd_idx = tvm.nd.array(np_idx, ctx)
+        nd_val = tvm.nd.array(np_val, ctx)
+        nd_res0 = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx)
+        nd_res1 = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx)
+        fargmax(nd_idx, nd_val, nd_res0, nd_res1)
+        np.testing.assert_allclose(np_res, nd_res0.asnumpy())
+
+    check_target("cuda")
+
 if __name__ == "__main__":
     test_rfactor_threads()
     test_rfactor()
     test_reduce_prims()
+    test_argmax()
+    test_rfactor_argmax()
@@ -101,8 +101,8 @@ def test_rfactor():
     s = tvm.create_schedule(B.op)
     BF = s.rfactor(B, k1)
     assert(tuple(BF.shape) == (n, n))
-    assert(set(BF.op.body.axis) == set([k2]))
-    assert(s[B].op.body.axis[0].dom.extent == n)
+    assert(set(BF.op.body[0].axis) == set([k2]))
+    assert(s[B].op.body[0].axis[0].dom.extent == n)
     assert(len(s[B].all_iter_vars) == 2)
     # schedule with split
     s = tvm.create_schedule(B.op)
@@ -111,9 +111,9 @@ def test_rfactor():
     BF = s.rfactor(B, ki)
     assert(BF.shape[0].value == 4)
     assert(BF.shape[1] == n)
-    assert(BF.op.body.axis[0] == k2)
-    assert(BF.op.body.axis[1].var == ko.var)
-    assert(s[B].op.body.axis[0].dom.extent.value == 4)
+    assert(BF.op.body[0].axis[0] == k2)
+    assert(BF.op.body[0].axis[1].var == ko.var)
+    assert(s[B].op.body[0].axis[0].dom.extent.value == 4)

 if __name__ == "__main__":
......
@@ -118,6 +118,43 @@ def test_extern_multi_out():
     assert(len(res) == 2)
     assert(res[1].value_index == 1)

+def test_tuple_inputs():
+    m = tvm.var('m')
+    n = tvm.var('n')
+    A0 = tvm.placeholder((m, n), name='A0')
+    A1 = tvm.placeholder((m, n), name='A1')
+    T0, T1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='T')
+    s = tvm.create_schedule(T0.op)
+
+    for i in range(len(T0.shape)):
+        assert(T0.shape[i] == T1.shape[i])
+    assert(T0.op == T1.op)
+    assert(T0.value_index == 0)
+    assert(T1.value_index == 1)
+
+def test_tuple_with_different_deps():
+    m = tvm.var('m')
+    n = tvm.var('n')
+    A0 = tvm.placeholder((m, n), name='A1')
+    A1 = tvm.placeholder((m, n), name='A2')
+    B0, B1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='B')
+    C = tvm.compute((m, n), lambda i, j: B0[i, j] + 4, name='C')
+
+    s = tvm.create_schedule(C.op)
+    xo, xi = s[C].split(C.op.axis[0], factor=10)
+    s[B0.op].compute_at(s[C], xo)
+    sch = s.normalize()
+    bounds = tvm.schedule.InferBound(sch)
+    stmt = tvm.schedule.ScheduleOps(sch, bounds)
+
+    def get_B1_realize(x):
+        if isinstance(x, tvm.stmt.Realize) and \
+           x.func == B1.op and x.value_index == 1:
+            ret.append(x)
+
+    ret = []
+    tvm.ir_pass.PostOrderVisit(stmt, get_B1_realize)
+
+    assert stmt.node == C.op and len(ret) == 1
+
 if __name__ == "__main__":
     test_conv1d()
@@ -128,3 +165,5 @@ if __name__ == "__main__":
     test_scan_multi_out()
     test_extern()
     test_extern_multi_out()
+    test_tuple_inputs()
+    test_tuple_with_different_deps()
@@ -6,7 +6,7 @@ def test_inline():
     T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
     stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
     stmt = tvm.ir_pass.Inline(
-        stmt, T.op, [x.var for x in T.op.axis], T.op.body)
+        stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
     print(stmt)
     assert(tvm.ir_pass.VerifySSA(stmt))
@@ -25,7 +25,7 @@ def test_inline2():
     T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
     stmt = tvm.make.Evaluate(tvm.exp(T[10]) + 11 * T[100])
     stmt = tvm.ir_pass.Inline(
-        stmt, T.op, [x.var for x in T.op.axis], T.op.body)
+        stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
     def check(op):
         if isinstance(op, tvm.expr.Call):
             assert op.func != T.op
......
@@ -89,7 +89,7 @@ def test_inline_mixed():
     def check(x):
         if isinstance(x, tvm.expr.Call):
             assert x.func != A2
-    tvm.ir_pass.PostOrderVisit(s[C].op.body, check)
+    tvm.ir_pass.PostOrderVisit(s[C].op.body[0], check)

 def test_scan_inline1():
......
@@ -125,6 +125,8 @@ np.testing.assert_allclose(
     b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4)

 ######################################################################
+# .. _general-reduction:
+#
 # Define General Commutative Reduction Operation
 # ----------------------------------------------
 # Besides the built-in reduction operations like :any:`tvm.sum`,
@@ -140,6 +142,12 @@ A = tvm.placeholder((n, m), name='A')
 k = tvm.reduce_axis((0, m), name='k')
 B = tvm.compute((n,), lambda i: product(A[i, k], axis=k), name='B')

+######################################################################
+# .. note::
+#
+#   Sometimes we would like to perform reduction that involves multiple
+#   values like :code:`argmax`, which can be done by tuple inputs.
+#   See :ref:`reduction-with-tuple-inputs` for more detail.
+
 ######################################################################
 # Summary
......
"""
Compute and Reduction with Tuple Inputs
=======================================
**Author**: `Ziheng Jiang <https://github.com/ZihengJiang>`_
Often we want to compute multiple outputs with the same shape within
a single loop or perform reduction that involves multiple values like
:code:`argmax`. These problems can be addressed by tuple inputs.
In this tutorial, we will introduce the usage of tuple inputs in TVM.
"""
from __future__ import absolute_import, print_function
import tvm
import numpy as np
######################################################################
# Describe Batchwise Computation
# ------------------------------
# For computations with the same output shape, we can put them together as
# the inputs of :any:`tvm.compute`, so that they can be scheduled
# together in the next scheduling procedure.
#
n = tvm.var("n")
m = tvm.var("m")
A0 = tvm.placeholder((m, n), name='A0')
A1 = tvm.placeholder((m, n), name='A1')
B0, B1 = tvm.compute((m, n), lambda i, j: (A0[i, j] + 2, A1[i, j] * 3), name='B')
# The generated IR code would be:
s = tvm.create_schedule(B0.op)
print(tvm.lower(s, [A0, A1, B0, B1], simple_mode=True))
######################################################################
# .. _reduction-with-tuple-inputs:
#
# Describe Reduction with Collaborative Inputs
# --------------------------------------------
# Sometimes we require multiple inputs to express a reduction
# operator, and the inputs collaborate with each other, e.g. :code:`argmax`.
# In the reduction procedure, :code:`argmax` needs to compare the values of
# the operands and also keep their indices. This can be expressed
# with :any:`comm_reducer` as below:
# x and y are the operands of the reduction; each of them is a tuple of
# index and value.
def fcombine(x, y):
lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
return lhs, rhs
# our identity element also needs to be a tuple, so `fidentity` accepts
# two types as inputs.
def fidentity(t0, t1):
return tvm.const(-1, t0), tvm.min_value(t1)
argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax')
# describe the reduction computation
m = tvm.var('m')
n = tvm.var('n')
idx = tvm.placeholder((m, n), name='idx', dtype='int32')
val = tvm.placeholder((m, n), name='val', dtype='int32')
k = tvm.reduce_axis((0, n), 'k')
T0, T1 = tvm.compute((m, ), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name='T')
# the generated IR code would be:
s = tvm.create_schedule(T0.op)
print(tvm.lower(s, [idx, val, T0, T1], simple_mode=True))
######################################################################
# .. note::
#
# For ones who are not familiar with reduction, please refer to
# :ref:`general-reduction`.
######################################################################
# Schedule Operation with Tuple Inputs
# ------------------------------------
# It is worth mentioning that although you get multiple outputs
# from one batch operation, they can only be scheduled together
# at the level of the operation.
n = tvm.var("n")
m = tvm.var("m")
A0 = tvm.placeholder((m, n), name='A0')
B0, B1 = tvm.compute((m, n), lambda i, j: (A0[i, j] + 2, A0[i, j] * 3), name='B')
A1 = tvm.placeholder((m, n), name='A1')
C = tvm.compute((m, n), lambda i, j: A1[i, j] + B0[i, j], name='C')
s = tvm.create_schedule(C.op)
s[B0].compute_at(s[C], C.op.axis[0])
# as you can see in the generated IR code below:
print(tvm.lower(s, [A0, A1, C], simple_mode=True))
######################################################################
# Summary
# -------
# This tutorial introduces the usage of tuple-input operations.
#
# - Describe normal batchwise computation.
# - Describe reduction operation with tuple inputs.
# - Notice that computation can only be scheduled per operation, not per tensor.