Commit 26d91985 by ziheng Committed by Tianqi Chen

[LANG] CommReducer (#103)

* [LANG] CommReducer

* Reorganize c_api

* Remove InitValue and Combine; refactor Functor

* Make CommReducer an Expr

* Make comm_reducer type independent

* Make CommReducerNode a Node

* Small fix

* Refine

* Refine front api; add integration testcases for min/max

* Fix python

* Refine

* Fix lint and add example
parent 7e032457
...@@ -22,12 +22,67 @@ using Halide::Internal::IRNodeType; ...@@ -22,12 +22,67 @@ using Halide::Internal::IRNodeType;
using Halide::Internal::ForType; using Halide::Internal::ForType;
using Halide::DeviceAPI; using Halide::DeviceAPI;
/*! \brief Reduction operator operator */ // Node container for CommReducer
struct Reduce : public ExprNode<Reduce> { struct CommReducerNode;
struct CommReducer : public NodeRef {
CommReducer() {}
explicit CommReducer(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const CommReducerNode* get() const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const CommReducerNode* operator->() const;
/*! \brief type indicate the container type */
using ContainerType = CommReducerNode;
};
/*!
* \brief A commutative reducer node to represent a commutative
* binary operator with identity element
*/
struct CommReducerNode : public Node {
/*! \brief The arguments of reducer */
Array<Var> args;
/*! \brief The result of reducer */
Expr result;
/*! /*!
* \brief The binary operator of reduction * \brief The identity element of reducer, which leaves other
* elements unchanged when combined with it, with respect to
* the binary operation of this reducer uses.
*/ */
std::string op; Expr identity_element;
/*! \brief Function call operator to combine a and b */
Expr operator()(Expr a, Expr b) const;
/*! \brief construct CommReducer from args, result and identity_element */
static CommReducer make(Array<Var> args, Expr result, Expr identity_element);
void VisitAttrs(AttrVisitor* v) final {
v->Visit("args", &args);
v->Visit("result", &result);
v->Visit("identity_element", &identity_element);
}
static constexpr const char* _type_key = "CommReducer";
TVM_DECLARE_NODE_TYPE_INFO(CommReducerNode, Node);
};
inline const CommReducerNode* CommReducer::get() const {
return static_cast<CommReducerNode*>(node_.get());
}
inline const CommReducerNode* CommReducer::operator->() const {
return static_cast<CommReducerNode*>(node_.get());
}
/*! \brief Reduction operator operator */
struct Reduce : public ExprNode<Reduce> {
/*! \brief The commutative combiner */
CommReducer combiner;
/*! \brief The source operand */ /*! \brief The source operand */
Expr source; Expr source;
/*! \brief The reduction axis */ /*! \brief The reduction axis */
...@@ -39,37 +94,19 @@ struct Reduce : public ExprNode<Reduce> { ...@@ -39,37 +94,19 @@ struct Reduce : public ExprNode<Reduce> {
Expr condition; Expr condition;
/*! \brief construct expr from op and rdom */ /*! \brief construct expr from op and rdom */
static Expr make(std::string op, Expr src, static Expr make(CommReducer combiner,
Expr src,
Array<IterVar> rdom, Array<IterVar> rdom,
Expr condition = const_true()); Expr condition = const_true());
/*!
* \brief Get initial value for reduction.
* \param op The operator
* \param type The data type.
* \return The initial value that can be assigned to reduction.
*/
static Expr InitValue(const std::string& op, Type type);
/*!
* \brief Combine two values with given reduction.
* \param op The operator
* \param a The left operand.
* \param b The left operand.
* \return The combined reduction result.
*/
static Expr Combine(const std::string& op, Expr a, Expr b);
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type); v->Visit("dtype", &type);
v->Visit("op", &op);
v->Visit("source", &source); v->Visit("source", &source);
v->Visit("axis", &axis); v->Visit("axis", &axis);
v->Visit("condition", &condition); v->Visit("condition", &condition);
} }
static const IRNodeType _type_info = IRNodeType::ExtensionExpr; static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "Reduce"; static constexpr const char* _type_key = "Reduce";
static constexpr const char* Add = "Add";
static constexpr const char* Max = "Max";
static constexpr const char* Min = "Min";
}; };
/*! /*!
...@@ -93,26 +130,20 @@ struct TensorKey { ...@@ -93,26 +130,20 @@ struct TensorKey {
/*! \brief namespace of possible attribute sin AttrStmt.type_key */ /*! \brief namespace of possible attribute sin AttrStmt.type_key */
namespace attr { namespace attr {
// The above attr does not pass to ir stage. // The above attr does not pass to ir stage.
/*! /*! \brief Mark launching extent of thread, used by device API. */
* \brief Mark launching extent of thread, used by device API.
*/
constexpr const char* thread_extent = "thread_extent"; constexpr const char* thread_extent = "thread_extent";
/*! /*! \brief Mark launching of a virtual thread. */
* \brief Mark launching of a virtual thread.
*/
constexpr const char* virtual_thread = "virtual_thread"; constexpr const char* virtual_thread = "virtual_thread";
/*! /*! \brief Mark the scope as volatile access for certain handle. */
* \brief Mark the scope as volatile access for certain handle.
*/
constexpr const char* volatile_scope = "volatile_scope"; constexpr const char* volatile_scope = "volatile_scope";
/*! /*! \brief Mark storage scope of buffers */
* \brief Mark storage scope of buffers
*/
constexpr const char* storage_scope = "storage_scope"; constexpr const char* storage_scope = "storage_scope";
/*! \brief Mark storage scope of realization */ /*! \brief Mark storage scope of realization */
constexpr const char* realize_scope = "realize_scope"; constexpr const char* realize_scope = "realize_scope";
/*! \brief Mark of loop scope */ /*! \brief Mark of loop scope */
constexpr const char* loop_scope = "loop_scope"; constexpr const char* loop_scope = "loop_scope";
/*! \brief Mark of reduce scope */
constexpr const char* reduce_scope = "reduce_scope";
/*! \brief Mark of scan update scope */ /*! \brief Mark of scan update scope */
constexpr const char* scan_update_scope = "scan_update_scope"; constexpr const char* scan_update_scope = "scan_update_scope";
/*! \brief Mark of scan init scope */ /*! \brief Mark of scan init scope */
......
...@@ -84,6 +84,7 @@ Stmt CanonicalSimplify(Stmt stmt); ...@@ -84,6 +84,7 @@ Stmt CanonicalSimplify(Stmt stmt);
* \return The converted form. * \return The converted form.
*/ */
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map); Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map);
Expr Substitute(Expr expr, const Map<Var, Expr>& value_map);
/*! /*!
* \brief inline all calls of f in stmt. * \brief inline all calls of f in stmt.
......
...@@ -21,6 +21,15 @@ int32 = "int32" ...@@ -21,6 +21,15 @@ int32 = "int32"
float32 = "float32" float32 = "float32"
handle = "handle" handle = "handle"
def min_value(dtype):
return _api_internal._min_value(dtype)
def max_value(dtype):
return _api_internal._max_value(dtype)
def const(value, dtype=None): def const(value, dtype=None):
"""construct a constant""" """construct a constant"""
if dtype is None: if dtype is None:
...@@ -414,93 +423,106 @@ def reduce_axis(dom, name="rv"): ...@@ -414,93 +423,106 @@ def reduce_axis(dom, name="rv"):
return _IterVar(dom, name, 2) return _IterVar(dom, name, 2)
def sum(expr, axis, where=None): def comm_reducer(fcombine, fidentity, name="reduce"):
"""Create a sum expression over axis """Create a commutative reducer for reduction.
Parameters Parameters
---------- ----------
expr : Expr fcombine : function(Expr -> Expr -> Expr)
The source expression. A binary function which takes two Expr as input to return a Expr.
axis : IterVar
The reduction IterVar axis
where : optional, Expr fidentity : function(str -> Expr)
Filtering predicate of the reduction. A function which takes a type string as input to return a const Expr.
Returns Returns
------- -------
value : Expr reducer : function
The result value. A function which creates a reduce expression over axis. There are two
""" to use it:
axis = axis if isinstance(axis, list) else [axis] 1. accept (expr, axis, where) to produce an Reduce Expr on
x = _make.Reduce("Add", expr, axis, where) specified axis;
return x 2. simply use it with multiple Exprs.
def min(lhs, rhs=None, axis=None, where=None):
"""Create a min expression.
Parameters
----------
lhs : Expr
The left hand expression.
rhs : Expr, optional
The right hand expression.
axis : IterVar, optional
The reduction IterVar axis
where : optional, Expr
Filtering predicate of the reduction.
Returns Example
------- -------
value : Expr .. code-block:: python
The result value. n = tvm.var('n')
m = tvm.var('m')
mysum = tvm.comm_reducer(lambda x, y: x+y,
lambda t: tvm.const(0, dtype=t), name="mysum")
A = tvm.placeholder((n, m), name='A')
k = tvm.reduce_axis((0, m), name='k')
B = tvm.compute((n,), lambda i: mysum(A[i, k], axis=k), name='B')
""" """
if rhs and axis: def _reduce_directly(*args):
raise ValueError("Can only take one argument, rhs or axis") num = len(args)
if isinstance(rhs, (_schedule.IterVar, list)): # process `where` is None
axis, rhs = rhs, axis if num == 3 and args[2] is None:
if rhs: num = 2
return _make.Min(lhs, rhs) res = args[0]
axis = axis if isinstance(axis, list) else [axis] for i in range(num-1):
x = _make.Reduce("Min", expr, axis, where) res = fcombine(res, args[i+1])
return x return res
def _make_reduce(expr, axis, where=None):
def max(lhs, rhs=None, axis=None, where=None): expr = convert(expr)
"""Create a max expression. dtype = expr.dtype
code = fcombine.__code__
Parameters assert fcombine.__code__.co_argcount == 2
---------- arg_vars = [var(name, dtype) for name in code.co_varnames]
lhs : Expr result = fcombine(*[v for v in arg_vars])
The left hand expression. result = convert(result)
id_elem = fidentity(dtype)
rhs : Expr, optional assert isinstance(id_elem, _expr.Expr)
The right hand expression. combiner = _make.CommReducer(arg_vars, result, id_elem)
axis = axis if isinstance(axis, list) else [axis]
axis : IterVar, optional return _make.Reduce(combiner, expr, axis, where)
The reduction IterVar axis
def reducer(expr, axis, where=None, *args):
where : optional, Expr if isinstance(axis, (_schedule.IterVar, list)):
Filtering predicate of the reduction. assert len(args) == 0
return _make_reduce(expr, axis, where)
else:
if where is None:
assert len(args) == 0
return _reduce_directly(expr, axis)
return _reduce_directly(expr, axis, where, *args)
doc_str = """Create a {0} expression over axis.
Parameters
----------
expr : Expr
The source expression.
axis : IterVar
The reduction IterVar axis
where : optional, Expr
Filtering predicate of the reduction.
Returns
-------
value : Expr
The result value.
Example
-------
.. code-block:: python
m = tvm.var("m")
n = tvm.var("n")
A = tvm.placeholder((m, n), name="A")
k = tvm.reduce_axis((0, n), name="k")
# there are two way to use this {0} reducer:
# mode 1, accept (expr, axis, where) to produce an Reduce Expr
B = tvm.compute((m,), lambda i: {0}(A[i, k], axis=k), name="B")
# mode 2, simply use it with multiple Exprs:
{0}_res = {0}(m, n)
"""
reducer.__doc__ = doc_str.format(name)
return reducer
Returns
-------
value : Expr
The result value.
"""
if rhs and axis:
raise ValueError("Can only take one argument, rhs or axis")
if isinstance(rhs, (_schedule.IterVar, list)):
axis, rhs = rhs, axis
if rhs:
return _make.Max(lhs, rhs)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Max", expr, axis, where)
return x
_init_api("tvm.api") _init_api("tvm.api")
#pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _make.Min(x, y), max_value, name='min')
max = comm_reducer(lambda x, y: _make.Max(x, y), min_value, name='max')
...@@ -56,6 +56,12 @@ TVM_REGISTER_API("make.Allocate") ...@@ -56,6 +56,12 @@ TVM_REGISTER_API("make.Allocate")
args[4]); args[4]);
}); });
TVM_REGISTER_API("make.CommReducer")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = CommReducerNode::make(args[0], args[1], args[2]);
});
// make from two arguments // make from two arguments
#define REGISTER_MAKE1(Node) \ #define REGISTER_MAKE1(Node) \
TVM_REGISTER_API("make."#Node) \ TVM_REGISTER_API("make."#Node) \
......
...@@ -13,6 +13,18 @@ ...@@ -13,6 +13,18 @@
namespace tvm { namespace tvm {
TVM_REGISTER_API("_min_value")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Type t = args[0].operator Type();
*ret = t.min();
});
TVM_REGISTER_API("_max_value")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Type t = args[0].operator Type();
*ret = t.max();
});
TVM_REGISTER_API("_const") TVM_REGISTER_API("_const")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
if (args[0].type_code() == kInt) { if (args[0].type_code() == kInt) {
...@@ -344,4 +356,11 @@ TVM_REGISTER_API("_ScheduleRFactor") ...@@ -344,4 +356,11 @@ TVM_REGISTER_API("_ScheduleRFactor")
.rfactor(args[1], args[2]); .rfactor(args[1], args[2]);
}); });
TVM_REGISTER_API("_CommReducerCombine")
.set_body([](TVMArgs args, TVMRetValue* ret) {
const ir::CommReducerNode* combiner =
args[0].operator ir::CommReducer().as<ir::CommReducerNode>();
*ret = (*combiner)(args[1], args[2]);
});
} // namespace tvm } // namespace tvm
...@@ -47,15 +47,27 @@ IterVar reduce_axis(Range dom, std::string name) { ...@@ -47,15 +47,27 @@ IterVar reduce_axis(Range dom, std::string name) {
} }
Expr sum(Expr source, Array<IterVar> rdom) { Expr sum(Expr source, Array<IterVar> rdom) {
return ir::Reduce::make("Add", source, rdom, make_const(Bool(1), true)); Var x("x"), y("y");
Expr result = ir::Add::make(x, y);
Expr identity_element = make_zero(source.type());
ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
} }
Expr max(Expr source, Array<IterVar> rdom) { Expr max(Expr source, Array<IterVar> rdom) {
return ir::Reduce::make("Max", source, rdom, make_const(Bool(1), true)); Var x("x"), y("y");
Expr result = ir::Max::make(x, y);
Expr identity_element = source.type().min();
ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
} }
Expr min(Expr source, Array<IterVar> rdom) { Expr min(Expr source, Array<IterVar> rdom) {
return ir::Reduce::make("Min", source, rdom, make_const(Bool(1), true)); Var x("x"), y("y");
Expr result = ir::Min::make(x, y);
Expr identity_element = source.type().max();
ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
} }
std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*) std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*)
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <tvm/base.h> #include <tvm/base.h>
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <ir/IR.h> #include <ir/IR.h>
#include <ir/IRPrinter.h> #include <ir/IRPrinter.h>
#include <memory> #include <memory>
...@@ -12,6 +13,7 @@ ...@@ -12,6 +13,7 @@
namespace Halide { namespace Halide {
namespace Internal { namespace Internal {
using tvm::ir::CommReducerNode;
using tvm::ir::Reduce; using tvm::ir::Reduce;
using tvm::ir::AttrStmt; using tvm::ir::AttrStmt;
...@@ -22,8 +24,8 @@ void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const { ...@@ -22,8 +24,8 @@ void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) { .set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) {
p->stream << "reduce(" p->stream << "reduce(combiner="
<< op->op << op->combiner
<< ", "; << ", ";
p->print(op->source); p->print(op->source);
p->stream << ", axis=" << op->axis; p->stream << ", axis=" << op->axis;
...@@ -33,13 +35,37 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -33,13 +35,37 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << ")"; p->stream << ")";
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<CommReducerNode>([](const CommReducerNode *op, IRPrinter *p) {
p->stream << "comm_reducer(result="
<< op->result
<< ", args=" << op->args
<< ", identity_element="
<< op->identity_element
<< ")";
});
} // namespace Internal } // namespace Internal
} // namespace Halide } // namespace Halide
namespace tvm { namespace tvm {
namespace ir { namespace ir {
Expr Reduce::make(std::string op, Expr source, CommReducer CommReducerNode::make(Array<Var> args, Expr result, Expr identity_element) {
auto node = std::make_shared<CommReducerNode>();
node->args = args;
node->result = result;
node->identity_element = identity_element;
return CommReducer(node);
}
Expr CommReducerNode::operator()(Expr a, Expr b) const {
Map<Var, Expr> value_map;
value_map.Set(args[0], a);
value_map.Set(args[1], b);
return Substitute(result, value_map);
}
Expr Reduce::make(CommReducer combiner, Expr source,
Array<IterVar> axis, Expr condition) { Array<IterVar> axis, Expr condition) {
for (size_t i = 0; i < axis.size(); ++i) { for (size_t i = 0; i < axis.size(); ++i) {
CHECK_EQ(axis[i]->iter_type, kCommReduce) CHECK_EQ(axis[i]->iter_type, kCommReduce)
...@@ -54,39 +80,13 @@ Expr Reduce::make(std::string op, Expr source, ...@@ -54,39 +80,13 @@ Expr Reduce::make(std::string op, Expr source,
CHECK(axis[i].defined()); CHECK(axis[i].defined());
} }
n->type = source.type(); n->type = source.type();
n->combiner = combiner;
n->source = source; n->source = source;
n->op = op;
n->axis = axis; n->axis = axis;
n->condition = condition; n->condition = condition;
return Expr(n); return Expr(n);
} }
Expr Reduce::InitValue(const std::string& op, Type type) {
if (op == "Add") {
return make_zero(type);
} else if (op == "Max") {
return type.min();
} else if (op == "Min") {
return type.max();
} else {
LOG(FATAL) << "Unsupported reduction " << op;
return Expr();
}
}
Expr Reduce::Combine(const std::string& op, Expr a, Expr b) {
if (op == "Add") {
return Add::make(a, b);
} else if (op == "Max") {
return Max::make(a, b);
} else if (op == "Min") {
return Min::make(a, b);
} else {
LOG(FATAL) << "Unsupported reduction " << op;
return Expr();
}
}
TVM_REGISTER_NODE_TYPE(Reduce); TVM_REGISTER_NODE_TYPE(Reduce);
TVM_REGISTER_NODE_TYPE(AttrStmt); TVM_REGISTER_NODE_TYPE(AttrStmt);
......
...@@ -173,8 +173,10 @@ void MakeReduction(const ComputeOpNode* op, ...@@ -173,8 +173,10 @@ void MakeReduction(const ComputeOpNode* op,
} }
const Reduce* reduce = op->body.as<Reduce>(); const Reduce* reduce = op->body.as<Reduce>();
CHECK(reduce); CHECK(reduce);
Expr init_value = Reduce::InitValue(reduce->op, reduce->type); const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>();
Expr update_value = Reduce::Combine(reduce->op, t(args), reduce->source); CHECK(combiner);
Expr init_value = combiner->identity_element;
Expr update_value = (*combiner)(t(args), reduce->source);
*init = Provide::make(t->op, t->value_index, init_value, args); *init = Provide::make(t->op, t->value_index, init_value, args);
*provide = Provide::make(t->op, t->value_index, update_value, args); *provide = Provide::make(t->op, t->value_index, update_value, args);
if (!is_one(reduce->condition)) { if (!is_one(reduce->condition)) {
...@@ -237,7 +239,6 @@ Stmt MakeCrossThreadReduction( ...@@ -237,7 +239,6 @@ Stmt MakeCrossThreadReduction(
} }
Var res_handle("reduce_temp", Handle()); Var res_handle("reduce_temp", Handle());
Array<Expr> freduce_args; Array<Expr> freduce_args;
freduce_args.push_back(StringImm::make(reduce->op));
freduce_args.push_back(reduce->source); freduce_args.push_back(reduce->source);
freduce_args.push_back(cond); freduce_args.push_back(cond);
...@@ -253,12 +254,17 @@ Stmt MakeCrossThreadReduction( ...@@ -253,12 +254,17 @@ Stmt MakeCrossThreadReduction(
} }
} }
} }
Stmt reduce_body = Store::make( Stmt reduce_body = Store::make(res_handle,
res_handle, Call::make( Call::make(
reduce->type, reduce->type,
ir::intrinsic::tvm_thread_allreduce, ir::intrinsic::tvm_thread_allreduce,
freduce_args, Call::Intrinsic), freduce_args, Call::Intrinsic),
0); 0);
reduce_body = AttrStmt::make(
reduce->combiner,
attr::reduce_scope,
make_zero(reduce->type),
reduce_body);
Stmt assign_body = Provide::make( Stmt assign_body = Provide::make(
stage->op, 0, Load::make(reduce->type, res_handle, 0), args); stage->op, 0, Load::make(reduce->type, res_handle, 0), args);
assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body); assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
......
...@@ -328,7 +328,7 @@ Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) { ...@@ -328,7 +328,7 @@ Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) {
op->condition.same_as(new_cond)) { op->condition.same_as(new_cond)) {
return e; return e;
} else { } else {
return Reduce::make(op->op, new_source, new_axis, new_cond); return Reduce::make(op->combiner, new_source, new_axis, new_cond);
} }
} }
......
...@@ -34,6 +34,13 @@ class ThreadAllreduceBuilder : public IRMutator { ...@@ -34,6 +34,13 @@ class ThreadAllreduceBuilder : public IRMutator {
} else { } else {
return ret; return ret;
} }
} else if (op->attr_key == attr::reduce_scope) {
const CommReducerNode *combiner = op->node.as<CommReducerNode>();
CHECK(combiner);
reduce_combiner_.push_back(combiner);
Stmt ret = IRMutator::Mutate_(op, s);
reduce_combiner_.pop_back();
return ret;
} else { } else {
return IRMutator::Mutate_(op, s); return IRMutator::Mutate_(op, s);
} }
...@@ -91,16 +98,17 @@ class ThreadAllreduceBuilder : public IRMutator { ...@@ -91,16 +98,17 @@ class ThreadAllreduceBuilder : public IRMutator {
}; };
// make allreduce. // make allreduce.
Stmt MakeAllreduce(const Store* op, const Call* call) { Stmt MakeAllreduce(const Store* op, const Call* call) {
const std::string& op_code = call->args[0].as<StringImm>()->value; CHECK(!reduce_combiner_.empty());
Expr value = call->args[1]; const CommReducerNode *combiner = reduce_combiner_.back();
Expr cond = call->args[2]; Expr init = combiner->identity_element;
Expr value = call->args[0];
Expr cond = call->args[1];
if (!is_one(cond)) { if (!is_one(cond)) {
value = Select::make( value = Select::make(cond, value, init);
cond, value, Reduce::InitValue(op_code, value.type()));
} }
std::unordered_set<const Variable*> reduce_set; std::unordered_set<const Variable*> reduce_set;
for (size_t i = 3; i < call->args.size(); ++i) { for (size_t i = 2; i < call->args.size(); ++i) {
const Variable* v = call->args[i].as<Variable>(); const Variable* v = call->args[i].as<Variable>();
CHECK(v); CHECK(v);
reduce_set.insert(v); reduce_set.insert(v);
...@@ -150,7 +158,7 @@ class ThreadAllreduceBuilder : public IRMutator { ...@@ -150,7 +158,7 @@ class ThreadAllreduceBuilder : public IRMutator {
BufIndex(reduce_index, group_index, reduce_extent))); BufIndex(reduce_index, group_index, reduce_extent)));
seq.emplace_back(SyncThread()); seq.emplace_back(SyncThread());
seq.emplace_back(MakeBufAllreduce( seq.emplace_back(MakeBufAllreduce(
op_code, value.type(), shared_buf, combiner, value.type(), shared_buf,
reduce_index, group_index, reduce_extent, threadx_extent)); reduce_index, group_index, reduce_extent, threadx_extent));
CHECK(!load_remap_.count(op->buffer_var.get())); CHECK(!load_remap_.count(op->buffer_var.get()));
load_remap_[op->buffer_var.get()] = load_remap_[op->buffer_var.get()] =
...@@ -164,7 +172,7 @@ class ThreadAllreduceBuilder : public IRMutator { ...@@ -164,7 +172,7 @@ class ThreadAllreduceBuilder : public IRMutator {
return MergeSeq(seq); return MergeSeq(seq);
} }
// make allreduce. // make allreduce.
Stmt MakeBufAllreduce(const std::string& op, Stmt MakeBufAllreduce(const CommReducerNode *combiner,
Type type, Type type,
Var shared_buf, Var shared_buf,
Expr reduce_index, Expr reduce_index,
...@@ -186,7 +194,7 @@ class ThreadAllreduceBuilder : public IRMutator { ...@@ -186,7 +194,7 @@ class ThreadAllreduceBuilder : public IRMutator {
type, shared_buf, type, shared_buf,
BufIndex(reduce_index + offset, group_index, reduce_extent)); BufIndex(reduce_index + offset, group_index, reduce_extent));
Expr a = Load::make(type, shared_buf, buf_index); Expr a = Load::make(type, shared_buf, buf_index);
return Store::make(shared_buf, Reduce::Combine(op, a, b), buf_index); return Store::make(shared_buf, (*combiner)(a, b), buf_index);
}; };
// Step one, check for // Step one, check for
if (reduce_align > reduce_extent) { if (reduce_align > reduce_extent) {
...@@ -260,6 +268,7 @@ class ThreadAllreduceBuilder : public IRMutator { ...@@ -260,6 +268,7 @@ class ThreadAllreduceBuilder : public IRMutator {
// surrounding scope of thread extent. // surrounding scope of thread extent.
std::vector<const AttrStmt*> thread_extents_; std::vector<const AttrStmt*> thread_extents_;
std::vector<const CommReducerNode*> reduce_combiner_;
// The load remap // The load remap
std::unordered_map<const Variable *, Expr> load_remap_; std::unordered_map<const Variable *, Expr> load_remap_;
// Allocate remap // Allocate remap
......
...@@ -57,6 +57,15 @@ Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) { ...@@ -57,6 +57,15 @@ Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
return m.Mutate(stmt); return m.Mutate(stmt);
} }
Expr Substitute(Expr expr, const Map<Var, Expr>& value_map) {
if (value_map.size() == 0) return expr;
IRSubstitue m;
for (auto kv : value_map) {
m.smap[kv.first.get()] = kv.second;
}
return m.Mutate(expr);
}
class ExprUseVarVisitor : public IRVisitor { class ExprUseVarVisitor : public IRVisitor {
public: public:
explicit ExprUseVarVisitor(const Variable* var) explicit ExprUseVarVisitor(const Variable* var)
......
...@@ -344,7 +344,7 @@ Tensor Schedule::rfactor(const Tensor& tensor, ...@@ -344,7 +344,7 @@ Tensor Schedule::rfactor(const Tensor& tensor,
n->reduce_axis.push_back(IterVar(ncpy)); n->reduce_axis.push_back(IterVar(ncpy));
} }
} }
n->body = Reduce::make(reduce->op, n->body = Reduce::make(reduce->combiner,
VarReplacer(vsub).Mutate(reduce->source), VarReplacer(vsub).Mutate(reduce->source),
n->reduce_axis, n->reduce_axis,
predicate); predicate);
...@@ -390,8 +390,8 @@ Tensor Schedule::rfactor(const Tensor& tensor, ...@@ -390,8 +390,8 @@ Tensor Schedule::rfactor(const Tensor& tensor,
for (Var v : i) { for (Var v : i) {
indices.push_back(v); indices.push_back(v);
} }
return Reduce::make( return Reduce::make(reduce->combiner,
reduce->op, factor_tensor(indices), {repl_red_axis}, const_true()); factor_tensor(indices), {repl_red_axis}, const_true());
}, old_tensor->op->name + ".repl"); }, old_tensor->op->name + ".repl");
std::unordered_map<Tensor, Tensor> vmap; std::unordered_map<Tensor, Tensor> vmap;
......
import tvm import tvm
import numpy as np import numpy as np
def test_sum(): def test_reduce_prims():
# graph def test_prim(reducer, np_reducer):
n = tvm.var('n') # graph
m = tvm.var('m') n = tvm.var('n')
A = tvm.placeholder((n, m), name='A') m = tvm.var('m')
k = tvm.reduce_axis((0, m)) A = tvm.placeholder((n, m), name='A')
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k, where=(i>1)), name='B') k = tvm.reduce_axis((0, m))
# schedule B = tvm.compute((n,), lambda i: reducer(A[i, k], axis=k, where=(i>1)), name='B')
s = tvm.create_schedule(B.op) # schedule
# create iter var and assign them tags. s = tvm.create_schedule(B.op)
num_thread = 1 # create iter var and assign them tags.
xo, xi = s[B].split(B.op.axis[0], factor=num_thread) num_thread = 1
s[B].bind(xo, tvm.thread_axis("blockIdx.x")) xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(xi, tvm.thread_axis("threadIdx.x")) s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
# one line to build the function. # one line to build the function.
def check_device(device, host="stackvm"): def check_device(device, host="stackvm"):
if not tvm.codegen.enabled(host): if not tvm.codegen.enabled(host):
return return
if not tvm.codegen.enabled(device): if not tvm.codegen.enabled(device):
return return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0) ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
fsum = tvm.build(s, freduce = tvm.build(s,
args=[A, B], args=[A, B],
target=device, target_host=host, target=device, target_host=host,
name="mysum") name="myreduce")
print(fsum.imported_modules[0].get_source()) print(freduce.imported_modules[0].get_source())
# launch the kernel. # launch the kernel.
n = 1028 n = 1028
m = 129 m = 129
a = tvm.nd.array(np.random.uniform(size=(n, m)).astype(A.dtype), ctx) x = tvm.nd.array(np.random.uniform(size=(n, m)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx) y = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
fsum(a, b) freduce(x, y)
res = np.sum(a.asnumpy(), axis=1) npy = y.asnumpy()
res[:2] = 0 npy[:2] = 0
np.testing.assert_allclose( res = np_reducer(x.asnumpy(), axis=1)
b.asnumpy(), res, rtol=1e-4) res[:2] = 0
np.testing.assert_allclose(npy, res, rtol=1e-4)
if tvm.module.enabled("opencl"): if tvm.module.enabled("opencl"):
tvm.module.init_opencl() tvm.module.init_opencl()
check_device("cuda")
check_device("opencl")
test_prim(tvm.sum, np.sum)
test_prim(tvm.min, np.amin)
test_prim(tvm.max, np.amax)
check_device("cuda")
check_device("opencl")
def test_rfactor(): def test_rfactor():
...@@ -127,4 +133,4 @@ def test_rfactor_threads(): ...@@ -127,4 +133,4 @@ def test_rfactor_threads():
if __name__ == "__main__": if __name__ == "__main__":
test_rfactor_threads() test_rfactor_threads()
test_rfactor() test_rfactor()
test_sum() test_reduce_prims()
...@@ -33,6 +33,20 @@ def test_tensor_slice(): ...@@ -33,6 +33,20 @@ def test_tensor_slice():
B = tvm.compute((n,), lambda i: A[0][i] + A[0][i]) B = tvm.compute((n,), lambda i: A[0][i] + A[0][i])
def test_tensor_comm_reducer():
m = tvm.var('m')
n = tvm.var('n')
A = tvm.placeholder((m, n), name='A')
k = tvm.reduce_axis((0, n), "k")
mysum = tvm.comm_reducer(lambda x, y: x+y, lambda t: tvm.const(0, dtype=t))
C = tvm.compute((m,), lambda i: mysum(A[i, k], axis=k))
def test_tensor_comm_reducer_overload():
m = tvm.var('m')
n = tvm.var('n')
mysum = tvm.comm_reducer(lambda x, y: x+y, lambda t: tvm.const(0, dtype=t))
sum_res = mysum(m, n)
def test_tensor_reduce(): def test_tensor_reduce():
m = tvm.var('m') m = tvm.var('m')
n = tvm.var('n') n = tvm.var('n')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment