Commit 26d91985 by ziheng, committed by Tianqi Chen

[LANG] CommReducer (#103)

* [LANG] CommReducer

* Reorganize c_api

* Remove InitValue and Combine; refactor Functor

* Make CommReducer an Expr

* Make comm_reducer type independent

* Make CommReducerNode a Node

* Small fix

* Refine

* Refine front api; add integration testcases for min/max

* Fix python

* Refine

* Fix lint and add example
parent 7e032457
......@@ -22,12 +22,67 @@ using Halide::Internal::IRNodeType;
using Halide::Internal::ForType;
using Halide::DeviceAPI;
/*! \brief Reduction operator */
struct Reduce : public ExprNode<Reduce> {
// Node container for CommReducer
struct CommReducerNode;
struct CommReducer : public NodeRef {
CommReducer() {}
explicit CommReducer(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const CommReducerNode* get() const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const CommReducerNode* operator->() const;
/*! \brief type indicating the container type */
using ContainerType = CommReducerNode;
};
/*!
* \brief A commutative reducer node to represent a commutative
* binary operator with an identity element
*/
struct CommReducerNode : public Node {
/*! \brief The arguments of reducer */
Array<Var> args;
/*! \brief The result of reducer */
Expr result;
/*! \brief The binary operator of reduction */
std::string op;
/*!
* \brief The identity element of the reducer, which leaves other
* elements unchanged when combined with it, with respect to
* the binary operation this reducer uses.
*/
Expr identity_element;
/*! \brief Function call operator to combine a and b */
Expr operator()(Expr a, Expr b) const;
/*! \brief construct CommReducer from args, result and identity_element */
static CommReducer make(Array<Var> args, Expr result, Expr identity_element);
void VisitAttrs(AttrVisitor* v) final {
v->Visit("args", &args);
v->Visit("result", &result);
v->Visit("identity_element", &identity_element);
}
static constexpr const char* _type_key = "CommReducer";
TVM_DECLARE_NODE_TYPE_INFO(CommReducerNode, Node);
};
inline const CommReducerNode* CommReducer::get() const {
return static_cast<CommReducerNode*>(node_.get());
}
inline const CommReducerNode* CommReducer::operator->() const {
return static_cast<CommReducerNode*>(node_.get());
}
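
For orientation, the node above just packages the reducer's formal arguments, the combining expression over them, and the identity element. A minimal Python sketch of building one through the binding this commit registers later in the diff (assuming it is exposed as tvm.make.CommReducer, with illustrative int32 operands):

    import tvm
    x = tvm.var('x')
    y = tvm.var('y')
    # args = [x, y], result = x + y, identity_element = 0
    add_combiner = tvm.make.CommReducer([x, y], x + y, tvm.const(0, 'int32'))
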
/*! \brief Reduction operator */
struct Reduce : public ExprNode<Reduce> {
/*! \brief The commutative combiner */
CommReducer combiner;
/*! \brief The source operand */
Expr source;
/*! \brief The reduction axis */
......@@ -39,37 +94,19 @@ struct Reduce : public ExprNode<Reduce> {
Expr condition;
/*! \brief construct expr from op and rdom */
static Expr make(std::string op, Expr src,
static Expr make(CommReducer combiner,
Expr src,
Array<IterVar> rdom,
Expr condition = const_true());
/*!
* \brief Get initial value for reduction.
* \param op The operator
* \param type The data type.
* \return The initial value that can be assigned to reduction.
*/
static Expr InitValue(const std::string& op, Type type);
/*!
* \brief Combine two values with given reduction.
* \param op The operator
* \param a The left operand.
* \param b The right operand.
* \return The combined reduction result.
*/
static Expr Combine(const std::string& op, Expr a, Expr b);
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("op", &op);
v->Visit("source", &source);
v->Visit("axis", &axis);
v->Visit("condition", &condition);
}
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "Reduce";
static constexpr const char* Add = "Add";
static constexpr const char* Max = "Max";
static constexpr const char* Min = "Min";
};
/*!
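
From the front end, the effect of this change is that a reduction expression now carries a full combiner instead of a string op. A minimal sketch following the integration test added in this commit (it assumes the compute body is reachable as B.op.body and that node printing is wired up as usual):

    import tvm
    n = tvm.var('n')
    m = tvm.var('m')
    A = tvm.placeholder((n, m), name='A')
    k = tvm.reduce_axis((0, m), name='k')
    B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k, where=(i > 1)), name='B')
    # B.op.body is a Reduce expression holding the combiner, the source
    # A[i, k], the reduction axis k and the condition (i > 1).
    print(B.op.body)
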
......@@ -93,26 +130,20 @@ struct TensorKey {
/*! \brief namespace of possible attributes in AttrStmt.type_key */
namespace attr {
// The above attr does not pass to ir stage.
/*!
* \brief Mark launching extent of thread, used by device API.
*/
/*! \brief Mark launching extent of thread, used by device API. */
constexpr const char* thread_extent = "thread_extent";
/*!
* \brief Mark launching of a virtual thread.
*/
/*! \brief Mark launching of a virtual thread. */
constexpr const char* virtual_thread = "virtual_thread";
/*!
* \brief Mark the scope as volatile access for certain handle.
*/
/*! \brief Mark the scope as volatile access for certain handle. */
constexpr const char* volatile_scope = "volatile_scope";
/*!
* \brief Mark storage scope of buffers
*/
/*! \brief Mark storage scope of buffers */
constexpr const char* storage_scope = "storage_scope";
/*! \brief Mark storage scope of realization */
constexpr const char* realize_scope = "realize_scope";
/*! \brief Mark of loop scope */
constexpr const char* loop_scope = "loop_scope";
/*! \brief Mark of reduce scope */
constexpr const char* reduce_scope = "reduce_scope";
/*! \brief Mark of scan update scope */
constexpr const char* scan_update_scope = "scan_update_scope";
/*! \brief Mark of scan init scope */
......
......@@ -84,6 +84,7 @@ Stmt CanonicalSimplify(Stmt stmt);
* \return The converted form.
*/
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map);
Expr Substitute(Expr expr, const Map<Var, Expr>& value_map);
/*!
* \brief inline all calls of f in stmt.
......
......@@ -21,6 +21,15 @@ int32 = "int32"
float32 = "float32"
handle = "handle"
def min_value(dtype):
return _api_internal._min_value(dtype)
def max_value(dtype):
return _api_internal._max_value(dtype)
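
These two helpers return the extreme constants of a dtype; further down in this file they serve as the identity elements of the built-in min/max reducers. A small usage sketch (names are illustrative):

    import tvm
    lo = tvm.min_value('float32')  # the dtype's minimum value, identity element for tvm.max below
    hi = tvm.max_value('float32')  # the dtype's maximum value, identity element for tvm.min below
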
def const(value, dtype=None):
"""construct a constant"""
if dtype is None:
......@@ -414,93 +423,106 @@ def reduce_axis(dom, name="rv"):
return _IterVar(dom, name, 2)
def sum(expr, axis, where=None):
"""Create a sum expression over axis
def comm_reducer(fcombine, fidentity, name="reduce"):
"""Create a commutative reducer for reduction.
Parameters
----------
expr : Expr
The source expression.
axis : IterVar
The reduction IterVar axis
fcombine : function(Expr -> Expr -> Expr)
A binary function which takes two Expr as input and returns an Expr.
where : optional, Expr
Filtering predicate of the reduction.
fidentity : function(str -> Expr)
A function which takes a type string as input and returns a constant Expr of that type.
Returns
-------
value : Expr
The result value.
"""
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Add", expr, axis, where)
return x
def min(lhs, rhs=None, axis=None, where=None):
"""Create a min expression.
Parameters
----------
lhs : Expr
The left hand expression.
rhs : Expr, optional
The right hand expression.
axis : IterVar, optional
The reduction IterVar axis
where : optional, Expr
Filtering predicate of the reduction.
reducer : function
A function which creates a reduce expression over axis. There are two
ways to use it:
1. accept (expr, axis, where) to produce a Reduce Expr on
the specified axis;
2. simply use it with multiple Exprs.
Returns
Example
-------
value : Expr
The result value.
.. code-block:: python
n = tvm.var('n')
m = tvm.var('m')
mysum = tvm.comm_reducer(lambda x, y: x+y,
lambda t: tvm.const(0, dtype=t), name="mysum")
A = tvm.placeholder((n, m), name='A')
k = tvm.reduce_axis((0, m), name='k')
B = tvm.compute((n,), lambda i: mysum(A[i, k], axis=k), name='B')
"""
if rhs and axis:
raise ValueError("Can only take one argument, rhs or axis")
if isinstance(rhs, (_schedule.IterVar, list)):
axis, rhs = rhs, axis
if rhs:
return _make.Min(lhs, rhs)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Min", expr, axis, where)
return x
def max(lhs, rhs=None, axis=None, where=None):
"""Create a max expression.
Parameters
----------
lhs : Expr
The left hand expression.
rhs : Expr, optional
The right hand expression.
axis : IterVar, optional
The reduction IterVar axis
where : optional, Expr
Filtering predicate of the reduction.
def _reduce_directly(*args):
num = len(args)
# drop a trailing `where` argument when it is None
if num == 3 and args[2] is None:
num = 2
res = args[0]
for i in range(num-1):
res = fcombine(res, args[i+1])
return res
def _make_reduce(expr, axis, where=None):
expr = convert(expr)
dtype = expr.dtype
code = fcombine.__code__
assert fcombine.__code__.co_argcount == 2
arg_vars = [var(name, dtype) for name in code.co_varnames]
result = fcombine(*[v for v in arg_vars])
result = convert(result)
id_elem = fidentity(dtype)
assert isinstance(id_elem, _expr.Expr)
combiner = _make.CommReducer(arg_vars, result, id_elem)
axis = axis if isinstance(axis, list) else [axis]
return _make.Reduce(combiner, expr, axis, where)
def reducer(expr, axis, where=None, *args):
if isinstance(axis, (_schedule.IterVar, list)):
assert len(args) == 0
return _make_reduce(expr, axis, where)
else:
if where is None:
assert len(args) == 0
return _reduce_directly(expr, axis)
return _reduce_directly(expr, axis, where, *args)
doc_str = """Create a {0} expression over axis.
Parameters
----------
expr : Expr
The source expression.
axis : IterVar
The reduction IterVar axis
where : optional, Expr
Filtering predicate of the reduction.
Returns
-------
value : Expr
The result value.
Example
-------
.. code-block:: python
m = tvm.var("m")
n = tvm.var("n")
A = tvm.placeholder((m, n), name="A")
k = tvm.reduce_axis((0, n), name="k")
# there are two ways to use this {0} reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr
B = tvm.compute((m,), lambda i: {0}(A[i, k], axis=k), name="B")
# mode 2, simply use it with multiple Exprs:
{0}_res = {0}(m, n)
"""
reducer.__doc__ = doc_str.format(name)
return reducer
Returns
-------
value : Expr
The result value.
"""
if rhs and axis:
raise ValueError("Can only take one argument, rhs or axis")
if isinstance(rhs, (_schedule.IterVar, list)):
axis, rhs = rhs, axis
if rhs:
return _make.Max(lhs, rhs)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Max", expr, axis, where)
return x
_init_api("tvm.api")
#pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _make.Min(x, y), max_value, name='min')
max = comm_reducer(lambda x, y: _make.Max(x, y), min_value, name='max')
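
The reducers built this way, including the sum/min/max defined just above, support both calling modes described in the docstring. A short sketch combining them with a user-defined product reducer (the product reducer itself is only an illustration, not part of this commit):

    import tvm
    n = tvm.var('n')
    m = tvm.var('m')
    A = tvm.placeholder((n, m), name='A')
    k = tvm.reduce_axis((0, m), name='k')
    # mode 1: reduce over an axis inside a compute expression
    B = tvm.compute((n,), lambda i: tvm.max(A[i, k], axis=k), name='B')
    # mode 2: fold a couple of expressions directly
    biggest = tvm.max(n, m)
    # the same two modes work for custom reducers
    product = tvm.comm_reducer(lambda x, y: x * y,
                               lambda t: tvm.const(1, dtype=t), name='product')
    C = tvm.compute((n,), lambda i: product(A[i, k], axis=k), name='C')
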
......@@ -56,6 +56,12 @@ TVM_REGISTER_API("make.Allocate")
args[4]);
});
TVM_REGISTER_API("make.CommReducer")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = CommReducerNode::make(args[0], args[1], args[2]);
});
// make from two arguments
#define REGISTER_MAKE1(Node) \
TVM_REGISTER_API("make."#Node) \
......
......@@ -13,6 +13,18 @@
namespace tvm {
TVM_REGISTER_API("_min_value")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Type t = args[0].operator Type();
*ret = t.min();
});
TVM_REGISTER_API("_max_value")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Type t = args[0].operator Type();
*ret = t.max();
});
TVM_REGISTER_API("_const")
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args[0].type_code() == kInt) {
......@@ -344,4 +356,11 @@ TVM_REGISTER_API("_ScheduleRFactor")
.rfactor(args[1], args[2]);
});
TVM_REGISTER_API("_CommReducerCombine")
.set_body([](TVMArgs args, TVMRetValue* ret) {
const ir::CommReducerNode* combiner =
args[0].operator ir::CommReducer().as<ir::CommReducerNode>();
*ret = (*combiner)(args[1], args[2]);
});
} // namespace tvm
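
The _CommReducerCombine hook above, together with the make.CommReducer registration added earlier, lets the combiner be built and exercised from Python. A hedged sketch, assuming the underscore-prefixed function surfaces under tvm._api_internal the same way _min_value does:

    import tvm
    x = tvm.var('x')
    y = tvm.var('y')
    comb = tvm.make.CommReducer([x, y], tvm.make.Max(x, y), tvm.min_value('int32'))
    a = tvm.var('a')
    b = tvm.var('b')
    # substitutes a and b for the reducer's formal arguments, yielding Max(a, b)
    combined = tvm._api_internal._CommReducerCombine(comb, a, b)
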
......@@ -47,15 +47,27 @@ IterVar reduce_axis(Range dom, std::string name) {
}
Expr sum(Expr source, Array<IterVar> rdom) {
return ir::Reduce::make("Add", source, rdom, make_const(Bool(1), true));
Var x("x"), y("y");
Expr result = ir::Add::make(x, y);
Expr identity_element = make_zero(source.type());
ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
}
Expr max(Expr source, Array<IterVar> rdom) {
return ir::Reduce::make("Max", source, rdom, make_const(Bool(1), true));
Var x("x"), y("y");
Expr result = ir::Max::make(x, y);
Expr identity_element = source.type().min();
ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
}
Expr min(Expr source, Array<IterVar> rdom) {
return ir::Reduce::make("Min", source, rdom, make_const(Bool(1), true));
Var x("x"), y("y");
Expr result = ir::Min::make(x, y);
Expr identity_element = source.type().max();
ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
}
std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*)
......
......@@ -5,6 +5,7 @@
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <ir/IR.h>
#include <ir/IRPrinter.h>
#include <memory>
......@@ -12,6 +13,7 @@
namespace Halide {
namespace Internal {
using tvm::ir::CommReducerNode;
using tvm::ir::Reduce;
using tvm::ir::AttrStmt;
......@@ -22,8 +24,8 @@ void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) {
p->stream << "reduce("
<< op->op
p->stream << "reduce(combiner="
<< op->combiner
<< ", ";
p->print(op->source);
p->stream << ", axis=" << op->axis;
......@@ -33,13 +35,37 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<CommReducerNode>([](const CommReducerNode *op, IRPrinter *p) {
p->stream << "comm_reducer(result="
<< op->result
<< ", args=" << op->args
<< ", identity_element="
<< op->identity_element
<< ")";
});
} // namespace Internal
} // namespace Halide
namespace tvm {
namespace ir {
Expr Reduce::make(std::string op, Expr source,
CommReducer CommReducerNode::make(Array<Var> args, Expr result, Expr identity_element) {
auto node = std::make_shared<CommReducerNode>();
node->args = args;
node->result = result;
node->identity_element = identity_element;
return CommReducer(node);
}
Expr CommReducerNode::operator()(Expr a, Expr b) const {
Map<Var, Expr> value_map;
value_map.Set(args[0], a);
value_map.Set(args[1], b);
return Substitute(result, value_map);
}
Expr Reduce::make(CommReducer combiner, Expr source,
Array<IterVar> axis, Expr condition) {
for (size_t i = 0; i < axis.size(); ++i) {
CHECK_EQ(axis[i]->iter_type, kCommReduce)
......@@ -54,39 +80,13 @@ Expr Reduce::make(std::string op, Expr source,
CHECK(axis[i].defined());
}
n->type = source.type();
n->combiner = combiner;
n->source = source;
n->op = op;
n->axis = axis;
n->condition = condition;
return Expr(n);
}
Expr Reduce::InitValue(const std::string& op, Type type) {
if (op == "Add") {
return make_zero(type);
} else if (op == "Max") {
return type.min();
} else if (op == "Min") {
return type.max();
} else {
LOG(FATAL) << "Unsupported reduction " << op;
return Expr();
}
}
Expr Reduce::Combine(const std::string& op, Expr a, Expr b) {
if (op == "Add") {
return Add::make(a, b);
} else if (op == "Max") {
return Max::make(a, b);
} else if (op == "Min") {
return Min::make(a, b);
} else {
LOG(FATAL) << "Unsupported reduction " << op;
return Expr();
}
}
TVM_REGISTER_NODE_TYPE(Reduce);
TVM_REGISTER_NODE_TYPE(AttrStmt);
......
......@@ -173,8 +173,10 @@ void MakeReduction(const ComputeOpNode* op,
}
const Reduce* reduce = op->body.as<Reduce>();
CHECK(reduce);
Expr init_value = Reduce::InitValue(reduce->op, reduce->type);
Expr update_value = Reduce::Combine(reduce->op, t(args), reduce->source);
const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>();
CHECK(combiner);
Expr init_value = combiner->identity_element;
Expr update_value = (*combiner)(t(args), reduce->source);
*init = Provide::make(t->op, t->value_index, init_value, args);
*provide = Provide::make(t->op, t->value_index, update_value, args);
if (!is_one(reduce->condition)) {
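
In words: the init Provide now stores the reducer's identity element, and the update Provide stores the combiner applied to the current output element and the incoming source value. A rough hand-written mirror of that for the sum reducer (acc and src are stand-in variables, not the compiler's actual buffers):

    import tvm
    acc = tvm.var('acc', 'float32')       # stands for the output element t(args)
    src = tvm.var('src', 'float32')       # stands for reduce->source
    init_value = tvm.const(0, 'float32')  # combiner->identity_element for sum
    update_value = acc + src              # (*combiner)(acc, src) expands to acc + src
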
......@@ -237,7 +239,6 @@ Stmt MakeCrossThreadReduction(
}
Var res_handle("reduce_temp", Handle());
Array<Expr> freduce_args;
freduce_args.push_back(StringImm::make(reduce->op));
freduce_args.push_back(reduce->source);
freduce_args.push_back(cond);
......@@ -253,12 +254,17 @@ Stmt MakeCrossThreadReduction(
}
}
}
Stmt reduce_body = Store::make(
res_handle, Call::make(
reduce->type,
ir::intrinsic::tvm_thread_allreduce,
freduce_args, Call::Intrinsic),
0);
Stmt reduce_body = Store::make(res_handle,
Call::make(
reduce->type,
ir::intrinsic::tvm_thread_allreduce,
freduce_args, Call::Intrinsic),
0);
reduce_body = AttrStmt::make(
reduce->combiner,
attr::reduce_scope,
make_zero(reduce->type),
reduce_body);
Stmt assign_body = Provide::make(
stage->op, 0, Load::make(reduce->type, res_handle, 0), args);
assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
......
......@@ -328,7 +328,7 @@ Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) {
op->condition.same_as(new_cond)) {
return e;
} else {
return Reduce::make(op->op, new_source, new_axis, new_cond);
return Reduce::make(op->combiner, new_source, new_axis, new_cond);
}
}
......
......@@ -34,6 +34,13 @@ class ThreadAllreduceBuilder : public IRMutator {
} else {
return ret;
}
} else if (op->attr_key == attr::reduce_scope) {
const CommReducerNode *combiner = op->node.as<CommReducerNode>();
CHECK(combiner);
reduce_combiner_.push_back(combiner);
Stmt ret = IRMutator::Mutate_(op, s);
reduce_combiner_.pop_back();
return ret;
} else {
return IRMutator::Mutate_(op, s);
}
......@@ -91,16 +98,17 @@ class ThreadAllreduceBuilder : public IRMutator {
};
// make allreduce.
Stmt MakeAllreduce(const Store* op, const Call* call) {
const std::string& op_code = call->args[0].as<StringImm>()->value;
Expr value = call->args[1];
Expr cond = call->args[2];
CHECK(!reduce_combiner_.empty());
const CommReducerNode *combiner = reduce_combiner_.back();
Expr init = combiner->identity_element;
Expr value = call->args[0];
Expr cond = call->args[1];
if (!is_one(cond)) {
value = Select::make(
cond, value, Reduce::InitValue(op_code, value.type()));
value = Select::make(cond, value, init);
}
std::unordered_set<const Variable*> reduce_set;
for (size_t i = 3; i < call->args.size(); ++i) {
for (size_t i = 2; i < call->args.size(); ++i) {
const Variable* v = call->args[i].as<Variable>();
CHECK(v);
reduce_set.insert(v);
......@@ -150,7 +158,7 @@ class ThreadAllreduceBuilder : public IRMutator {
BufIndex(reduce_index, group_index, reduce_extent)));
seq.emplace_back(SyncThread());
seq.emplace_back(MakeBufAllreduce(
op_code, value.type(), shared_buf,
combiner, value.type(), shared_buf,
reduce_index, group_index, reduce_extent, threadx_extent));
CHECK(!load_remap_.count(op->buffer_var.get()));
load_remap_[op->buffer_var.get()] =
......@@ -164,7 +172,7 @@ class ThreadAllreduceBuilder : public IRMutator {
return MergeSeq(seq);
}
// make allreduce.
Stmt MakeBufAllreduce(const std::string& op,
Stmt MakeBufAllreduce(const CommReducerNode *combiner,
Type type,
Var shared_buf,
Expr reduce_index,
......@@ -186,7 +194,7 @@ class ThreadAllreduceBuilder : public IRMutator {
type, shared_buf,
BufIndex(reduce_index + offset, group_index, reduce_extent));
Expr a = Load::make(type, shared_buf, buf_index);
return Store::make(shared_buf, Reduce::Combine(op, a, b), buf_index);
return Store::make(shared_buf, (*combiner)(a, b), buf_index);
};
// Step one, check for
if (reduce_align > reduce_extent) {
......@@ -260,6 +268,7 @@ class ThreadAllreduceBuilder : public IRMutator {
// surrounding scope of thread extent.
std::vector<const AttrStmt*> thread_extents_;
std::vector<const CommReducerNode*> reduce_combiner_;
// The load remap
std::unordered_map<const Variable *, Expr> load_remap_;
// Allocate remap
......
......@@ -57,6 +57,15 @@ Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
return m.Mutate(stmt);
}
Expr Substitute(Expr expr, const Map<Var, Expr>& value_map) {
if (value_map.size() == 0) return expr;
IRSubstitue m;
for (auto kv : value_map) {
m.smap[kv.first.get()] = kv.second;
}
return m.Mutate(expr);
}
class ExprUseVarVisitor : public IRVisitor {
public:
explicit ExprUseVarVisitor(const Variable* var)
......
......@@ -344,7 +344,7 @@ Tensor Schedule::rfactor(const Tensor& tensor,
n->reduce_axis.push_back(IterVar(ncpy));
}
}
n->body = Reduce::make(reduce->op,
n->body = Reduce::make(reduce->combiner,
VarReplacer(vsub).Mutate(reduce->source),
n->reduce_axis,
predicate);
......@@ -390,8 +390,8 @@ Tensor Schedule::rfactor(const Tensor& tensor,
for (Var v : i) {
indices.push_back(v);
}
return Reduce::make(
reduce->op, factor_tensor(indices), {repl_red_axis}, const_true());
return Reduce::make(reduce->combiner,
factor_tensor(indices), {repl_red_axis}, const_true());
}, old_tensor->op->name + ".repl");
std::unordered_map<Tensor, Tensor> vmap;
......
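
rfactor simply forwards the reduction's combiner to both the factored tensor and the replacement Reduce, so user-level rfactor code is unchanged. A minimal sketch in the spirit of the existing rfactor tests (the shape and split factor are illustrative):

    import tvm
    n = tvm.var('n')
    A = tvm.placeholder((n, 128), name='A')
    k = tvm.reduce_axis((0, 128), name='k')
    B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')
    s = tvm.create_schedule(B.op)
    ko, ki = s[B].split(k, factor=16)
    BF = s.rfactor(B, ki)   # BF carries the same combiner as B's original Reduce
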
import tvm
import numpy as np
def test_sum():
# graph
n = tvm.var('n')
m = tvm.var('m')
A = tvm.placeholder((n, m), name='A')
k = tvm.reduce_axis((0, m))
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k, where=(i>1)), name='B')
# schedule
s = tvm.create_schedule(B.op)
# create iter var and assign them tags.
num_thread = 1
xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
def test_reduce_prims():
def test_prim(reducer, np_reducer):
# graph
n = tvm.var('n')
m = tvm.var('m')
A = tvm.placeholder((n, m), name='A')
k = tvm.reduce_axis((0, m))
B = tvm.compute((n,), lambda i: reducer(A[i, k], axis=k, where=(i>1)), name='B')
# schedule
s = tvm.create_schedule(B.op)
# create iter var and assign them tags.
num_thread = 1
xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
# one line to build the function.
def check_device(device, host="stackvm"):
if not tvm.codegen.enabled(host):
return
if not tvm.codegen.enabled(device):
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
fsum = tvm.build(s,
args=[A, B],
target=device, target_host=host,
name="mysum")
print(fsum.imported_modules[0].get_source())
# launch the kernel.
n = 1028
m = 129
a = tvm.nd.array(np.random.uniform(size=(n, m)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
fsum(a, b)
res = np.sum(a.asnumpy(), axis=1)
res[:2] = 0
np.testing.assert_allclose(
b.asnumpy(), res, rtol=1e-4)
# one line to build the function.
def check_device(device, host="stackvm"):
if not tvm.codegen.enabled(host):
return
if not tvm.codegen.enabled(device):
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
freduce = tvm.build(s,
args=[A, B],
target=device, target_host=host,
name="myreduce")
print(freduce.imported_modules[0].get_source())
# launch the kernel.
n = 1028
m = 129
x = tvm.nd.array(np.random.uniform(size=(n, m)).astype(A.dtype), ctx)
y = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
freduce(x, y)
npy = y.asnumpy()
npy[:2] = 0
res = np_reducer(x.asnumpy(), axis=1)
res[:2] = 0
np.testing.assert_allclose(npy, res, rtol=1e-4)
if tvm.module.enabled("opencl"):
tvm.module.init_opencl()
if tvm.module.enabled("opencl"):
tvm.module.init_opencl()
check_device("cuda")
check_device("opencl")
test_prim(tvm.sum, np.sum)
test_prim(tvm.min, np.amin)
test_prim(tvm.max, np.amax)
check_device("cuda")
check_device("opencl")
def test_rfactor():
......@@ -127,4 +133,4 @@ def test_rfactor_threads():
if __name__ == "__main__":
test_rfactor_threads()
test_rfactor()
test_sum()
test_reduce_prims()
......@@ -33,6 +33,20 @@ def test_tensor_slice():
B = tvm.compute((n,), lambda i: A[0][i] + A[0][i])
def test_tensor_comm_reducer():
m = tvm.var('m')
n = tvm.var('n')
A = tvm.placeholder((m, n), name='A')
k = tvm.reduce_axis((0, n), "k")
mysum = tvm.comm_reducer(lambda x, y: x+y, lambda t: tvm.const(0, dtype=t))
C = tvm.compute((m,), lambda i: mysum(A[i, k], axis=k))
def test_tensor_comm_reducer_overload():
m = tvm.var('m')
n = tvm.var('n')
mysum = tvm.comm_reducer(lambda x, y: x+y, lambda t: tvm.const(0, dtype=t))
sum_res = mysum(m, n)
def test_tensor_reduce():
m = tvm.var('m')
n = tvm.var('n')
......