Commit 64870ffb by ziheng Committed by Tianqi Chen

[Contrib] CuDNN v7 Support (#311)

* [Contrib] CuDNN v7 Support

* Add test
parent 0ccc281d
......@@ -106,6 +106,7 @@ endif
include make/contrib/cblas.mk
include make/contrib/nnpack.mk
include make/contrib/cudnn.mk
ifdef ADD_CFLAGS
CFLAGS += $(ADD_CFLAGS)
......
......@@ -31,6 +31,10 @@ ADD_CFLAGS =
# whether enable CUDA during compile
USE_CUDA = 0
# add the path to CUDA library to link and compile flag
# if you have already add them to environment variable.
# CUDA_PATH = /usr/local/cuda
# whether enable OpenCL during compile
USE_OPENCL = 0
......@@ -52,9 +56,9 @@ USE_RPC = 0
# Whether use BLAS, choices: openblas, atlas, blas, apple
USE_BLAS = none
# Whether use NNPack
USE_NNPACK = 0
# NNPACK_PATH = none
# add the path to CUDA library to link and compile flag
# if you have already add them to environment variable.
# CUDA_PATH = /usr/local/cuda
# Whether use CuDNN
USE_CUDNN = 0
CUDNN_CONTRIB_SRC = $(wildcard src/contrib/cudnn/*.cc)
CUDNN_CONTRIB_OBJ = $(patsubst src/%.cc, build/%.o, $(CUDNN_CONTRIB_SRC))
ifeq ($(USE_CUDNN), 1)
CFLAGS += -DTVM_USE_CUDNN=1 -I$(CUDA_PATH)/include
ADD_LDFLAGS += -lcudnn
RUNTIME_DEP += $(CUDNN_CONTRIB_OBJ)
endif
"""External function interface to CuDNN v7 library."""
# pylint: disable-msg=C0103
import ctypes
import numpy as np
from .. import api as _api
from .. import intrin as _intrin
from .. import get_global_func as _get_global_func
# algos can be read from cudnn.h
_FWD_ALGOS = [
"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM",
"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM",
"CUDNN_CONVOLUTION_FWD_ALGO_GEMM",
"CUDNN_CONVOLUTION_FWD_ALGO_DIRECT",
"CUDNN_CONVOLUTION_FWD_ALGO_FFT",
"CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING",
"CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD",
"CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED",
"CUDNN_CONVOLUTION_FWD_ALGO_COUNT",
]
_BWD_FILTER_ALGOS = [
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0",
# non-deterministic
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1",
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT",
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3",
# non-deterministic, algo0 with workspaceS
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD",
# not implemented
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED",
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING",
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT",
]
_BWD_DATA_ALGOS = [
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_0",
# non-deterministic
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_1",
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT",
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING",
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD",
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED",
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT",
]
_ALGO_TYPE = [
"fwd",
"bwd_filter",
"bwd_data"
]
def algo_to_index(algo_type, algo_name):
"""Return a index represents the algorithm, which can be used in
calling CuDNN function
Parameters
----------
algo_type : str
["fwd", "bwd_filter", "bwd_data]
algo_name : str
algorithm name in cudnn definition
fwd = [
"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM",
"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM",
"CUDNN_CONVOLUTION_FWD_ALGO_GEMM",
"CUDNN_CONVOLUTION_FWD_ALGO_DIRECT",
"CUDNN_CONVOLUTION_FWD_ALGO_FFT",
"CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING",
"CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD",
"CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED",
"CUDNN_CONVOLUTION_FWD_ALGO_COUNT",
]
bwd_filter = [
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0",
# non-deterministic
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1",
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT",
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3",
# non-deterministic, algo0 with workspaceS
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD",
# not implemented
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED",
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING",
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT",
]
bwd_data = [
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_0",
# non-deterministic
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_1",
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT",
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING",
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD",
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED",
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT",
]
Returns
-------
algo: int
algorithm index
"""
idx = -1
if algo_type == "fwd":
idx = _FWD_ALGOS.index(algo_name)
elif algo_type == "bwd_filter":
idx = _BWD_FILTER_ALGOS.index(algo_name)
elif algo_type == "bwd_data":
idx = _BWD_DATA_ALGOS.index(algo_name)
assert idx >= 0
return idx
def _get_np_int32_array_handle(arr):
"""Return a void_p handle for a numpy array
Parameters
----------
arr: numpy.NDArray
source numpy array
Returns
-------
ptr: ctypes.c_void_p
pointer to the data
"""
assert arr.dtype == np.int32
ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
return ctypes.cast(ptr, ctypes.c_void_p)
def conv2d_w_shape(in_channel,
out_channel,
filter_h,
filter_w):
"""Get weight shape for a 2D convolution
Parameters
----------
in_channel: int
input channel
out_channel: int
output channel
filter_h: int
filter height
filter_w: int
filter width
Returns
-------
wshape: list
weight shape
"""
return [out_channel, in_channel, filter_h, filter_w]
def conv2d_output_shape(tensor_format,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
x_shape,
w_shape):
"""Get output shape of 2D convolution
Paramters
---------
tensor_format: int
0: CUDNN_TENSOR_NCHW
1: CUDNN_TENSOR_NHWC
2: CUDNN_TENSOR_NCHW_VECT_C
pad_h: int
height pad
pad_w: int
weight pad
stride_h: int
height stride
stride_w: int
width stride
dilation_h: int
height dilation
dilation_w: int
width dilation
x_shape: list
input shape
w_shape: list
weight shape
Returns
-------
oshape: list
output shape
"""
assert isinstance(x_shape, list)
assert isinstance(w_shape, list)
assert len(x_shape) == 4
assert len(w_shape) == 4
oshape = np.zeros((len(x_shape)), dtype=np.int32)
func = _get_global_func("tvm.contrib.cudnn.conv2d.output_shape")
func(tensor_format,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
x_shape[0].value,
x_shape[1].value,
x_shape[2].value,
x_shape[3].value,
w_shape[0].value,
w_shape[1].value,
w_shape[2].value,
w_shape[3].value,
_get_np_int32_array_handle(oshape))
return list(oshape)
def conv2d_forward(x,
w,
stride_h=1,
stride_w=1,
pad_h=0,
pad_w=0,
dilation_h=1,
dilation_w=1,
conv_mode=1,
tensor_format=0,
algo=0):
"""Create an extern op that compute 2D convolution with CuDNN
Parameters
----------
x: Tensor
input feature map
w: Tensor
convolution weight
stride_h: int
height stride
stride_w: int
width stride
pad_h: int
height pad
pad_w: int
weight pad
dilation_h: int
height dilation
dilation_w: int
width dilation
conv_mode: int
0: CUDNN_CONVOLUTION
1: CUDNN_CROSS_CORRELATION
tensor_format: int
0: CUDNN_TENSOR_NCHW
1: CUDNN_TENSOR_NHWC
2: CUDNN_TENSOR_NCHW_VECT_C
algo: int
Forward algorithm, get index from ```algo_to_index``` function
Returns
-------
y: Tensor
The result tensor
"""
oshape = conv2d_output_shape(tensor_format,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
list(x.shape),
list(w.shape))
return _api.extern(
oshape, [x, w],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.cudnn.conv2d.forward",
conv_mode,
tensor_format,
algo,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
ins[0],
ins[1],
outs[0]), name="y")
/*!
* Copyright (c) 2017 by Contributors
* \file Use external cudnn utils function
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <tvm/runtime/device_api.h>
#include "cudnn_utils.h"
namespace tvm {
namespace contrib {
using namespace runtime;
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
.set_body([](TVMArgs args, TVMRetValue *ret) {
int mode = args[0];
int format = args[1];
int algo = args[2];
int pad_h = args[3];
int pad_w = args[4];
int stride_h = args[5];
int stride_w = args[6];
int dilation_h = args[7];
int dilation_w = args[8];
DLTensor *x = args[9];
DLTensor *w = args[10];
DLTensor *y = args[11];
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
// Set Mode
entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
// Set Format
entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
// Set Algo
entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);
// Set Ctx
entry_ptr->conv_entry.ctx = x->ctx;
// Set Data Type
entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
// Set Desc
CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
entry_ptr->conv_entry.mode,
entry_ptr->conv_entry.data_type));
// Set Filter
CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc,
entry_ptr->conv_entry.data_type,
CUDNN_TENSOR_NCHW,
static_cast<int>(w->shape[0]),
static_cast<int>(w->shape[1]),
static_cast<int>(w->shape[2]),
static_cast<int>(w->shape[3])));
// Set Input
CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
entry_ptr->conv_entry.tensor_format,
entry_ptr->conv_entry.data_type,
static_cast<int>(x->shape[0]),
static_cast<int>(x->shape[1]),
static_cast<int>(x->shape[2]),
static_cast<int>(x->shape[3])));
// Set Output
CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.output_desc,
entry_ptr->conv_entry.tensor_format,
entry_ptr->conv_entry.data_type,
static_cast<int>(y->shape[0]),
static_cast<int>(y->shape[1]),
static_cast<int>(y->shape[2]),
static_cast<int>(y->shape[3])));
// Set workspace
size_t workspace_size = 0;
CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(entry_ptr->handle,
entry_ptr->conv_entry.input_desc,
entry_ptr->conv_entry.filter_desc,
entry_ptr->conv_entry.conv_desc,
entry_ptr->conv_entry.output_desc,
entry_ptr->conv_entry.fwd_algo,
&workspace_size));
entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
CUDNN_CALL(cudnnConvolutionForward(entry_ptr->handle,
CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
entry_ptr->conv_entry.input_desc,
x->data,
entry_ptr->conv_entry.filter_desc,
w->data,
entry_ptr->conv_entry.conv_desc,
entry_ptr->conv_entry.fwd_algo,
entry_ptr->conv_entry.workspace,
workspace_size,
CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type),
entry_ptr->conv_entry.output_desc,
y->data));
});
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.output_shape")
.set_body([](TVMArgs args, TVMRetValue *ret) {
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
int format = args[0];
int pad_h = args[1];
int pad_w = args[2];
int stride_h = args[3];
int stride_w = args[4];
int dilation_h = args[5];
int dilation_w = args[6];
int x_dim0 = args[7];
int x_dim1 = args[8];
int x_dim2 = args[9];
int x_dim3 = args[10];
int w_dim0 = args[11];
int w_dim1 = args[12];
int w_dim2 = args[12];
int w_dim3 = args[14];
void *out_shape = args[15];
// Set Format
entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
// conv desc
CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
CUDNN_CROSS_CORRELATION,
entry_ptr->conv_entry.data_type));
// input desc
CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
entry_ptr->conv_entry.tensor_format,
CUDNN_DATA_FLOAT,
x_dim0,
x_dim1,
x_dim2,
x_dim3));
// filter desc
CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc,
CUDNN_DATA_FLOAT,
CUDNN_TENSOR_NCHW,
w_dim0,
w_dim1,
w_dim2,
w_dim3));
CUDNN_CALL(cudnnGetConvolution2dForwardOutputDim(entry_ptr->conv_entry.conv_desc,
entry_ptr->conv_entry.input_desc,
entry_ptr->conv_entry.filter_desc,
static_cast<int*>(out_shape),
static_cast<int*>(out_shape) + 1,
static_cast<int*>(out_shape) + 2,
static_cast<int*>(out_shape) + 3));
});
} // namespace contrib
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file Use external cudnn utils function
*/
#include "cudnn_utils.h"
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
namespace tvm {
namespace contrib {
// CuDNN Data Type
cudnnDataType_t CuDNNDataType::DLTypeToCuDNNType(const DLDataType &dtype) {
switch (dtype.code) {
case kInt:
if (dtype.bits == 8 && dtype.lanes == 1) return CUDNN_DATA_INT8;
else if (dtype.bits == 32 && dtype.lanes == 1) return CUDNN_DATA_INT32;
else if (dtype.bits == 8 && dtype.lanes == 4) return CUDNN_DATA_INT8x4;
else
LOG(FATAL) << "Unsupported type";
break;
case kUInt:
LOG(FATAL) << "Unsupported type";
break;
case kFloat:
if (dtype.bits == 32 && dtype.lanes == 1) return CUDNN_DATA_FLOAT;
else if (dtype.bits == 64 && dtype.lanes == 1) return CUDNN_DATA_DOUBLE;
else if (dtype.bits == 16 && dtype.lanes == 1) return CUDNN_DATA_HALF;
else
LOG(FATAL) << "Unsupported type";
break;
}
return CUDNN_DATA_FLOAT;
}
template<>
const void* CuDNNDataType::GetConst<0>(cudnnDataType_t type) {
static const int int_v = 0;
static const float float_v = 0;
static const double double_v = 0;
if (type == CUDNN_DATA_FLOAT || type == CUDNN_DATA_HALF) {
return static_cast<const void*>(&float_v);
}
if (type == CUDNN_DATA_DOUBLE) {
return static_cast<const void*>(&double_v);
}
if (type == CUDNN_DATA_INT8 || type == CUDNN_DATA_INT32 || type == CUDNN_DATA_INT8x4) {
return static_cast<const void*>(&int_v);
}
return nullptr;
}
template<>
const void* CuDNNDataType::GetConst<1>(cudnnDataType_t type) {
static const int int_v = 1;
static const float float_v = 1.f;
static const double double_v = 1.f;
if (type == CUDNN_DATA_FLOAT || type == CUDNN_DATA_HALF) {
return static_cast<const void*>(&float_v);
}
if (type == CUDNN_DATA_DOUBLE) {
return static_cast<const void*>(&double_v);
}
if (type == CUDNN_DATA_INT8 || type == CUDNN_DATA_INT32 || type == CUDNN_DATA_INT8x4) {
return static_cast<const void*>(&int_v);
}
return nullptr;
}
// CuDNNThreadEntry
CuDNNThreadEntry::CuDNNThreadEntry() {
auto stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
auto func = runtime::Registry::Get("device_api.gpu");
void *ret = (*func)();
cuda_api = static_cast<runtime::DeviceAPI*>(ret);
CUDNN_CALL(cudnnCreate(&handle));
CUDNN_CALL(cudnnSetStream(handle, stream));
conv_entry.cuda_api = cuda_api;
}
CuDNNThreadEntry::~CuDNNThreadEntry() {
CUDNN_CALL(cudnnDestroy(handle));
}
typedef dmlc::ThreadLocalStore<CuDNNThreadEntry> CuDNNThreadStore;
CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal() {
return CuDNNThreadStore::Get();
}
// ConvEntry
ConvEntry::ConvEntry() {
CUDNN_CALL(cudnnCreateConvolutionDescriptor(&conv_desc));
CUDNN_CALL(cudnnCreateFilterDescriptor(&filter_desc));
CUDNN_CALL(cudnnCreateTensorDescriptor(&input_desc));
CUDNN_CALL(cudnnCreateTensorDescriptor(&output_desc));
}
ConvEntry::~ConvEntry() {
CUDNN_CALL(cudnnDestroyFilterDescriptor(filter_desc));
CUDNN_CALL(cudnnDestroyConvolutionDescriptor(conv_desc));
CUDNN_CALL(cudnnDestroyTensorDescriptor(input_desc));
CUDNN_CALL(cudnnDestroyTensorDescriptor(output_desc));
CleanWorkspace();
}
void ConvEntry::UpdateWorkspace(const size_t wsize) {
if (workspace_size < wsize) {
if (workspace != nullptr) {
CleanWorkspace();
}
workspace_size = wsize;
workspace = cuda_api->AllocWorkspace(ctx, workspace_size);
}
}
void ConvEntry::CleanWorkspace() {
if (workspace) cuda_api->FreeWorkspace(ctx, workspace);
workspace_size = 0;
}
} // namespace contrib
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file Use external cudnn utils function
*/
#ifndef TVM_CONTRIB_CUDNN_CUDNN_UTILS_H_
#define TVM_CONTRIB_CUDNN_CUDNN_UTILS_H_
#include <dmlc/logging.h>
#include <cudnn.h>
#include <tvm/runtime/device_api.h>
#include "../../runtime/cuda/cuda_common.h"
namespace tvm {
namespace contrib {
#define CUDNN_CALL(func) \
{ \
cudnnStatus_t e = (func); \
CHECK_EQ(e, CUDNN_STATUS_SUCCESS) << "cuDNN: " << cudnnGetErrorString(e); \
}
/*! breif Convert DLTensor type to CuDNN type */
struct CuDNNDataType {
static cudnnDataType_t DLTypeToCuDNNType(const DLDataType &dtype);
template<int v>
static const void* GetConst(cudnnDataType_t type);
}; // struct CuDNNDataType
inline void GetStride(int nbdim, const int *dims, int *strides) {
int mul = 1;
for (int i = nbdim - 1; i >=0; --i) {
mul *= dims[i];
strides[i] = mul;
}
}
struct ConvEntry {
cudnnConvolutionDescriptor_t conv_desc;
cudnnConvolutionMode_t mode;
cudnnFilterDescriptor_t filter_desc;
cudnnDataType_t data_type;
cudnnTensorFormat_t tensor_format;
cudnnTensorDescriptor_t input_desc;
cudnnTensorDescriptor_t output_desc;
cudnnConvolutionFwdAlgo_t fwd_algo;
// cudnnMathType_t math_type;
TVMContext ctx;
runtime::DeviceAPI *cuda_api;
void *workspace{nullptr};
size_t workspace_size{0};
int group_count {0};
ConvEntry();
~ConvEntry();
void UpdateWorkspace(const size_t wsize);
void CleanWorkspace();
}; // ConvThreadEntry
struct CuDNNThreadEntry {
CuDNNThreadEntry();
~CuDNNThreadEntry();
cudnnHandle_t handle{nullptr};
ConvEntry conv_entry;
runtime::DeviceAPI *cuda_api{nullptr};
static CuDNNThreadEntry* ThreadLocal();
}; // CuDNNThreadEntry
} // namespace contrib
} // namespace tvm
#endif // TVM_CONTRIB_CUDNN_CUDNN_UTILS_H_
......@@ -81,7 +81,7 @@ size_t InferTensorizeRegion(
for (const auto& kv : in_dom) {
Array<Range> vec;
const Tensor& t = kv.first;
for (int i = 0; i < t.ndim(); ++i) {
for (size_t i = 0; i < t.ndim(); ++i) {
Range r = arith::Union(kv.second.data.at(i)).cover_range(none);
CHECK(r.defined()) << "cannot deduce region of tensorized scope for input " << t;
vec.push_back(std::move(r));
......@@ -333,7 +333,7 @@ Stmt MakeTensorize(const ComputeOpNode* self,
CHECK_EQ(inputs.size(), intrin->inputs.size())
<< "Tensorize failed: input size mismatch ";
// input binding
for (int i = 0; i < intrin->inputs.size(); ++i) {
for (size_t i = 0; i < intrin->inputs.size(); ++i) {
Tensor tensor = inputs[i];
Buffer buffer = intrin->buffers[i];
Array<NodeRef> bind_spec{buffer, tensor};
......
......@@ -153,7 +153,7 @@ class Vectorizer : public IRMutator {
base = BroadcastTo(base, lanes);
stride = BroadcastTo(stride, lanes);
Array<Expr> elems;
for (size_t i = 0; i < lanes; ++i) {
for (int i = 0; i < lanes; ++i) {
elems.push_back(
Ramp::make(Shuffle::make_extract_element(base, i),
Shuffle::make_extract_element(stride, i),
......
import tvm
from tvm.contrib import cudnn
import numpy as np
def test_conv2d():
in_channel = 3
out_channel = 32
filter_h = 3
filter_w = 3
pad_h = 1
pad_w = 1
stride_h = 1
stride_w = 1
dilation_h = 1
dilation_w = 1
xshape = [4, 3, 32, 32]
if not tvm.module.enabled("cuda"):
print("skip because cuda is not enabled...")
return
if not tvm.get_global_func("tvm.contrib.cudnn.conv2d.output_shape", True):
print("skip because cudnn is not enabled...")
return
wshape = cudnn.conv2d_w_shape(in_channel,
out_channel,
filter_h,
filter_w)
X = tvm.placeholder(xshape, name='X')
W = tvm.placeholder(wshape, name='W')
Y = cudnn.conv2d_forward(X,
W,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
conv_mode=1,
tensor_format=0,
algo=1)
yshape = [x.value for x in Y.shape]
s = tvm.create_schedule(Y.op)
def verify():
ctx = tvm.gpu(0)
f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv2d")
x = tvm.nd.array(np.random.uniform(-1, 1, xshape).astype(np.float32),
ctx)
w = tvm.nd.array(np.random.uniform(-1, 1, wshape).astype(np.float32),
ctx)
y = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32),
ctx)
f(x, w, y)
verify()
if __name__ == "__main__":
test_conv2d()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment