Commit 2f5b155a by Haichen Shen, committed by Yizhi Liu

[Relay/TOPI][Op] Add erf intrinsic and op (#3702)

* add more ops

* stop vectorization for erf

* x

* cleanup

* fix

* add whitelist for vectorizable intrin

* add tf converter

* fix dense

* fix

* add missing intrin

* fix mxnet frontend

* fix nvptx
parent 6a377f77
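
For orientation, here is a minimal end-to-end sketch of the new op as the updated Relay tests exercise it. The shapes, tolerances, and executor setup below are illustrative, not part of the diff:

import numpy as np
import scipy.special
import tvm
from tvm import relay

# Hypothetical smoke test mirroring test_op_level1.py: wrap the new erf op in a
# one-op Relay function and compare against scipy.special.erf on CPU.
x = relay.var("x", shape=(4, 4), dtype="float32")
func = relay.Function([x], relay.erf(x))

data = np.random.uniform(-3, 3, size=(4, 4)).astype("float32")
intrp = relay.create_executor(ctx=tvm.cpu(0), target="llvm")
out = intrp.evaluate(func)(data).asnumpy()
np.testing.assert_allclose(out, scipy.special.erf(data), rtol=1e-5, atol=1e-5)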
@@ -512,6 +512,7 @@ TVM_DLL Expr trunc(Expr x);
   }                                           \
 TVM_DECLARE_INTRIN_UNARY(exp);
+TVM_DECLARE_INTRIN_UNARY(erf);
 TVM_DECLARE_INTRIN_UNARY(tanh);
 TVM_DECLARE_INTRIN_UNARY(sigmoid);
 TVM_DECLARE_INTRIN_UNARY(sqrt);
...
@@ -556,6 +556,9 @@ class Call : public ExprNode {
            name == intrin_name);
   }
+  /*! \return Whether call node can be vectorized. */
+  bool is_vectorizable() const;
+
   static constexpr const char* _type_key = "Call";
   TVM_DECLARE_NODE_TYPE_INFO(Call, ExprNode);
@@ -571,6 +574,9 @@ class Call : public ExprNode {
   static constexpr const char* likely = "likely";
   static constexpr const char* glsl_texture_store = "glsl_texture_store";
   static constexpr const char* prefetch = "prefetch";
+
+  /*! \brief Vectorizable intrinsic list. */
+  static const char* vectorizable_intrinsics[];
 };
 /*!
...
@@ -211,6 +211,22 @@ def exp(x):
     return call_pure_intrin(x.dtype, "exp", x)
+
+
+def erf(x):
+    """Take gauss error function of the input x.
+
+    Parameters
+    ----------
+    x : Expr
+        Input argument.
+
+    Returns
+    -------
+    y : Expr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "erf", x)
+
+
 def tanh(x):
     """Take hyperbolic tanh of input x.
...
@@ -170,8 +170,8 @@ class Executor(object):
             return args
         if kwargs and not isinstance(expr, Function):
-            raise Exception("can only supply keyword parameters for a \
-                relay.Function, found {0}".format(expr))
+            raise Exception("can only supply keyword parameters for a "
+                            "relay.Function, found {0}".format(expr))
         params = expr.params
         param_names = [p.name_hint for p in params]
@@ -182,16 +182,16 @@ class Executor(object):
             if i < num_of_args:
                 if kwargs.get(name):
                     raise Exception(
-                        "duplicate argument supplied in \
-                        both positional args (at position: {0}), \
-                        and keyword argument (with name: {1})".format(i, name))
+                        "duplicate argument supplied in "
+                        "both positional args (at position: {0}), "
+                        "and keyword argument (with name: {1})".format(i, name))
                 else:
                     cargs.append(kwargs[name])
         if len(cargs) != len(params):
             raise Exception(
-                "insufficient arguments, expected" \
-                " {0}, provided {1}".format(len(cargs), len(params)))
+                "insufficient arguments, expected "
+                "{0}, provided {1}".format(len(cargs), len(params)))
         return tuple(cargs)
...
@@ -124,7 +124,16 @@ class StrAttrsDict(object):
         """
         if key in self.attrs:
             tshape = self.attrs[key]
-            return tuple(int(x.strip()) for x in tshape.strip('()[]').split(',') if x)
+            ret = []
+            for x in tshape.strip('()[]').split(','):
+                x = x.strip()
+                if not x:
+                    continue
+                if x == "None":
+                    ret.append(None)
+                else:
+                    ret.append(int(x))
+            return tuple(ret)
         if isinstance(default, RequiredAttr):
             raise AttributeError("Required attribute {} not found.".format(key))
         return default
...
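
The rewritten parser above keeps literal "None" entries instead of failing in int(); the MXNet slice handling below relies on that. A standalone sketch of the same logic (the helper name is hypothetical, not part of the frontend class):

def parse_int_tuple(tshape):
    # Mirror of the parsing added above: blanks are skipped, "None" is kept as
    # None, and everything else is converted to int.
    ret = []
    for x in tshape.strip('()[]').split(','):
        x = x.strip()
        if not x:
            continue
        ret.append(None if x == "None" else int(x))
    return tuple(ret)

assert parse_int_tuple("(None, 0, 4)") == (None, 0, 4)
assert parse_int_tuple("[2, 3, 11]") == (2, 3, 11)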
@@ -55,10 +55,17 @@ def _mx_fully_connected(inputs, attrs):
     use_flatten = attrs.get_bool("flatten", True)
     if has_flatten and use_flatten:
         inputs[0] = _op.nn.batch_flatten(inputs[0])
+    data_shape = _infer_type(inputs[0]).checked_type.shape
+    if len(data_shape) > 2:
+        inputs[0] = _op.reverse_reshape(inputs[0], [-1, 0])
     res = _op.nn.dense(inputs[0], inputs[1], units=units)
     if use_bias:
         assert len(inputs) == 3
         res = _op.nn.bias_add(res, inputs[2], axis=-1)
+    if len(data_shape) > 2:
+        new_shape = data_shape[:-1]
+        new_shape.append(units)
+        res = _op.reshape(res, new_shape)
     return res
@@ -241,8 +248,8 @@ def _mx_layer_norm(inputs, attrs):
 def _mx_slice(inputs, attrs):
     new_attrs = {}
-    begin = attrs.get_int_tuple('begin', None)
-    end = attrs.get_int_tuple('end', None)
+    begin = list(attrs.get_int_tuple('begin', None))
+    end = list(attrs.get_int_tuple('end', None))
     stride = attrs.get_int_tuple('step', None)
     if begin is None:
         raise tvm.error.OpAttributeRequired(
@@ -251,11 +258,12 @@ def _mx_slice(inputs, attrs):
         raise tvm.error.OpAttributeRequired(
             'Attribute "end" not found in operator Slice.')
     if None in begin:
-        raise tvm.error.OpAttributeInvalid(
-            'Value None in attribute "begin" of operator Slice is not valid.')
-    if None in end:
-        raise tvm.error.OpAttributeInvalid(
-            'Value None in attribute "end" of operator Slice is not valid.')
+        data_shape = _infer_type(inputs[0]).checked_type.shape
+        for i, beg in enumerate(begin):
+            if beg is None:
+                assert end[i] is None
+                begin[i] = 0
+                end[i] = data_shape[i]
     new_attrs = {'begin': begin, 'end': end}
     if stride is not None:
         new_attrs['strides'] = stride
@@ -497,7 +505,8 @@ def _mx_arange(inputs, attrs):
             'Attribute "repeat" is not supported in operator arange.')
     new_attrs = {}
     new_attrs["start"] = _expr.const(attrs.get_float("start", 0.0))
-    new_attrs["stop"] = _expr.const(attrs.get_float("stop"))
+    stop = attrs.get_str("stop", "None")
+    new_attrs["stop"] = None if stop == "None" else _expr.const(float(stop))
     new_attrs["step"] = _expr.const(attrs.get_float("step", 1.0))
     new_attrs["dtype"] = attrs.get_str("dtype", "float32")
     return _op.arange(**new_attrs)
@@ -910,6 +919,7 @@ def _mx_one_hot(inputs, attrs):
 _identity_list = [
     "log",
     "exp",
+    "erf",
     "sqrt",
     "floor",
     "ceil",
...
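
As a plain-numpy illustration of the FullyConnected change above (all sizes are made up; MXNet stores weight as (units, in_dim)): inputs with more than two dimensions are collapsed to 2-D before the dense and reshaped back afterwards.

import numpy as np

data = np.random.randn(2, 3, 8).astype("float32")     # rank-3 input
weight = np.random.randn(16, 8).astype("float32")     # (units, in_dim)
units = weight.shape[0]

flat = data.reshape(-1, data.shape[-1])                # reverse_reshape([-1, 0]) analogue
out = flat.dot(weight.T)                               # nn.dense analogue
out = out.reshape(data.shape[:-1] + (units,))          # reshape to data_shape[:-1] + [units]
assert out.shape == (2, 3, 16)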
@@ -1261,6 +1261,7 @@ _convert_map = {
     'DepthToSpace' : _depth_to_space(),
     'Equal' : _broadcast('equal'),
     'Elu' : _elu(),
+    'Erf' : AttrCvt('erf'),
     'Exp' : AttrCvt('exp'),
     'ExpandDims' : _expand_dims(),
     'Fill' : _fill(),
...
@@ -30,6 +30,7 @@ register_schedule("log1p", schedule_broadcast)
 register_schedule("cos", schedule_broadcast)
 register_schedule("sin", schedule_broadcast)
 register_schedule("exp", schedule_broadcast)
+register_schedule("erf", schedule_broadcast)
 register_schedule("sqrt", schedule_broadcast)
 register_schedule("rsqrt", schedule_broadcast)
 register_schedule("sigmoid", schedule_broadcast)
...
@@ -92,6 +92,22 @@ def exp(data):
     return _make.exp(data)
+
+
+def erf(data):
+    """Compute elementwise error function of data.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.erf(data)
+
+
 def sqrt(data):
     """Compute elementwise sqrt of data.
...
@@ -31,6 +31,9 @@ namespace intrin {
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.exp")
 .set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf")
+.set_body(DispatchExtern<FloatSuffix>);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log")
 .set_body(DispatchExtern<FloatSuffix>);
...
@@ -92,6 +92,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
 .set_body(DispatchExtern<CUDAFastMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf")
+.set_body(DispatchExtern<CUDAMath>);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
 .set_body(DispatchExtern<CUDAFastMath>);
...
@@ -64,6 +64,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp")
 .set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf")
+.set_body(DispatchExternLibDevice);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma")
 .set_body(DispatchExternLibDevice);
...
@@ -176,6 +176,22 @@ Expr Let::make(Var var, Expr value, Expr body) {
   return Expr(node);
 }
+
+const char* Call::vectorizable_intrinsics[] = {
+    "floor", "ceil", "sign", "trunc", "fabs", "round", "exp", "tanh", "sqrt",
+    "log", "sin", "cos", "pow", ir::Call::shift_left, ir::Call::shift_right,
+    ir::Call::likely, ir::Call::popcount
+};
+
+bool Call::is_vectorizable() const {
+  size_t cnt = sizeof(Call::vectorizable_intrinsics) / sizeof(char*);
+  for (size_t i = 0; i < cnt; ++i) {
+    if (name == Call::vectorizable_intrinsics[i]) {
+      return true;
+    }
+  }
+  return false;
+}
+
 Expr Call::make(DataType type,
                 std::string name,
                 Array<Expr> args,
...
@@ -268,16 +268,34 @@ class Vectorizer : public IRMutator {
     if (op->name == intrinsic::tvm_if_then_else) {
       return MutateIfThenElseExpr_(op, e);
     }
-    int lane = 0;
-    Array<Expr> new_args = MutateArray(op->args, &lane);
-    // normal code path.
-    if (op->args.same_as(new_args)) {
-      return e;
+    if (!op->is_vectorizable()) {
+      // Cannot vectorize this op
+      Array<Expr> new_args;
+      for (auto arg : op->args) {
+        auto new_arg = this->Mutate(arg);
+        if (new_arg.type().is_vector()) {
+          need_scalarize_ = true;
+          return e;
+        }
+        new_args.push_back(new_arg);
+      }
+      if (op->args.same_as(new_args)) {
+        return e;
+      } else {
+        return Call::make(
+            op->type, op->name, new_args, op->call_type, op->func, op->value_index);
+      }
     } else {
-      return Call::make(
-          op->type.with_lanes(lane), op->name, new_args,
-          op->call_type, op->func, op->value_index);
+      int lane = 0;
+      Array<Expr> new_args = MutateArray(op->args, &lane);
+      // normal code path.
+      if (op->args.same_as(new_args)) {
+        return e;
+      } else {
+        return Call::make(
+            op->type.with_lanes(lane), op->name, new_args,
+            op->call_type, op->func, op->value_index);
+      }
     }
   }
   // Load
...
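
A small sketch of what the whitelist buys (shapes and the split factor are arbitrary): with the change above, vectorizing a loop that contains erf still lowers, because the call is scalarized rather than emitted as a vector erf that libm does not provide.

import tvm

n = 64
A = tvm.placeholder((n,), name="A", dtype="float32")
B = tvm.compute((n,), lambda i: tvm.erf(A[i]), name="B")

s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=8)
s[B].vectorize(xi)   # erf is not in vectorizable_intrinsics, so the call is scalarized
f = tvm.build(s, [A, B], "llvm")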
@@ -85,6 +85,18 @@ RELAY_REGISTER_UNARY_OP("exp")
 .set_support_level(1)
 .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));
+
+RELAY_REGISTER_UNARY_OP("erf")
+.describe(R"code(Returns the error function value for input array, computed element-wise.
+
+.. math::
+   \erf(x)
+
+)code" TVM_ADD_FILELINE)
+.set_support_level(1)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::erf));
+
 RELAY_REGISTER_UNARY_OP("sqrt")
 .describe(R"code(Returns the sqrt input array, computed element-wise.
...
@@ -1844,6 +1844,14 @@ def test_forward_zeros_like():
     _test_forward_zeros_like((2, 3, 11), "float32")
     _test_forward_zeros_like((2, 3, 11), "float64")
+
+
+def test_forward_erf():
+    ishape = (1, 3, 10, 10)
+    inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
+    with tf.Graph().as_default():
+        in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
+        tf.math.erf(in1)
+        compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Erf:0')
+
 def _test_forward_reverse_v2(in_shape, axis, dtype):
     np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype)
     tf.reset_default_graph()
@@ -2244,6 +2252,7 @@ if __name__ == '__main__':
     test_forward_log_softmax()
     test_forward_bias_add()
     test_forward_zeros_like()
+    test_forward_erf()

     # Reductions
     test_forward_argminmax()
...
@@ -16,6 +16,7 @@
 # under the License.
 import numpy as np
 import tvm
+import scipy
 from tvm import relay
 from tvm.relay import transform
 from tvm.relay.testing import ctx_list
@@ -67,6 +68,7 @@ def test_unary_op():
     for opfunc, ref in [(tvm.relay.log, np.log),
                         (tvm.relay.exp, np.exp),
+                        (tvm.relay.erf, scipy.special.erf),
                         (tvm.relay.sqrt, np.sqrt),
                         (tvm.relay.rsqrt, rsqrt),
                         (tvm.relay.sigmoid, sigmoid),
...
@@ -46,6 +46,7 @@ using namespace tvm;
 }
 TOPI_DECLARE_UNARY_OP(exp);
+TOPI_DECLARE_UNARY_OP(erf);
 TOPI_DECLARE_UNARY_OP(sigmoid);
 TOPI_DECLARE_UNARY_OP(sqrt);
 TOPI_DECLARE_UNARY_OP(log);
...
@@ -75,6 +75,23 @@ def exp(x):
 @tvm.tag_scope(tag=tag.ELEMWISE)
+def erf(x):
+    """Take gauss error function of input x.
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        Input argument.
+
+    Returns
+    -------
+    y : tvm.Tensor
+        The result.
+    """
+    return tvm.compute(x.shape, lambda *i: tvm.erf(x(*i)))
+
+
+@tvm.tag_scope(tag=tag.ELEMWISE)
 def tanh(x):
     """Take hyperbolic tanh of input x.
...
@@ -28,12 +28,19 @@ from ..util import traverse_inline, get_const_tuple
 @autotvm.register_topi_compute(nn.dense, "cpu", "direct")
 def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
-    batch, _ = get_const_tuple(data.shape)
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        C = cblas.matmul(data, weight, False, True)
+        if bias is not None:
+            C = tvm.compute(C.shape, lambda i, j: C[i, j] + bias[j].astype(out_dtype),
+                            tag=tag.BROADCAST)
+        return C
+
+    M, _ = get_const_tuple(data.shape)
     # For small batch sizes, don't pack weight into cache-friendly layout
     # because of overhead in packing and limited reuse from batch dimension
     # TODO(icemelon9): use a more systematic way to determine which schedule to use
-    if batch <= 16:
+    if M <= 16:
         return _declaration_dense_nopack(cfg, data, weight, bias, out_dtype)
     return _declaration_dense_pack(cfg, data, weight, bias, out_dtype)
@@ -41,35 +48,31 @@ def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
 # Declare dense compute with packing weight into cache-friendly layout
 @autotvm.register_topi_compute(nn.dense, "cpu", "direct_pack")
 def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
-    target = tvm.target.current_target()
-    if "cblas" in target.libs:
-        C = cblas.matmul(data, weight, False, True)
-    else:
-        if out_dtype is None:
-            out_dtype = data.dtype
-        batch, in_dim = get_const_tuple(data.shape)
-        out_dim, _ = get_const_tuple(weight.shape)
-        # create tuning space
-        cfg.define_split("tile_y", batch, num_outputs=3)
-        cfg.define_split("tile_x", out_dim, num_outputs=3)
-        cfg.define_split("tile_k", in_dim, num_outputs=2)
-        if cfg.is_fallback:
-            _default_dense_pack_config(cfg, batch, out_dim, in_dim)
-
-        packw_bn = cfg["tile_x"].size[-1]
-        packw_shape = (out_dim // packw_bn, in_dim, packw_bn)
-        packw = tvm.compute(packw_shape,
-                            lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
-
-        k = tvm.reduce_axis((0, in_dim), name="k")
-        C = tvm.compute((batch, out_dim),
-                        lambda y, x: tvm.sum(
-                            data[y, k].astype(out_dtype) *
-                            packw[x // packw_bn, k, x % packw_bn].astype(out_dtype),
-                            axis=k),
-                        tag="dense_pack")
+    if out_dtype is None:
+        out_dtype = data.dtype
+    M, K = get_const_tuple(data.shape)  # batch, in_dim
+    N, _ = get_const_tuple(weight.shape)  # out_dim
+    # create tuning space
+    cfg.define_split("tile_y", M, num_outputs=3)
+    cfg.define_split("tile_x", N, num_outputs=3)
+    cfg.define_split("tile_k", K, num_outputs=2)
+    if cfg.is_fallback:
+        _default_dense_pack_config(cfg, M, N, K)
+
+    packw_bn = cfg["tile_x"].size[-1]
+    packw_shape = (N // packw_bn, K, packw_bn)
+    packw = tvm.compute(packw_shape,
+                        lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
+
+    k = tvm.reduce_axis((0, K), name="k")
+    C = tvm.compute((M, N),
+                    lambda y, x: tvm.sum(
+                        data[y, k].astype(out_dtype) *
+                        packw[x // packw_bn, k, x % packw_bn].astype(out_dtype),
+                        axis=k),
+                    tag="dense_pack")
     if bias is not None:
-        C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
+        C = tvm.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
                         tag=tag.BROADCAST)
     return C
@@ -77,34 +80,30 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
 # Declare dense compute without packing weight
 @autotvm.register_topi_compute(nn.dense, "cpu", "direct_nopack")
 def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
-    target = tvm.target.current_target()
-    if "cblas" in target.libs:
-        C = cblas.matmul(data, weight, False, True)
-    else:
-        if out_dtype is None:
-            out_dtype = data.dtype
-        batch, in_dim = get_const_tuple(data.shape)
-        out_dim, _ = get_const_tuple(weight.shape)
-        # create tuning space
-        cfg.define_split("tile_x", out_dim, num_outputs=2)
-        cfg.define_split("tile_y", batch, num_outputs=2)
-        cfg.define_split("tile_k", in_dim, num_outputs=2)
-        if cfg.is_fallback:
-            _default_dense_nopack_config(cfg, batch, out_dim, in_dim)
-
-        vec = cfg["tile_k"].size[-1]
-        k = tvm.reduce_axis((0, in_dim // vec), "k")
-        CC = tvm.compute((batch, out_dim, vec),
-                         lambda z, y, x: tvm.sum(
-                             data[z, k * vec + x].astype(out_dtype) *
-                             weight[y, k * vec + x].astype(out_dtype), axis=k))
-
-        kk = tvm.reduce_axis((0, vec), "kk")
-        C = tvm.compute((batch, out_dim),
-                        lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
-                        tag="dense_nopack")
+    if out_dtype is None:
+        out_dtype = data.dtype
+    M, K = get_const_tuple(data.shape)
+    N, _ = get_const_tuple(weight.shape)
+    # create tuning space
+    cfg.define_split("tile_y", M, num_outputs=2)
+    cfg.define_split("tile_x", N, num_outputs=2)
+    cfg.define_split("tile_k", K, num_outputs=2)
+    if cfg.is_fallback:
+        _default_dense_nopack_config(cfg, M, N, K)
+
+    vec = cfg["tile_k"].size[-1]
+    k = tvm.reduce_axis((0, K // vec), "k")
+    CC = tvm.compute((M, N, vec),
+                     lambda z, y, x: tvm.sum(
+                         data[z, k * vec + x].astype(out_dtype) *
+                         weight[y, k * vec + x].astype(out_dtype), axis=k))
+
+    kk = tvm.reduce_axis((0, vec), "kk")
+    C = tvm.compute((M, N),
+                    lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
+                    tag="dense_nopack")
     if bias is not None:
-        C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
+        C = tvm.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
                         tag=tag.BROADCAST)
     return C
...
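
A brief note on the dense refactor above: the cblas branch now lives only in the top-level _declaration_dense and is gated on the target's library list. A rough sketch of that condition (the target strings are illustrative, and actually taking the cblas path assumes a cblas-enabled TVM build):

import tvm

target = tvm.target.create("llvm -libs=cblas")
assert "cblas" in target.libs        # condition checked in _declaration_dense

plain = tvm.target.create("llvm")
assert "cblas" not in plain.libs     # falls through to the pack/nopack schedules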
@@ -148,6 +148,11 @@ TVM_REGISTER_GLOBAL("topi.exp")
     *rv = exp(args[0]);
   });
+TVM_REGISTER_GLOBAL("topi.erf")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+    *rv = erf(args[0]);
+  });
+
 TVM_REGISTER_GLOBAL("topi.cos")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
     *rv = cos(args[0]);
@@ -157,7 +162,6 @@ TVM_REGISTER_GLOBAL("topi.sin")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
     *rv = sin(args[0]);
   });
-
 TVM_REGISTER_GLOBAL("topi.tanh")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
     *rv = tanh(args[0]);
...
@@ -36,6 +36,7 @@ def test_ewise():
         assert B.op.body[0].name == name
     test_apply(topi.exp, "exp")
+    test_apply(topi.erf, "erf")
    test_apply(topi.tanh, "tanh")
     test_apply(topi.sigmoid, "sigmoid")
     test_apply(topi.log, "log")
...
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import numpy as np
+import scipy
 import tvm
 import topi
 import topi.testing
@@ -86,6 +87,7 @@ def test_ewise():
     test_apply(topi.rsqrt, "rsqrt", lambda x: np.ones_like(x) / np.sqrt(x), 0, 100, skip_name_check=True)
     test_apply(topi.cos, "cos", np.cos, -2.0*np.pi, 2.0*np.pi)
     test_apply(topi.sin, "sin", np.sin, -2.0*np.pi, 2.0*np.pi)
+    test_apply(topi.erf, "erf", scipy.special.erf, -.1, .1, dtype="float32")

 def test_cast():
...