Commit 2f5b155a by Haichen Shen Committed by Yizhi Liu

[Relay/TOPI][Op] Add erf intrinsic and op (#3702)

* add more ops

* stop vectorization for erf

* x

* cleanup

* fix

* add whitelist for vectorizable intrin

* add tf converter

* fix dense

* fix

* add missing intrin

* fix mxnet frontend

* fix nvptx
parent 6a377f77
......@@ -512,6 +512,7 @@ TVM_DLL Expr trunc(Expr x);
} \
TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(erf);
TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt);
......
......@@ -556,6 +556,9 @@ class Call : public ExprNode {
name == intrin_name);
}
/*! \return Whether call node can be vectorized. */
bool is_vectorizable() const;
static constexpr const char* _type_key = "Call";
TVM_DECLARE_NODE_TYPE_INFO(Call, ExprNode);
......@@ -571,6 +574,9 @@ class Call : public ExprNode {
static constexpr const char* likely = "likely";
static constexpr const char* glsl_texture_store = "glsl_texture_store";
static constexpr const char* prefetch = "prefetch";
/*! \brief Vectorizable intrinsic list. */
static const char* vectorizable_intrinsics[];
};
/*!
......
......@@ -211,6 +211,22 @@ def exp(x):
return call_pure_intrin(x.dtype, "exp", x)
def erf(x):
"""Take gauss error function of the input x.
Parameters
----------
x : Expr
Input argument.
Returns
-------
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "erf", x)
def tanh(x):
"""Take hyperbolic tanh of input x.
......
......@@ -170,8 +170,8 @@ class Executor(object):
return args
if kwargs and not isinstance(expr, Function):
raise Exception("can only supply keyword parameters for a \
relay.Function, found {0}".format(expr))
raise Exception("can only supply keyword parameters for a "
"relay.Function, found {0}".format(expr))
params = expr.params
param_names = [p.name_hint for p in params]
......@@ -182,16 +182,16 @@ class Executor(object):
if i < num_of_args:
if kwargs.get(name):
raise Exception(
"duplicate argument supplied in \
both positional args (at position: {0}), \
and keyword argument (with name: {1})".format(i, name))
"duplicate argument supplied in "
"both positional args (at position: {0}), "
"and keyword argument (with name: {1})".format(i, name))
else:
cargs.append(kwargs[name])
if len(cargs) != len(params):
raise Exception(
"insufficient arguments, expected" \
" {0}, provided {1}".format(len(cargs), len(params)))
"insufficient arguments, expected "
"{0}, provided {1}".format(len(cargs), len(params)))
return tuple(cargs)
......
......@@ -124,7 +124,16 @@ class StrAttrsDict(object):
"""
if key in self.attrs:
tshape = self.attrs[key]
return tuple(int(x.strip()) for x in tshape.strip('()[]').split(',') if x)
ret = []
for x in tshape.strip('()[]').split(','):
x = x.strip()
if not x:
continue
if x == "None":
ret.append(None)
else:
ret.append(int(x))
return tuple(ret)
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default
......
......@@ -55,10 +55,17 @@ def _mx_fully_connected(inputs, attrs):
use_flatten = attrs.get_bool("flatten", True)
if has_flatten and use_flatten:
inputs[0] = _op.nn.batch_flatten(inputs[0])
data_shape = _infer_type(inputs[0]).checked_type.shape
if len(data_shape) > 2:
inputs[0] = _op.reverse_reshape(inputs[0], [-1, 0])
res = _op.nn.dense(inputs[0], inputs[1], units=units)
if use_bias:
assert len(inputs) == 3
res = _op.nn.bias_add(res, inputs[2], axis=-1)
if len(data_shape) > 2:
new_shape = data_shape[:-1]
new_shape.append(units)
res = _op.reshape(res, new_shape)
return res
......@@ -241,8 +248,8 @@ def _mx_layer_norm(inputs, attrs):
def _mx_slice(inputs, attrs):
new_attrs = {}
begin = attrs.get_int_tuple('begin', None)
end = attrs.get_int_tuple('end', None)
begin = list(attrs.get_int_tuple('begin', None))
end = list(attrs.get_int_tuple('end', None))
stride = attrs.get_int_tuple('step', None)
if begin is None:
raise tvm.error.OpAttributeRequired(
......@@ -251,11 +258,12 @@ def _mx_slice(inputs, attrs):
raise tvm.error.OpAttributeRequired(
'Attribute "end" not found in operator Slice.')
if None in begin:
raise tvm.error.OpAttributeInvalid(
'Value None in attribute "begin" of operator Slice is not valid.')
if None in end:
raise tvm.error.OpAttributeInvalid(
'Value None in attribute "end" of operator Slice is not valid.')
data_shape = _infer_type(inputs[0]).checked_type.shape
for i, beg in enumerate(begin):
if beg is None:
assert end[i] is None
begin[i] = 0
end[i] = data_shape[i]
new_attrs = {'begin': begin, 'end': end}
if stride is not None:
new_attrs['strides'] = stride
......@@ -497,7 +505,8 @@ def _mx_arange(inputs, attrs):
'Attribute "repeat" is not supported in operator arange.')
new_attrs = {}
new_attrs["start"] = _expr.const(attrs.get_float("start", 0.0))
new_attrs["stop"] = _expr.const(attrs.get_float("stop"))
stop = attrs.get_str("stop", "None")
new_attrs["stop"] = None if stop == "None" else _expr.const(float(stop))
new_attrs["step"] = _expr.const(attrs.get_float("step", 1.0))
new_attrs["dtype"] = attrs.get_str("dtype", "float32")
return _op.arange(**new_attrs)
......@@ -910,6 +919,7 @@ def _mx_one_hot(inputs, attrs):
_identity_list = [
"log",
"exp",
"erf",
"sqrt",
"floor",
"ceil",
......
......@@ -1261,6 +1261,7 @@ _convert_map = {
'DepthToSpace' : _depth_to_space(),
'Equal' : _broadcast('equal'),
'Elu' : _elu(),
'Erf' : AttrCvt('erf'),
'Exp' : AttrCvt('exp'),
'ExpandDims' : _expand_dims(),
'Fill' : _fill(),
......
......@@ -30,6 +30,7 @@ register_schedule("log1p", schedule_broadcast)
register_schedule("cos", schedule_broadcast)
register_schedule("sin", schedule_broadcast)
register_schedule("exp", schedule_broadcast)
register_schedule("erf", schedule_broadcast)
register_schedule("sqrt", schedule_broadcast)
register_schedule("rsqrt", schedule_broadcast)
register_schedule("sigmoid", schedule_broadcast)
......
......@@ -92,6 +92,22 @@ def exp(data):
return _make.exp(data)
def erf(data):
"""Compute elementwise error function of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.erf(data)
def sqrt(data):
"""Compute elementwise sqrt of data.
......
......@@ -31,6 +31,9 @@ namespace intrin {
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.exp")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log")
.set_body(DispatchExtern<FloatSuffix>);
......
......@@ -92,6 +92,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
.set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf")
.set_body(DispatchExtern<CUDAMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
.set_body(DispatchExtern<CUDAFastMath>);
......
......@@ -64,6 +64,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp")
.set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf")
.set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma")
.set_body(DispatchExternLibDevice);
......
......@@ -176,6 +176,22 @@ Expr Let::make(Var var, Expr value, Expr body) {
return Expr(node);
}
const char* Call::vectorizable_intrinsics[] = {
"floor", "ceil", "sign", "trunc", "fabs", "round", "exp", "tanh", "sqrt",
"log", "sin", "cos", "pow", ir::Call::shift_left, ir::Call::shift_right,
ir::Call::likely, ir::Call::popcount
};
bool Call::is_vectorizable() const {
size_t cnt = sizeof(Call::vectorizable_intrinsics) / sizeof(char*);
for (size_t i = 0; i < cnt; ++i) {
if (name == Call::vectorizable_intrinsics[i]) {
return true;
}
}
return false;
}
Expr Call::make(DataType type,
std::string name,
Array<Expr> args,
......
......@@ -268,16 +268,34 @@ class Vectorizer : public IRMutator {
if (op->name == intrinsic::tvm_if_then_else) {
return MutateIfThenElseExpr_(op, e);
}
int lane = 0;
Array<Expr> new_args = MutateArray(op->args, &lane);
// normal code path.
if (op->args.same_as(new_args)) {
return e;
if (!op->is_vectorizable()) {
// Cannot vectorize this op
Array<Expr> new_args;
for (auto arg : op->args) {
auto new_arg = this->Mutate(arg);
if (new_arg.type().is_vector()) {
need_scalarize_ = true;
return e;
}
new_args.push_back(new_arg);
}
if (op->args.same_as(new_args)) {
return e;
} else {
return Call::make(
op->type, op->name, new_args, op->call_type, op->func, op->value_index);
}
} else {
return Call::make(
op->type.with_lanes(lane), op->name, new_args,
op->call_type, op->func, op->value_index);
int lane = 0;
Array<Expr> new_args = MutateArray(op->args, &lane);
// normal code path.
if (op->args.same_as(new_args)) {
return e;
} else {
return Call::make(
op->type.with_lanes(lane), op->name, new_args,
op->call_type, op->func, op->value_index);
}
}
}
// Load
......
......@@ -85,6 +85,18 @@ RELAY_REGISTER_UNARY_OP("exp")
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));
RELAY_REGISTER_UNARY_OP("erf")
.describe(R"code(Returns the error function value for input array, computed element-wise.
.. math::
\erf(x)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::erf));
RELAY_REGISTER_UNARY_OP("sqrt")
.describe(R"code(Returns the sqrt input array, computed element-wise.
......
......@@ -1844,6 +1844,14 @@ def test_forward_zeros_like():
_test_forward_zeros_like((2, 3, 11), "float32")
_test_forward_zeros_like((2, 3, 11), "float64")
def test_forward_erf():
ishape = (1, 3, 10, 10)
inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
tf.math.erf(in1)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Erf:0')
def _test_forward_reverse_v2(in_shape, axis, dtype):
np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype)
tf.reset_default_graph()
......@@ -2244,6 +2252,7 @@ if __name__ == '__main__':
test_forward_log_softmax()
test_forward_bias_add()
test_forward_zeros_like()
test_forward_erf()
# Reductions
test_forward_argminmax()
......
......@@ -16,6 +16,7 @@
# under the License.
import numpy as np
import tvm
import scipy
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import ctx_list
......@@ -67,6 +68,7 @@ def test_unary_op():
for opfunc, ref in [(tvm.relay.log, np.log),
(tvm.relay.exp, np.exp),
(tvm.relay.erf, scipy.special.erf),
(tvm.relay.sqrt, np.sqrt),
(tvm.relay.rsqrt, rsqrt),
(tvm.relay.sigmoid, sigmoid),
......
......@@ -46,6 +46,7 @@ using namespace tvm;
}
TOPI_DECLARE_UNARY_OP(exp);
TOPI_DECLARE_UNARY_OP(erf);
TOPI_DECLARE_UNARY_OP(sigmoid);
TOPI_DECLARE_UNARY_OP(sqrt);
TOPI_DECLARE_UNARY_OP(log);
......
......@@ -75,6 +75,23 @@ def exp(x):
@tvm.tag_scope(tag=tag.ELEMWISE)
def erf(x):
"""Take gauss error function of input x.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.erf(x(*i)))
@tvm.tag_scope(tag=tag.ELEMWISE)
def tanh(x):
"""Take hyperbolic tanh of input x.
......
......@@ -28,12 +28,19 @@ from ..util import traverse_inline, get_const_tuple
@autotvm.register_topi_compute(nn.dense, "cpu", "direct")
def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
batch, _ = get_const_tuple(data.shape)
target = tvm.target.current_target()
if "cblas" in target.libs:
C = cblas.matmul(data, weight, False, True)
if bias is not None:
C = tvm.compute(C.shape, lambda i, j: C[i, j] + bias[j].astype(out_dtype),
tag=tag.BROADCAST)
return C
M, _ = get_const_tuple(data.shape)
# For small batch sizes, don't pack weight into cache-friendly layout
# because of overhead in packing and limited reuse from batch dimension
# TODO(icemelon9): use a more systematic way to determine which schedule to use
if batch <= 16:
if M <= 16:
return _declaration_dense_nopack(cfg, data, weight, bias, out_dtype)
return _declaration_dense_pack(cfg, data, weight, bias, out_dtype)
......@@ -41,35 +48,31 @@ def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
# Declare dense compute with packing weight into cache-friendly layout
@autotvm.register_topi_compute(nn.dense, "cpu", "direct_pack")
def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
target = tvm.target.current_target()
if "cblas" in target.libs:
C = cblas.matmul(data, weight, False, True)
else:
if out_dtype is None:
out_dtype = data.dtype
batch, in_dim = get_const_tuple(data.shape)
out_dim, _ = get_const_tuple(weight.shape)
# create tuning space
cfg.define_split("tile_y", batch, num_outputs=3)
cfg.define_split("tile_x", out_dim, num_outputs=3)
cfg.define_split("tile_k", in_dim, num_outputs=2)
if cfg.is_fallback:
_default_dense_pack_config(cfg, batch, out_dim, in_dim)
packw_bn = cfg["tile_x"].size[-1]
packw_shape = (out_dim // packw_bn, in_dim, packw_bn)
packw = tvm.compute(packw_shape,
lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
k = tvm.reduce_axis((0, in_dim), name="k")
C = tvm.compute((batch, out_dim),
lambda y, x: tvm.sum(
data[y, k].astype(out_dtype) *
packw[x // packw_bn, k, x % packw_bn].astype(out_dtype),
axis=k),
tag="dense_pack")
if out_dtype is None:
out_dtype = data.dtype
M, K = get_const_tuple(data.shape) # batch, in_dim
N, _ = get_const_tuple(weight.shape) # out_dim
# create tuning space
cfg.define_split("tile_y", M, num_outputs=3)
cfg.define_split("tile_x", N, num_outputs=3)
cfg.define_split("tile_k", K, num_outputs=2)
if cfg.is_fallback:
_default_dense_pack_config(cfg, M, N, K)
packw_bn = cfg["tile_x"].size[-1]
packw_shape = (N // packw_bn, K, packw_bn)
packw = tvm.compute(packw_shape,
lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
k = tvm.reduce_axis((0, K), name="k")
C = tvm.compute((M, N),
lambda y, x: tvm.sum(
data[y, k].astype(out_dtype) *
packw[x // packw_bn, k, x % packw_bn].astype(out_dtype),
axis=k),
tag="dense_pack")
if bias is not None:
C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
C = tvm.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
tag=tag.BROADCAST)
return C
......@@ -77,34 +80,30 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
# Declare dense compute without packing weight
@autotvm.register_topi_compute(nn.dense, "cpu", "direct_nopack")
def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
target = tvm.target.current_target()
if "cblas" in target.libs:
C = cblas.matmul(data, weight, False, True)
else:
if out_dtype is None:
out_dtype = data.dtype
batch, in_dim = get_const_tuple(data.shape)
out_dim, _ = get_const_tuple(weight.shape)
# create tuning space
cfg.define_split("tile_x", out_dim, num_outputs=2)
cfg.define_split("tile_y", batch, num_outputs=2)
cfg.define_split("tile_k", in_dim, num_outputs=2)
if cfg.is_fallback:
_default_dense_nopack_config(cfg, batch, out_dim, in_dim)
vec = cfg["tile_k"].size[-1]
k = tvm.reduce_axis((0, in_dim // vec), "k")
CC = tvm.compute((batch, out_dim, vec),
lambda z, y, x: tvm.sum(
data[z, k * vec + x].astype(out_dtype) *
weight[y, k * vec + x].astype(out_dtype), axis=k))
kk = tvm.reduce_axis((0, vec), "kk")
C = tvm.compute((batch, out_dim),
lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
tag="dense_nopack")
if out_dtype is None:
out_dtype = data.dtype
M, K = get_const_tuple(data.shape)
N, _ = get_const_tuple(weight.shape)
# create tuning space
cfg.define_split("tile_y", M, num_outputs=2)
cfg.define_split("tile_x", N, num_outputs=2)
cfg.define_split("tile_k", K, num_outputs=2)
if cfg.is_fallback:
_default_dense_nopack_config(cfg, M, N, K)
vec = cfg["tile_k"].size[-1]
k = tvm.reduce_axis((0, K // vec), "k")
CC = tvm.compute((M, N, vec),
lambda z, y, x: tvm.sum(
data[z, k * vec + x].astype(out_dtype) *
weight[y, k * vec + x].astype(out_dtype), axis=k))
kk = tvm.reduce_axis((0, vec), "kk")
C = tvm.compute((M, N),
lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
tag="dense_nopack")
if bias is not None:
C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
C = tvm.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
tag=tag.BROADCAST)
return C
......
......@@ -148,6 +148,11 @@ TVM_REGISTER_GLOBAL("topi.exp")
*rv = exp(args[0]);
});
TVM_REGISTER_GLOBAL("topi.erf")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = erf(args[0]);
});
TVM_REGISTER_GLOBAL("topi.cos")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = cos(args[0]);
......@@ -157,7 +162,6 @@ TVM_REGISTER_GLOBAL("topi.sin")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sin(args[0]);
});
TVM_REGISTER_GLOBAL("topi.tanh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tanh(args[0]);
......
......@@ -36,6 +36,7 @@ def test_ewise():
assert B.op.body[0].name == name
test_apply(topi.exp, "exp")
test_apply(topi.erf, "erf")
test_apply(topi.tanh, "tanh")
test_apply(topi.sigmoid, "sigmoid")
test_apply(topi.log, "log")
......
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import numpy as np
import scipy
import tvm
import topi
import topi.testing
......@@ -86,6 +87,7 @@ def test_ewise():
test_apply(topi.rsqrt, "rsqrt", lambda x: np.ones_like(x) / np.sqrt(x), 0, 100, skip_name_check=True)
test_apply(topi.cos, "cos", np.cos, -2.0*np.pi, 2.0*np.pi)
test_apply(topi.sin, "sin", np.sin, -2.0*np.pi, 2.0*np.pi)
test_apply(topi.erf, "erf", scipy.special.erf, -.1, .1, dtype="float32")
def test_cast():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment