Commit cc09497e by Wuwei Lin, committed by ziheng

[Relay, Quantization, TOPI] int8 dense on CUDA & Dense op quantization (#2877)

* Quantize dense layers

* Add out_dtype argument to dense; Add dense_int8 on CUDA

* Add topi unittest of dense int8

* Fix relay

* Fix topi integration

* Fix quantization

* Update dense_rewrite

* Trigger CI

* Change qconfig quantize_dense to quantize_op

* Fix

* Remove quantize_op from qconfig
parent 879f91d0
......@@ -336,10 +336,16 @@ struct GlobalPool2DAttrs : public tvm::AttrsNode<GlobalPool2DAttrs> {
/*! \brief Attributes for dense operator */
struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
IndexExpr units;
DataType out_dtype;
TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") {
TVM_ATTR_FIELD(units)
.describe("Number of hidden units of the dense transformation.");
// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};
......
......@@ -188,7 +188,7 @@ class TaskExtractEnv:
def _topi_nn_dense(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
data, weight, bias = args
data, weight, bias, _ = args
C = topi.nn.dense(*args, **kwargs)
s = topi.generic.schedule_dense([C])
if bias is not None:
......
......@@ -51,7 +51,9 @@ reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
@reg.register_compute("nn.dense")
def compute_dense(attrs, inputs, out_type, target):
"""Compute definition of dense"""
return [topi.nn.dense(inputs[0], inputs[1])]
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
return [topi.nn.dense(inputs[0], inputs[1], out_dtype=out_dtype)]
@reg.register_schedule("nn.dense")
def schedule_dense(attrs, outputs, target):
......
......@@ -475,7 +475,7 @@ def bias_add(data, bias, axis=1):
return _make.bias_add(data, bias, axis)
def dense(data, weight, units=None):
def dense(data, weight, units=None, out_dtype=""):
"""Dense operator.
Applies a linear transformation
......@@ -494,12 +494,15 @@ def dense(data, weight, units=None):
units : int, optional
Number of hidden units of the dense transformation.
out_dtype : str, optional
Specifies the output data type for mixed precision dense.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.dense(data, weight, units)
return _make.dense(data, weight, units, out_dtype)
def relu(data):
......
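For reference, a minimal sketch of calling the extended Relay API with the new `out_dtype` argument; the shapes and dtypes below are illustrative, not taken from the patch.

```python
# Hedged sketch: int8 inputs accumulating into an int32 output via out_dtype.
from tvm import relay

data = relay.var("data", shape=(1, 1024), dtype="int8")
weight = relay.var("weight", shape=(1000, 1024), dtype="int8")
out = relay.nn.dense(data, weight, units=1000, out_dtype="int32")
func = relay.Function([data, weight], out)
```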
......@@ -171,6 +171,26 @@ def conv2d_rewrite(ref_call, new_args, ctx):
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
@register_annotate_function("nn.dense")
def dense_rewrite(ref_call, new_args, ctx):
"""Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of
dense will be quantized to weight field. Output would be in activation field."""
cnt = _conv_counter()
if cnt < current_qconfig().skip_k_conv:
return None
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
if lhs_kind is None or lhs_kind != QAnnotateKind.INPUT:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
assert rhs_kind is None
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
@register_annotate_function("multiply")
def multiply_rewrite(ref_call, new_args, ctx):
"""Rewrite function for multiply."""
......
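For context, a hedged sketch of the quantization flow that exercises `dense_rewrite`; the `relay.quantize` entry points and the `skip_k_conv` field follow the API of this era and may differ in later versions.

```python
# Hedged sketch: quantizing a Relay function that contains dense layers.
from tvm import relay
from tvm.relay import quantize as qtz

with qtz.qconfig(skip_k_conv=0):        # skip_k_conv is checked by dense_rewrite above
    qfunc = qtz.quantize(func, params=params)   # func/params are assumed to exist
```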
......@@ -131,8 +131,12 @@ bool DenseRel(const Array<Type>& types,
oshape.Set((oshape.size() - 1), wshape[0]);
}
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
// assign output type
reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
return true;
}
......@@ -140,9 +144,11 @@ bool DenseRel(const Array<Type>& types,
// Positional relay function to create dense operator used by frontend FFI.
Expr MakeDense(Expr data,
Expr weight,
IndexExpr units) {
IndexExpr units,
DataType out_dtype) {
auto attrs = make_node<DenseAttrs>();
attrs->units = units;
attrs->out_dtype = out_dtype;
static const Op& op = Op::Get("nn.dense");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
}
......
......@@ -296,6 +296,39 @@ RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);
Expr DenseRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
const QConfig& cfg = QConfig::Current();
CHECK_EQ(new_args.size(), 2);
if (!new_args[0]->derived_from<TempExprNode>() || !new_args[1]->derived_from<TempExprNode>()) {
return Expr(nullptr);
}
const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
Expr ldata = lhs->data;
if (lhs->dtype != cfg->dtype_input) {
ldata = Cast(ldata, cfg->dtype_input);
}
Expr rdata = Cast(rhs->data, cfg->dtype_weight);
const auto ref_attrs = ref_call->attrs.as<DenseAttrs>();
auto attrs = make_node<DenseAttrs>();
*attrs = *ref_attrs;
DataType out_dtype = cfg->dtype_activation;
attrs->out_dtype = out_dtype;
Expr ret = CallNode::make(ref_call->op,
{ldata, rdata}, Attrs(attrs), ref_call->type_args);
Expr dom_scale = FoldConstant(Multiply(lhs->dom_scale, rhs->dom_scale));
return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
}
RELAY_REGISTER_OP("nn.dense")
.set_attr<FForwardRewrite>("FQRealizeRewrite", DenseRealize);
Expr MulRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
......
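The realized call folds `dom_scale = lhs->dom_scale * rhs->dom_scale` because dense is bilinear: if `lhs ≈ ldata * s_l` and `rhs ≈ rdata * s_r`, then `dense(lhs, rhs) = (s_l * s_r) * dense(ldata, rdata)`. A small numpy check of that identity with illustrative values:

```python
# Why DenseRealize multiplies the two domain scales (illustrative check).
import numpy as np

s_l, s_r = 0.05, 0.02                                    # illustrative scales
ldata = np.random.randint(-128, 128, (4, 8)).astype(np.int32)
rdata = np.random.randint(-128, 128, (16, 8)).astype(np.int32)
lhs, rhs = ldata * s_l, rdata * s_r                      # simulated real-valued tensors
np.testing.assert_allclose(lhs @ rhs.T, (ldata @ rdata.T) * (s_l * s_r), rtol=1e-6)
```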
......@@ -44,13 +44,15 @@ namespace cuda {
* \param data Tensor with shape [batch, in_dim]
* \param weight Tensor with shape [out_dim, in_dim]
* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
* \param out_dtype Output data type. Used for mixed precision.
*
* \return Tensor with shape [batch, out_dim]
*/
inline tvm::Tensor dense_cuda(const Target& target,
const tvm::Tensor& data,
const tvm::Tensor& weight,
const tvm::Tensor& bias) {
const tvm::Tensor& bias,
const Type& out_dtype) {
CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
if (bias.defined()) {
......@@ -62,6 +64,7 @@ inline tvm::Tensor dense_cuda(const Target& target,
auto out_dim = weight->shape[0];
if (target->libs().count("cublas")) {
CHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported.";
auto mm = topi::contrib::cublas_matmul(data, weight, false, true);
if (bias.defined()) {
mm = tvm::compute({ batch, out_dim },
......@@ -72,7 +75,7 @@ inline tvm::Tensor dense_cuda(const Target& target,
return mm;
} else {
return topi::nn::dense(data, weight, bias);
return topi::nn::dense(data, weight, bias, out_dtype);
}
}
......
......@@ -40,12 +40,14 @@ using namespace tvm;
* \param data Tensor with shape [batch, in_dim]
* \param weight Tensor with shape [out_dim, in_dim]
* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
* \param out_dtype Output data type. Used for mixed precision.
*
* \return Tensor with shape [batch, out_dim]
*/
inline tvm::Tensor dense(const tvm::Tensor& data,
const tvm::Tensor& weight,
const tvm::Tensor& bias) {
const tvm::Tensor& bias,
const Type& out_dtype) {
CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
if (bias.defined()) {
......@@ -60,14 +62,15 @@ inline tvm::Tensor dense(const tvm::Tensor& data,
auto matmul = tvm::compute(
{ batch, out_dim },
[&](Var i, Var j) {
return tvm::sum(data(i, k) * weight(j, k), { k });
return tvm::sum(tvm::cast(out_dtype, data(i, k)) *
tvm::cast(out_dtype, weight(j, k)), { k });
}, "tensor", "dense");
if (bias.defined()) {
matmul = tvm::compute(
{ batch, out_dim },
[&](Var i, Var j) {
return matmul(i, j) + bias(j);
return matmul(i, j) + tvm::cast(out_dtype, bias(j));
}, "tensor", kBroadcast);
}
......
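The cast-before-accumulate semantics introduced above can be summarized by a small numpy reference; this is a sketch of the semantics, not the TOPI kernel.

```python
# Reference semantics of mixed-precision dense: operands are cast to out_dtype
# before the reduction, and the bias is cast before the add.
import numpy as np

def dense_ref(data, weight, bias=None, out_dtype="int32"):
    out = data.astype(out_dtype) @ weight.astype(out_dtype).T
    if bias is not None:
        out = out + bias.astype(out_dtype)
    return out
```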
......@@ -45,13 +45,15 @@ namespace rocm {
* \param data Tensor with shape [batch, in_dim]
* \param weight Tensor with shape [out_dim, in_dim]
* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
* \param out_dtype Output data type. Used for mixed precision.
*
* \return Tensor with shape [batch, out_dim]
*/
inline tvm::Tensor dense_rocm(const Target& target,
const tvm::Tensor& data,
const tvm::Tensor& weight,
const tvm::Tensor& bias) {
const tvm::Tensor& bias,
const Type& out_dtype) {
CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
if (bias.defined()) {
......@@ -63,6 +65,7 @@ inline tvm::Tensor dense_rocm(const Target& target,
auto out_dim = weight->shape[0];
if (target->libs().count("rocblas")) {
CHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported.";
auto mm = topi::contrib::rocblas_matmul(data, weight, false, true);
if (bias.defined()) {
mm = tvm::compute({ batch, out_dim },
......@@ -73,7 +76,7 @@ inline tvm::Tensor dense_rocm(const Target& target,
return mm;
} else {
return topi::nn::dense(data, weight, bias);
return topi::nn::dense(data, weight, bias, out_dtype);
}
}
......
......@@ -2,7 +2,8 @@
"""CUDA specific declaration and schedules."""
from __future__ import absolute_import as _abs
from . import conv2d, depthwise_conv2d, conv2d_transpose_nchw, deformable_conv2d, group_conv2d_nchw
from . import conv2d, depthwise_conv2d, conv2d_transpose_nchw, deformable_conv2d, \
group_conv2d_nchw, dense
from .conv2d_hwcn import schedule_conv2d_hwcn
from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc
from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
......@@ -10,7 +11,6 @@ from .group_conv2d_nchw import schedule_conv2d_nchw_cuda
from .reduction import schedule_reduce
from .softmax import schedule_softmax
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .dense import dense_cuda, schedule_dense
from .pooling import schedule_pool, schedule_global_pool
from .extern import schedule_extern
from .nn import schedule_lrn, schedule_l2_normalize
......
......@@ -18,13 +18,17 @@
"""Schedule for dense operator"""
from __future__ import absolute_import as _abs
import tvm
import tvm.autotvm as autotvm
from tvm.contrib import cublas
from .tensor_intrin import dp4a
from ..nn.dense import dense, dense_default
from .. import tag
from .. import generic
from ..util import traverse_inline, get_const_tuple
@dense.register("cuda")
def dense_cuda(data, weight, bias=None):
@autotvm.register_topi_compute(dense, ["cuda", "gpu"], "direct")
def dense_cuda(cfg, data, weight, bias=None, out_dtype=None):
"""Dense operator for cuda backend.
Parameters
......@@ -43,25 +47,29 @@ def dense_cuda(data, weight, bias=None):
output : tvm.Tensor
2-D with shape [batch, out_dim]
"""
# pylint: disable=unused-argument
assert len(data.shape) == 2 and len(weight.shape) == 2, \
"only support 2-dim dense"
if bias is not None:
assert len(bias.shape) == 1
if out_dtype is None:
out_dtype = data.dtype
batch, in_dim = data.shape
out_dim, _ = weight.shape
target = tvm.target.current_target()
if "cublas" in target.libs:
assert out_dtype == data.dtype, "Mixed precision not supported."
matmul = cublas.matmul(data, weight, False, True)
if bias is not None:
matmul = tvm.compute((batch, out_dim), \
lambda i, j: matmul[i, j] + bias[j], \
tag=tag.BROADCAST)
return matmul
return dense_default(data, weight, bias)
return dense_default(data, weight, bias, out_dtype)
@generic.schedule_dense.register(["cuda", "gpu"])
def schedule_dense(outs):
@autotvm.register_topi_schedule(generic.schedule_dense, ["cuda", "gpu"], "direct")
def schedule_dense(cfg, outs):
"""Schedule for dense operator.
Parameters
......@@ -75,6 +83,7 @@ def schedule_dense(outs):
s: Schedule
The computation schedule for dense.
"""
# pylint: disable=unused-argument
target = tvm.target.current_target()
if target.target_name == "cuda" and "cublas" in target.libs:
return generic.schedule_extern(outs)
......@@ -124,3 +133,112 @@ def schedule_dense(outs):
traverse(outs[0].op)
return s
@autotvm.register_topi_compute(dense, ['cuda'], ['int8'])
def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
"""Dense operator for int8 on CUDA"""
if out_dtype is None:
out_dtype = data.dtype
batch, in_dim = get_const_tuple(data.shape)
out_dim, _ = get_const_tuple(weight.shape)
k = tvm.reduce_axis((0, in_dim), name='k')
matmul = tvm.compute((batch, out_dim),
lambda i, j: tvm.sum(data[i, k].astype(out_dtype) *
weight[j, k].astype(out_dtype), axis=[k]),
tag="dense_int8")
cfg.add_flop(batch * in_dim * out_dim * 2)
if bias is not None:
matmul = tvm.compute((batch, out_dim),
lambda i, j: matmul[i, j] + bias[j].astype(out_dtype),
tag=tag.BROADCAST)
cfg.add_flop(batch * out_dim)
return matmul
@autotvm.register_topi_schedule(generic.schedule_dense, ['cuda', 'gpu'], ['int8'])
def schedule_dense_int8(cfg, outs):
s = tvm.create_schedule([x.op for x in outs])
def _callback(op):
if "dense_int8" in op.tag:
_schedule_dense_int8(cfg, s, op.output(0))
traverse_inline(s, outs[0].op, _callback)
return s
_dp4a = dp4a('shared', 'shared', 'local')
def _schedule_dense_int8(cfg, s, output):
data, weight = s[output].op.input_tensors
batch, in_dim = get_const_tuple(data.shape)
out_dim, _ = get_const_tuple(weight.shape)
in_dim_factor = 4
assert in_dim % in_dim_factor == 0, "Input dimension must divide {}".format(in_dim_factor)
if in_dim % 16 == 0:
in_dim_factor = 16
# create tuning space
cfg.define_split("tile_y", batch, num_outputs=4)
cfg.define_split("tile_x", out_dim, num_outputs=4)
cfg.define_split("tile_k", in_dim // in_dim_factor, num_outputs=2)
cfg.define_knob('auto_unroll_max_step', [0, 512, 1500])
# create cache stage
AA = s.cache_read(data, 'shared', [output])
WW = s.cache_read(weight, 'shared', [output])
CC = s.cache_write(output, 'local')
# handle bias
if output.op not in s.outputs:
s[output].compute_inline()
output = s.outputs[0].output(0)
n, x = s[output].op.axis
# this is the scope to attach global config inside this kernel
kernel_scope, n = s[output].split(n, nparts=1)
ko = CC.op.reduce_axis[0]
ko, ki = s[CC].split(ko, factor=4)
ko, kt = cfg['tile_k'].apply(s, CC, ko)
s[CC].tensorize(ki, _dp4a)
by, vy, ty, yi = cfg['tile_y'].apply(s, output, n)
bx, vx, tx, xi = cfg['tile_x'].apply(s, output, x)
s[output].reorder(by, bx, vy, vx, ty, tx, yi, xi)
s[output].bind(by, tvm.thread_axis('blockIdx.y'))
s[output].bind(bx, tvm.thread_axis('blockIdx.x'))
s[output].bind(vy, tvm.thread_axis('vthread'))
s[output].bind(vx, tvm.thread_axis('vthread'))
s[output].bind(ty, tvm.thread_axis('threadIdx.y'))
s[output].bind(tx, tvm.thread_axis('threadIdx.x'))
n_ty = cfg['tile_y'].size[2]
n_tx = cfg['tile_x'].size[2]
s[CC].compute_at(s[output], tx)
yo, xo = CC.op.axis[:2]
s[CC].reorder(ko, kt, yo, xo, ki)
for load in [AA, WW]:
s[load].compute_at(s[CC], ko)
outer, inner = s[load].split(s[load].op.axis[-1], factor=in_dim_factor)
s[load].vectorize(inner)
fused = s[load].op.axis[:-1] + [outer]
fused = s[load].fuse(*fused)
fused, tx = s[load].split(fused, factor=n_tx)
fused, ty = s[load].split(fused, factor=n_ty)
s[load].bind(tx, tvm.thread_axis('threadIdx.x'))
s[load].bind(ty, tvm.thread_axis('threadIdx.y'))
s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
s[output].pragma(kernel_scope, 'unroll_explicit', False)
return s
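
The factor-4 split of the reduction axis feeds the dp4a tensor intrinsic, which performs a 4-wide int8 dot product accumulated into int32. A numpy sketch of what one dp4a step computes (illustrative, not the intrinsic definition):

```python
# Hedged sketch of one dp4a step: four int8 multiply-accumulates into an int32 accumulator.
import numpy as np

def dp4a_ref(a4, b4, acc):
    # a4, b4: length-4 int8 vectors; acc: int32 accumulator
    return acc + int(np.dot(a4.astype(np.int32), b4.astype(np.int32)))
```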
......@@ -19,7 +19,7 @@ from __future__ import absolute_import
import tvm
from .. import tag
def dense_default(data, weight, bias=None):
def dense_default(data, weight, bias=None, out_dtype=None):
"""The default implementation of dense in topi.
Parameters
......@@ -33,6 +33,9 @@ def dense_default(data, weight, bias=None):
bias : tvm.Tensor, optional
1-D with shape [out_dim]
out_dtype : str
The output type. This is used for mixed precision.
Returns
-------
output : tvm.Tensor
......@@ -42,21 +45,24 @@ def dense_default(data, weight, bias=None):
"only support 2-dim dense"
if bias is not None:
assert len(bias.shape) == 1
if out_dtype is None:
out_dtype = data.dtype
batch, in_dim = data.shape
out_dim, _ = weight.shape
k = tvm.reduce_axis((0, in_dim), name='k')
matmul = tvm.compute((batch, out_dim), \
lambda i, j: tvm.sum(data[i, k] * weight[j, k], axis=k), \
lambda i, j: tvm.sum(data[i, k].astype(out_dtype) * \
weight[j, k].astype(out_dtype), axis=k), \
name='T_dense', tag='dense')
if bias is not None:
matmul = tvm.compute((batch, out_dim), \
lambda i, j: matmul[i, j] + bias[j], \
lambda i, j: matmul[i, j] + bias[j].astype(out_dtype), \
tag=tag.BROADCAST)
return matmul
@tvm.target.override_native_generic_func("dense")
def dense(data, weight, bias=None):
def dense(data, weight, bias=None, out_dtype=None):
"""Applies a linear transformation: :math:`Y = XW^T + b`.
Parameters
......@@ -70,9 +76,12 @@ def dense(data, weight, bias=None):
bias : tvm.Tensor, optional
1-D with shape [out_dim]
out_dtype : str
The output type. This is used for mixed precision.
Returns
-------
output : tvm.Tensor
2-D with shape [batch, out_dim]
"""
return dense_default(data, weight, bias)
return dense_default(data, weight, bias, out_dtype)
......@@ -25,7 +25,7 @@ from .. import tag
from .. import generic
@dense.register("rocm")
def dense_rocm(data, weight, bias=None):
def dense_rocm(data, weight, bias=None, out_dtype=None):
"""Dense operator for rocm backend.
Parameters
......@@ -39,6 +39,9 @@ def dense_rocm(data, weight, bias=None):
bias : tvm.Tensor, optional
1-D with shape [out_dim]
out_dtype : str
The output type. This is used for mixed precision.
Returns
-------
output : tvm.Tensor
......@@ -48,17 +51,20 @@ def dense_rocm(data, weight, bias=None):
"only support 2-dim dense"
if bias is not None:
assert len(bias.shape) == 1
if out_dtype is None:
out_dtype = data.dtype
batch, in_dim = data.shape
out_dim, _ = weight.shape
target = tvm.target.current_target()
if "rocblas" in target.libs:
assert out_dtype == data.dtype, "Mixed precision not supported."
matmul = rocblas.matmul(data, weight, False, True)
if bias is not None:
matmul = tvm.compute((batch, out_dim), \
lambda i, j: matmul[i, j] + bias[j], \
tag=tag.BROADCAST)
return matmul
return dense_default(data, weight, bias)
return dense_default(data, weight, bias, out_dtype)
@generic.schedule_dense.register(["rocm"])
......
......@@ -26,20 +26,22 @@ from .. import generic, tag, nn
from ..util import traverse_inline, get_const_tuple
@autotvm.register_topi_compute(nn.dense, "cpu", "direct")
def _declaration_dense(cfg, data, weight, bias=None):
def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
batch, _ = get_const_tuple(data.shape)
# For small batch sizes, don't pack weight into cache-friendly layout
# because of overhead in packing and limited reuse from batch dimension
# TODO(icemelon9): use a more systematic way to determine which schedule to use
if batch <= 16:
return _declaration_dense_nopack(cfg, data, weight, bias)
return _declaration_dense_pack(cfg, data, weight, bias)
return _declaration_dense_nopack(cfg, data, weight, bias, out_dtype)
return _declaration_dense_pack(cfg, data, weight, bias, out_dtype)
# Declare dense compute with packing weight into cache-friendly layout
@autotvm.register_topi_compute(nn.dense, "cpu", "direct_pack")
def _declaration_dense_pack(cfg, data, weight, bias=None):
def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
if out_dtype is None:
out_dtype = data.dtype
batch, in_dim = get_const_tuple(data.shape)
out_dim, _ = get_const_tuple(weight.shape)
# create tuning space
......@@ -57,18 +59,21 @@ def _declaration_dense_pack(cfg, data, weight, bias=None):
k = tvm.reduce_axis((0, in_dim), name="k")
C = tvm.compute((batch, out_dim),
lambda y, x: tvm.sum(
data[y, k] * packw[x // packw_bn, k, x % packw_bn],
data[y, k].astype(out_dtype) *
packw[x // packw_bn, k, x % packw_bn].astype(out_dtype),
axis=k),
tag="dense_pack")
if bias is not None:
C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j],
C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
tag=tag.BROADCAST)
return C
# Declare dense compute without packing weight
@autotvm.register_topi_compute(nn.dense, "cpu", "direct_nopack")
def _declaration_dense_nopack(cfg, data, weight, bias=None):
def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
if out_dtype is None:
out_dtype = data.dtype
batch, in_dim = get_const_tuple(data.shape)
out_dim, _ = get_const_tuple(weight.shape)
# create tuning space
......@@ -82,14 +87,15 @@ def _declaration_dense_nopack(cfg, data, weight, bias=None):
k = tvm.reduce_axis((0, in_dim // vec), "k")
CC = tvm.compute((batch, out_dim, vec),
lambda z, y, x: tvm.sum(
data[z, k * vec + x] * weight[y, k * vec + x], axis=k))
data[z, k * vec + x].astype(out_dtype) *
weight[y, k * vec + x].astype(out_dtype), axis=k))
kk = tvm.reduce_axis((0, vec), "kk")
C = tvm.compute((batch, out_dim),
lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
tag="dense_nopack")
if bias is not None:
C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j],
C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
tag=tag.BROADCAST)
return C
......
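For reference, the "pack" variant rearranges the weight so that a block of output channels is contiguous in the innermost axis; a hedged numpy sketch of that layout, where `bn` stands in for the tuned `packw_bn` block size:

```python
# Hedged sketch of the packed weight layout used by _declaration_dense_pack:
# weight[out_dim, in_dim] -> packw[out_dim // bn, in_dim, bn], so that
# packw[x // bn, k, x % bn] == weight[x, k] as accessed in the compute above.
import numpy as np

def pack_weight(weight, bn):
    out_dim, in_dim = weight.shape
    assert out_dim % bn == 0
    return weight.reshape(out_dim // bn, bn, in_dim).transpose(0, 2, 1)
```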
......@@ -403,7 +403,7 @@ TVM_REGISTER_GLOBAL("topi.nn.binary_dense")
/* Ops from nn/dense.h */
TVM_REGISTER_GLOBAL("topi.nn.dense")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::dense(args[0], args[1], args[2]);
*rv = nn::dense(args[0], args[1], args[2], args[3]);
});
/* Ops from nn/bias_add.h */
......@@ -544,7 +544,7 @@ TVM_REGISTER_GLOBAL("topi.x86.schedule_injective")
/* ROCm schedules */
TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = rocm::dense_rocm(args[0], args[1], args[2], args[3]);
*rv = rocm::dense_rocm(args[0], args[1], args[2], args[3], args[4]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense")
......@@ -565,7 +565,7 @@ TVM_REGISTER_GLOBAL("topi.rocm.schedule_l2_normalize")
/* CUDA schedules */
TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = cuda::dense_cuda(args[0], args[1], args[2], args[3]);
*rv = cuda::dense_cuda(args[0], args[1], args[2], args[3], args[4]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_dense")
......@@ -686,7 +686,8 @@ TVM_REGISTER_GENERIC_FUNC(schedule_binary_dense)
using FTVMDenseOpBuilder = std::function<tvm::Tensor(const Target& target,
const tvm::Tensor& data,
const tvm::Tensor& weight,
const tvm::Tensor& bias)>;
const tvm::Tensor& bias,
const Type& out_dtype)>;
/*!
* \brief Helper function for registering dense ops matching the
......@@ -703,8 +704,9 @@ inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) {
Tensor data = args[0];
Tensor weight = args[1];
Tensor bias = args[2];
Type out_dtype = args[3];
*ret = builder(target, data, weight, bias);
*ret = builder(target, data, weight, bias, out_dtype);
});
}
......@@ -712,8 +714,9 @@ TVM_REGISTER_GENERIC_FUNC(dense)
.set_default(WrapDenseOp([](const Target& target,
const tvm::Tensor& data,
const tvm::Tensor& weight,
const tvm::Tensor& bias) {
return topi::nn::dense(data, weight, bias);
const tvm::Tensor& bias,
const Type& out_dtype) {
return topi::nn::dense(data, weight, bias, out_dtype);
}))
.register_func({ "cuda", "gpu" }, WrapDenseOp(topi::cuda::dense_cuda))
.register_func({ "rocm" }, WrapDenseOp(topi::rocm::dense_rocm));
......
......@@ -32,7 +32,7 @@ def get_all_backend():
'llvm -device=arm_cpu', 'opencl -device=mali', 'aocl_sw_emu']
class NCHWcInt8Fallback(autotvm.FallbackContext):
class Int8Fallback(autotvm.FallbackContext):
def _query_inside(self, target, workload):
key = (target, workload)
if key in self.memory:
......
......@@ -25,7 +25,7 @@ import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from common import get_all_backend, NCHWcInt8Fallback
from common import get_all_backend, Int8Fallback
oc_block_factor = 4
......@@ -105,7 +105,7 @@ def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, str
def test_conv2d_nchw():
with NCHWcInt8Fallback():
with Int8Fallback():
# ResNet18 workloads where channels in / out are multiple of oc_block_factor
verify_conv2d_NCHWc_int8(1, 64, 56, 64, 3, 1, 1)
verify_conv2d_NCHWc_int8(1, 64, 56, 64, 1, 1, 0)
......
......@@ -22,7 +22,7 @@ import topi.testing
from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize
from common import get_all_backend
from common import get_all_backend, Int8Fallback
def verify_dense(batch, in_dim, out_dim, use_bias=True):
A = tvm.placeholder((batch, in_dim), name='A')
......@@ -65,6 +65,55 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
for device in get_all_backend():
check_device(device)
def verify_dense_int8(batch, in_dim, out_dim, use_bias=True):
dtype = 'int8'
out_dtype = 'int32'
A = tvm.placeholder((batch, in_dim), name='A', dtype=dtype)
B = tvm.placeholder((out_dim, in_dim), name='B', dtype=dtype)
C = tvm.placeholder((out_dim,), name='C', dtype=out_dtype)
# use memoize to pickle the test data for next time use
@memoize("topi.tests.test_topi_dense_int8")
def get_ref_data():
a_np = np.random.randint(low=-128, high=127, size=(batch, in_dim)).astype(dtype)
b_np = np.random.randint(low=-128, high=127, size=(out_dim, in_dim)).astype(dtype)
c_np = np.random.randint(low=-128, high=127, size=(out_dim,)).astype(out_dtype)
d_np = np.dot(a_np.astype(out_dtype), b_np.T.astype(out_dtype))
if use_bias:
d_np += c_np
d_np = np.maximum(d_np, 0.0)
return (a_np, b_np, c_np, d_np)
# get the test data
a_np, b_np, c_np, d_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version):
print("Skip because int8 intrinsics are not available")
return
print("Running on target: %s" % device)
with tvm.target.create(device):
D = topi.nn.dense(A, B, C if use_bias else None, out_dtype=out_dtype)
D = topi.nn.relu(D)
s = topi.generic.schedule_dense([D])
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(c_np, ctx)
d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=out_dtype), ctx)
f = tvm.build(s, [A, B, C, D], device, name="dense")
f(a, b, c, d)
tvm.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)
for device in ['cuda']:
check_device(device)
def test_dense():
verify_dense(1, 1024, 1000, use_bias=True)
verify_dense(1, 1024, 1000, use_bias=False)
......@@ -72,5 +121,12 @@ def test_dense():
verify_dense(2, 1024, 1000, use_bias=True)
def test_dense_int8():
with Int8Fallback():
verify_dense_int8(2, 1024, 1000, use_bias=True)
verify_dense_int8(2, 1024, 1000, use_bias=False)
if __name__ == "__main__":
test_dense()
test_dense_int8()
......@@ -25,7 +25,7 @@ import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from common import get_all_backend, NCHWcInt8Fallback
from common import get_all_backend, Int8Fallback
def verify_group_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups, add_bias=False, add_relu=False):
......@@ -203,7 +203,7 @@ def test_group_conv2d_nchw():
def test_group_conv2d_NCHWc_int8():
with NCHWcInt8Fallback():
with Int8Fallback():
# ResNeXt-50 workload
verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 1, 32)
verify_group_conv2d_NCHWc_int8(1, 256, 56, 256, 3, 2, 1, 1, 32)
......