Commit db4be63c by Tianqi Chen Committed by GitHub

[TOPI] Numpy consistency: always broadcast binary op. (#1321)

parent 90db723d
...@@ -1934,7 +1934,7 @@ ENABLE_PREPROCESSING = YES ...@@ -1934,7 +1934,7 @@ ENABLE_PREPROCESSING = YES
# The default value is: NO. # The default value is: NO.
# This tag requires that the tag ENABLE_PREPROCESSING is set to YES. # This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
MACRO_EXPANSION = NO MACRO_EXPANSION = YES
# If the EXPAND_ONLY_PREDEF and MACRO_EXPANSION tags are both set to YES then # If the EXPAND_ONLY_PREDEF and MACRO_EXPANSION tags are both set to YES then
# the macro expansion is limited to the macros specified with the PREDEFINED and # the macro expansion is limited to the macros specified with the PREDEFINED and
......
...@@ -31,8 +31,6 @@ List of operators ...@@ -31,8 +31,6 @@ List of operators
topi.take topi.take
topi.full topi.full
topi.full_like topi.full_like
topi.greater
topi.less
topi.nn.relu topi.nn.relu
topi.nn.leaky_relu topi.nn.leaky_relu
topi.nn.dilate topi.nn.dilate
...@@ -49,12 +47,16 @@ List of operators ...@@ -49,12 +47,16 @@ List of operators
topi.sum topi.sum
topi.min topi.min
topi.broadcast_to topi.broadcast_to
topi.broadcast_add topi.add
topi.broadcast_sub topi.subtract
topi.broadcast_mul topi.multiply
topi.broadcast_div topi.divide
topi.broadcast_maximum topi.mod
topi.broadcast_minimum topi.maximum
topi.minimum
topi.power
topi.greater
topi.less
topi.image.resize topi.image.resize
...@@ -94,19 +96,20 @@ topi ...@@ -94,19 +96,20 @@ topi
.. autofunction:: topi.take .. autofunction:: topi.take
.. autofunction:: topi.full .. autofunction:: topi.full
.. autofunction:: topi.full_like .. autofunction:: topi.full_like
.. autofunction:: topi.greater
.. autofunction:: topi.less
.. autofunction:: topi.max .. autofunction:: topi.max
.. autofunction:: topi.sum .. autofunction:: topi.sum
.. autofunction:: topi.min .. autofunction:: topi.min
.. autofunction:: topi.broadcast_to .. autofunction:: topi.broadcast_to
.. autofunction:: topi.broadcast_add .. autofunction:: topi.add
.. autofunction:: topi.broadcast_sub .. autofunction:: topi.subtract
.. autofunction:: topi.broadcast_mul .. autofunction:: topi.multiply
.. autofunction:: topi.broadcast_div .. autofunction:: topi.divide
.. autofunction:: topi.broadcast_maximum .. autofunction:: topi.mod
.. autofunction:: topi.broadcast_minimum .. autofunction:: topi.maximum
.. autofunction:: topi.minimum
.. autofunction:: topi.power
.. autofunction:: topi.greater
.. autofunction:: topi.less
topi.nn topi.nn
~~~~~~~ ~~~~~~~
......
...@@ -105,7 +105,7 @@ def compute_conv2d(attrs, inputs, _): ...@@ -105,7 +105,7 @@ def compute_conv2d(attrs, inputs, _):
bias = inputs[2] bias = inputs[2]
expand_axis = 1 if layout == "NCHW" else 0 expand_axis = 1 if layout == "NCHW" else 0
bias = topi.expand_dims(bias, axis=expand_axis, num_newaxis=2) bias = topi.expand_dims(bias, axis=expand_axis, num_newaxis=2)
out = topi.broadcast_add(out, bias) out = topi.add(out, bias)
return out return out
@reg.register_schedule("conv2d") @reg.register_schedule("conv2d")
...@@ -146,7 +146,7 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _): ...@@ -146,7 +146,7 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
if attrs.get_bool("use_bias"): if attrs.get_bool("use_bias"):
bias = inputs[2] bias = inputs[2]
bias = topi.expand_dims(bias, axis=1, num_newaxis=2) bias = topi.expand_dims(bias, axis=1, num_newaxis=2)
out = topi.broadcast_add(out, bias) out = topi.add(out, bias)
return out return out
@reg.register_schedule("_contrib_conv2d_NCHWc") @reg.register_schedule("_contrib_conv2d_NCHWc")
...@@ -181,7 +181,7 @@ def compute_conv2d_transpose(attrs, inputs, _): ...@@ -181,7 +181,7 @@ def compute_conv2d_transpose(attrs, inputs, _):
if attrs.get_bool("use_bias"): if attrs.get_bool("use_bias"):
bias = inputs[2] bias = inputs[2]
bias = topi.expand_dims(bias, axis=1, num_newaxis=2) bias = topi.expand_dims(bias, axis=1, num_newaxis=2)
out = topi.broadcast_add(out, bias) out = topi.add(out, bias)
output_padding = attrs.get_int_tuple("output_padding") output_padding = attrs.get_int_tuple("output_padding")
out = topi.nn.pad(out, \ out = topi.nn.pad(out, \
[0, 0, 0, 0], [0, 0, output_padding[0], output_padding[1]]) [0, 0, 0, 0], [0, 0, output_padding[0], output_padding[1]])
......
...@@ -244,7 +244,7 @@ reg.register_schedule("ones_like", _fschedule_elemwise) ...@@ -244,7 +244,7 @@ reg.register_schedule("ones_like", _fschedule_elemwise)
@reg.register_compute("greater") @reg.register_compute("greater")
def compute_greater(_, inputs, out_info): def compute_greater(_, inputs, out_info):
"""Compute definition of greater""" """Compute definition of greater"""
return topi.tensor.greater(inputs[0], inputs[1], 'float32') return topi.greater(inputs[0], inputs[1]).astype('float32')
reg.register_pattern("greater", OpPattern.ELEMWISE) reg.register_pattern("greater", OpPattern.ELEMWISE)
reg.register_schedule("greater", _fschedule_elemwise) reg.register_schedule("greater", _fschedule_elemwise)
...@@ -252,7 +252,7 @@ reg.register_schedule("greater", _fschedule_elemwise) ...@@ -252,7 +252,7 @@ reg.register_schedule("greater", _fschedule_elemwise)
@reg.register_compute("less") @reg.register_compute("less")
def compute_less(_, inputs, out_info): def compute_less(_, inputs, out_info):
"""Compute definition of less""" """Compute definition of less"""
return topi.tensor.less(inputs[0], inputs[1], 'float32') return topi.less(inputs[0], inputs[1]).astype('float32')
reg.register_pattern("less", OpPattern.ELEMWISE) reg.register_pattern("less", OpPattern.ELEMWISE)
reg.register_schedule("less", _fschedule_elemwise) reg.register_schedule("less", _fschedule_elemwise)
......
...@@ -200,7 +200,7 @@ inline bool BinaryBroadcastCorrectLayout(const NodeAttrs& attrs, ...@@ -200,7 +200,7 @@ inline bool BinaryBroadcastCorrectLayout(const NodeAttrs& attrs,
return true; return true;
} }
#define NNVM_REGISTER_BINARY_BROADCAST_OP(name) \ #define NNVM_REGISTER_BINARY_BROADCAST_OP(name, TOPIOp) \
NNVM_REGISTER_OP(name) \ NNVM_REGISTER_OP(name) \
.set_num_inputs(2) \ .set_num_inputs(2) \
.set_num_outputs(1) \ .set_num_outputs(1) \
...@@ -217,13 +217,13 @@ inline bool BinaryBroadcastCorrectLayout(const NodeAttrs& attrs, ...@@ -217,13 +217,13 @@ inline bool BinaryBroadcastCorrectLayout(const NodeAttrs& attrs,
const Array<Tensor>& inputs, \ const Array<Tensor>& inputs, \
const Array<Tensor>& out_info) { \ const Array<Tensor>& out_info) { \
return Array<Tensor>{ \ return Array<Tensor>{ \
topi::name(inputs[0], inputs[1]) }; \ topi::TOPIOp(inputs[0], inputs[1]) }; \
}) \ }) \
.add_argument("lhs", "Tensor", "first input") \ .add_argument("lhs", "Tensor", "first input") \
.add_argument("rhs", "Tensor", "second input") .add_argument("rhs", "Tensor", "second input")
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_add) NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_add, add)
.add_alias("__add_symbol__") .add_alias("__add_symbol__")
.describe(R"code(Returns element-wise sum of the input arrays with broadcasting. .describe(R"code(Returns element-wise sum of the input arrays with broadcasting.
...@@ -241,7 +241,7 @@ Example:: ...@@ -241,7 +241,7 @@ Example::
)code" NNVM_ADD_FILELINE); )code" NNVM_ADD_FILELINE);
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_sub) NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_sub, subtract)
.add_alias("__sub_symbol__") .add_alias("__sub_symbol__")
.describe(R"code(Returns element-wise difference of the input arrays with broadcasting. .describe(R"code(Returns element-wise difference of the input arrays with broadcasting.
...@@ -259,7 +259,7 @@ Example:: ...@@ -259,7 +259,7 @@ Example::
)code" NNVM_ADD_FILELINE); )code" NNVM_ADD_FILELINE);
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_mul) NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_mul, multiply)
.add_alias("__mul_symbol__") .add_alias("__mul_symbol__")
.describe(R"code(Returns element-wise product of the input arrays with broadcasting. .describe(R"code(Returns element-wise product of the input arrays with broadcasting.
...@@ -276,7 +276,7 @@ Example:: ...@@ -276,7 +276,7 @@ Example::
)code" NNVM_ADD_FILELINE); )code" NNVM_ADD_FILELINE);
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_div) NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_div, divide)
.add_alias("__div_symbol__") .add_alias("__div_symbol__")
.describe(R"code(Returns element-wise division of the input arrays with broadcasting. .describe(R"code(Returns element-wise division of the input arrays with broadcasting.
......
...@@ -230,7 +230,7 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_add) ...@@ -230,7 +230,7 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_add)
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Array<Tensor>& out_info) { const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::broadcast_add(inputs[0], inputs[1]) }; return Array<Tensor>{ topi::add(inputs[0], inputs[1]) };
}) })
.set_attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
...@@ -253,7 +253,7 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_sub) ...@@ -253,7 +253,7 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_sub)
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Array<Tensor>& out_info) { const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::broadcast_sub(inputs[0], inputs[1]) }; return Array<Tensor>{ topi::subtract(inputs[0], inputs[1]) };
}) })
.set_attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
...@@ -276,7 +276,7 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_mul) ...@@ -276,7 +276,7 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_mul)
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Array<Tensor>& out_info) { const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::broadcast_mul(inputs[0], inputs[1]) }; return Array<Tensor>{ topi::multiply(inputs[0], inputs[1]) };
}) })
.set_attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
...@@ -301,7 +301,7 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_div) ...@@ -301,7 +301,7 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_div)
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Array<Tensor>& out_info) { const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::broadcast_div(inputs[0], inputs[1]) }; return Array<Tensor>{ topi::divide(inputs[0], inputs[1]) };
}) })
.set_attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
......
...@@ -137,7 +137,7 @@ class ExprOp(object): ...@@ -137,7 +137,7 @@ class ExprOp(object):
expr : Expr expr : Expr
Expression with new type Expression with new type
""" """
return _make.static_cast(dtype, self) return _generic.cast(self, dtype)
class EqualOp(NodeGeneric, ExprOp): class EqualOp(NodeGeneric, ExprOp):
......
...@@ -79,3 +79,19 @@ def divide(lhs, rhs): ...@@ -79,3 +79,19 @@ def divide(lhs, rhs):
The result Expr of divide operaton. The result Expr of divide operaton.
""" """
return _make.Div(lhs, rhs) return _make.Div(lhs, rhs)
def cast(src, dtype):
"""Generic cast operator.
Parameters
----------
src : object
The source operand.
Returns
-------
op : tvm.Expr
The result Expr of divide operaton.
"""
return _make.static_cast(dtype, src)
...@@ -15,11 +15,11 @@ def test_operator_type_and_tags(): ...@@ -15,11 +15,11 @@ def test_operator_type_and_tags():
assert isinstance(k + n, tvm.expr.Expr) assert isinstance(k + n, tvm.expr.Expr)
assert isinstance(n + n, tvm.expr.Expr) assert isinstance(n + n, tvm.expr.Expr)
assert isinstance(k + A, tvm.expr.Expr) assert isinstance(k + A, tvm.tensor.Tensor)
assert isinstance(A + k, tvm.expr.Expr) assert isinstance(A + k, tvm.tensor.Tensor)
assert isinstance(n + A, tvm.expr.Expr) assert isinstance(n + A, tvm.tensor.Tensor)
assert isinstance(A + n, tvm.expr.Expr) assert isinstance(A + n, tvm.tensor.Tensor)
assert isinstance(A + A, tvm.expr.Expr) assert isinstance(A + A, tvm.tensor.Tensor)
assert isinstance(k + B, tvm.tensor.Tensor) assert isinstance(k + B, tvm.tensor.Tensor)
assert isinstance(B + k, tvm.tensor.Tensor) assert isinstance(B + k, tvm.tensor.Tensor)
...@@ -33,8 +33,8 @@ def test_operator_type_and_tags(): ...@@ -33,8 +33,8 @@ def test_operator_type_and_tags():
assert (B + k).op.tag == topi.tag.ELEMWISE assert (B + k).op.tag == topi.tag.ELEMWISE
assert (n + B).op.tag == topi.tag.ELEMWISE assert (n + B).op.tag == topi.tag.ELEMWISE
assert (B + n).op.tag == topi.tag.ELEMWISE assert (B + n).op.tag == topi.tag.ELEMWISE
assert (A + B).op.tag == topi.tag.ELEMWISE assert (A + B).op.tag == topi.tag.BROADCAST
assert (B + A).op.tag == topi.tag.ELEMWISE assert (B + A).op.tag == topi.tag.BROADCAST
assert (B + B).op.tag == topi.tag.BROADCAST assert (B + B).op.tag == topi.tag.BROADCAST
assert isinstance(k + B2, tvm.expr.Expr) assert isinstance(k + B2, tvm.expr.Expr)
...@@ -42,8 +42,8 @@ def test_operator_type_and_tags(): ...@@ -42,8 +42,8 @@ def test_operator_type_and_tags():
assert isinstance(n + B2, tvm.expr.Expr) assert isinstance(n + B2, tvm.expr.Expr)
assert isinstance(B2 + n, tvm.expr.Expr) assert isinstance(B2 + n, tvm.expr.Expr)
assert isinstance(B2 + B2, tvm.expr.Expr) assert isinstance(B2 + B2, tvm.expr.Expr)
assert isinstance(B2 + A, tvm.expr.Expr) assert isinstance(B2 + A, tvm.tensor.Tensor)
assert isinstance(A + B2, tvm.expr.Expr) assert isinstance(A + B2, tvm.tensor.Tensor)
assert isinstance(B2 + B, tvm.tensor.Tensor) assert isinstance(B2 + B, tvm.tensor.Tensor)
assert isinstance(B + B2, tvm.tensor.Tensor) assert isinstance(B + B2, tvm.tensor.Tensor)
...@@ -246,4 +246,4 @@ if __name__ == "__main__": ...@@ -246,4 +246,4 @@ if __name__ == "__main__":
test_combination() test_combination()
test_tensor_scalar_bop() test_tensor_scalar_bop()
test_broadcast_bop() test_broadcast_bop()
test_conv2d_scalar_bop() test_conv2d_scalar_bop()
\ No newline at end of file
...@@ -69,55 +69,6 @@ inline Tensor negative(const Tensor& x, ...@@ -69,55 +69,6 @@ inline Tensor negative(const Tensor& x,
} }
/*! /*!
* \brief Creates an operation that raises each element of tensor x to power y
*
* \param x The input tensor
* \param y The exponent
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the pow operation
*/
inline Tensor pow(const Tensor& x,
const Expr& y,
std::string name = "tensor",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
return tvm::pow(x(i), y);
}, name, tag);
}
/*!
* \brief Creates an operation that performs pointwise left shift by n bits
*
* \param x The input tensor
* \param n The number of bits to shift by
*
* \return A Tensor whose op member is the left shift operation
*/
inline Tensor operator<<(const Tensor& x,
const Expr& n) {
return compute(x->shape, [&](const Array<Var>& i) {
return x(i) << n;
}, "tensor", kElementWise);
}
/*!
* \brief Creates an operation that performs pointwise right shift by n bits
*
* \param x The input tensor
* \param n The number of bits to shift by
*
* \return A Tensor whose op member is the right shift operation
*/
inline Tensor operator>>(const Tensor& x,
const Expr& n) {
return compute(x->shape, [&](const Array<Var>& i) {
return x(i) >> n;
}, "tensor", kElementWise);
}
/*!
* \brief Creates an operation that clips each element of a tensor to * \brief Creates an operation that clips each element of a tensor to
* the interval [a_min, a_max] * the interval [a_min, a_max]
* *
......
...@@ -26,20 +26,20 @@ using namespace tvm; ...@@ -26,20 +26,20 @@ using namespace tvm;
* \return A Tensor whose op member is the l2 normalization operation * \return A Tensor whose op member is the l2 normalization operation
*/ */
inline Tensor l2_normalize(const Tensor& data, inline Tensor l2_normalize(const Tensor& data,
float eps, float eps,
const Array<Expr>& axis, const Array<Expr>& axis,
std::string name = "tensor", std::string name = "tensor",
std::string tag = "l2_normalize") { std::string tag = "l2_normalize") {
CHECK_EQ(data->shape.size(), 4) << "L2 normalization requires 4-D input"; CHECK_EQ(data->shape.size(), 4) << "L2 normalization requires 4-D input";
auto input_shape = data->shape; auto input_shape = data->shape;
Tensor dot_value = pow(data, static_cast<float>(2.0)); Tensor dot_value = topi::power(data, static_cast<float>(2.0));
Tensor sum_value = topi::sum(dot_value, axis, true); Tensor sum_value = topi::sum(dot_value, axis, true);
Tensor expand_sum = topi::broadcast_to(sum_value, input_shape); Tensor expand_sum = topi::broadcast_to(sum_value, input_shape);
return topi::broadcast_div(data, return topi::divide(data,
topi::sqrt(tvm::compute(expand_sum->shape, topi::sqrt(tvm::compute(expand_sum->shape,
[&](const Array<Var>& i){ [&](const Array<Var>& i){
return (max(expand_sum(i), eps)); return (max(expand_sum(i), eps));
}, name = name, tag = tag))); }, name = name, tag = tag)));
} }
} // namespace nn } // namespace nn
} // namespace topi } // namespace topi
......
...@@ -63,13 +63,14 @@ inline Tensor lrn(const Tensor& data, ...@@ -63,13 +63,14 @@ inline Tensor lrn(const Tensor& data,
{rxs}); {rxs});
}); });
} }
auto sqrt_sum_up = tvm::compute(input_shape, auto sqrt_sum_up = tvm::compute(
[&](Var i, Var j, Var k, Var l) { input_shape,
return tvm::pow(bias + [&](Var i, Var j, Var k, Var l) {
(alpha * sqr_sum(i, j, k, l) / size), return tvm::pow(bias +
beta); (alpha * sqr_sum(i, j, k, l) / size),
}); beta);
return topi::broadcast_div(data, sqrt_sum_up); });
return topi::divide(data, sqrt_sum_up);
} }
} // namespace nn } // namespace nn
} // namespace topi } // namespace topi
......
...@@ -5,7 +5,6 @@ import ctypes ...@@ -5,7 +5,6 @@ import ctypes
from imp import new_module as _new_module from imp import new_module as _new_module
from tvm._ffi.function import _init_api_prefix from tvm._ffi.function import _init_api_prefix
from tvm._ffi import libinfo from tvm._ffi import libinfo
import tvm as _tvm
def _get_lib_names(): def _get_lib_names():
if sys.platform.startswith('win32'): if sys.platform.startswith('win32'):
...@@ -35,7 +34,6 @@ def _create_module(name): ...@@ -35,7 +34,6 @@ def _create_module(name):
return mod return mod
# pylint: disable-msg=C0103 # pylint: disable-msg=C0103
nn = _create_module("nn") nn = _create_module("nn")
_init_api_prefix("topi.cpp.nn", "topi.nn") _init_api_prefix("topi.cpp.nn", "topi.nn")
generic = _create_module("generic") generic = _create_module("generic")
...@@ -52,41 +50,3 @@ yolo2 = _create_module("vision.yolo2") ...@@ -52,41 +50,3 @@ yolo2 = _create_module("vision.yolo2")
_init_api_prefix("topi.cpp.vision.yolo2", "topi.vision.yolo2") _init_api_prefix("topi.cpp.vision.yolo2", "topi.vision.yolo2")
image = _create_module("image") image = _create_module("image")
_init_api_prefix("topi.cpp.image", "topi.image") _init_api_prefix("topi.cpp.image", "topi.image")
class IntVector(object):
"""Handle to std::vector<int> instance """
_tvm_tcode = 27
def __init__(self, handle):
self.handle = handle
def __del__(self):
_tvm.nd.free_extension_handle(self.handle, 27)
@property
def _tvm_handle(self):
return self.handle.value
def __getitem__(self, idx):
return ivec_get(self, idx)
_tvm.register_extension(IntVector, IntVector)
class Target(object):
"""Handle to C++ Target instance """
_tvm_tcode = 28
def __init__(self, handle):
self.handle = handle
def __del__(self):
_tvm.nd.free_extension_handle(self.handle, 28)
@property
def _tvm_handle(self):
return self.handle.value
def __getitem__(self, idx):
return ivec_get(self, idx)
_tvm.register_extension(Target, Target)
...@@ -3,10 +3,10 @@ ...@@ -3,10 +3,10 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from . import broadcast as _broadcast from . import broadcast as _broadcast
from . import tag from . import math as _math
def _make_bop(elementwise_bop, broadcast_bop, orig_bop): def _make_bop(broadcast_bop, orig_bop):
"""Make a specific overloaded binary operator of Tensor when applicable; """Make a specific overloaded binary operator of Tensor when applicable;
apply the original operator if it is not supposed to be overloaded. apply the original operator if it is not supposed to be overloaded.
...@@ -23,9 +23,6 @@ def _make_bop(elementwise_bop, broadcast_bop, orig_bop): ...@@ -23,9 +23,6 @@ def _make_bop(elementwise_bop, broadcast_bop, orig_bop):
Parameters Parameters
---------- ----------
elementwise_bop : operator function
Operator for element-wise tensor-scalar operation, for rule (2).
broadcast_bop : operator function broadcast_bop : operator function
Operator for broadcast tensor-tensor operation, for rule (1). Operator for broadcast tensor-tensor operation, for rule (1).
...@@ -66,36 +63,9 @@ def _make_bop(elementwise_bop, broadcast_bop, orig_bop): ...@@ -66,36 +63,9 @@ def _make_bop(elementwise_bop, broadcast_bop, orig_bop):
tvm.Expr (otherwise) tvm.Expr (otherwise)
The result of {op} operation. The result of {op} operation.
""" """
if not isinstance(lhs, tvm.tensor.Tensor) and not isinstance(rhs, tvm.tensor.Tensor):
def _get_rank(x):
"""Get the rank of a value.
If x is Tensor, then return its rank;
if x is scalar_like (i.e., numeric types, Expr, or TensorSlice), return 0;
otherwise, return -1.
"""
if isinstance(x, tvm.tensor.Tensor):
return len(x.shape)
elif isinstance(x, (int, float, tvm.expr.Expr, tvm.tensor.TensorSlice)):
return 0
return -1
rl = _get_rank(lhs)
rr = _get_rank(rhs)
if rl == -1 or rr == -1 or (rl == 0 and rr == 0):
return orig_bop(lhs, rhs) return orig_bop(lhs, rhs)
elif rl > 0 and rr > 0: return broadcast_bop(lhs, rhs)
return broadcast_bop(lhs, rhs)
elif rl == 0:
f = lambda *i: elementwise_bop(lhs, rhs(*i))
with tvm.tag_scope(tag=tag.ELEMWISE):
return tvm.compute(rhs.shape, f, "tensor_" + name)
elif rr == 0:
f = lambda *i: elementwise_bop(lhs(*i), rhs)
with tvm.tag_scope(tag=tag.ELEMWISE):
return tvm.compute(lhs.shape, f, "tensor_" + name)
else:
raise AssertionError("Cannot reach this line.")
_tensor_bop_impl.__doc__ = _tensor_bop_impl.__doc__.format(op=name) _tensor_bop_impl.__doc__ = _tensor_bop_impl.__doc__.format(op=name)
return _tensor_bop_impl return _tensor_bop_impl
...@@ -106,18 +76,10 @@ def _bind_generic_ops(): ...@@ -106,18 +76,10 @@ def _bind_generic_ops():
__op_priority__ = 1 __op_priority__ = 1
if __op_priority__ > tvm.generic.__op_priority__: if __op_priority__ > tvm.generic.__op_priority__:
tvm.generic.__op_priority__ = __op_priority__ tvm.generic.__op_priority__ = __op_priority__
tvm.generic.add = _make_bop(lambda x, y: x + y, tvm.generic.add = _make_bop(_broadcast.add, tvm.generic.add)
_broadcast.broadcast_add, tvm.generic.subtract = _make_bop(_broadcast.subtract, tvm.generic.subtract)
tvm.generic.add) tvm.generic.multiply = _make_bop(_broadcast.multiply, tvm.generic.multiply)
tvm.generic.subtract = _make_bop(lambda x, y: x - y, tvm.generic.divide = _make_bop(_broadcast.divide, tvm.generic.divide)
_broadcast.broadcast_sub, tvm.generic.cast = _math.cast
tvm.generic.subtract)
tvm.generic.multiply = _make_bop(lambda x, y: x * y,
_broadcast.broadcast_mul,
tvm.generic.multiply)
tvm.generic.divide = _make_bop(lambda x, y: x / y,
_broadcast.broadcast_div,
tvm.generic.divide)
_bind_generic_ops() _bind_generic_ops()
...@@ -258,14 +258,14 @@ def clip(x, a_min, a_max): ...@@ -258,14 +258,14 @@ def clip(x, a_min, a_max):
return tvm.compute(x.shape, _compute) return tvm.compute(x.shape, _compute)
@tvm.tag_scope(tag=tag.ELEMWISE)
def cast(x, dtype): def cast(x, dtype):
"""Cast input to specified data type. """Cast input to specified data type.
Parameters Parameters
---------- ----------
x : tvm.Tensor x : tvm.Tensor or Expr
Input argument. Input argument.
dtype : str dtype : str
Data type. Data type.
...@@ -274,4 +274,7 @@ def cast(x, dtype): ...@@ -274,4 +274,7 @@ def cast(x, dtype):
y : tvm.Tensor y : tvm.Tensor
The result. The result.
""" """
return tvm.compute(x.shape, lambda *i: x(*i).astype(dtype)) if isinstance(x, tvm.tensor.Tensor):
return tvm.compute(
x.shape, lambda *i: x(*i).astype(dtype), tag=tag.ELEMWISE)
return tvm.make.static_cast(dtype, x)
...@@ -68,52 +68,3 @@ def full_like(x, fill_value): ...@@ -68,52 +68,3 @@ def full_like(x, fill_value):
""" """
dtype = x.dtype dtype = x.dtype
return tvm.compute(x.shape, lambda *i: tvm.const(fill_value, dtype)) return tvm.compute(x.shape, lambda *i: tvm.const(fill_value, dtype))
@tvm.tag_scope(tag=tag.ELEMWISE)
def greater(lhs, rhs, out_type=tvm.int8):
"""Compare two input tensors element-wise and return an mask tensor
which contains 1 if lhs > rhs holds else 0
Parameters
----------
lhs : tvm.Tensor
Left input argument.
rhs : tvm.Tensor
Right argument.
out_type: str
Output data type. Default is int8
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(lhs.shape,
lambda *i: tvm.select(lhs(*i) > rhs(*i),
tvm.const(1, out_type),
tvm.const(0, out_type)))
@tvm.tag_scope(tag=tag.ELEMWISE)
def less(lhs, rhs, out_type=tvm.int8):
"""Compare two input tensors element-wise and return an mask tensor
which contains 1 if lhs < rhs holds else 0
Parameters
----------
lhs : tvm.Tensor
Left input argument.
rhs : tvm.Tensor
Right argument.
out_type: str
Output data type. Default is int8
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(lhs.shape,
lambda *i: tvm.select(lhs(*i) < rhs(*i),
tvm.const(1, out_type),
tvm.const(0, out_type)))
...@@ -67,51 +67,55 @@ Array<Expr> ArrayOrInt(TVMArgValue arg) { ...@@ -67,51 +67,55 @@ Array<Expr> ArrayOrInt(TVMArgValue arg) {
} }
} }
inline bool IsTensorType(TVMArgValue arg) {
return (arg.type_code() == kNodeHandle &&
arg.node_sptr()->is_type<tvm::TensorNode>());
}
TVM_REGISTER_GLOBAL("topi.TEST_create_target") TVM_REGISTER_GLOBAL("topi.TEST_create_target")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tvm::Target::create(args[0]); *rv = tvm::Target::create(args[0]);
}); });
/* Ops from broadcast.h */ /* Ops from broadcast.h */
#define TOPI_REGISTER_BCAST_OP(OpName, Op) \
TVM_REGISTER_GLOBAL(OpName) \
.set_body([](TVMArgs args, TVMRetValue *rv) { \
bool lhs_is_tensor = IsTensorType(args[0]); \
bool rhs_is_tensor = IsTensorType(args[1]); \
if (lhs_is_tensor && rhs_is_tensor) { \
*rv = Op(args[0].operator tvm::Tensor(), \
args[1].operator tvm::Tensor()); \
} else if (!lhs_is_tensor && rhs_is_tensor) { \
*rv = Op(args[0].operator tvm::Expr(), \
args[1].operator tvm::Tensor()); \
} else if (lhs_is_tensor && !rhs_is_tensor) { \
*rv = Op(args[0].operator tvm::Tensor(), \
args[1].operator tvm::Expr()); \
} else if (!lhs_is_tensor && !rhs_is_tensor) { \
*rv = Op(args[0].operator tvm::Expr(), \
args[1].operator tvm::Expr()); \
} \
}); \
TVM_REGISTER_GLOBAL("topi.broadcast_to") TVM_REGISTER_GLOBAL("topi.broadcast_to")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = broadcast_to(args[0], args[1]); *rv = broadcast_to(args[0], args[1]);
}); });
TVM_REGISTER_GLOBAL("topi.broadcast_add") TOPI_REGISTER_BCAST_OP("topi.add", topi::add);
.set_body([](TVMArgs args, TVMRetValue *rv) { TOPI_REGISTER_BCAST_OP("topi.subtract", topi::subtract);
*rv = broadcast_add(args[0], args[1]); TOPI_REGISTER_BCAST_OP("topi.multiply", topi::multiply);
}); TOPI_REGISTER_BCAST_OP("topi.divide", topi::divide);
TOPI_REGISTER_BCAST_OP("topi.mod", topi::mod);
TVM_REGISTER_GLOBAL("topi.broadcast_sub") TOPI_REGISTER_BCAST_OP("topi.maximum", topi::maximum);
.set_body([](TVMArgs args, TVMRetValue *rv) { TOPI_REGISTER_BCAST_OP("topi.minimum", topi::minimum);
*rv = broadcast_sub(args[0], args[1]); TOPI_REGISTER_BCAST_OP("topi.power", topi::power);
}); TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift);
TOPI_REGISTER_BCAST_OP("topi.right_shift", topi::right_shift);
TVM_REGISTER_GLOBAL("topi.broadcast_mul") TOPI_REGISTER_BCAST_OP("topi.greater", topi::greater);
.set_body([](TVMArgs args, TVMRetValue *rv) { TOPI_REGISTER_BCAST_OP("topi.less", topi::less);
*rv = broadcast_mul(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.broadcast_div")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = broadcast_div(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.broadcast_maximum")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = broadcast_maximum(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.broadcast_minimum")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = broadcast_minimum(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.broadcast_pow")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = broadcast_pow(args[0], args[1]);
});
/* Ops from elemwise.h */ /* Ops from elemwise.h */
TVM_REGISTER_GLOBAL("topi.exp") TVM_REGISTER_GLOBAL("topi.exp")
...@@ -149,25 +153,6 @@ TVM_REGISTER_GLOBAL("topi.negative") ...@@ -149,25 +153,6 @@ TVM_REGISTER_GLOBAL("topi.negative")
*rv = negative(args[0]); *rv = negative(args[0]);
}); });
TVM_REGISTER_GLOBAL("topi.pow")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = pow(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.left_shift")
.set_body([](TVMArgs args, TVMRetValue *rv) {
Tensor lhs = args[0];
Expr rhs = args[1];
*rv = lhs >> rhs;
});
TVM_REGISTER_GLOBAL("topi.right_shift")
.set_body([](TVMArgs args, TVMRetValue *rv) {
Tensor lhs = args[0];
Expr rhs = args[1];
*rv = lhs << rhs;
});
TVM_REGISTER_GLOBAL("topi.clip") TVM_REGISTER_GLOBAL("topi.clip")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = clip(args[0], args[1], args[2]); *rv = clip(args[0], args[1], args[2]);
......
...@@ -4,10 +4,10 @@ import numpy as np ...@@ -4,10 +4,10 @@ import numpy as np
import tvm import tvm
import topi import topi
def verify_broadcast_to_ele(in_shape, out_shape): def verify_broadcast_to_ele(in_shape, out_shape, fbcast):
# Build the logic and compile the function # Build the logic and compile the function
A = tvm.placeholder(shape=in_shape, name="A") A = tvm.placeholder(shape=in_shape, name="A")
B = topi.broadcast_to(A, out_shape) B = fbcast(A, out_shape)
def check_device(device): def check_device(device):
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
if not ctx.exist: if not ctx.exist:
...@@ -32,24 +32,20 @@ def verify_broadcast_to_ele(in_shape, out_shape): ...@@ -32,24 +32,20 @@ def verify_broadcast_to_ele(in_shape, out_shape):
check_device("rocm") check_device("rocm")
def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"): def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
ftopi, fnumpy,
lhs_min=-100, lhs_max=100,
rhs_min=-100, rhs_max=100,
dtype="float32"):
# Build the logic and compile the function # Build the logic and compile the function
A = tvm.placeholder(shape=lhs_shape, name="A") A = (tvm.var("A", dtype=dtype) if lhs_shape is None
B = tvm.placeholder(shape=rhs_shape, name="B") else tvm.placeholder(shape=lhs_shape, name="A", dtype=dtype))
if typ == "add": B = (tvm.var("B", dtype=dtype) if rhs_shape is None
C = topi.broadcast_add(A, B) else tvm.placeholder(shape=rhs_shape, name="B", dtype=dtype))
elif typ == "sub": C = ftopi(A, B)
C = topi.broadcast_sub(A, B) if (isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr)):
elif typ == "div": assert(isinstance(C, tvm.expr.Expr))
C = topi.broadcast_div(A, B) return
elif typ == "mul":
C = topi.broadcast_mul(A, B)
elif typ == "maximum":
C = topi.broadcast_maximum(A, B)
elif typ == "minimum":
C = topi.broadcast_minimum(A, B)
else:
raise NotImplementedError
def check_device(device): def check_device(device):
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
if not ctx.exist: if not ctx.exist:
...@@ -58,54 +54,102 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"): ...@@ -58,54 +54,102 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
s = topi.generic.schedule_broadcast(C) s = topi.generic.schedule_broadcast(C)
foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + typ) foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + ftopi.__name__)
lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype) if lhs_shape is None:
rhs_npy = np.random.uniform(size=rhs_shape).astype(A.dtype) lhs_npy = float(np.random.uniform(low=lhs_min, high=lhs_max))
if typ == "add": if dtype.startswith('int'):
out_npy = lhs_npy + rhs_npy lhs_npy = int(lhs_npy)
elif typ == "sub": lhs_nd = lhs_npy
out_npy = lhs_npy - rhs_npy
elif typ == "div":
rhs_npy = np.abs(rhs_npy) + 0.001
out_npy = lhs_npy / rhs_npy
elif typ == "mul":
out_npy = lhs_npy * rhs_npy
elif typ == "maximum":
out_npy = np.maximum(lhs_npy, rhs_npy)
elif typ == "minimum":
out_npy = np.minimum(lhs_npy, rhs_npy)
else: else:
raise NotImplementedError lhs_npy = np.random.uniform(low=lhs_min, high=lhs_max,
lhs_nd = tvm.nd.array(lhs_npy, ctx) size=lhs_shape).astype(A.dtype)
rhs_nd = tvm.nd.array(rhs_npy, ctx) lhs_nd = tvm.nd.array(lhs_npy, ctx)
out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), ctx)
if rhs_shape is None:
rhs_npy = float(np.random.uniform(low=rhs_min, high=rhs_max))
if dtype.startswith('int'):
lhs_npy = int(lhs_npy)
rhs_nd = rhs_npy
else:
rhs_npy = np.random.uniform(low=rhs_min, high=rhs_max,
size=rhs_shape).astype(A.dtype)
rhs_nd = tvm.nd.array(rhs_npy, ctx)
out_npy = fnumpy(lhs_npy, rhs_npy)
out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
for _ in range(1): for _ in range(1):
foo(lhs_nd, rhs_nd, out_nd) foo(lhs_nd, rhs_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4) np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
check_device("vulkan")
check_device("opencl") check_device("opencl")
check_device("vulkan")
check_device("cuda") check_device("cuda")
check_device("metal") check_device("metal")
check_device("rocm") check_device("rocm")
def test_broadcast_to(): def test_broadcast_to():
verify_broadcast_to_ele((1,), (10,)) verify_broadcast_to_ele((1,), (10,), topi.broadcast_to)
verify_broadcast_to_ele((), (10,)) verify_broadcast_to_ele((), (10,), topi.broadcast_to)
verify_broadcast_to_ele((1, 1, 5, 4), (3, 4, 4, 4, 5, 4)) verify_broadcast_to_ele((1, 1, 5, 4), (3, 4, 4, 4, 5, 4), topi.broadcast_to)
verify_broadcast_to_ele((1, 128, 1, 32), (64, 128, 64, 32)) verify_broadcast_to_ele((1, 128, 1, 32), (64, 128, 64, 32), topi.broadcast_to)
def test_add():
verify_broadcast_binary_ele(
(5, 2, 3), (2, 1), topi.add, np.add)
def test_subtract():
verify_broadcast_binary_ele(
(5, 2, 3), (), topi.subtract, np.subtract)
verify_broadcast_binary_ele(
(5, 2, 3), None, topi.subtract, np.subtract)
verify_broadcast_binary_ele(
None, None, topi.subtract, np.subtract)
verify_broadcast_binary_ele(
(1, 32), (64, 32), topi.subtract, np.subtract)
def test_multiply():
verify_broadcast_binary_ele(
(5, 64, 128), (2, 5, 64, 1), topi.multiply, np.multiply)
def test_divide():
verify_broadcast_binary_ele(
None, (10,), topi.divide, np.divide, rhs_min=0.0001)
verify_broadcast_binary_ele(
(2, 3, 1, 32), (64, 32), topi.divide, np.divide, rhs_min=0.0001)
def test_maximum_minmum():
verify_broadcast_binary_ele(
(32,), (64, 32), topi.maximum, np.maximum)
verify_broadcast_binary_ele(
(1, 2, 2, 1, 32), (64, 32), topi.minimum, np.minimum)
def test_power():
verify_broadcast_binary_ele(
(1, 2, 2), (2,), topi.power, np.power, lhs_min=0.001, rhs_min=0.001, rhs_max=2)
def test_mod():
verify_broadcast_binary_ele(
(1, 2, 2), (2,), topi.mod, np.mod, lhs_min=0.001, rhs_min=1, dtype="int32")
def test_broadcast_binary(): def test_cmp():
verify_broadcast_binary_ele((5, 2, 3), (2, 1), typ="add") # explicit specify the output type
verify_broadcast_binary_ele((5, 2, 3), (), typ="add") def greater(x, y):
verify_broadcast_binary_ele((5, 64, 128), (2, 5, 64, 1), typ="mul") return topi.greater(x, y).astype("int8")
verify_broadcast_binary_ele((2, 3, 1, 32), (64, 32), typ="div") def less(x, y):
verify_broadcast_binary_ele((1, 32), (64, 32), typ="sub") return topi.less(x, y).astype("int8")
verify_broadcast_binary_ele((32,), (64, 32), typ="maximum") verify_broadcast_binary_ele(
verify_broadcast_binary_ele((1, 2, 2, 1, 32), (64, 32), typ="minimum") (1, 2, 2), (2,), greater, np.greater)
verify_broadcast_binary_ele(
(2, 1, 2), (2, 3, 1), less, np.less)
if __name__ == "__main__": if __name__ == "__main__":
test_broadcast_binary() test_cmp()
test_mod()
test_add()
test_subtract()
test_multiply()
test_divide()
test_maximum_minmum()
test_power()
test_broadcast_to() test_broadcast_to()
...@@ -69,44 +69,6 @@ def verify_full(shape, dtype, fill_value): ...@@ -69,44 +69,6 @@ def verify_full(shape, dtype, fill_value):
check_device(device) check_device(device)
def verify_comparator(shape, dtype, out_type='int8'):
A = tvm.placeholder(shape, dtype, name="A")
B = tvm.placeholder(shape, dtype, name="B")
C = topi.less(A, B)
s_less = tvm.create_schedule([C.op])
D = tvm.placeholder(shape, dtype, name="D")
E = tvm.placeholder(shape, dtype, name="E")
F = topi.greater(D, E, out_type)
s_greater = tvm.create_schedule([F.op])
@memoize("topi.tests.test_topi_indicator")
def get_ref_data():
return [np.random.uniform(0, 10, size=shape).astype(dtype),
np.random.uniform(0, 10, size=shape).astype(dtype)]
[np_l, np_r] = get_ref_data()
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.context(device, 0)
out = tvm.nd.array(np.zeros(shape, dtype=out_type), ctx)
tvm_l = tvm.nd.array(np_l, ctx)
tvm_r = tvm.nd.array(np_r, ctx)
f = tvm.build(s_less, [A, B, C], device, name="less")
f(tvm_l, tvm_r, out)
np.testing.assert_allclose(out.asnumpy(), np.less(np_l, np_r).astype(out_type), rtol=1e-5)
f = tvm.build(s_greater, [D, E, F], device, name="greater")
f(tvm_l, tvm_r, out)
np.testing.assert_allclose(out.asnumpy(), np.greater(np_l, np_r).astype(out_type), rtol=1e-5)
for device in ["llvm"]:
check_device(device)
def test_elemwise_sum(): def test_elemwise_sum():
verify_elemwise_sum(1, "float32") verify_elemwise_sum(1, "float32")
verify_elemwise_sum(5, "float32") verify_elemwise_sum(5, "float32")
...@@ -118,12 +80,6 @@ def test_full(): ...@@ -118,12 +80,6 @@ def test_full():
verify_full((10,), "int32", 7) verify_full((10,), "int32", 7)
def test_comparator():
verify_comparator((3,4,5), "float32")
verify_comparator((7,), "int32")
verify_comparator((3,4,5), "float32", "int8")
if __name__ == "__main__": if __name__ == "__main__":
test_elemwise_sum() test_elemwise_sum()
test_full() test_full()
test_comparator()
"""Test code for broadcasting operators."""
import os
import numpy as np
import tvm
import topi
def verify_broadcast_to_ele(in_shape, out_shape):
# Build the logic and compile the function
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.cpp.broadcast_to(A, out_shape)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
s = topi.cpp.cuda.schedule_injective(target, [B])
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="broadcast_to")
data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
out_npy = np.broadcast_to(data_npy, out_shape)
data_nd = tvm.nd.array(data_npy, ctx)
out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), ctx)
for _ in range(1):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
check_device("opencl")
check_device("cuda")
#check_device("metal")
#check_device("rocm")
def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
# Build the logic and compile the function
A = tvm.placeholder(shape=lhs_shape, name="A")
B = tvm.placeholder(shape=rhs_shape, name="B")
if typ == "add":
C = topi.cpp.broadcast_add(A, B)
elif typ == "sub":
C = topi.cpp.broadcast_sub(A, B)
elif typ == "div":
C = topi.cpp.broadcast_div(A, B)
elif typ == "mul":
C = topi.cpp.broadcast_mul(A, B)
elif typ == "maximum":
C = topi.cpp.broadcast_maximum(A, B)
elif typ == "minimum":
C = topi.cpp.broadcast_minimum(A, B)
elif typ == "pow":
C = topi.cpp.broadcast_pow(A, B)
else:
raise NotImplementedError
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
s = topi.cpp.cuda.schedule_injective(target, [C])
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + typ)
lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
rhs_npy = np.random.uniform(size=rhs_shape).astype(A.dtype)
if typ == "add":
out_npy = lhs_npy + rhs_npy
elif typ == "sub":
out_npy = lhs_npy - rhs_npy
elif typ == "div":
rhs_npy = np.abs(rhs_npy) + 0.001
out_npy = lhs_npy / rhs_npy
elif typ == "mul":
out_npy = lhs_npy * rhs_npy
elif typ == "maximum":
out_npy = np.maximum(lhs_npy, rhs_npy)
elif typ == "minimum":
out_npy = np.minimum(lhs_npy, rhs_npy)
elif typ == "pow":
out_npy = lhs_npy ** rhs_npy
else:
raise NotImplementedError
lhs_nd = tvm.nd.array(lhs_npy, ctx)
rhs_nd = tvm.nd.array(rhs_npy, ctx)
out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), ctx)
for _ in range(1):
foo(lhs_nd, rhs_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
check_device("opencl")
check_device("cuda")
#check_device("metal")
#check_device("rocm")
def test_broadcast_to():
verify_broadcast_to_ele((1,), (10,))
verify_broadcast_to_ele((), (10,))
verify_broadcast_to_ele((1, 1, 5, 4), (3, 4, 4, 4, 5, 4))
verify_broadcast_to_ele((1, 128, 1, 32), (64, 128, 64, 32))
def test_broadcast_binary():
verify_broadcast_binary_ele((5, 2, 3), (2, 1), typ="add")
verify_broadcast_binary_ele((5, 2, 3), (), typ="add")
verify_broadcast_binary_ele((5, 64, 128), (2, 5, 64, 1), typ="mul")
verify_broadcast_binary_ele((2, 3, 1, 32), (64, 32), typ="div")
verify_broadcast_binary_ele((1, 32), (64, 32), typ="sub")
verify_broadcast_binary_ele((32,), (64, 32), typ="maximum")
verify_broadcast_binary_ele((1, 2, 2, 1, 32), (64, 32), typ="minimum")
verify_broadcast_binary_ele((1, 32), (64, 32), typ="pow")
if __name__ == "__main__":
test_broadcast_to()
test_broadcast_binary()
...@@ -243,7 +243,7 @@ def verify_concatenate_broadcast(shapes, axis, rhs_shape): ...@@ -243,7 +243,7 @@ def verify_concatenate_broadcast(shapes, axis, rhs_shape):
for i, shape in enumerate(shapes): for i, shape in enumerate(shapes):
tensor_l.append(tvm.placeholder(shape, name="A" + str(i))) tensor_l.append(tvm.placeholder(shape, name="A" + str(i)))
out_tensor = topi.cpp.concatenate(tensor_l, axis) out_tensor = topi.cpp.concatenate(tensor_l, axis)
C = topi.cpp.broadcast_add(out_tensor, B) C = out_tensor + B
def check_device(device): def check_device(device):
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
if not ctx.exist: if not ctx.exist:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment