Commit 50681784 by Tatsuya Nishiyama Committed by Tianqi Chen

[TOPI] Add C++ implementation of elementwise operators (#1306)

parent 44d8203f
...@@ -187,7 +187,7 @@ def compute_elemwise_sum(attrs, inputs, _): ...@@ -187,7 +187,7 @@ def compute_elemwise_sum(attrs, inputs, _):
"""Compute definition of elemwise sum""" """Compute definition of elemwise sum"""
num_args = attrs.get_int("num_args") num_args = attrs.get_int("num_args")
assert num_args == len(inputs), "Number of tensors does not match num_args." assert num_args == len(inputs), "Number of tensors does not match num_args."
return topi.tensor.elemwise_sum(inputs, num_args) return topi.tensor.elemwise_sum(inputs)
reg.register_pattern("elemwise_sum", OpPattern.ELEMWISE) reg.register_pattern("elemwise_sum", OpPattern.ELEMWISE)
reg.register_schedule("elemwise_sum", _fschedule_elemwise) reg.register_schedule("elemwise_sum", _fschedule_elemwise)
......
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
#include "topi/tags.h" #include "topi/tags.h"
#include "tvm/tvm.h" #include "tvm/tvm.h"
#include "tvm/ir.h"
#include "tvm/ir_pass.h"
namespace topi { namespace topi {
using namespace tvm; using namespace tvm;
...@@ -122,5 +124,76 @@ inline Tensor cast(const Tensor& x, ...@@ -122,5 +124,76 @@ inline Tensor cast(const Tensor& x,
}, name, tag); }, name, tag);
} }
/*!
* \brief Creates an operation that sum each element of a tensor
*
* \param xs The input tensor array
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the sum operation
*/
inline Tensor elemwise_sum(const Array<Tensor>& xs,
std::string name = "tensor",
std::string tag = kElementWise) {
CHECK_GT(xs.size(), 0) << "elemwise sum must have at least one input tensor.";
return compute(xs[0]->shape, [&](const Array<Var>& i) {
auto sum_expr = xs[0](i);
for (size_t j = 1; j < xs.size(); j++) {
sum_expr = sum_expr + xs[j](i);
}
return sum_expr;
}, name, tag);
}
/*!
* \brief Creates an operation that fill a tensor with fill_value
*
* \param shape The shape of a tensor
* \param dtype The Type of fill_value
* \param fill_value The value to be filled
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the full operation
*/
inline Tensor full(const Array<Expr>& shape,
Type dtype,
const Expr fill_value,
std::string name = "tensor",
std::string tag = kElementWise) {
Expr ev = lossless_cast(dtype, fill_value);
if (!ev.defined()) {
LOG(ERROR) << "Can't cast fill_value to " << dtype;
}
return compute(shape, [&](const Array<Var>& i) {
return ev;
}, name, tag);
}
/*!
* \brief Creates an operation that construct a tensor with same shape as input tensor,
* then fill a tensor with fill_value
*
* \param x The input tensor
* \param fill_value The value to be filled
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op memeber is the full_like operation
*/
inline Tensor full_like(const Tensor& x,
const Expr fill_value,
std::string name = "tensor",
std::string tag = kElementWise) {
Expr ev = lossless_cast(x->dtype, fill_value);
if (!ev.defined()) {
LOG(ERROR) << "Can't cast fill_value to " << x->dtype;
}
return compute(x->shape, [&](const Array<Var>& i) {
return ev;
}, name, tag);
}
} // namespace topi } // namespace topi
#endif // TOPI_ELEMWISE_H_ #endif // TOPI_ELEMWISE_H_
...@@ -2,30 +2,24 @@ ...@@ -2,30 +2,24 @@
"""Elementwise operators""" """Elementwise operators"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from . import cpp
from . import tag from . import tag
@tvm.tag_scope(tag=tag.ELEMWISE) @tvm.tag_scope(tag=tag.ELEMWISE)
def elemwise_sum(xs, num_args): def elemwise_sum(xs):
"""Perform element-wise sum on inputs """Perform element-wise sum on inputs
Parameters Parameters
---------- ----------
xs : list of tvm.Tensor xs : list of tvm.Tensor
Input arguments. Input arguments.
num_args : int
Number of arguments
Returns Returns
------- -------
y : tvm.Tensor y : tvm.Tensor
The result. The result.
""" """
assert len(xs) > 0, "elemwise sum must have at least one input tensor." return cpp.elemwise_sum(xs)
def _compute(*i):
return sum([x(*i) for x in xs])
return tvm.compute(xs[0].shape, _compute)
@tvm.tag_scope(tag=tag.ELEMWISE) @tvm.tag_scope(tag=tag.ELEMWISE)
...@@ -46,7 +40,7 @@ def full(shape, dtype, fill_value): ...@@ -46,7 +40,7 @@ def full(shape, dtype, fill_value):
y : tvm.Tensor y : tvm.Tensor
The result. The result.
""" """
return tvm.compute(shape, lambda *i: tvm.const(fill_value, dtype)) return cpp.full(shape, dtype, fill_value)
@tvm.tag_scope(tag=tag.ELEMWISE) @tvm.tag_scope(tag=tag.ELEMWISE)
...@@ -66,5 +60,4 @@ def full_like(x, fill_value): ...@@ -66,5 +60,4 @@ def full_like(x, fill_value):
y : tvm.Tensor y : tvm.Tensor
The result. The result.
""" """
dtype = x.dtype return cpp.full_like(x, fill_value)
return tvm.compute(x.shape, lambda *i: tvm.const(fill_value, dtype))
...@@ -163,6 +163,21 @@ TVM_REGISTER_GLOBAL("topi.cast") ...@@ -163,6 +163,21 @@ TVM_REGISTER_GLOBAL("topi.cast")
*rv = cast(args[0], args[1]); *rv = cast(args[0], args[1]);
}); });
TVM_REGISTER_GLOBAL("topi.elemwise_sum")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = elemwise_sum(args[0]);
});
TVM_REGISTER_GLOBAL("topi.full")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = full(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.full_like")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = full_like(args[0], args[1]);
});
/* Ops from nn.h */ /* Ops from nn.h */
TVM_REGISTER_GLOBAL("topi.nn.relu") TVM_REGISTER_GLOBAL("topi.nn.relu")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
......
...@@ -11,7 +11,7 @@ def verify_elemwise_sum(num_args, dtype): ...@@ -11,7 +11,7 @@ def verify_elemwise_sum(num_args, dtype):
for i in range(num_args): for i in range(num_args):
tvm_placeholders.append( tvm_placeholders.append(
tvm.placeholder(shape, name="data"+str(i), dtype=dtype)) tvm.placeholder(shape, name="data"+str(i), dtype=dtype))
esum = topi.elemwise_sum(tvm_placeholders, num_args=num_args) esum = topi.elemwise_sum(tvm_placeholders)
s = tvm.create_schedule([esum.op]) s = tvm.create_schedule([esum.op])
@memoize("topi.tests.test_topi_elemwise_sum") @memoize("topi.tests.test_topi_elemwise_sum")
......
"""Test code for tensor operator"""
import numpy as np
import tvm
import topi
def verify_elemwise_sum(num_args, dtype):
shape = (3,5,4)
tvm_placeholders = []
for i in range(num_args):
tvm_placeholders.append(
tvm.placeholder(shape, name="data"+str(i), dtype=dtype))
esum = topi.cpp.elemwise_sum(tvm_placeholders)
s = tvm.create_schedule([esum.op])
def get_ref_data():
np_nd = [np.random.uniform(0, 10, size=shape).astype(dtype)
for i in range(num_args)]
return np_nd
np_nd = get_ref_data()
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.context(device, 0)
out = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx)
f = tvm.build(s, tvm_placeholders + [esum], device, name="elemwise_sum")
tvm_nd = [tvm.nd.array(nd, ctx) for nd in np_nd] + [out]
f(*tvm_nd)
np_out = np.sum(np.array(np_nd), axis=0)
np.testing.assert_allclose(out.asnumpy(), np_out, rtol=1e-5)
for device in ["llvm"]:
check_device(device)
def verify_full(shape, dtype, fill_value):
A = tvm.placeholder(shape, dtype=dtype, name="A")
B = topi.cpp.full_like(A, fill_value)
C = topi.cpp.full(shape, dtype, fill_value)
s1 = tvm.create_schedule([B.op])
s2 = tvm.create_schedule([C.op])
def get_ref_data():
return np.full(shape, fill_value, dtype)
np_nd = get_ref_data()
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
target = topi.cpp.TEST_create_target(device)
ctx = tvm.context(device, 0)
out = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx)
f = tvm.build(s1, [A, B], device, name="full_like")
f(tvm.nd.array(np.zeros(shape, dtype), ctx), out)
np.testing.assert_allclose(out.asnumpy(), np_nd, rtol=1e-5)
f = tvm.build(s2, [C], device, name="full")
f(out)
np.testing.assert_allclose(out.asnumpy(), np_nd, rtol=1e-5)
for device in ["llvm"]:
check_device(device)
def test_elemwise_sum():
verify_elemwise_sum(1, "float32")
verify_elemwise_sum(5, "float32")
verify_elemwise_sum(4, "int32")
def test_full():
verify_full((3,4,5), "float32", 3.14)
verify_full((10,), "int32", 7)
if __name__ == "__main__":
test_elemwise_sum()
test_full()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment