/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ #ifdef TF_TVMDSOOP_ENABLE_GPU #include <cuda_runtime.h> #endif #include <dlpack/dlpack.h> #include <tvm/runtime/device_api.h> #include <tvm/runtime/module.h> #include <tvm/runtime/packed_func.h> #include <tvm/runtime/registry.h> #include "tensorflow/core/framework/op_kernel.h" typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; typedef tensorflow::gtl::InlinedVector<tensorflow::int64, 4> ShapeContainer; using tensorflow::OpKernel; using tensorflow::OpKernelConstruction; using tensorflow::OpKernelContext; using tvm::runtime::TVMArgs; using tvm::runtime::TVMArgsSetter; using tvm::runtime::TVMRetValue; // Op utility trait for diffrent device type template template <typename DEVICE_TYPE> class TVMDSOOpTrait; // Buffer information used for actual computation. // Each buffer is associated with one TensorFlow tensor // whose underlying buffer is record into "origin_buf". // For input tensor, we copy data from origin_buf to buf // and for output tensor, copy data from buf to origin_buf class TensorAsBuf { public: tensorflow::Tensor inline_tensor; tensorflow::Tensor* tensor; size_t size; size_t offset; int device_type; char* origin_buf; char* buf; void CopyToOrigin() { if (buf == origin_buf) { return; } if (device_type == kDLCPU) { memcpy(origin_buf, buf + offset, size); #ifdef TF_TVMDSOOP_ENABLE_GPU } else if (device_type == kDLGPU) { cudaMemcpy(origin_buf, buf + offset, size, cudaMemcpyDeviceToDevice); #endif } else { LOG(FATAL) << "Only support CPU and CUDA now. Device " << device_type << " is not implemented currently"; } } void CopyFromOrigin() { if (buf == origin_buf) { return; } if (device_type == kDLCPU) { memcpy(buf + offset, origin_buf, size); #ifdef TF_TVMDSOOP_ENABLE_GPU } else if (device_type == kDLGPU) { cudaMemcpy(buf + offset, origin_buf, size, cudaMemcpyDeviceToDevice); #endif } else { LOG(FATAL) << "Only support CPU and CUDA now. Device " << device_type << " is not implemented currently"; } } }; tensorflow::Status GetDLPackDtype(const tensorflow::Tensor& tf_tensor, DLDataType* res) { auto dtype = tf_tensor.dtype(); if (dtype == tensorflow::DT_FLOAT) { *res = {kDLFloat, 32, 1}; } else if (dtype == tensorflow::DT_INT64) { *res = {kDLInt, 64, 1}; } else if (dtype == tensorflow::DT_INT32) { *res = {kDLInt, 32, 1}; } else { return tensorflow::Status(tensorflow::error::INTERNAL, "Fail to get dlpack datatype"); } return tensorflow::Status::OK(); } // Ensure buffer used for actual computation take 64byte alignment void EnsureAlignment(OpKernelContext* ctx, const tensorflow::Tensor& tensor, TensorAsBuf* out) { char* buf = const_cast<char*>(tensor.tensor_data().data()); out->origin_buf = buf; out->size = tensor.TotalBytes(); int alignment = 64; char* aligned = reinterpret_cast<char*>(((uint64_t)buf + alignment - 1) & (~(alignment - 1))); if (buf == aligned) { out->tensor = const_cast<tensorflow::Tensor*>(&tensor); out->buf = buf; out->offset = 0; } else { tensorflow::TensorShape buf_shape; tensorflow::int64 dims[1] = {(tensorflow::int64)(tensor.TotalBytes() + alignment)}; tensorflow::TensorShapeUtils::MakeShape(dims, 1, &buf_shape); out->tensor = &out->inline_tensor; ctx->allocate_temp(tensor.dtype(), buf_shape, out->tensor); buf = const_cast<char*>(out->tensor->tensor_data().data()); char* buf_aligned = reinterpret_cast<char*>(((uint64_t)buf + alignment) & (~(alignment - 1))); out->buf = buf; out->offset = buf_aligned - buf; } } // Create DLPack tensor from TensorFlow tensor tensorflow::Status MakeDLTensor(const TensorAsBuf& src, const DLContext& ctx, int64_t* tf_shape, DLTensor* out) { DLDataType dlpack_type; const tensorflow::Tensor& tensor = *src.tensor; auto status = GetDLPackDtype(tensor, &dlpack_type); if (!status.ok()) { return status; } out->ctx = ctx; out->ndim = tensor.shape().dims(); out->shape = tf_shape; out->strides = nullptr; out->byte_offset = 0; out->dtype = dlpack_type; out->data = src.buf + src.offset; return tensorflow::Status::OK(); } template <> class TVMDSOOpTrait<CPUDevice> { public: static const int device_type = kDLCPU; static int device_id(OpKernelContext* context) { return 0; } static void make_shape_from_tensor(const tensorflow::Tensor& shape_tensor, tensorflow::TensorShape* output_shape) { tensorflow::int64 num_dims = shape_tensor.NumElements(); const tensorflow::int64* dims = shape_tensor.flat<tensorflow::int64>().data(); tensorflow::TensorShapeUtils::MakeShape(dims, num_dims, output_shape); } }; #ifdef TF_TVMDSOOP_ENABLE_GPU template <> class TVMDSOOpTrait<GPUDevice> { public: static const int device_type = kDLGPU; static int device_id(OpKernelContext* context) { auto device_base = context->device(); auto gpu_device_info = device_base->tensorflow_gpu_device_info(); return gpu_device_info->gpu_id; } static void make_shape_from_tensor(const tensorflow::Tensor& shape_tensor, tensorflow::TensorShape* output_shape) { tensorflow::int64 num_dims = shape_tensor.NumElements(); const tensorflow::int64* flat = shape_tensor.flat<tensorflow::int64>().data(); tensorflow::int64* dims = new tensorflow::int64[num_dims]; cudaMemcpy(dims, flat, sizeof(tensorflow::int64) * num_dims, cudaMemcpyDeviceToHost); tensorflow::TensorShapeUtils::MakeShape(dims, num_dims, output_shape); delete dims; } }; #endif template <typename DEVICE_TYPE> class TVMDSOOp : public OpKernel { private: tvm::runtime::PackedFunc tvm_func; std::string lib_path; std::string func_name; tensorflow::DataType output_dtype; bool has_static_output_shape; std::vector<tensorflow::int64> static_output_shape; void initAttributes(OpKernelConstruction* context) { context->GetAttr("lib_path", &lib_path); context->GetAttr("func_name", &func_name); context->GetAttr("output_dtype", &output_dtype); context->GetAttr("has_static_output_shape", &has_static_output_shape); context->GetAttr("static_output_shape", &static_output_shape); } public: explicit TVMDSOOp(OpKernelConstruction* context) : OpKernel(context) { // Get attr initAttributes(context); // Load TVM function from dynamic library tvm::runtime::Module mod_dylib = tvm::runtime::Module::LoadFromFile(lib_path); tvm_func = mod_dylib.GetFunction(func_name); CHECK(tvm_func != nullptr); } void Compute(tensorflow::OpKernelContext* context) override { // the last input is output shape spec const int num_inputs = context->num_inputs() - 1; const int num_total_args = num_inputs + 1; std::vector<DLTensor> args(num_total_args); std::vector<TensorAsBuf> buf_info(num_inputs); std::vector<ShapeContainer> shapes(num_inputs); tensorflow::Status status; int device_id = TVMDSOOpTrait<DEVICE_TYPE>::device_id(context); int device_type = TVMDSOOpTrait<DEVICE_TYPE>::device_type; DLContext dl_ctx = {DLDeviceType(device_type), device_id}; // Get output shape tensorflow::TensorShape output_shape; auto& output_shape_tensor = context->input(num_inputs); if (has_static_output_shape) { // use static output shape const tensorflow::int64* dims = static_output_shape.data(); tensorflow::TensorShapeUtils::MakeShape(dims, static_output_shape.size(), &output_shape); } else if (output_shape_tensor.dims() == 1) { // use shape tensor values as output shape TVMDSOOpTrait<DEVICE_TYPE>::make_shape_from_tensor(output_shape_tensor, &output_shape); } else { // use input tensor shape by default output_shape = context->input(0).shape(); } for (int i = 0; i < num_inputs; ++i) { // Grab the input tensor auto& input_tensor = context->input(i); // Create shape container, should keep ref during execution shapes[i] = input_tensor.shape().dim_sizes(); auto shape_ptr = reinterpret_cast<int64_t*>(shapes[i].data()); TensorAsBuf& input = buf_info[i]; input.device_type = device_type; EnsureAlignment(context, input_tensor, &input); input.CopyFromOrigin(); status = MakeDLTensor(input, dl_ctx, shape_ptr, &args[i]); OP_REQUIRES_OK(context, status); } // Allocate output tensor tensorflow::Tensor* output_tensor; OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor)); // shape dimension buf should keel alive on stack auto output_shape_dim_buf = output_tensor->shape().dim_sizes(); auto output_shape_ptr = reinterpret_cast<int64_t*>(output_shape_dim_buf.data()); TensorAsBuf output; output.device_type = device_type; EnsureAlignment(context, *output_tensor, &output); status = MakeDLTensor(output, dl_ctx, output_shape_ptr, &args[num_inputs]); OP_REQUIRES_OK(context, status); // Prepare PackedFunc arguments std::vector<TVMValue> tvm_values(num_total_args); std::vector<int> tvm_type_codes(num_total_args); TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data()); for (int k = 0; k < num_total_args; ++k) { setter(k, &args[k]); } TVMRetValue rv; tvm_func.CallPacked(TVMArgs(tvm_values.data(), tvm_type_codes.data(), num_total_args), &rv); output.CopyToOrigin(); } }; #ifdef TF_TVMDSOOP_ENABLE_GPU REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(tensorflow::DEVICE_CPU), TVMDSOOp<CPUDevice>); REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(tensorflow::DEVICE_GPU), TVMDSOOp<GPUDevice>); #else REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(tensorflow::DEVICE_CPU), TVMDSOOp<CPUDevice>); #endif