Unverified Commit 53a4ad35 by tobe Committed by GitHub

[RUNTIME] Implement TVMDSOOp(TensorFlow custom op) for TVM runtime (#4459)

* Add implementation of TVMDSOOp

* feat: Update cmake script to work with c++11 and in-repo build

* feat: Use libtvm as oplib dependency

* fix: Add missing link dependency to libtvm

* feat: Update tf tvmdso op by review comments

* fix: Update with pr comments

* fix: Fix lint

* feat: Add test script and fix gpu shape

* feat: Add test script and fix gpu shape

* fix: Conditional build tftvm op for gpu

* fix: Conditional build tftvm op for gpu

* fix: Fix pylint of tf_op module.py

* fix: Fix pylint of tf_op module.py

* feat: Conditional enable gpu test for tftvm op

* feat: Conditional enable gpu test for tftvm op

* feat: Add tf_tvmdsoop test script as an app test

* fix: Fix gpu/cpu enabled check on tvm in test script

* fix: Make tf tvmdso op test script runnable with pytest

* remove unused test script test_tfop_module.py

* fix: Remove pushd & popd in tfdsoop test script

* fix: Upgrade tftvmop use python3 to find TensorFlow

* fix: Upgrade tftvmop use python3 to find TensorFlow

* fix: Change target_link_options to target_link_libraries

* fix: Add tftvmop build script's c++ option

* fix: Add tvm library path to tf op test library path

* fix: Debug ci build for tftvm dso op

* fix: Fix cmake error and skip tfop test

* fix: Fix typo and indentation issues

* feat: Use TF list input op def

* fix: Fix style and unexpected changes

Co-authored-by: baoxinqi <baoxinqi@4paradigm.com>
Co-authored-by: Chen Dihao <chendihao@4paradigm.com>
Co-authored-by: wrongtest <wrongtest@4paradigm.com>
parent 4e007632
......@@ -41,6 +41,7 @@ tvm_option(USE_MSVC_MT "Build with MT" OFF)
tvm_option(USE_MICRO "Build with Micro" OFF)
tvm_option(INSTALL_DEV "Install compiler infrastructure" OFF)
tvm_option(HIDE_PRIVATE_SYMBOLS "Compile with -fvisibility=hidden." OFF)
tvm_option(USE_TF_TVMDSOOP "Build with TensorFlow TVMDSOOp" OFF)
# 3rdparty libraries
tvm_option(DLPACK_PATH "Path to DLPACK" "3rdparty/dlpack/include")
......@@ -259,6 +260,7 @@ include(cmake/modules/contrib/Sort.cmake)
include(cmake/modules/contrib/NNPack.cmake)
include(cmake/modules/contrib/HybridDump.cmake)
include(cmake/modules/contrib/TFLite.cmake)
include(cmake/modules/contrib/TF_TVMDSOOP.cmake)
if(NOT MSVC)
include(CheckCXXCompilerFlag)
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
cmake_minimum_required(VERSION 3.2)
project(tf_tvmdsoop C CXX)
set(TFTVM_COMPILE_FLAGS -std=c++11)
set(BUILD_TVMDSOOP_ONLY ON)
set(CMAKE_CURRENT_SOURCE_DIR ${TVM_ROOT})
set(CMAKE_CURRENT_BINARY_DIR ${TVM_ROOT}/build)
include_directories(${TVM_ROOT}/3rdparty/dlpack/include/)
include_directories(${TVM_ROOT}/3rdparty/dmlc-core/include/)
include_directories(${TVM_ROOT}/include)
link_directories(${TVM_ROOT}/build)
include(${TVM_ROOT}/cmake/util/FindCUDA.cmake)
include(${TVM_ROOT}/cmake/modules/CUDA.cmake)
include(${TVM_ROOT}/cmake/modules/contrib/TF_TVMDSOOP.cmake)
#!/bin/bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
TVM_ROOT=$(cd $(dirname $0)/../..; pwd)
echo "TVM_ROOT=${TVM_ROOT}"
export PYTHONPATH=${TVM_ROOT}/python
python3 -c "import tvm; print(tvm.runtime.enabled('gpu'))" | grep -e 1
if [ "$?" -eq 0 ]; then
echo "Build TF_TVMDSOOP with gpu support and execute tests"
CMAKE_OPTIONS="-DUSE_CUDA=ON -DPython3_EXECUTABLE=python3 -DTVM_ROOT=${TVM_ROOT}"
mkdir -p build
cd build; cmake .. ${CMAKE_OPTIONS} && make
cd ..
LD_LIBRARY_PATH=${TVM_ROOT}/build:./build:$LD_LIBRARY_PATH python3 -m pytest -v ./tests
fi
#!/usr/bin/env python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test script for tf op module"""
import tempfile
import os
import logging
import tensorflow as tf
import numpy as np
import tvm
from tvm import te
from tvm.contrib import tf_op
def test_use_tvmdso_op():
"""main test function"""
def export_cpu_add_lib():
"""create cpu add op lib"""
n = te.var("n")
ph_a = te.placeholder((n,), name='ph_a')
ph_b = te.placeholder((n,), name='ph_b')
ph_c = te.compute(ph_a.shape, lambda i: ph_a[i] + ph_b[i], name='ph_c')
sched = te.create_schedule(ph_c.op)
fadd_dylib = tvm.build(sched, [ph_a, ph_b, ph_c], "c", name="vector_add")
lib_path = tempfile.mktemp("tvm_add_dll.so")
fadd_dylib.export_library(lib_path)
return lib_path
def export_gpu_add_lib():
"""create gpu add op lib"""
n = te.var("n")
ph_a = te.placeholder((n,), name='ph_a')
ph_b = te.placeholder((n,), name='ph_b')
ph_c = te.compute(ph_a.shape, lambda i: ph_a[i] + ph_b[i], name='ph_c')
sched = te.create_schedule(ph_c.op)
b_axis, t_axis = sched[ph_c].split(ph_c.op.axis[0], factor=64)
sched[ph_c].bind(b_axis, te.thread_axis("blockIdx.x"))
sched[ph_c].bind(t_axis, te.thread_axis("threadIdx.x"))
fadd_dylib = tvm.build(sched, [ph_a, ph_b, ph_c], "cuda", name="vector_add")
lib_path = tempfile.mktemp("tvm_add_cuda_dll.so")
fadd_dylib.export_library(lib_path)
return lib_path
def test_add(session, lib_path, tf_device):
"""test add lib with TensorFlow wrapper"""
module = tf_op.OpModule(lib_path)
left = tf.placeholder("float32", shape=[4])
right = tf.placeholder("float32", shape=[4])
feed_dict = {left: [1.0, 2.0, 3.0, 4.0], right: [5.0, 6.0, 7.0, 8.0]}
expect = np.asarray([6.0, 8.0, 10.0, 12.0])
add1 = module.func("vector_add", output_shape=[4], output_dtype="float")
add2 = module.func("vector_add", output_shape=tf.shape(left), output_dtype="float")
add3 = module.func("vector_add", output_shape=[tf.shape(left)[0]], output_dtype="float")
with tf.device(tf_device):
output1 = session.run(add1(left, right), feed_dict)
np.testing.assert_equal(output1, expect)
output2 = session.run(add2(left, right), feed_dict)
np.testing.assert_equal(output2, expect)
output3 = session.run(add3(left, right), feed_dict)
np.testing.assert_equal(output3, expect)
def cpu_test(session):
"""test function for cpu"""
cpu_lib = None
try:
cpu_lib = export_cpu_add_lib()
test_add(session, cpu_lib, "/cpu:0")
finally:
if cpu_lib is not None:
os.remove(cpu_lib)
def gpu_test(session):
"""test function for gpu"""
gpu_lib = None
try:
gpu_lib = export_gpu_add_lib()
test_add(session, gpu_lib, "/gpu:0")
finally:
if gpu_lib is not None:
os.remove(gpu_lib)
with tf.Session() as session:
if tvm.runtime.enabled("cpu"):
logging.info("Test TensorFlow op on cpu kernel")
cpu_test(session)
if tvm.runtime.enabled("gpu"):
logging.info("Test TensorFlow op on gpu kernel")
gpu_test(session)
if __name__ == "__main__":
test_use_tvmdso_op()
......@@ -204,3 +204,7 @@ set(USE_EXAMPLE_EXT_RUNTIME OFF)
# Whether use Thrust
set(USE_THRUST OFF)
# Whether to build the TensorFlow TVMDSOOp module
set(USE_TF_TVMDSOOP OFF)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
if(NOT USE_TF_TVMDSOOP STREQUAL "OFF")
find_package(Python3 COMPONENTS Interpreter)
execute_process(COMMAND ${Python3_EXECUTABLE} -c "import tensorflow as tf; print(' '.join(tf.sysconfig.get_compile_flags()))"
OUTPUT_VARIABLE TF_COMPILE_FLAGS_STR
RESULT_VARIABLE TF_STATUS)
if (NOT ${TF_STATUS} EQUAL 0)
message(FATAL_ERROR "Fail to get TensorFlow compile flags")
endif()
if(NOT USE_CUDA STREQUAL "OFF")
add_definitions(-DTF_TVMDSOOP_ENABLE_GPU)
endif()
execute_process(COMMAND ${Python3_EXECUTABLE} -c "import tensorflow as tf; print(' '.join(tf.sysconfig.get_link_flags()))"
OUTPUT_VARIABLE TF_LINK_FLAGS_STR
RESULT_VARIABLE TF_STATUS)
if (NOT ${TF_STATUS} EQUAL 0)
message(FATAL_ERROR "Fail to get TensorFlow link flags")
endif()
string(REGEX REPLACE "\n" " " TF_FLAGS "${TF_COMPILE_FLAGS} ${TF_LINK_FLAGS}")
separate_arguments(TF_COMPILE_FLAGS UNIX_COMMAND ${TF_COMPILE_FLAGS_STR})
separate_arguments(TF_LINK_FLAGS UNIX_COMMAND ${TF_LINK_FLAGS_STR})
set(OP_LIBRARY_NAME tvm_dso_op)
file(GLOB_RECURSE TFTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/tf_op/*.cc)
add_library(${OP_LIBRARY_NAME} SHARED ${TFTVM_SRCS})
set_target_properties(${OP_LIBRARY_NAME} PROPERTIES PREFIX "")
set(TFTVM_LINK_FLAGS -ltvm -L${CMAKE_CURRENT_BINARY_DIR})
if (NOT BUILD_TVMDSOOP_ONLY STREQUAL "ON")
add_dependencies(${OP_LIBRARY_NAME} tvm)
endif()
target_compile_options(${OP_LIBRARY_NAME} PUBLIC ${TFTVM_COMPILE_FLAGS} ${TF_COMPILE_FLAGS})
target_link_libraries(${OP_LIBRARY_NAME} PUBLIC ${TFTVM_LINK_FLAGS} ${TF_LINK_FLAGS})
endif()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Module container of TensorFlow TVMDSO op"""
from . import module
OpModule = module.OpModule
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Module container of TensorFlow TVMDSO op"""
import tensorflow as tf
from tensorflow.python.framework import load_library
class OpModule:
"""Module container of TensorFlow TVMDSO op which wraps exported
TVM op implementation library to be called on TensorFlow side"""
def __init__(self, lib_path):
self.lib_path = lib_path
def func(self, name, output_dtype=None, output_shape=None):
"""Get tvm op function wrapped as TensorFlow tensor to tensor function
Parameters
----------
name: str
function name
output_dtype: str or TensorFlow datatype
Output datatype, default is float32
output_shape: List of integer/tf scalar tensor or tf shape tensor
Output shape, default the same with first input's shape
Returns
----------
Func object that acts as TensorFlow tensor to tensor function.
"""
return TensorFunc(self.lib_path, name, output_dtype, output_shape)
def __getitem__(self, func_name):
return self.func(func_name)
class TensorFunc:
"""Function object that acts as TensorFlow tensor to tensor function."""
def __init__(self, lib_path, func_name, output_dtype, output_shape):
self.lib_path = lib_path
self.func_name = func_name
self.output_dtype = output_dtype
# const(0) indicate invalid dynamic shape
self.dynamic_output_shape = tf.constant(0, tf.int64)
self.static_output_shape = None
self.has_static_output_shape = False # extra flag is required
if self._is_static_shape(output_shape):
self.static_output_shape = output_shape
self.has_static_output_shape = True
elif output_shape is not None:
self.dynamic_output_shape = self._pack_shape_tensor(output_shape)
self.module = load_library.load_op_library('tvm_dso_op.so')
self.tvm_dso_op = self.module.tvm_dso_op
def apply(self, *params):
return self.tvm_dso_op(params,
dynamic_output_shape=self.dynamic_output_shape,
static_output_shape=self.static_output_shape,
has_static_output_shape=self.has_static_output_shape,
lib_path=self.lib_path,
func_name=self.func_name,
output_dtype=self.output_dtype)
def __call__(self, *params):
return self.apply(*params)
def _is_static_shape(self, shape):
if shape is None or not isinstance(shape, list):
return False
for dim_value in shape:
if not isinstance(dim_value, int):
return False
if dim_value < 0:
raise Exception("Negative dimension is illegal: %d" % dim_value)
return True
def _pack_shape_tensor(self, shape):
if isinstance(shape, tf.Tensor):
if shape.dtype == tf.int32:
shape = tf.cast(shape, tf.int64)
elif isinstance(shape, list):
shape_dims = []
for dim_value in shape:
if isinstance(dim_value, int):
shape_dims.append(tf.constant(dim_value, tf.int64))
elif isinstance(dim_value, tf.Tensor) and dim_value.shape.rank == 0:
if dim_value.dtype == tf.int32:
dim_value = tf.cast(dim_value, tf.int64)
shape_dims.append(dim_value)
else:
raise TypeError("Input shape dimension is neither scalar tensor nor int")
shape = tf.stack(shape_dims)
else:
raise TypeError("Input shape is neither tensor nor list")
return shape
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "tensorflow/core/framework/op.h"
REGISTER_OP("TvmDsoOp")
.Input("input_args: ListT")
.Attr("ListT: list({int8, int32, int64, float16, float32})")
.Input("dynamic_output_shape: int64")
.Output("output: output_dtype")
.Attr("lib_path: string")
.Attr("func_name: string")
.Attr("output_dtype: {int8, int32, int64, float16, float32} = DT_FLOAT")
.Attr("static_output_shape: list(int) >= 0 = []")
.Attr("has_static_output_shape: bool");
......@@ -53,6 +53,9 @@ cd ../..
TVM_FFI=cython python3 -m pytest -v apps/dso_plugin_module
TVM_FFI=ctypes python3 -m pytest -v apps/dso_plugin_module
# Do not enable TensorFlow op
# TVM_FFI=cython sh prepare_and_test_tfop_module.sh
# TVM_FFI=ctypes sh prepare_and_test_tfop_module.sh
TVM_FFI=ctypes python3 -m pytest -v tests/python/integration
TVM_FFI=ctypes python3 -m pytest -v tests/python/contrib
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment