Commit f467f66e by ziheng Committed by GitHub

Support for Tuple Inputs of Reducer and ComputeOp (#175)

* Support for batch ComputeOp

* Support for batch ComputeOp

* Fix CrossThreadReduction

* Fix lint

* Add UpdateArray, remove support for batch reduce

* Tuple input support for reduce

* rfactor works with multiple reducer; support multiple reducers with different types

* Small fix

* Small fix

* Change return type of rfactor to Array<Expr>

* Fix lint

* Improve

* Add tutorial

* Improve tutorial

* Improve tutorial
parent ef50162b
......@@ -47,23 +47,27 @@ struct CommReducer : public NodeRef {
* binary operator with identity element
*/
struct CommReducerNode : public Node {
/*! \brief The arguments of reducer */
Array<Var> args;
/*! \brief The left argument of reducer */
Array<Var> lhs;
/*! \brief The right argument of reducer */
Array<Var> rhs;
/*! \brief The result of reducer */
Expr result;
Array<Expr> result;
/*!
* \brief The identity element of reducer, which leaves other
* elements unchanged when combined with it, with respect to
* the binary operation this reducer uses.
*/
Expr identity_element;
Array<Expr> identity_element;
/*! \brief Function call operator to combine a and b */
Expr operator()(Expr a, Expr b) const;
Array<Expr> operator()(Array<Expr> a, Array<Expr> b) const;
/*! \brief construct CommReducer from lhs, rhs, result and identity_element */
static CommReducer make(Array<Var> args, Expr result, Expr identity_element);
static CommReducer make(Array<Var> lhs, Array<Var> rhs,
Array<Expr> result, Array<Expr> identity_element);
void VisitAttrs(AttrVisitor* v) final {
v->Visit("args", &args);
v->Visit("lhs", &lhs);
v->Visit("rhs", &rhs);
v->Visit("result", &result);
v->Visit("identity_element", &identity_element);
}
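The reducer's operands are now kept as parallel lhs/rhs arrays, with one result and one identity element per value. As a rough sketch of how this maps onto the Python bindings registered by this commit (tvm.make.CommReducer; the variable names here are illustrative), an argmax-style reducer over an (index, value) pair could be built directly:

import tvm

# hedged sketch: build a two-value CommReducer by hand, mirroring what
# tvm.comm_reducer does internally for tuple inputs
lhs = [tvm.var("x0", dtype="int32"), tvm.var("x1", dtype="float32")]
rhs = [tvm.var("y0", dtype="int32"), tvm.var("y1", dtype="float32")]
# element i of result combines lhs[i] with rhs[i]: keep the index and the
# value of whichever side holds the larger value
result = [tvm.make.Select(lhs[1] >= rhs[1], lhs[0], rhs[0]),
          tvm.make.Select(lhs[1] >= rhs[1], lhs[1], rhs[1])]
identity = [tvm.const(-1, "int32"), tvm.min_value("float32")]
combiner = tvm.make.CommReducer(tvm.convert(lhs), tvm.convert(rhs),
                                tvm.convert(result), tvm.convert(identity))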
......@@ -84,7 +88,7 @@ struct Reduce : public ExprNode<Reduce> {
/*! \brief The commutative combiner */
CommReducer combiner;
/*! \brief The source operand */
Expr source;
Array<Expr> source;
/*! \brief The reduction axis */
Array<IterVar> axis;
/*!
......@@ -92,18 +96,22 @@ struct Reduce : public ExprNode<Reduce> {
* Only add the body to reduction if condition is true.
*/
Expr condition;
/*! \brief the index of this reduce node */
int value_index;
/*! \brief construct expr from op and rdom */
static Expr make(CommReducer combiner,
Expr src,
Array<Expr> src,
Array<IterVar> rdom,
Expr condition = const_true());
Expr condition,
int value_index);
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("source", &source);
v->Visit("axis", &axis);
v->Visit("condition", &condition);
v->Visit("value_index", &value_index);
}
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "Reduce";
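Every output of a (possibly tuple-valued) reduction is its own Reduce expression; the outputs share the combiner, source array, axis and condition and differ only in value_index. Even a plain sum now carries a one-element source. A small hedged check from the Python side:

import tvm

# hedged sketch: a plain sum becomes a Reduce whose source is a
# one-element array and whose value_index is 0
n = tvm.var("n")
A = tvm.placeholder((n,), name="A")
k = tvm.reduce_axis((0, n), "k")
red = tvm.sum(A[k], axis=k)
assert red.value_index == 0
assert len(red.source) == 1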
......@@ -292,11 +300,12 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
/*!
* \brief See pseudo code
*
* Expr tvm_thread_allreduce(CommReducer combiner, Expr value, Expr cond,
* Var thread_idx1, thread_idx2...) {
* void tvm_thread_allreduce(UIntImm size, Expr source0, ..., Expr cond,
* Var reduce_temp0, .., Var thread_idx1, ...) {
* // constraint: the other thread_idx values remain the same.
* return reduce(combiner, value, cond,
* over [thread_idx1, thread_idx2] passed by any caller)
* // reduce_temp is used to save intermediate result.
* reduce_temp0, ... = reduce(combiner, source0, ..., cond
* over [thread_idx1, thread_idx2] passed by any caller)
* }
*/
constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
......
......@@ -96,10 +96,10 @@ Expr Substitute(Expr expr, const Map<Var, Expr>& value_map);
/*!
* \brief inline all calls of f in stmt.
*
* \param stmt The statement to apply inline optimization.
* \param f The function reference to be inlined
* \param args The argument variables of the function.
* \param body The defintion body of the function.
* \param stmt The statement to apply inline optimization.
* \param body The definition body of the function.
* \return The result stmt
*
* \note All the passes in this file use SSA form and output SSA form.
......
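Since ComputeOpNode::body is now an array, callers of Inline pass a single element of it. A minimal sketch mirroring the updated unit test:

import tvm

# hedged sketch: inline a one-output compute op, passing body[0] rather
# than body, to match the updated signature
m = tvm.var("m")
A = tvm.placeholder((m,), name="A")
T = tvm.compute((m,), lambda i: A[i] + 10, name="T")
stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
stmt = tvm.ir_pass.Inline(stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
assert tvm.ir_pass.VerifySSA(stmt)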
......@@ -182,7 +182,7 @@ class ComputeOpNode : public OperationNode {
/*! \brief IterVar on each reduction axis, if the body is a Reduce */
Array<IterVar> reduce_axis;
/*! \brief the compute expression */
Expr body;
Array<Expr> body;
/*! \brief constructor */
ComputeOpNode() {}
// override functions
......@@ -218,7 +218,7 @@ class ComputeOpNode : public OperationNode {
}
static Operation make(std::string name,
Array<IterVar> axis,
Expr body);
Array<Expr> body);
static constexpr const char* _type_key = "ComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode);
......@@ -358,6 +358,9 @@ class ExternOpNode : public OperationNode {
/*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>;
/*! \brief The compute function to specify the input sources of multiple Tensors */
using FBatchCompute = std::function<Array<Expr> (const Array<Var>& i)>;
/*!
* \brief create a place holder tensor.
* \param shape The shape of the tensor.
......@@ -378,6 +381,15 @@ Tensor placeholder(Array<Expr> shape,
Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
/*!
* \brief Construct new tensors by computing over shape,
* using the computation rule: result_tensor[axis] = fcompute(axis)
* \param shape Shape of the tensor.
* \param fcompute The compute function to create the tensors.
* \param name The optional name of the tensor.
*/
Array<Tensor> compute(Array<Expr> shape, FBatchCompute fcompute, std::string name = "tensor");
/*!
* \brief Construct new tensors by scan.
*
* \param init The initial tensor of the first K steps.
......
......@@ -252,15 +252,15 @@ class Schedule : public NodeRef {
/*!
* \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
* This will create a new stage that generates the new tensor with axis
* as the first dimension. The tensor's body wil be rewriten as a reduction
* as the first dimension. The tensor's body will be rewritten as a reduction
* over the factored tensor.
*
* \param tensor The tensor to be factored.
* \param axis The reduction axis in tensor's schedule to be factored.
* \return The created factored tensor.
* \return The created factored tensors.
*/
Tensor rfactor(const Tensor& tensor,
const IterVar& axis);
Array<Tensor> rfactor(const Tensor& tensor,
const IterVar& axis);
/*!
* \brief Normalize the schedule.
* This is needed before bound inference.
......
......@@ -174,10 +174,14 @@ def compute(shape, fcompute, name="compute"):
dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)]
body = fcompute(*[v.var for v in dim_var])
if not isinstance(body, (list, tuple)):
body = [body]
body = convert(body)
op_node = _api_internal._ComputeOp(
name, dim_var, body)
return op_node.output(0)
num = op_node.num_outputs
outputs = tuple(op_node.output(i) for i in range(num))
return outputs[0] if num == 1 else outputs
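When fcompute returns a tuple, compute() now yields one tensor per element; the tensors share a single op and differ only in value_index. A minimal sketch mirroring the new unit test:

import tvm

m = tvm.var("m")
n = tvm.var("n")
A0 = tvm.placeholder((m, n), name="A0")
A1 = tvm.placeholder((m, n), name="A1")
# two outputs computed in one batch operation
T0, T1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name="T")
assert T0.op == T1.op
assert T0.value_index == 0 and T1.value_index == 1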
def scan(init, update, state_placeholder, inputs=None, name="scan"):
......@@ -525,18 +529,45 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
return res
def _make_reduce(expr, axis, where=None):
expr = convert(expr)
dtype = expr.dtype
code = fcombine.__code__
assert fcombine.__code__.co_argcount == 2
arg_vars = [var(name, dtype) for name in code.co_varnames]
result = fcombine(*[v for v in arg_vars])
expr = convert(expr)
if isinstance(expr, _collections.Array):
size = len(expr)
larr = []
rarr = []
dtypes = []
for i in range(size):
dtype = expr[i].dtype
dtypes.append(dtype)
lname = code.co_varnames[0] + '_' + str(i)
larr.append(var(lname, dtype))
rname = code.co_varnames[1] + '_' + str(i)
rarr.append(var(rname, dtype))
lhs = convert(larr)
rhs = convert(rarr)
result = fcombine(lhs, rhs)
id_elem = fidentity(*dtypes)
else:
assert isinstance(expr, _expr.Expr)
size = 1
dtype = expr.dtype
lvar = var(code.co_varnames[0], dtype)
rvar = var(code.co_varnames[1], dtype)
result = [fcombine(lvar, rvar)]
id_elem = [fidentity(dtype)]
lhs = convert([lvar])
rhs = convert([rvar])
expr = convert([expr])
result = convert(result)
id_elem = fidentity(dtype)
assert isinstance(id_elem, _expr.Expr)
combiner = _make.CommReducer(arg_vars, result, id_elem)
axis = axis if isinstance(axis, list) else [axis]
return _make.Reduce(combiner, expr, axis, where)
id_elem = convert(id_elem)
combiner = _make.CommReducer(lhs, rhs, result, id_elem)
axis = convert(axis if isinstance(axis, list) else [axis])
if where is None:
where = convert(True)
outputs = tuple(_make.Reduce(combiner, expr, axis, where, i)
for i in range(size))
return outputs[0] if size == 1 else outputs
def reducer(expr, axis, where=None, *args):
if isinstance(axis, (_schedule.IterVar, list)):
......
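With tuple inputs, fcombine receives and returns one entry per reduction value and fidentity receives the corresponding dtypes. A hedged sketch of an argmax reducer, matching the new tests:

import tvm

def fcombine(x, y):
    # keep the (index, value) pair whose value is larger
    lhs = tvm.make.Select(x[1] >= y[1], x[0], y[0])
    rhs = tvm.make.Select(x[1] >= y[1], x[1], y[1])
    return lhs, rhs

def fidentity(t0, t1):
    return tvm.const(-1, t0), tvm.min_value(t1)

argmax = tvm.comm_reducer(fcombine, fidentity, name="argmax")

m = tvm.var("m")
n = tvm.var("n")
idx = tvm.placeholder((m, n), name="idx", dtype="int32")
val = tvm.placeholder((m, n), name="val", dtype="float32")
k = tvm.reduce_axis((0, n), "k")
# two output tensors that share one reduction op
T0, T1 = tvm.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name="T")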
......@@ -181,7 +181,7 @@ class Schedule(NodeBase):
""" Factor a reduction axis in tensor's schedule to be an explicit axis.
This will create a new stage that generates the new tensor with axis
as the first dimension. The tensor's body wil be rewriten as a reduction
as the first dimension. The tensor's body will be rewritten as a reduction
over the factored tensor.
Parameters
......@@ -193,10 +193,11 @@ class Schedule(NodeBase):
Returns
-------
tfactor : Tensor
tfactor : Tensor or Array of Tensor
The created factored tensor.
"""
return _api_internal._ScheduleRFactor(self, tensor, axis)
factored = _api_internal._ScheduleRFactor(self, tensor, axis)
return factored[0] if len(factored) == 1 else factored
@register_node
......
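rfactor on a single-output reduction still hands back a plain Tensor, while a tuple reduction yields one factored tensor per output (e.g. BF0, BF1 = s.rfactor(B0, kf) in the new argmax test). A hedged sketch of the single-output case:

import tvm

n = tvm.var("n")
A = tvm.placeholder((n,), name="A")
k = tvm.reduce_axis((0, n), "k")
B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name="B")
s = tvm.create_schedule(B.op)
ko, ki = s[B].split(k, factor=16)
BF = s.rfactor(B, ki)  # a single Tensor: the one-element result is unwrapped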
......@@ -68,11 +68,13 @@ TVM_REGISTER_API("make.Call")
});
TVM_REGISTER_API("make.CommReducer")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = CommReducerNode::make(args[0], args[1], args[2]);
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = CommReducerNode::make(args[0],
args[1],
args[2],
args[3]);
});
// make from two arguments
#define REGISTER_MAKE1(Node) \
TVM_REGISTER_API("make."#Node) \
......@@ -112,7 +114,7 @@ TVM_REGISTER_API("make.CommReducer")
*ret = Node::make(a, b); \
})
REGISTER_MAKE4(Reduce);
REGISTER_MAKE5(Reduce);
REGISTER_MAKE4(AttrStmt);
REGISTER_MAKE2(IntImm);
......
......@@ -50,24 +50,27 @@ Expr sum(Expr source, Array<IterVar> rdom) {
Var x("x"), y("y");
Expr result = ir::Add::make(x, y);
Expr identity_element = make_zero(source.type());
ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}
Expr max(Expr source, Array<IterVar> rdom) {
Var x("x"), y("y");
Expr result = ir::Max::make(x, y);
Expr identity_element = source.type().min();
ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}
Expr min(Expr source, Array<IterVar> rdom) {
Var x("x"), y("y");
Expr result = ir::Min::make(x, y);
Expr identity_element = source.type().max();
ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}
std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*)
......
......@@ -9,6 +9,7 @@
#include <ir/IR.h>
#include <ir/IRPrinter.h>
#include <memory>
#include "../pass/ir_util.h"
namespace Halide {
namespace Internal {
......@@ -25,23 +26,20 @@ void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) {
p->stream << "reduce(combiner="
<< op->combiner
<< ", ";
p->print(op->source);
<< op->combiner;
p->stream << ", source=" << op->source;
p->stream << ", axis=" << op->axis;
if (!is_const(op->condition, 1)) {
p->stream << ", where=" << op->condition;
}
p->stream << ", where=" << op->condition;
p->stream << ", value_index=" << op->value_index;
p->stream << ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<CommReducerNode>([](const CommReducerNode *op, IRPrinter *p) {
p->stream << "comm_reducer(result="
<< op->result
<< ", args=" << op->args
<< ", identity_element="
<< op->identity_element
p->stream << "comm_reducer(result=" << op->result
<< ", lhs=" << op->lhs
<< ", rhs=" << op->rhs
<< ", identity_element=" << op->identity_element
<< ")";
});
} // namespace Internal
......@@ -50,23 +48,34 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
namespace tvm {
namespace ir {
CommReducer CommReducerNode::make(Array<Var> args, Expr result, Expr identity_element) {
CommReducer CommReducerNode::make(Array<Var> lhs,
Array<Var> rhs,
Array<Expr> result,
Array<Expr> identity_element) {
auto node = std::make_shared<CommReducerNode>();
node->args = args;
node->lhs = lhs;
node->rhs = rhs;
node->result = result;
node->identity_element = identity_element;
return CommReducer(node);
}
Expr CommReducerNode::operator()(Expr a, Expr b) const {
Array<Expr> CommReducerNode::operator()(Array<Expr> a, Array<Expr> b) const {
CHECK_EQ(a.size(), b.size());
CHECK_EQ(lhs.size(), a.size());
CHECK_EQ(rhs.size(), b.size());
Map<Var, Expr> value_map;
value_map.Set(args[0], a);
value_map.Set(args[1], b);
return Substitute(result, value_map);
for (size_t i = 0; i < a.size(); ++i) {
value_map.Set(lhs[i], a[i]);
value_map.Set(rhs[i], b[i]);
}
return UpdateArray(result, [&value_map] (const Expr& e) {
return Substitute(e, value_map);
});
}
Expr Reduce::make(CommReducer combiner, Expr source,
Array<IterVar> axis, Expr condition) {
Expr Reduce::make(CommReducer combiner, Array<Expr> source,
Array<IterVar> axis, Expr condition, int value_index) {
for (size_t i = 0; i < axis.size(); ++i) {
CHECK_EQ(axis[i]->iter_type, kCommReduce)
<< "Can only take axis created by reduce_axis";
......@@ -79,11 +88,12 @@ Expr Reduce::make(CommReducer combiner, Expr source,
for (size_t i = 0; i < axis.size(); ++i) {
CHECK(axis[i].defined());
}
n->type = source.type();
n->combiner = combiner;
n->source = source;
n->axis = axis;
n->type = source[value_index].type();
n->combiner = std::move(combiner);
n->source = std::move(source);
n->axis = std::move(axis);
n->condition = condition;
n->value_index = value_index;
return Expr(n);
}
......
......@@ -4,6 +4,7 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include "./ir_util.h"
namespace tvm {
namespace ir {
......@@ -17,19 +18,7 @@ IRMutator::FMutateStmt& IRMutator::vtable_stmt() { // NOLINT(*)
}
inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
std::vector<Expr> new_arr(arr.size());
bool changed = false;
for (size_t i = 0; i < arr.size(); i++) {
Expr old_elem = arr[i];
Expr new_elem = m->Mutate(old_elem);
if (!new_elem.same_as(old_elem)) changed = true;
new_arr[i] = new_elem;
}
if (!changed) {
return arr;
} else {
return Array<Expr>(new_arr);
}
return UpdateArray(arr, [&m] (const Expr& e) { return m->Mutate(e); });
}
inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) {
......@@ -323,14 +312,15 @@ DEFINE_BIOP_EXPR_MUTATE_(Or)
Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) {
Array<IterVar> new_axis = MutateIterVarArr(op->axis, this);
Expr new_source = this->Mutate(op->source);
Array<Expr> new_source = MutateArray(op->source, this);
Expr new_cond = this->Mutate(op->condition);
if (op->axis.same_as(new_axis) &&
op->source.same_as(new_source) &&
op->condition.same_as(new_cond)) {
return e;
} else {
return Reduce::make(op->combiner, new_source, new_axis, new_cond);
return Reduce::make(
op->combiner, new_source, new_axis, new_cond, op->value_index);
}
}
......
......@@ -13,6 +13,32 @@ namespace tvm {
namespace ir {
/*!
* \brief update array with a unary function
* \param arr array
* \param fupdate a unary function
* \tparam T type of array element
* \tparam F type of the unary function
* \return if update happens, return the new array, else return the
* original array
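*
* e.g. UpdateArray(exprs, [&](const Expr& e) { return Substitute(e, vmap); })
* hands back the original array object whenever no element changes.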
*/
template<typename T, typename F>
inline Array<T> UpdateArray(Array<T> arr, F fupdate) {
std::vector<T> new_arr(arr.size());
bool changed = false;
for (size_t i = 0; i < arr.size(); ++i) {
T old_elem = arr[i];
T new_elem = fupdate(old_elem);
if (!new_elem.same_as(old_elem)) changed = true;
new_arr[i] = new_elem;
}
if (!changed) {
return arr;
} else {
return Array<T>(new_arr);
}
}
/*!
* \brief combine the nest stmt, whose body is not defined.
* \param nest A list of For and LetStmt, whose body is not defined.
* \param body body
......
......@@ -133,7 +133,7 @@ DEFINE_BINOP_VISIT_(Or)
void IRVisitor::Visit_(const Reduce* op) {
VisitRDom(op->axis, this);
this->Visit(op->source);
VisitArray(op->source, this);
}
void IRVisitor::Visit_(const Cast* op) {
......
......@@ -45,12 +45,12 @@ class ThreadAllreduceBuilder : public IRMutator {
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt Mutate_(const Evaluate* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Store>();
op = stmt.as<Evaluate>();
const Call* call = op->value.as<Call>();
if (call && call->is_intrinsic(intrinsic::tvm_thread_allreduce)) {
return MakeAllreduce(op, call);
return MakeAllreduce(call);
} else {
return stmt;
}
......@@ -97,18 +97,34 @@ class ThreadAllreduceBuilder : public IRMutator {
}
};
// make allreduce.
Stmt MakeAllreduce(const Store* op, const Call* call) {
Stmt MakeAllreduce(const Call* call) {
CHECK(!reduce_combiner_.empty());
const CommReducerNode *combiner = reduce_combiner_.back();
Expr init = combiner->identity_element;
Expr value = call->args[0];
Expr cond = call->args[1];
if (!is_one(cond)) {
value = Select::make(cond, value, init);
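// Argument layout of the updated tvm_thread_allreduce call:
//   args[0]                    number of reduction values (UIntImm)
//   args[1 .. size]            the source expressions
//   args[size+1]               the reduction condition
//   args[size+2 .. 2*size+1]   the output buffer vars (reduce_temp)
//   args[2*size+2 .. ]         the thread index variables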
size_t size = combiner->result.size();
const UIntImm *size_of_args = call->args[0].as<UIntImm>();
CHECK(size_of_args) << call->args[0]->type_key();
CHECK_EQ(size, size_of_args->value);
Array<Expr> inits = combiner->identity_element;
std::vector<Expr> values(size);
std::vector<Type> types(size);
Expr cond = call->args[size+1];
for (size_t idx = 0; idx < size; ++idx) {
values[idx] = call->args[1+idx];
if (!is_one(cond)) {
values[idx] = Select::make(cond, values[idx], inits[idx]);
}
types[idx] = values[idx].type();
}
std::vector<const Variable*> buffers(size);
for (size_t idx = 0; idx < size; ++idx) {
const Variable* buffer = call->args[2+size+idx].as<Variable>();
CHECK(buffer);
buffers[idx] = buffer;
}
std::unordered_set<const Variable*> reduce_set;
for (size_t i = 2; i < call->args.size(); ++i) {
for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) {
const Variable* v = call->args[i].as<Variable>();
CHECK(v);
reduce_set.insert(v);
......@@ -143,40 +159,50 @@ class ThreadAllreduceBuilder : public IRMutator {
int threadx_extent = 1;
Expr reduce_index = FlattenThread(vred, &reduce_extent);
Expr group_index = FlattenThread(vpar, &group_extent);
Expr pred = const_true(value.type().lanes());
if (reduce_extent == 1) {
// special case, no reduction is needed.
return Store::make(op->buffer_var, value, 0, pred);
std::vector<Stmt> stores(size);
for (size_t i = 0; i < size; ++i) {
Expr pred = const_true(types[i].lanes());
Var buffer_var(call->args[2+size+i].node_);
stores[i] = Store::make(buffer_var, values[i], 0, pred);
}
return Block::make(stores);
}
// Whether the threadIdx.x is involved in reduction.
if (vred[0].scope.dim_index == 0) {
threadx_extent = vred[0].extent;
}
Var shared_buf("red_buf", Handle());
std::vector<Stmt> seq;
seq.emplace_back(Store::make(
shared_buf, value,
BufIndex(reduce_index, group_index, reduce_extent), pred));
std::vector<Var> shared_bufs(size);
for (size_t idx = 0; idx < size; ++idx) {
shared_bufs[idx] = Var("red_buf"+std::to_string(idx), Handle());
Expr pred = const_true(types[idx].lanes());
seq.emplace_back(Store::make(
shared_bufs[idx], values[idx],
BufIndex(reduce_index, group_index, reduce_extent), pred));
}
seq.emplace_back(SyncThread("shared"));
seq.emplace_back(MakeBufAllreduce(
combiner, value.type(), shared_buf,
combiner, types, shared_bufs,
reduce_index, group_index, reduce_extent, threadx_extent));
CHECK(!load_remap_.count(op->buffer_var.get()));
load_remap_[op->buffer_var.get()] =
Load::make(
value.type(), shared_buf,
BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent),
pred);
alloc_remap_[op->buffer_var.get()] =
Allocate::make(shared_buf, value.type(),
{Expr(group_extent), Expr(reduce_extent)},
pred, Evaluate::make(0));
for (size_t idx = 0; idx < size; ++idx) {
CHECK(!load_remap_.count(buffers[idx]));
Expr pred = const_true(types[idx].lanes());
load_remap_[buffers[idx]] = Load::make(
types[idx], shared_bufs[idx],
BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent), pred);
alloc_remap_[buffers[idx]] = Allocate::make(
shared_bufs[idx], types[idx],
{Expr(group_extent), Expr(reduce_extent)},
pred, Evaluate::make(0));
}
return MergeSeq(seq);
}
// make allreduce.
Stmt MakeBufAllreduce(const CommReducerNode *combiner,
Type type,
Var shared_buf,
const std::vector<Type>& types,
const Array<Var>& shared_bufs,
Expr reduce_index,
Expr group_index,
int reduce_extent,
......@@ -189,14 +215,23 @@ class ThreadAllreduceBuilder : public IRMutator {
CHECK_GT(reduce_align, 1);
std::vector<Stmt> seq;
size_t size = shared_bufs.size();
Expr buf_index = BufIndex(reduce_index, group_index, reduce_extent);
// make reduction
auto freduce = [&](int offset) {
Expr b = Load::make(
type, shared_buf,
BufIndex(reduce_index + offset, group_index, reduce_extent), const_true());
Expr a = Load::make(type, shared_buf, buf_index, const_true());
return Store::make(shared_buf, (*combiner)(a, b), buf_index, const_true());
Array<Expr> a, b;
for (size_t i = 0; i < size; ++i) {
b.push_back(Load::make(types[i], shared_bufs[i],
BufIndex(reduce_index + offset, group_index, reduce_extent),
const_true()));
a.push_back(Load::make(types[i], shared_bufs[i], buf_index, const_true()));
}
Array<Expr> ret = (*combiner)(a, b);
std::vector<Stmt> stores(size);
for (size_t i = 0; i < size; ++i) {
stores[i] = Store::make(shared_bufs[i], ret[i], buf_index, const_true());
}
return Block::make(stores);
};
// Step one, check for
if (reduce_align > reduce_extent) {
......
......@@ -157,7 +157,9 @@ class StorageFlattener : public IRMutator {
CHECK_EQ(extern_buf_remap_.size(), 0U);
for (size_t i = 0; i < ext_op->output_placeholders.size(); ++i) {
TensorKey key{func, static_cast<int>(i)};
CHECK(buf_map_.count(key));
CHECK(buf_map_.count(key))
<< "Cannot find allocated buffer for " << key.f
<< "(" << key.value_index << ")";
extern_buf_remap_[ext_op->output_placeholders[i]->data.get()] =
buf_map_.at(key).buffer->data;
}
......
......@@ -46,7 +46,7 @@ class ElemWiseDetector : public ir::IRVisitor {
bool IsElemWise(const Operation& op) {
if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
ElemWiseDetector v = ElemWiseDetector(compute->axis);
v.Visit(compute->body);
for (auto& e : compute->body) v.Visit(e);
return v.is_elem_wise_;
}
return false;
......
......@@ -260,7 +260,9 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
}
}
};
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
for (auto& e : op.as<ComputeOpNode>()->body) {
ir::PostOrderVisit(e, fvisit);
}
}
}
return reach;
......@@ -321,11 +323,14 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
}
}
} else if (op.as<ComputeOpNode>()) {
std::unordered_map<const Node*, TensorDimKey> vmap;
std::unordered_map<const Node*, std::vector<TensorDimKey> > vmap;
const auto& axis = op.as<ComputeOpNode>()->axis;
Tensor t = op.output(0);
for (size_t i = 0; i < axis.size(); ++i) {
vmap[axis[i]->var.get()] = TensorDimKey(t, i);
std::vector<TensorDimKey> keys;
for (int j = 0; j < op->num_outputs(); ++j) {
keys.emplace_back(op.output(j), i);
}
vmap[axis[i]->var.get()] = std::move(keys);
}
auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](
const NodeRef& n) {
......@@ -335,7 +340,10 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
auto it = vmap.find(call->args[i].get());
TensorDimKey src(call, static_cast<int>(i));
if (it != vmap.end()) {
f_merge_key(it->second, src);
const std::vector<TensorDimKey>& keys = it->second;
for (const auto& key : keys) {
f_merge_key(key, src);
}
} else {
if (exact_reach.count(src)) {
fail_set.insert(exact_reach.at(src));
......@@ -344,7 +352,9 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
}
}
};
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
for (auto& e : op.as<ComputeOpNode>()->body) {
ir::PostOrderVisit(e, fvisit);
}
}
}
ReachGraph reach;
......
......@@ -27,7 +27,7 @@ using ReadGraph = Map<Operation, Array<Tensor> >;
using AttachPath = Map<Operation, Array<IterVar> >;
/*!
* \brief The map beteen tensor and operation it feeds to.
* \brief The map between tensor and operation it feeds to.
*/
using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
......@@ -46,7 +46,7 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots);
* The subgraph contains the operations that are input-reachable from any of the
* inputs and output-reachable to any of the outputs.
*
* The inputs won't be included in the subgraph, the outputs will be inclued.
* The inputs won't be included in the subgraph, the outputs will be included.
*
* \param outputs The outputs of the subgraph
* \param inputs The inputs to the subgraph.
......
......@@ -8,6 +8,7 @@
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "./message_passing.h"
#include "../pass/ir_util.h"
namespace tvm {
......@@ -120,13 +121,13 @@ Tensor Schedule::cache_write(const Tensor& tensor,
vsub[iv->var.get()] = new_iv->var;
}
VarReplacer repl(vsub);
Expr body = repl.Mutate(compute->body);
Expr body = repl.Mutate(compute->body[tensor->value_index]);
Operation cache_op = ComputeOpNode::make(
compute->name + "." + scope, new_axis, body);
compute->name + "." + scope, new_axis, {body});
Tensor cache_tensor = cache_op.output(0);
Operation orig_new_op = ComputeOpNode::make(
compute->name, compute->axis,
cache_tensor(args));
{cache_tensor(args)});
std::unordered_map<Tensor, Tensor> vmap;
vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
......@@ -198,14 +199,15 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
void InjectInline(ScheduleNode* sch) {
sch->InvalidateCache();
std::vector<Expr> new_body(sch->stages.size());
std::vector<Array<Expr>> new_body(sch->stages.size());
std::vector<bool> changed(sch->stages.size(), false);
// inline all the ops
for (size_t i = sch->stages.size(); i != 0; --i) {
Stage stage = sch->stages[i - 1];
if (stage->attach_type == kInline) {
stage->attach_type = kInlinedAlready;
Array<Var> args;
Expr body;
Array<Expr> body;
{
// setup args
const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
......@@ -220,11 +222,14 @@ void InjectInline(ScheduleNode* sch) {
Stage s = sch->stages[j];
const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
if (compute) {
if (!new_body[j].defined()) {
if (!new_body[j].size()) {
new_body[j] = s->op.as<ComputeOpNode>()->body;
}
new_body[j] = ir::Inline(ir::Evaluate::make(new_body[j]),
stage->op, args, body).as<ir::Evaluate>()->value;
for (size_t k = 0; k < body.size(); ++k) {
changed[j] = true;
new_body[j].Set(k, ir::Inline(ir::Evaluate::make(new_body[j][k]),
stage->op, args, body[k]).as<ir::Evaluate>()->value);
}
}
}
}
......@@ -234,19 +239,21 @@ void InjectInline(ScheduleNode* sch) {
for (size_t i = 0; i < sch->stages.size(); ++i) {
Stage s = sch->stages[i];
if (s->attach_type == kInlinedAlready) continue;
if (new_body[i].defined()) {
if (new_body[i].size()) {
// Logics from ReplaceDataFlow
const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
CHECK(compute);
Operation op = s->op;
if (!new_body[i].same_as(compute->body)) {
if (changed[i]) {
op = ComputeOpNode::make(
compute->name, compute->axis, new_body[i]);
}
op = op->ReplaceInputs(op, repl);
if (!op.same_as(s->op)) {
repl[s->op.output(0)] = op.output(0);
s->op = op;
for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
repl[s->op.output(idx)] = op.output(idx);
s->op = op;
}
}
} else {
Operation op = s->op->ReplaceInputs(s->op, repl);
......@@ -268,15 +275,15 @@ Schedule Schedule::normalize() {
}
// Handle reduction factor.
Tensor Schedule::rfactor(const Tensor& tensor,
const IterVar& axis) {
Array<Tensor> Schedule::rfactor(const Tensor& tensor,
const IterVar& axis) {
(*this)->InvalidateCache();
using ir::Reduce;
CHECK_EQ(axis->iter_type, kCommReduce)
<< "Can only factor reduction axis";
Stage reduce_stage = operator[](tensor->op);
const ComputeOpNode* compute_op = reduce_stage->op.as<ComputeOpNode>();
CHECK(compute_op) << "Can only factor ComputeOp";
CHECK(compute_op) << "Can only factor ComputeOp";
ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite();
{
size_t axis_pos = FindNodeRef(leaf_vars, axis);
......@@ -329,7 +336,8 @@ Tensor Schedule::rfactor(const Tensor& tensor,
}
}
// predicate generation, copy not touched axis.
const Reduce* reduce = compute_op->body.as<Reduce>();
int idx = tensor->value_index;
const Reduce* reduce = compute_op->body[idx].as<Reduce>();
CHECK(reduce) << "Can only rfactor non-inline reductions";
Expr predicate = reduce->condition;
std::unordered_map<const Variable*, Expr> vsub;
......@@ -359,10 +367,18 @@ Tensor Schedule::rfactor(const Tensor& tensor,
n->reduce_axis.push_back(IterVar(ncpy));
}
}
n->body = Reduce::make(reduce->combiner,
VarReplacer(vsub).Mutate(reduce->source),
n->reduce_axis,
predicate);
VarReplacer replacer(vsub);
Array<Expr> new_source = ir::UpdateArray(reduce->source,
[&replacer] (const Expr& e) { return replacer.Mutate(e); });
std::vector<Expr> body;
for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
body.emplace_back(Reduce::make(reduce->combiner,
new_source,
n->reduce_axis,
predicate,
idx));
}
n->body = Array<Expr>(body);
// refresh relations, keep the un-touched relations.
Array<IterVarRelation> rels;
for (IterVarRelation rel : reduce_stage->relations) {
......@@ -397,26 +413,44 @@ Tensor Schedule::rfactor(const Tensor& tensor,
// Replace the old reduction.
IterVar repl_red_axis = reduce_axis(
dom_map.at(axis), axis->var->name_hint + ".v");
Tensor factor_tensor = factor_op.output(0);
Tensor old_tensor = reduce_stage->op.output(0);
Tensor repl_tensor = compute(old_tensor->shape, [&](const Array<Var>& i) {
Array<Tensor> factor_tensors;
Array<Tensor> old_tensors;
int size = factor_op->num_outputs();
for (int idx = 0; idx < size; ++idx) {
factor_tensors.push_back(factor_op.output(idx));
old_tensors.push_back(reduce_stage->op.output(idx));
}
Array<Tensor> repl_tensors = compute(old_tensors[0]->shape,
[&](const Array<Var>& i) {
Array<Expr> indices;
indices.push_back(repl_red_axis->var);
for (Var v : i) {
indices.push_back(v);
}
return Reduce::make(reduce->combiner,
factor_tensor(indices), {repl_red_axis}, const_true());
}, old_tensor->op->name + ".repl");
Array<Expr> factor_exprs;
for (int idx = 0; idx < size; ++idx) {
factor_exprs.push_back(factor_tensors[idx](indices));
}
Array<Expr> reductions;
Array<IterVar> axis = {repl_red_axis};
Expr cond = const_true();
for (int idx = 0; idx < size; ++idx) {
reductions.push_back(Reduce::make(reduce->combiner,
factor_exprs, axis, cond, idx));
}
return reductions;
}, reduce_stage->op->name + ".repl");
std::unordered_map<Tensor, Tensor> vmap;
vmap[old_tensor] = repl_tensor;
for (int idx = 0; idx < size; ++idx) {
vmap[old_tensors[idx]] = repl_tensors[idx];
}
ReplaceDataFlow((*this)->stages, &vmap);
// revamp the reduction stage.
reduce_stage->op = repl_tensor->op;
reduce_stage->all_iter_vars = repl_tensor->op->root_iter_vars();
reduce_stage->op = repl_tensors[0]->op;
reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars();
reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars;
reduce_stage->relations = Array<IterVarRelation>();
return factor_tensor;
return factor_tensors;
}
} // namespace tvm
......@@ -253,7 +253,7 @@ class SchedulePostProc : public IRMutator {
// This must be checked for all ops, including scan.
if (!s->op.same_as(s->origin_op)) {
for (int i = 0; i < s->op->num_outputs(); ++i) {
Tensor target = s->origin_op.output(0);
Tensor target = s->origin_op.output(i);
AddReplace(s->op.output(i), target,
target, s->origin_op);
}
......
......@@ -49,7 +49,6 @@ def test_reduce_prims():
test_prim(tvm.max, np.amax)
def test_rfactor():
n = tvm.convert(1027)
A = tvm.placeholder((n,), name='A')
......@@ -128,7 +127,115 @@ def test_rfactor_threads():
check_target("metal")
check_target("opencl")
def test_argmax():
def fcombine(x, y):
lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
return lhs, rhs
def fidentity(t0, t1):
return tvm.const(-1, t0), tvm.min_value(t1)
argmax = tvm.comm_reducer(fcombine,
fidentity,
name='argmax')
m = tvm.var('m')
n = tvm.var('n')
idx = tvm.placeholder((m, n), name='idx', dtype='int32')
val = tvm.placeholder((m, n), name='val', dtype='float32')
k = tvm.reduce_axis((0, n), 'k')
T0, T1 = tvm.compute((m,), lambda i: argmax((idx[i,k], val[i,k]), axis=k), name='T')
s = tvm.create_schedule(T0.op)
def check_target():
device = 'cpu'
if not tvm.module.enabled(device):
print("skip because %s is not enabled.." % device)
return
ctx = tvm.context(device, 0)
fapi = tvm.lower(s, args=[idx, val, T0, T1])
fargmax = tvm.build(fapi,
target='llvm',
name="argmax")
mm = 12
nn = 16
np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0)
np_val = np.random.uniform(size=(mm, nn)).astype('float32')
np_res = np.argmax(np_val, axis=1)
nd_idx = tvm.nd.array(np_idx, ctx)
nd_val = tvm.nd.array(np_val, ctx)
nd_res0 = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx)
nd_res1 = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx)
fargmax(nd_idx, nd_val, nd_res0, nd_res1)
np.testing.assert_allclose(np_res, nd_res0.asnumpy())
check_target()
def test_rfactor_argmax():
def fcombine(x, y):
lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
return lhs, rhs
def fidentity(t0, t1):
return tvm.const(-1, t0), tvm.min_value(t1)
argmax = tvm.comm_reducer(fcombine,
fidentity,
name='argmax')
nn = 1027
mm = 10
n = tvm.convert(nn)
m = tvm.convert(mm)
A0 = tvm.placeholder((m, n), name='A0', dtype='int32')
A1 = tvm.placeholder((m, n), name='A1', dtype='float32')
k = tvm.reduce_axis((0, n))
B0, B1 = tvm.compute((m,), lambda i: argmax((A0[i, k], A1[i, k]), axis=k), name='B')
# schedule
s = tvm.create_schedule(B0.op)
nthread = 16
ko, kf = s[B0].split(k, factor=nthread)
BF0, BF1 = s.rfactor(B0, kf)
bx, ty = s[B0].split(s[B0].op.axis[0], factor=nthread)
s[B0].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B0].bind(ty, tvm.thread_axis("threadIdx.y"))
tx = s[B0].op.reduce_axis[0]
thread_x = tvm.thread_axis("threadIdx.x")
s[B0].bind(tx, thread_x)
s[BF0.op].compute_at(s[B0], tx)
s[B0].set_store_predicate(thread_x.var.equal(0))
def check_target(device):
if not tvm.module.enabled(device):
print("skip because %s is not enabled.." % device)
return
ctx = tvm.context(device, 0)
fapi = tvm.lower(s, args=[A0, A1, B0, B1])
fargmax = tvm.build(fapi,
target=device,
name="argmax")
np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0)
np_val = np.random.uniform(size=(mm, nn)).astype('float32')
np_res = np.argmax(np_val, axis=1)
nd_idx = tvm.nd.array(np_idx, ctx)
nd_val = tvm.nd.array(np_val, ctx)
nd_res0 = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx)
nd_res1 = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx)
fargmax(nd_idx, nd_val, nd_res0, nd_res1)
np.testing.assert_allclose(np_res, nd_res0.asnumpy())
check_target("cuda")
if __name__ == "__main__":
test_rfactor_threads()
test_rfactor()
test_reduce_prims()
test_argmax()
test_rfactor_argmax()
......@@ -101,8 +101,8 @@ def test_rfactor():
s = tvm.create_schedule(B.op)
BF = s.rfactor(B, k1)
assert(tuple(BF.shape) == (n, n))
assert(set(BF.op.body.axis) == set([k2]))
assert(s[B].op.body.axis[0].dom.extent == n)
assert(set(BF.op.body[0].axis) == set([k2]))
assert(s[B].op.body[0].axis[0].dom.extent == n)
assert(len(s[B].all_iter_vars) == 2)
# schedule with split
s = tvm.create_schedule(B.op)
......@@ -111,9 +111,9 @@ def test_rfactor():
BF = s.rfactor(B, ki)
assert(BF.shape[0].value == 4)
assert(BF.shape[1] == n)
assert(BF.op.body.axis[0] == k2)
assert(BF.op.body.axis[1].var == ko.var)
assert(s[B].op.body.axis[0].dom.extent.value == 4)
assert(BF.op.body[0].axis[0] == k2)
assert(BF.op.body[0].axis[1].var == ko.var)
assert(s[B].op.body[0].axis[0].dom.extent.value == 4)
if __name__ == "__main__":
......
......@@ -118,6 +118,43 @@ def test_extern_multi_out():
assert(len(res) == 2)
assert(res[1].value_index == 1)
def test_tuple_inputs():
m = tvm.var('m')
n = tvm.var('n')
A0 = tvm.placeholder((m, n), name='A0')
A1 = tvm.placeholder((m, n), name='A1')
T0, T1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='T')
s = tvm.create_schedule(T0.op)
for i in range(len(T0.shape)):
assert(T0.shape[i] == T1.shape[i])
assert(T0.op == T1.op)
assert(T0.value_index == 0)
assert(T1.value_index == 1)
def test_tuple_with_different_deps():
m = tvm.var('m')
n = tvm.var('n')
A0 = tvm.placeholder((m, n), name='A1')
A1 = tvm.placeholder((m, n), name='A2')
B0, B1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='B')
C = tvm.compute((m, n), lambda i, j: B0[i, j] + 4, name='C')
s = tvm.create_schedule(C.op)
xo, xi = s[C].split(C.op.axis[0], factor=10)
s[B0.op].compute_at(s[C], xo)
sch = s.normalize()
bounds = tvm.schedule.InferBound(sch)
stmt = tvm.schedule.ScheduleOps(sch, bounds)
def get_B1_realize(x):
if isinstance(x, tvm.stmt.Realize) and \
x.func == B1.op and x.value_index == 1:
ret.append(x)
ret = []
tvm.ir_pass.PostOrderVisit(stmt, get_B1_realize)
assert stmt.node == C.op and len(ret) == 1
if __name__ == "__main__":
test_conv1d()
......@@ -128,3 +165,5 @@ if __name__ == "__main__":
test_scan_multi_out()
test_extern()
test_extern_multi_out()
test_tuple_inputs()
test_tuple_with_different_deps()
......@@ -6,7 +6,7 @@ def test_inline():
T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
stmt = tvm.ir_pass.Inline(
stmt, T.op, [x.var for x in T.op.axis], T.op.body)
stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
print(stmt)
assert(tvm.ir_pass.VerifySSA(stmt))
......@@ -25,7 +25,7 @@ def test_inline2():
T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
stmt = tvm.make.Evaluate(tvm.exp(T[10]) + 11 * T[100])
stmt = tvm.ir_pass.Inline(
stmt, T.op, [x.var for x in T.op.axis], T.op.body)
stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
def check(op):
if isinstance(op, tvm.expr.Call):
assert op.func != T.op
......
......@@ -89,7 +89,7 @@ def test_inline_mixed():
def check(x):
if isinstance(x, tvm.expr.Call):
assert x.func != A2
tvm.ir_pass.PostOrderVisit(s[C].op.body, check)
tvm.ir_pass.PostOrderVisit(s[C].op.body[0], check)
def test_scan_inline1():
......
......@@ -125,6 +125,8 @@ np.testing.assert_allclose(
b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4)
######################################################################
# .. _general-reduction:
#
# Define General Commutative Reduction Operation
# ----------------------------------------------
# Besides the built-in reduction operations like :any:`tvm.sum`,
......@@ -140,6 +142,12 @@ A = tvm.placeholder((n, m), name='A')
k = tvm.reduce_axis((0, m), name='k')
B = tvm.compute((n,), lambda i: product(A[i, k], axis=k), name='B')
######################################################################
# .. note::
#
# Sometimes we would like to perform a reduction that involves multiple
# values, like :code:`argmax`, which can be done with tuple inputs.
# See :ref:`reduction-with-tuple-inputs` for more detail.
######################################################################
# Summary
......
"""
Compute and Reduction with Tuple Inputs
=======================================
**Author**: `Ziheng Jiang <https://github.com/ZihengJiang>`_
Often we want to compute multiple outputs with the same shape within
a single loop, or perform a reduction that involves multiple values, like
:code:`argmax`. These problems can be addressed with tuple inputs.
In this tutorial, we will introduce the usage of tuple inputs in TVM.
"""
from __future__ import absolute_import, print_function
import tvm
import numpy as np
######################################################################
# Describe Batchwise Computation
# ------------------------------
# For operators that have the same output shape, we can put them together as
# the inputs of :any:`tvm.compute` if we want them to be scheduled
# together in the later scheduling procedure.
#
n = tvm.var("n")
m = tvm.var("m")
A0 = tvm.placeholder((m, n), name='A0')
A1 = tvm.placeholder((m, n), name='A1')
B0, B1 = tvm.compute((m, n), lambda i, j: (A0[i, j] + 2, A1[i, j] * 3), name='B')
# The generated IR code would be:
s = tvm.create_schedule(B0.op)
print(tvm.lower(s, [A0, A1, B0, B1], simple_mode=True))
######################################################################
# .. _reduction-with-tuple-inputs:
#
# Describe Reduction with Collaborative Inputs
# --------------------------------------------
# Sometimes we require multiple inputs to express some reduction
# operators, and these inputs collaborate with each other, e.g. :code:`argmax`.
# In the reduction procedure, :code:`argmax` needs to compare the values of the
# operands and also keep the index of the operand. It can be expressed
# with :any:`comm_reducer` as below:
# x and y are the operands of the reduction; each of them is a tuple of index
# and value.
def fcombine(x, y):
lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
return lhs, rhs
# our identity element also needs to be a tuple, so `fidentity` accepts
# two types as inputs.
def fidentity(t0, t1):
return tvm.const(-1, t0), tvm.min_value(t1)
argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax')
# describe the reduction computation
m = tvm.var('m')
n = tvm.var('n')
idx = tvm.placeholder((m, n), name='idx', dtype='int32')
val = tvm.placeholder((m, n), name='val', dtype='int32')
k = tvm.reduce_axis((0, n), 'k')
T0, T1 = tvm.compute((m, ), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name='T')
# the generated IR code would be:
s = tvm.create_schedule(T0.op)
print(tvm.lower(s, [idx, val, T0, T1], simple_mode=True))
######################################################################
# .. note::
#
# For ones who are not familiar with reduction, please refer to
# :ref:`general-reduction`.
######################################################################
# Schedule Operation with Tuple Inputs
# ------------------------------------
# It is worth mentioning that although you get multiple outputs
# from one batch operation, they can only be scheduled together
# at the level of the operation.
n = tvm.var("n")
m = tvm.var("m")
A0 = tvm.placeholder((m, n), name='A0')
B0, B1 = tvm.compute((m, n), lambda i, j: (A0[i, j] + 2, A0[i, j] * 3), name='B')
A1 = tvm.placeholder((m, n), name='A1')
C = tvm.compute((m, n), lambda i, j: A1[i, j] + B0[i, j], name='C')
s = tvm.create_schedule(C.op)
s[B0].compute_at(s[C], C.op.axis[0])
# as you can see in the generated IR code below:
print(tvm.lower(s, [A0, A1, C], simple_mode=True))
######################################################################
# Summary
# -------
# This tutorial introduces the usage of operations with tuple inputs.
#
# - Describe normal batchwise computation.
# - Describe a reduction operation with tuple inputs.
# - Notice that computation can only be scheduled per operation, not per tensor.