Unverified Commit 799ff356 by Haichen Shen Committed by GitHub

[Runtime][Contrib] Support cudnn softmax (#5214)

parent 02eb1833
@@ -402,3 +402,27 @@ def conv_forward(x,
            ins[1],
            outs[0],
            conv_dtype), name="y")


def softmax(x, axis=-1):
    """Compute softmax using CuDNN.

    Parameters
    ----------
    x : tvm.te.Tensor
        The input tensor

    axis : int
        The axis to compute softmax over

    Returns
    -------
    ret : tvm.te.Tensor
        The result tensor
    """
    return te.extern(
        x.shape, [x],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.cudnn.softmax.forward",
            ins[0],
            outs[0],
            axis), name="y")
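For reference, a minimal sketch of driving this wrapper end to end (the shape, dtype, and variable names are illustrative assumptions, it requires a CUDA build of TVM with cuDNN enabled, and it mirrors the unit test added later in this commit):

# Hedged sketch, not part of the patch: call cudnn.softmax through te.extern.
import numpy as np
import tvm
from tvm import te
from tvm.contrib import cudnn

x = te.placeholder((4, 10), name="x", dtype="float32")
y = cudnn.softmax(x, axis=-1)            # extern op calling tvm.contrib.cudnn.softmax.forward
s = te.create_schedule(y.op)
f = tvm.build(s, [x, y], "cuda", target_host="llvm")

ctx = tvm.gpu(0)
a = tvm.nd.array(np.random.uniform(size=(4, 10)).astype("float32"), ctx)
b = tvm.nd.empty((4, 10), "float32", ctx)
f(a, b)                                  # b now holds softmax(a) along the last axis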
@@ -34,12 +34,12 @@ reg.register_pattern("nn.relu", OpPattern.ELEMWISE)
# softmax
reg.register_schedule("nn.softmax", strategy.schedule_softmax)
reg.register_strategy("nn.softmax", strategy.softmax_strategy)
reg.register_pattern("nn.softmax", OpPattern.OPAQUE)
# log_softmax
reg.register_schedule("nn.log_softmax", strategy.schedule_softmax)
reg.register_schedule("nn.log_softmax", strategy.schedule_log_softmax)
reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
@@ -69,6 +69,11 @@ class DenseAttrs(Attrs):
"""Attributes for nn.dense"""
@tvm._ffi.register_object("relay.attrs.SoftmaxAttrs")
class SoftmaxAttrs(Attrs):
"""Attributes for nn.softmax"""
@tvm._ffi.register_object("relay.attrs.FIFOBufferAttrs")
class FIFOBufferAttrs(Attrs):
"""Attributes for nn.fifo_buffer"""
@@ -60,9 +60,25 @@ def schedule_adaptive_pool_cuda(attrs, outs, target):
    with target:
        return topi.cuda.schedule_adaptive_pool(outs)


@schedule_softmax.register(["cuda", "gpu"])
def schedule_softmax_cuda(attrs, outs, target):
    """schedule softmax for cuda"""
@softmax_strategy.register(["cuda", "gpu"])
def softmax_strategy_cuda(attrs, inputs, out_type, target):
    """softmax cuda strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_softmax(topi.nn.softmax),
        wrap_topi_schedule(topi.cuda.schedule_softmax),
        name="softmax.cuda")
    if target.target_name == "cuda" and "cudnn" in target.libs:
        strategy.add_implementation(
            wrap_compute_softmax(topi.cuda.softmax_cudnn),
            wrap_topi_schedule(topi.cuda.schedule_softmax_cudnn),
            name="softmax.cudnn",
            plevel=15)
    return strategy


@schedule_log_softmax.register(["cuda", "gpu"])
def schedule_log_softmax_cuda(attrs, outs, target):
    """schedule log_softmax for cuda"""
    with target:
        return topi.cuda.schedule_softmax(outs)
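The cuDNN implementation is only registered when the target lists cudnn among its libraries, and its plevel of 15 outranks the plain TOPI implementation added above at the default priority (10), so it wins implementation selection whenever it is available. A hedged sketch of how a user opts in with the build API of this TVM version (mod and params are placeholders for a real Relay module and weights, not part of this patch):

# Hedged sketch: "cudnn" must appear in target.libs for softmax.cudnn to be considered.
import tvm
from tvm import relay

target = tvm.target.create("cuda -libs=cudnn")
assert target.target_name == "cuda" and "cudnn" in target.libs

with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod, target=target, params=params)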
@@ -107,9 +107,27 @@ def schedule_adaptive_pool(attrs, outs, target):
        return topi.generic.schedule_adaptive_pool(outs)


# softmax
def wrap_compute_softmax(topi_compute):
    """Wrap softmax topi compute"""
    def _compute_softmax(attrs, inputs, out_type):
        axis = attrs.get_int("axis")
        return [topi_compute(inputs[0], axis)]
    return _compute_softmax


@override_native_generic_func("softmax_strategy")
def softmax_strategy(attrs, inputs, out_type, target):
    """softmax generic strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_softmax(topi.nn.softmax),
        wrap_topi_schedule(topi.generic.schedule_softmax),
        name="softmax.generic")
    return strategy


# log_softmax
@generic_func
def schedule_softmax(attrs, outs, target):
    """Schedule softmax"""
def schedule_log_softmax(attrs, outs, target):
    """Schedule log_softmax op"""
    with target:
        return topi.generic.schedule_softmax(outs)
@@ -50,9 +50,19 @@ def schedule_adaptive_pool_hls(attrs, outs, target):
    with target:
        return topi.hls.schedule_adaptive_pool(outs)


@schedule_softmax.register("hls")
def schedule_softmax_hls(attrs, outs, target):
    """schedule softmax for hls"""
@softmax_strategy.register("hls")
def softmax_strategy_hls(attrs, inputs, out_type, target):
    """softmax hls strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_softmax(topi.nn.softmax),
        wrap_topi_schedule(topi.hls.schedule_softmax),
        name="softmax.hls")
    return strategy


@schedule_log_softmax.register("hls")
def schedule_log_softmax_hls(attrs, outs, target):
    """schedule log_softmax for hls"""
    with target:
        return topi.hls.schedule_softmax(outs)
@@ -44,9 +44,19 @@ def schedule_adaptive_pool_opengl(attrs, outs, target):
    with target:
        return topi.opengl.schedule_adaptive_pool(outs)


@schedule_softmax.register("opengl")
def schedule_softmax_opengl(attrs, outs, target):
    """schedule softmax for opengl"""
@softmax_strategy.register("opengl")
def softmax_strategy_opengl(attrs, inputs, out_type, target):
    """softmax opengl strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_softmax(topi.nn.softmax),
        wrap_topi_schedule(topi.opengl.schedule_softmax),
        name="softmax.opengl")
    return strategy


@schedule_log_softmax.register("opengl")
def schedule_log_softmax_opengl(attrs, outs, target):
    """schedule log_softmax for opengl"""
    with target:
        return topi.opengl.schedule_softmax(outs)
@@ -55,9 +55,19 @@ def schedule_adaptive_pool_cpu(attrs, outs, target):
    with target:
        return topi.x86.schedule_adaptive_pool(outs)


@schedule_softmax.register("cpu")
def schedule_softmax_cpu(attrs, outs, target):
    """schedule softmax for x86"""
@softmax_strategy.register("cpu")
def softmax_strategy_cpu(attrs, inputs, out_type, target):
    """softmax x86 strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_softmax(topi.nn.softmax),
        wrap_topi_schedule(topi.x86.schedule_softmax),
        name="softmax.x86")
    return strategy


@schedule_log_softmax.register("cpu")
def schedule_log_softmax_cpu(attrs, outs, target):
    """schedule log_softmax op for x86"""
    with target:
        return topi.x86.schedule_softmax(outs)
@@ -347,14 +347,7 @@ RELAY_REGISTER_OP("nn.softmax")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
                                         const Array<te::Tensor>& inputs,
                                         const Type& out_type) {
  const auto* param = attrs.as<SoftmaxAttrs>();
  CHECK(param != nullptr);
  return Array<te::Tensor>{ topi::nn::softmax(inputs[0], param->axis) };
});
.add_type_rel("Identity", IdentityRel);
// relay.nn.log_softmax
@@ -140,5 +140,15 @@ void ConvEntry::CleanWorkspace() {
  workspace_size = 0;
}

// SoftmaxEntry

SoftmaxEntry::SoftmaxEntry() {
  CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc));
}

SoftmaxEntry::~SoftmaxEntry() {
  CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc));
}

}  // namespace contrib
}  // namespace tvm
@@ -85,12 +85,20 @@ struct ConvEntry {
  void CleanWorkspace();
};  // ConvThreadEntry

struct SoftmaxEntry {
  cudnnSoftmaxMode_t mode;
  cudnnDataType_t data_type;
  cudnnTensorDescriptor_t shape_desc;
  SoftmaxEntry();
  ~SoftmaxEntry();
};  // SoftmaxEntry

struct CuDNNThreadEntry {
  CuDNNThreadEntry();
  ~CuDNNThreadEntry();
  cudnnHandle_t handle{nullptr};
  ConvEntry conv_entry;
  SoftmaxEntry softmax_entry;
  runtime::DeviceAPI *cuda_api{nullptr};
  static CuDNNThreadEntry* ThreadLocal();
};  // CuDNNThreadEntry
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/runtime/contrib/cudnn/softmax.cc
* \brief Use external cudnn softmax function
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/device_api.h>
#include "cudnn_utils.h"
namespace tvm {
namespace contrib {
using namespace runtime;
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  DLTensor* x = args[0];
  DLTensor* y = args[1];
  int axis = args[2];
  int ndim = x->ndim;
  int64_t* shape = x->shape;
  if (axis < 0) axis += ndim;
  CHECK(axis >= 0 && axis < ndim);

  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
  entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);

  // Set mode and shape descriptor
  if (axis == ndim - 1) {
    int64_t N = 1;
    for (int i = 0; i < ndim - 1; ++i) {
      N *= shape[i];
    }
    entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_INSTANCE;
    CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc,
                                          CUDNN_TENSOR_NCHW,
                                          entry_ptr->softmax_entry.data_type,
                                          static_cast<int>(N),
                                          static_cast<int>(shape[ndim - 1]),
                                          1,
                                          1));
  } else {
    int64_t pre_axis_dim = 1;
    int64_t post_axis_dim = 1;
    for (int i = 0; i < ndim; ++i) {
      if (i < axis) {
        pre_axis_dim *= shape[i];
      } else if (i > axis) {
        post_axis_dim *= shape[i];
      }
    }
    entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_CHANNEL;
    CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc,
                                          CUDNN_TENSOR_NCHW,
                                          entry_ptr->softmax_entry.data_type,
                                          static_cast<int>(pre_axis_dim),
                                          static_cast<int>(shape[axis]),
                                          static_cast<int>(post_axis_dim),
                                          1));
  }

  auto alpha = CuDNNDataType::GetConst<1>(entry_ptr->softmax_entry.data_type);
  auto beta = CuDNNDataType::GetConst<0>(entry_ptr->softmax_entry.data_type);
  CUDNN_CALL(cudnnSoftmaxForward(entry_ptr->handle,
                                 CUDNN_SOFTMAX_ACCURATE,
                                 entry_ptr->softmax_entry.mode,
                                 alpha,
                                 entry_ptr->softmax_entry.shape_desc,
                                 x->data,
                                 beta,
                                 entry_ptr->softmax_entry.shape_desc,
                                 y->data));
});
} // namespace contrib
} // namespace tvm
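The descriptor setup above folds an arbitrary-rank input into the 4D NCHW shape cuDNN expects: when softmax runs over the last axis, every leading dimension collapses into N and CUDNN_SOFTMAX_MODE_INSTANCE is used; otherwise the dimensions before the axis collapse into N, the axis itself becomes C, the trailing dimensions collapse into H, and CUDNN_SOFTMAX_MODE_CHANNEL is used. A small sketch (not part of the patch) that reproduces the mapping:

# Hedged helper mirroring the descriptor logic in softmax.cc.
def cudnn_softmax_dims(shape, axis):
    ndim = len(shape)
    if axis < 0:
        axis += ndim
    if axis == ndim - 1:
        # MODE_INSTANCE: softmax over the flattened last axis for each "image"
        n = 1
        for dim in shape[:-1]:
            n *= dim
        return "CUDNN_SOFTMAX_MODE_INSTANCE", (n, shape[-1], 1, 1)
    # MODE_CHANNEL: softmax over C at every (N, H, W) position
    pre, post = 1, 1
    for dim in shape[:axis]:
        pre *= dim
    for dim in shape[axis + 1:]:
        post *= dim
    return "CUDNN_SOFTMAX_MODE_CHANNEL", (pre, shape[axis], post, 1)

print(cudnn_softmax_dims((32, 10), -1))          # INSTANCE, (32, 10, 1, 1)
print(cudnn_softmax_dims((1, 16, 256, 256), 1))  # CHANNEL, (1, 16, 65536, 1)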
@@ -158,6 +158,52 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0):
def test_conv3d():
    verify_conv3d("float32", "float32", tensor_format=0)


def verify_softmax(shape, axis, dtype="float32"):
    A = te.placeholder(shape, dtype=dtype, name='A')
    B = cudnn.softmax(A, axis)
    s = te.create_schedule([B.op])

    ctx = tvm.gpu(0)
    a_np = np.random.uniform(size=shape).astype(dtype)
    b_np = topi.testing.softmax_python(a_np)
    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(b_np, ctx)
    f = tvm.build(s, [A, B], "cuda", target_host="llvm", name="softmax")
    f(a, b)
    tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3)


def verify_softmax_4d(shape, dtype="float32"):
    A = te.placeholder(shape, dtype=dtype, name='A')
    B = cudnn.softmax(A, axis=1)
    s = te.create_schedule([B.op])

    ctx = tvm.gpu(0)
    n, c, h, w = shape
    a_np = np.random.uniform(size=shape).astype(dtype)
    b_np = topi.testing.softmax_python(a_np.transpose(0, 2, 3, 1).reshape(h*w, c))
    b_np = b_np.reshape(n, h, w, c).transpose(0, 3, 1, 2)
    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(b_np, ctx)
    f = tvm.build(s, [A, B], "cuda", target_host="llvm", name="softmax")
    f(a, b)
    tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3)


def test_softmax():
    if not tvm.runtime.enabled("cuda"):
        print("skip because cuda is not enabled...")
        return
    if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True):
        print("skip because cudnn is not enabled...")
        return

    verify_softmax((32, 10), -1)
    verify_softmax((3, 4), -1)
    verify_softmax((1, 5), -1, "float64")
    verify_softmax_4d((1, 16, 256, 256))
    verify_softmax_4d((1, 16, 256, 256), "float64")


if __name__ == "__main__":
    test_conv2d()
    test_conv3d()
    test_softmax()
@@ -34,7 +34,7 @@ from .conv3d import *
from .conv3d_winograd import *
from . import conv3d_alter_op
from .reduction import schedule_reduce
from .softmax import schedule_softmax
from .softmax import *
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .dense import *
from .pooling import *
@@ -17,6 +17,8 @@
# pylint: disable=invalid-name, unused-variable, trailing-whitespace
"""Schedule for softmax operator"""
from tvm import te
from tvm.contrib import cudnn
from .. import generic
from .injective import schedule_injective_from_existing
@@ -79,3 +81,13 @@ def schedule_softmax(outs):
        s[softmax].bind(tx, thread_x)
    return s


def softmax_cudnn(x, axis=-1):
    """Perform softmax on the data using cudnn"""
    return cudnn.softmax(x, axis)


def schedule_softmax_cudnn(outs):
    """Schedule for softmax cudnn op"""
    return generic.schedule_extern(outs)
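Because cudnn.softmax lowers to a te.extern call, there is nothing for TVM itself to schedule, so schedule_softmax_cudnn just falls back to the generic extern schedule. A minimal sketch of driving these two entry points directly (the shape is an illustrative assumption; both names are re-exported through topi.cuda by the wildcard import change earlier in this diff, and a CUDA build with cuDNN is required):

# Hedged sketch, not part of the patch: build the cuDNN-backed TOPI softmax.
import tvm
from tvm import te
import topi

x = te.placeholder((8, 128), name="x")
y = topi.cuda.softmax_cudnn(x, axis=-1)      # te.extern node wrapping cuDNN
s = topi.cuda.schedule_softmax_cudnn([y])    # just the generic extern schedule
f = tvm.build(s, [x, y], "cuda", target_host="llvm")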
@@ -77,7 +77,6 @@ def softmax(x, axis=-1):
    return te.compute(shape, lambda *indices: _normalize(exp, expsum, *indices),
                      name='T_softmax_norm', attrs={"axis" : axis})


@tvm.te.tag_scope(tag='log_softmax_output')
def log_softmax(x):
    """Perform log softmax activation on the data