Commit 0ee68d72 by Yizhi Liu Committed by Tianqi Chen

[APP] Android RPC (#359)

* [APP] Android RPC first version

* [APP] Android RPC build jni automatically

* [APP] Android OpenCL RPC tested on real devices

* [APP] optimize android app interface. add ndk compile tool

* add ndk compile tool

* [APP] fix android app thread crash; add android test script

* [APP] android app - show alert dialog and disconnect when error occurs

* fix ndk build script code lint

* fix ndk build default argument

* ndk script build remove shell=True. disable android app screen orientation
parent 85bba12c
...@@ -268,6 +268,11 @@ jvmpkg: ...@@ -268,6 +268,11 @@ jvmpkg:
mvn clean package -P$(JVM_PKG_PROFILE) -Dcxx="$(CXX)" \ mvn clean package -P$(JVM_PKG_PROFILE) -Dcxx="$(CXX)" \
-Dcflags="$(CFLAGS)" -Dldflags="$(LDFLAGS)" \ -Dcflags="$(CFLAGS)" -Dldflags="$(LDFLAGS)" \
-Dcurrent_libdir="$(ROOTDIR)/lib" $(JVM_TEST_ARGS)) -Dcurrent_libdir="$(ROOTDIR)/lib" $(JVM_TEST_ARGS))
jvminstall:
(cd $(ROOTDIR)/jvm; \
mvn install -P$(JVM_PKG_PROFILE) -Dcxx="$(CXX)" \
-Dcflags="$(CFLAGS)" -Dldflags="$(LDFLAGS)" \
-Dcurrent_libdir="$(ROOTDIR)/lib" $(JVM_TEST_ARGS))
clean: clean:
$(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o */*.d */*/*.d */*/*/*.d $(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o */*.d */*/*.d */*/*/*.d
......
*.iml
.gradle
/local.properties
/.idea/workspace.xml
/.idea/libraries
.DS_Store
/build
/captures
.externalNativeBuild
apply plugin: 'com.android.application'
task buildJni(type: Exec, description: 'Build JNI libs') {
commandLine 'sh', 'src/main/jni/build.sh'
}
tasks.withType(JavaCompile) {
compileTask -> compileTask.dependsOn buildJni
}
android {
compileSdkVersion 26
buildToolsVersion "26.0.1"
defaultConfig {
applicationId "ml.dmlc.tvm.tvmrpc"
minSdkVersion 17
targetSdkVersion 26
versionCode 1
versionName "1.0"
testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
}
buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
}
}
sourceSets {
main {
jni.srcDirs = []
jniLibs.srcDirs = ['src/main/libs']
}
}
}
dependencies {
compile fileTree(dir: 'libs', include: ['*.jar'])
androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', {
exclude group: 'com.android.support', module: 'support-annotations'
})
compile 'com.android.support:appcompat-v7:26.0.1'
compile 'com.android.support.constraint:constraint-layout:1.0.2'
compile 'com.android.support:design:26.0.1'
compile 'ml.dmlc.tvm:tvm4j-core:0.0.1-SNAPSHOT'
testCompile 'junit:junit:4.12'
}
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="ml.dmlc.tvm.tvmrpc" >
<application
android:allowBackup="true"
android:label="@string/app_name"
android:supportsRtl="true"
android:theme="@style/AppTheme" >
<activity
android:name=".MainActivity"
android:label="@string/app_name"
android:theme="@style/AppTheme.NoActionBar"
android:screenOrientation="portrait">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>
<uses-permission android:name="android.permission.INTERNET" />
</manifest>
\ No newline at end of file
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.tvmrpc;
import android.annotation.SuppressLint;
import android.app.AlertDialog;
import android.content.DialogInterface;
import android.os.Bundle;
import android.os.Handler;
import android.os.Message;
import android.support.v7.app.AppCompatActivity;
import android.support.v7.widget.Toolbar;
import android.widget.CompoundButton;
import android.widget.EditText;
import android.widget.Switch;
public class MainActivity extends AppCompatActivity {
static final int MSG_RPC_ERROR = 0;
static final String MSG_RPC_ERROR_DATA_KEY = "msg_rpc_error_data_key";
private RPCProcessor tvmServerWorker;
@SuppressLint("HandlerLeak")
private final Handler rpcHandler = new Handler() {
@Override
public void dispatchMessage(Message msg) {
Switch switchConnect = findViewById(R.id.switch_connect);
if (msg.what == MSG_RPC_ERROR && switchConnect.isChecked()) {
// switch off and show alert dialog.
switchConnect.setChecked(false);
String msgBody = msg.getData().getString(MSG_RPC_ERROR_DATA_KEY);
showDialog("Error", msgBody);
}
}
};
private void showDialog(String title, String msg) {
AlertDialog.Builder builder = new AlertDialog.Builder(this);
builder.setTitle(title);
builder.setMessage(msg);
builder.setCancelable(true);
builder.setNeutralButton(android.R.string.ok,
new DialogInterface.OnClickListener() {
public void onClick(DialogInterface dialog, int id) {
dialog.cancel();
}
});
builder.create().show();
}
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
Toolbar toolbar = findViewById(R.id.toolbar);
setSupportActionBar(toolbar);
tvmServerWorker = new RPCProcessor(rpcHandler);
tvmServerWorker.setDaemon(true);
tvmServerWorker.start();
Switch switchConnect = findViewById(R.id.switch_connect);
switchConnect.setOnCheckedChangeListener(new CompoundButton.OnCheckedChangeListener() {
@Override
public void onCheckedChanged(CompoundButton buttonView, boolean isChecked) {
if (isChecked) {
connectProxy();
} else {
disconnect();
}
}
});
}
@Override
protected void onDestroy() {
super.onDestroy();
tvmServerWorker.disconnect();
}
private void connectProxy() {
EditText edProxyAddress = findViewById(R.id.input_address);
EditText edProxyPort = findViewById(R.id.input_port);
EditText edAppKey = findViewById(R.id.input_key);
final String proxyHost = edProxyAddress.getText().toString();
final int proxyPort = Integer.parseInt(edProxyPort.getText().toString());
final String key = edAppKey.getText().toString();
tvmServerWorker.connect(proxyHost, proxyPort, key);
}
private void disconnect() {
tvmServerWorker.disconnect();
System.err.println("Disconnected.");
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.tvmrpc;
import android.os.Bundle;
import android.os.Handler;
import android.os.Message;
import android.os.ParcelFileDescriptor;
import java.net.Socket;
import ml.dmlc.tvm.rpc.ConnectProxyServerProcessor;
import ml.dmlc.tvm.rpc.SocketFileDescriptorGetter;
/**
* Connect to RPC proxy and deal with requests.
*/
class RPCProcessor extends Thread {
private String host;
private int port;
private String key;
private boolean running = false;
private ConnectProxyServerProcessor currProcessor;
private final Handler uiHandler;
static final SocketFileDescriptorGetter socketFdGetter
= new SocketFileDescriptorGetter() {
@Override
public int get(Socket socket) {
return ParcelFileDescriptor.fromSocket(socket).getFd();
}
};
RPCProcessor(Handler uiHandler) {
this.uiHandler = uiHandler;
}
@Override public void run() {
while (true) {
synchronized (this) {
currProcessor = null;
while (!running) {
try {
this.wait();
} catch (InterruptedException e) {
}
}
currProcessor = new ConnectProxyServerProcessor(host, port, key, socketFdGetter);
}
try {
currProcessor.run();
} catch (Throwable e) {
disconnect();
// turn connect switch off.
Message message = new Message();
message.what = MainActivity.MSG_RPC_ERROR;
Bundle bundle = new Bundle();
bundle.putString(MainActivity.MSG_RPC_ERROR_DATA_KEY, e.getMessage());
message.setData(bundle);
uiHandler.sendMessage(message);
}
}
}
/**
* Disconnect from the proxy server.
*/
synchronized void disconnect() {
if (running) {
running = false;
if (currProcessor != null) {
currProcessor.terminate();
}
}
}
/**
* Start rpc processor and connect to the proxy server.
* @param host proxy server host.
* @param port proxy server port.
* @param key proxy server key.
*/
synchronized void connect(String host, int port, String key) {
this.host = host;
this.port = port;
this.key = key;
running = true;
notify();
}
}
LOCAL_PATH := $(call my-dir)
MY_PATH := $(LOCAL_PATH)
include $(CLEAR_VARS)
LOCAL_PATH := $(MY_PATH)
ROOT_PATH := $(MY_PATH)/../../../../../..
ifndef config
ifneq ("$(wildcard ./config.mk)","")
config ?= config.mk
else
config ?= make/config.mk
endif
endif
include $(config)
LOCAL_SRC_FILES := ml_dmlc_tvm_native_c_api.cc
LOCAL_LDFLAGS := -L$(SYSROOT)/usr/lib/ -llog
LOCAL_C_INCLUDES := $(ROOT_PATH)/include \
$(ROOT_PATH)/dlpack/include \
$(ROOT_PATH)/dmlc-core/include \
$(ROOT_PATH)/HalideIR/src \
$(ROOT_PATH)/topi/include
LOCAL_MODULE = tvm4j_runtime_packed
LOCAL_CPP_FEATURES += exceptions
LOCAL_LDLIBS += -latomic
LOCAL_ARM_MODE := arm
ifdef ADD_C_INCLUDES
LOCAL_C_INCLUDES += $(ADD_C_INCLUDES)
endif
ifdef ADD_LDLIBS
LOCAL_LDLIBS += $(ADD_LDLIBS)
endif
include $(BUILD_SHARED_LIBRARY)
ifndef config
ifneq ("$(wildcard ./config.mk)","")
config ?= config.mk
else
config ?= make/config.mk
endif
endif
include $(config)
APP_STL := gnustl_static
APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++11 -Oz -frtti
ifeq ($(USE_OPENCL), 1)
APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1
endif
#!/bin/bash
PATH="$PATH:/usr/local/bin"
CURR_DIR=$(cd `dirname $0`; pwd)
ROOT_DIR="$CURR_DIR/../../../../../.."
javah -o $CURR_DIR/ml_dmlc_tvm_native_c_api.h -cp "$ROOT_DIR/jvm/core/target/*" ml.dmlc.tvm.LibInfo || exit -1
cp -f $ROOT_DIR/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc $CURR_DIR/ || exit -1
cp -f $ROOT_DIR/jvm/native/src/main/native/jni_helper_func.h $CURR_DIR/ || exit -1
rm -rf $CURR_DIR/../libs
ndk-build --directory=$CURR_DIR
#-------------------------------------------------------------------------------
# Template configuration for compiling
#
# If you want to change the configuration, please use the following
# steps. Assume you are on the root directory. First copy the this
# file so that any local changes will be ignored by git
#
# cp make/config.mk .
#
# Next modify the according entries, and then compile by
#
# ./build.sh
#
#-------------------------------------------------------------------------------
APP_ABI = all
APP_PLATFORM = android-17
# whether enable OpenCL during compile
USE_OPENCL = 0
# the additional include headers you want to add, e.g., SDK_PATH/adrenosdk/Development/Inc
ADD_C_INCLUDES =
# the additional link libs you want to add, e.g., ANDROID_LIB_PATH/libOpenCL.so
ADD_LDLIBS =
/*!
* Copyright (c) 2017 by Contributors
* \file tvm_runtime.h
* \brief Pack all tvm runtime source files
*/
#include <sys/stat.h>
#include <fstream>
#include "../src/runtime/c_runtime_api.cc"
#include "../src/runtime/cpu_device_api.cc"
#include "../src/runtime/workspace_pool.cc"
#include "../src/runtime/module_util.cc"
#include "../src/runtime/system_lib_module.cc"
#include "../src/runtime/module.cc"
#include "../src/runtime/registry.cc"
#include "../src/runtime/file_util.cc"
#include "../src/runtime/dso_module.cc"
#include "../src/runtime/rpc/rpc_session.cc"
#include "../src/runtime/rpc/rpc_event_impl.cc"
#include "../src/runtime/rpc/rpc_server_env.cc"
#include "../src/runtime/rpc/rpc_module.cc"
#include "../src/runtime/rpc/rpc_socket_impl.cc"
#include "../src/runtime/thread_pool.cc"
#ifdef TVM_OPENCL_RUNTIME
#include "../src/runtime/opencl/opencl_device_api.cc"
#include "../src/runtime/opencl/opencl_module.cc"
#endif
<?xml version="1.0" encoding="utf-8"?>
<android.support.design.widget.CoordinatorLayout
xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context="ml.dmlc.tvm.tvmrpc.MainActivity">
<android.support.design.widget.AppBarLayout
android:layout_height="wrap_content"
android:layout_width="match_parent"
android:theme="@style/AppTheme.AppBarOverlay">
<android.support.v7.widget.Toolbar
android:id="@+id/toolbar"
android:layout_width="match_parent"
android:layout_height="?attr/actionBarSize"
android:background="?attr/colorPrimary"
app:popupTheme="@style/AppTheme.PopupOverlay" />
</android.support.design.widget.AppBarLayout>
<include layout="@layout/content_main"/>
</android.support.design.widget.CoordinatorLayout>
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
xmlns:app="http://schemas.android.com/apk/res-auto"
android:orientation="vertical"
android:layout_width="fill_parent"
android:layout_height="wrap_content"
app:layout_behavior="@string/appbar_scrolling_view_behavior"
tools:showIn="@layout/activity_main">
<LinearLayout
android:orientation="horizontal"
android:layout_width="fill_parent"
android:layout_height="wrap_content">
<TextView
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="@string/label_address"/>
<EditText
android:id="@+id/input_address"
android:hint="@string/input_address"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:background="@android:drawable/editbox_background"/>
</LinearLayout>
<LinearLayout
android:orientation="horizontal"
android:layout_width="fill_parent"
android:layout_height="wrap_content">
<TextView
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="@string/label_port"/>
<EditText
android:id="@+id/input_port"
android:hint="@string/input_port"
android:minWidth="100dip"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:background="@android:drawable/editbox_background"/>
</LinearLayout>
<LinearLayout
android:orientation="horizontal"
android:layout_width="fill_parent"
android:layout_height="wrap_content">
<TextView
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="@string/label_key"/>
<EditText
android:id="@+id/input_key"
android:hint="@string/input_key"
android:minWidth="100dip"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:background="@android:drawable/editbox_background"/>
</LinearLayout>
<LinearLayout
android:orientation="horizontal"
android:layout_width="fill_parent"
android:layout_height="wrap_content">
<TextView
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="@string/label_connect"/>
<Switch
android:id="@+id/switch_connect"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:switchMinWidth="55dp"
android:paddingLeft="10dip"
android:checked="false"
android:textOff="@string/switch_off"
android:textOn="@string/switch_on" />
</LinearLayout>
</LinearLayout>
<?xml version="1.0" encoding="utf-8"?>
<resources>
<color name="colorPrimary">#3F51B5</color>
<color name="colorPrimaryDark">#303F9F</color>
<color name="colorAccent">#06d467</color>
</resources>
<resources>
<string name="app_name">TVM RPC</string>
<string name="input_address">Enter the proxy server address</string>
<string name="input_port">Enter the proxy server port</string>
<string name="input_key">Enter the app connection key</string>
<string name="label_address">Address</string>
<string name="label_port">Port</string>
<string name="label_key">Key</string>
<string name="label_connect">Connect to Proxy</string>
<string name="switch_on">Connected</string>
<string name="switch_off">Disconnected</string>
</resources>
<resources>
<!-- Base application theme. -->
<style name="AppTheme" parent="Theme.AppCompat.Light.DarkActionBar">
<!-- Customize your theme here. -->
<item name="colorPrimary">@color/colorPrimary</item>
<item name="colorPrimaryDark">@color/colorPrimaryDark</item>
<item name="colorAccent">@color/colorAccent</item>
</style>
<style name="AppTheme.NoActionBar">
<item name="windowActionBar">false</item>
<item name="windowNoTitle">true</item>
</style>
<style name="AppTheme.AppBarOverlay" parent="ThemeOverlay.AppCompat.Dark.ActionBar" />
<style name="AppTheme.PopupOverlay" parent="ThemeOverlay.AppCompat.Light" />
</resources>
// Top-level build file where you can add configuration options common to all sub-projects/modules.
buildscript {
repositories {
jcenter()
}
dependencies {
classpath 'com.android.tools.build:gradle:2.3.3'
// NOTE: Do not place your application dependencies here; they belong
// in the individual module build.gradle files
}
}
allprojects {
repositories {
jcenter()
maven {
url 'https://maven.google.com'
}
mavenLocal()
mavenCentral()
}
}
task clean(type: Delete) {
delete rootProject.buildDir
}
#!/bin/bash
CURR_DIR=$(cd `dirname $0`; pwd)
keytool -genkey -keystore $CURR_DIR/tvmrpc.keystore -alias tvmrpc -keyalg RSA -validity 10000
#!/bin/bash
CURR_DIR=$(cd `dirname $0`; pwd)
APK_DIR=$CURR_DIR/../app/build/outputs/apk
UNSIGNED_APK=$APK_DIR/app-release-unsigned.apk
SIGNED_APK=$APK_DIR/tvmrpc-release.apk
jarsigner -verbose -keystore tvmrpc.keystore -signedjar $SIGNED_APK $UNSIGNED_APK 'tvmrpc'
#Mon Aug 14 21:31:55 CST 2017
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-3.3-all.zip
#!/usr/bin/env bash
##############################################################################
##
## Gradle start up script for UN*X
##
##############################################################################
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS=""
APP_NAME="Gradle"
APP_BASE_NAME=`basename "$0"`
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD="maximum"
warn ( ) {
echo "$*"
}
die ( ) {
echo
echo "$*"
echo
exit 1
}
# OS specific support (must be 'true' or 'false').
cygwin=false
msys=false
darwin=false
case "`uname`" in
CYGWIN* )
cygwin=true
;;
Darwin* )
darwin=true
;;
MINGW* )
msys=true
;;
esac
# Attempt to set APP_HOME
# Resolve links: $0 may be a link
PRG="$0"
# Need this for relative symlinks.
while [ -h "$PRG" ] ; do
ls=`ls -ld "$PRG"`
link=`expr "$ls" : '.*-> \(.*\)$'`
if expr "$link" : '/.*' > /dev/null; then
PRG="$link"
else
PRG=`dirname "$PRG"`"/$link"
fi
done
SAVED="`pwd`"
cd "`dirname \"$PRG\"`/" >/dev/null
APP_HOME="`pwd -P`"
cd "$SAVED" >/dev/null
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
# IBM's JDK on AIX uses strange locations for the executables
JAVACMD="$JAVA_HOME/jre/sh/java"
else
JAVACMD="$JAVA_HOME/bin/java"
fi
if [ ! -x "$JAVACMD" ] ; then
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
else
JAVACMD="java"
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
# Increase the maximum file descriptors if we can.
if [ "$cygwin" = "false" -a "$darwin" = "false" ] ; then
MAX_FD_LIMIT=`ulimit -H -n`
if [ $? -eq 0 ] ; then
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
MAX_FD="$MAX_FD_LIMIT"
fi
ulimit -n $MAX_FD
if [ $? -ne 0 ] ; then
warn "Could not set maximum file descriptor limit: $MAX_FD"
fi
else
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
fi
fi
# For Darwin, add options to specify how the application appears in the dock
if $darwin; then
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
fi
# For Cygwin, switch paths to Windows format before running java
if $cygwin ; then
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
JAVACMD=`cygpath --unix "$JAVACMD"`
# We build the pattern for arguments to be converted via cygpath
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
SEP=""
for dir in $ROOTDIRSRAW ; do
ROOTDIRS="$ROOTDIRS$SEP$dir"
SEP="|"
done
OURCYGPATTERN="(^($ROOTDIRS))"
# Add a user-defined pattern to the cygpath arguments
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
fi
# Now convert the arguments - kludge to limit ourselves to /bin/sh
i=0
for arg in "$@" ; do
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
else
eval `echo args$i`="\"$arg\""
fi
i=$((i+1))
done
case $i in
(0) set -- ;;
(1) set -- "$args0" ;;
(2) set -- "$args0" "$args1" ;;
(3) set -- "$args0" "$args1" "$args2" ;;
(4) set -- "$args0" "$args1" "$args2" "$args3" ;;
(5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
(6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
(7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
(8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
(9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
esac
fi
# Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules
function splitJvmOpts() {
JVM_OPTS=("$@")
}
eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS
JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME"
exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@"
@if "%DEBUG%" == "" @echo off
@rem ##########################################################################
@rem
@rem Gradle startup script for Windows
@rem
@rem ##########################################################################
@rem Set local scope for the variables with windows NT shell
if "%OS%"=="Windows_NT" setlocal
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
set DEFAULT_JVM_OPTS=
set DIRNAME=%~dp0
if "%DIRNAME%" == "" set DIRNAME=.
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%
@rem Find java.exe
if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if "%ERRORLEVEL%" == "0" goto init
echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:findJavaFromJavaHome
set JAVA_HOME=%JAVA_HOME:"=%
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
if exist "%JAVA_EXE%" goto init
echo.
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:init
@rem Get command-line arguments, handling Windowz variants
if not "%OS%" == "Windows_NT" goto win9xME_args
if "%@eval[2+2]" == "4" goto 4NT_args
:win9xME_args
@rem Slurp the command line arguments.
set CMD_LINE_ARGS=
set _SKIP=2
:win9xME_args_slurp
if "x%~1" == "x" goto execute
set CMD_LINE_ARGS=%*
goto execute
:4NT_args
@rem Get arguments from the 4NT Shell from JP Software
set CMD_LINE_ARGS=%$
:execute
@rem Setup the command line
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
:end
@rem End local scope for the variables with windows NT shell
if "%ERRORLEVEL%"=="0" goto mainEnd
:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
exit /b 1
:mainEnd
if "%OS%"=="Windows_NT" endlocal
:omega
"""Testcode for Android RPC.
To use it, start a rpc proxy with "python -m tvm.exec.rpc_proxy".
And configure the proxy host field as commented.
"""
import tvm
import os
from tvm.contrib import rpc, util, ndk, rpc_proxy
import numpy as np
# Set to be address of tvm proxy.
proxy_host = os.environ["TVM_ANDROID_RPC_PROXY_HOST"]
proxy_port = 9090
key = "android"
# Change target configuration.
# Run `adb shell cat /proc/cpuinfo` to find the arch.
arch = "arm64"
target = "llvm -target=%s-linux-android" % arch
def test_rpc_module():
# graph
n = tvm.convert(1024)
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
temp = util.tempdir()
s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=64)
s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
# Build the dynamic lib.
# If we don't want to do metal and only use cpu, just set target to be target
f = tvm.build(s, [A, B], "opencl", target_host=target, name="myadd")
path_dso1 = temp.relpath("dev_lib.so")
f.export_library(path_dso1, ndk.create_shared)
s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=64)
s[B].parallel(xi)
s[B].pragma(xo, "parallel_launch_point")
s[B].pragma(xi, "parallel_barrier_when_finish")
f = tvm.build(s, [A, B], target, name="myadd_cpu")
path_dso2 = temp.relpath("cpu_lib.so")
f.export_library(path_dso2, ndk.create_shared)
# connect to the proxy
remote = rpc.connect(proxy_host, proxy_port, key=key)
print('Run GPU test ...')
ctx = remote.cl(0)
remote.upload(path_dso1)
f1 = remote.load_module("dev_lib.so")
a_np = np.random.uniform(size=1024).astype(A.dtype)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
time_f = f1.time_evaluator(f1.entry_name, ctx, number=10)
cost = time_f(a, b).mean
print('%g secs/op' % cost)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
print('Run CPU test ...')
ctx = remote.cpu(0)
remote.upload(path_dso2)
f2 = remote.load_module("cpu_lib.so")
a_np = np.random.uniform(size=1024).astype(A.dtype)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
time_f = f2.time_evaluator(f2.entry_name, ctx, number=10)
cost = time_f(a, b).mean
print('%g secs/op' % cost)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
if __name__ == "__main__":
test_rpc_module()
...@@ -60,6 +60,7 @@ final class Base { ...@@ -60,6 +60,7 @@ final class Base {
public static final LibInfo _LIB = new LibInfo(); public static final LibInfo _LIB = new LibInfo();
static { static {
boolean loadNativeRuntimeLib = true;
try { try {
try { try {
tryLoadLibraryOS("tvm4j"); tryLoadLibraryOS("tvm4j");
...@@ -72,34 +73,49 @@ final class Base { ...@@ -72,34 +73,49 @@ final class Base {
NativeLibraryLoader.loadLibrary("tvm4j"); NativeLibraryLoader.loadLibrary("tvm4j");
} }
} catch (Throwable e) { } catch (Throwable e) {
System.err.println("[ERROR] Couldn't find native library tvm4j"); System.err.println("[WARN] Couldn't find native library tvm4j.");
throw new RuntimeException(e); e.printStackTrace();
System.err.println("Try to load tvm4j (runtime packed version) ...");
try {
System.loadLibrary("tvm4j_runtime_packed");
// if tvm runtime is packed in libtvm4j, we do not need to dlopen libtvm_runtime.so.
loadNativeRuntimeLib = false;
} catch (UnsatisfiedLinkError errFull) {
System.err.println("[ERROR] Couldn't find native library tvm4j_runtime_packed.");
throw new RuntimeException(errFull);
}
} }
String tvmLibFilename = System.getProperty("libtvm.so.path"); System.err.println("libtvm4j loads successfully.");
if (tvmLibFilename == null || !new File(tvmLibFilename).isFile()
|| _LIB.nativeLibInit(tvmLibFilename) != 0) { if (loadNativeRuntimeLib) {
try { String tvmLibFilename = System.getProperty("libtvm.so.path");
String runtimeLibname; if (tvmLibFilename == null || !new File(tvmLibFilename).isFile()
String os = System.getProperty("os.name"); || _LIB.nativeLibInit(tvmLibFilename) != 0) {
// ref: http://lopica.sourceforge.net/os.html try {
if (os.startsWith("Linux")) { String runtimeLibname;
runtimeLibname = "libtvm_runtime.so"; String os = System.getProperty("os.name");
} else if (os.startsWith("Mac")) { // ref: http://lopica.sourceforge.net/os.html
runtimeLibname = "libtvm_runtime.dylib"; if (os.startsWith("Linux")) {
} else { runtimeLibname = "libtvm_runtime.so";
// TODO(yizhi) support windows later } else if (os.startsWith("Mac")) {
throw new UnsatisfiedLinkError("Windows not supported currently"); runtimeLibname = "libtvm_runtime.dylib";
} } else {
NativeLibraryLoader.extractResourceFileToTempDir(runtimeLibname, new Action() { // TODO(yizhi) support windows later
@Override public void invoke(File target) { throw new UnsatisfiedLinkError(os + " not supported currently");
System.err.println("Loading tvm runtime from " + target.getPath());
checkCall(_LIB.nativeLibInit(target.getPath()));
} }
}); NativeLibraryLoader.extractResourceFileToTempDir(runtimeLibname, new Action() {
} catch (IOException e) { @Override public void invoke(File target) {
throw new RuntimeException(e); System.err.println("Loading tvm runtime from " + target.getPath());
checkCall(_LIB.nativeLibInit(target.getPath()));
}
});
} catch (IOException e) {
throw new RuntimeException(e);
}
} }
} else {
_LIB.nativeLibInit(null);
} }
Runtime.getRuntime().addShutdownHook(new Thread() { Runtime.getRuntime().addShutdownHook(new Thread() {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.rpc;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
/**
* Server processor for proxy connection.
*/
public class ConnectProxyServerProcessor implements ServerProcessor {
private final String host;
private final int port;
private final String key;
private final SocketFileDescriptorGetter socketFileDescriptorGetter;
private volatile Socket currSocket = new Socket();
/**
* Construct proxy server processor.
* @param host Proxy server host.
* @param port Proxy server port.
* @param key Proxy server key.
* @param sockFdGetter Method to get file descriptor from Java socket.
*/
public ConnectProxyServerProcessor(String host, int port, String key,
SocketFileDescriptorGetter sockFdGetter) {
this.host = host;
this.port = port;
this.key = "server:" + key;
socketFileDescriptorGetter = sockFdGetter;
}
/**
* Close the socket.
*/
@Override public void terminate() {
Utils.closeQuietly(currSocket);
}
@Override public void run() {
try {
SocketAddress address = new InetSocketAddress(host, port);
currSocket.connect(address, 6000);
InputStream in = currSocket.getInputStream();
OutputStream out = currSocket.getOutputStream();
out.write(Utils.toBytes(RPC.RPC_MAGIC));
out.write(Utils.toBytes(key.length()));
out.write(Utils.toBytes(key));
int magic = Utils.wrapBytes(Utils.recvAll(in, 4)).getInt();
if (magic == RPC.RPC_MAGIC + 1) {
throw new RuntimeException(
String.format("key: %s has already been used in proxy", key));
} else if (magic == RPC.RPC_MAGIC + 2) {
System.err.println("RPCProxy do not have matching client key " + key);
} else if (magic != RPC.RPC_MAGIC) {
throw new RuntimeException(address + " is not RPC Proxy");
}
System.err.println("RPCProxy connected to " + address);
final int sockFd = socketFileDescriptorGetter.get(currSocket);
if (sockFd != -1) {
new NativeServerLoop(sockFd).run();
System.err.println("Finish serving " + address);
}
} catch (Throwable e) {
e.printStackTrace();
throw new RuntimeException(e);
} finally {
terminate();
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.rpc;
import ml.dmlc.tvm.Function;
import ml.dmlc.tvm.Module;
import ml.dmlc.tvm.TVMValue;
import java.io.File;
import java.io.IOException;
/**
* Call native ServerLoop on socket file descriptor.
*/
public class NativeServerLoop implements Runnable {
private final int sockFd;
/**
* Constructor for NativeServerLoop
* @param nativeSockFd native socket file descriptor.
*/
public NativeServerLoop(final int nativeSockFd) {
sockFd = nativeSockFd;
}
@Override public void run() {
File tempDir = null;
try {
tempDir = serverEnv();
RPC.getApi("_ServerLoop").pushArg(sockFd).invoke();
} catch (IOException e) {
e.printStackTrace();
} finally {
if (tempDir != null) {
String[] entries = tempDir.list();
for (String s : entries) {
File currentFile = new File(tempDir.getPath(), s);
if (!currentFile.delete()) {
System.err.println(
"[WARN] Couldn't delete temporary file " + currentFile.getAbsolutePath());
}
}
if (!tempDir.delete()) {
System.err.println(
"[WARN] Couldn't delete temporary directory " + tempDir.getAbsolutePath());
}
}
}
}
private static File serverEnv() throws IOException {
// Server environment function return temp dir.
final File tempDir = File.createTempFile("tvm4j_rpc_", "");
if (!tempDir.delete() || !tempDir.mkdir()) {
throw new IOException("Couldn't create directory " + tempDir.getAbsolutePath());
}
Function.register("tvm.contrib.rpc.server.workpath", new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
return tempDir + File.separator + args[0].asString();
}
}, true);
Function.register("tvm.contrib.rpc.server.load_module", new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
String filename = args[0].asString();
String path = tempDir + File.separator + filename;
System.err.println("Load module from " + path);
return Module.load(path);
}
}, true);
return tempDir;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.rpc;
/**
* Abstract runnable class for RPC server process.
*/
public interface ServerProcessor extends Runnable {
public void terminate();
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.rpc;
import java.net.Socket;
/**
* Interface for defining different socket fd getter.
*/
public interface SocketFileDescriptorGetter {
/**
* Get native socket file descriptor.
* @param socket Java socket.
* @return native socket fd.
*/
public int get(Socket socket);
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.rpc;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.ServerSocket;
import java.net.Socket;
/**
* Server processor for standalone.
*/
public class StandaloneServerProcessor implements ServerProcessor {
private final ServerSocket server;
private final SocketFileDescriptorGetter socketFileDescriptorGetter;
public StandaloneServerProcessor(int serverPort,
SocketFileDescriptorGetter sockFdGetter) throws IOException {
this.server = new ServerSocket(serverPort);
this.socketFileDescriptorGetter = sockFdGetter;
}
@Override public void terminate() {
try {
server.close();
} catch (IOException e) {
e.printStackTrace();
}
}
@Override public void run() {
try {
Socket socket = server.accept();
InputStream in = socket.getInputStream();
OutputStream out = socket.getOutputStream();
int magic = Utils.wrapBytes(Utils.recvAll(in, 4)).getInt();
if (magic != RPC.RPC_MAGIC) {
Utils.closeQuietly(socket);
return;
}
int keyLen = Utils.wrapBytes(Utils.recvAll(in, 4)).getInt();
String key = Utils.decodeToStr(Utils.recvAll(in, keyLen));
if (!key.startsWith("client:")) {
out.write(Utils.toBytes(RPC.RPC_MAGIC + 2));
} else {
out.write(Utils.toBytes(RPC.RPC_MAGIC));
}
System.err.println("Connection from " + socket.getRemoteSocketAddress().toString());
final int sockFd = socketFileDescriptorGetter.get(socket);
if (sockFd != -1) {
new NativeServerLoop(sockFd).run();
System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
}
Utils.closeQuietly(socket);
} catch (Throwable e) {
e.printStackTrace();
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.rpc;
import java.io.IOException;
import java.io.InputStream;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
/**
* Utilities for RPC.
*/
class Utils {
public static byte[] recvAll(final InputStream in, final int numBytes) throws IOException {
byte[] res = new byte[numBytes];
int numRead = 0;
while (numRead < numBytes) {
int chunk = in.read(res, numRead, Math.min(numBytes - numRead, 1024));
numRead += chunk;
}
return res;
}
public static void closeQuietly(Socket socket) {
if (socket != null) {
try {
socket.shutdownInput();
socket.shutdownOutput();
socket.close();
} catch (IOException ioe) {
// close quietly, do nothing.
}
}
}
public static ByteBuffer wrapBytes(byte[] bytes) {
ByteBuffer bb = ByteBuffer.wrap(bytes);
bb.order(ByteOrder.LITTLE_ENDIAN);
return bb;
}
public static byte[] toBytes(int number) {
ByteBuffer bb = ByteBuffer.allocate(4);
bb.order(ByteOrder.LITTLE_ENDIAN);
return bb.putInt(number).array();
}
public static byte[] toBytes(String str) {
byte[] bytes = new byte[str.length()];
for (int i = 0; i < str.length(); ++i) {
bytes[i] = (byte) str.charAt(i);
}
return bytes;
}
public static String decodeToStr(byte[] bytes) {
StringBuilder builder = new StringBuilder();
for (byte bt : bytes) {
builder.append((char) bt);
}
return builder.toString();
}
}
...@@ -181,6 +181,7 @@ jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) { ...@@ -181,6 +181,7 @@ jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) {
default: default:
LOG(FATAL) << "Do NOT know how to handle return type code " << tcode; LOG(FATAL) << "Do NOT know how to handle return type code " << tcode;
} }
return NULL;
} }
#endif // TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_ #endif // TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_
...@@ -4,10 +4,14 @@ ...@@ -4,10 +4,14 @@
* \brief tvm4j jni source file * \brief tvm4j jni source file
*/ */
#include "ml_dmlc_tvm_native_c_api.h" // generated by javah #include "ml_dmlc_tvm_native_c_api.h" // generated by javah
#ifdef TVM4J_ANDROID
#include "tvm_runtime.h"
#else
#include <dlfcn.h> #include <dlfcn.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#endif
#include <iostream> #include <iostream>
#include <cstring> #include <cstring>
#include <vector> #include <vector>
...@@ -16,7 +20,7 @@ ...@@ -16,7 +20,7 @@
#include "jni_helper_func.h" #include "jni_helper_func.h"
JavaVM *_jvm; JavaVM *_jvm;
void *_tvmHandle; void *_tvmHandle = nullptr;
struct TVMFuncArgsThreadLocalEntry { struct TVMFuncArgsThreadLocalEntry {
std::vector<TVMValue> tvmFuncArgValues; std::vector<TVMValue> tvmFuncArgValues;
std::vector<int> tvmFuncArgTypes; std::vector<int> tvmFuncArgTypes;
...@@ -28,7 +32,7 @@ typedef dmlc::ThreadLocalStore<TVMFuncArgsThreadLocalEntry> TVMFuncArgsThreadLoc ...@@ -28,7 +32,7 @@ typedef dmlc::ThreadLocalStore<TVMFuncArgsThreadLocalEntry> TVMFuncArgsThreadLoc
JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_nativeLibInit JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_nativeLibInit
(JNIEnv *env, jobject obj, jstring jtvmLibFile) { (JNIEnv *env, jobject obj, jstring jtvmLibFile) {
if (_tvmHandle == NULL) { if (_tvmHandle == NULL && !env->IsSameObject(jtvmLibFile, NULL)) {
const char *tvmLibFile = env->GetStringUTFChars(jtvmLibFile, 0); const char *tvmLibFile = env->GetStringUTFChars(jtvmLibFile, 0);
_tvmHandle = dlopen(tvmLibFile, RTLD_LAZY | RTLD_GLOBAL); _tvmHandle = dlopen(tvmLibFile, RTLD_LAZY | RTLD_GLOBAL);
env->ReleaseStringUTFChars(jtvmLibFile, tvmLibFile); env->ReleaseStringUTFChars(jtvmLibFile, tvmLibFile);
...@@ -127,7 +131,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncListGlobalNames( ...@@ -127,7 +131,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncListGlobalNames(
// fill names // fill names
for (int i = 0; i < outSize; ++i) { for (int i = 0; i < outSize; ++i) {
jstring jname = env->NewStringUTF(outArray[i]); jstring jname = env->NewStringUTF(outArray[i]);
env->CallObjectMethod(jfuncNames, arrayAppend, jname); env->CallBooleanMethod(jfuncNames, arrayAppend, jname);
env->DeleteLocalRef(jname); env->DeleteLocalRef(jname);
} }
...@@ -203,7 +207,11 @@ extern "C" int funcInvokeCallback(TVMValue *args, ...@@ -203,7 +207,11 @@ extern "C" int funcInvokeCallback(TVMValue *args,
JNIEnv *env; JNIEnv *env;
int jniStatus = _jvm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6); int jniStatus = _jvm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6);
if (jniStatus == JNI_EDETACHED) { if (jniStatus == JNI_EDETACHED) {
#ifdef TVM4J_ANDROID
_jvm->AttachCurrentThread(&env, nullptr);
#else
_jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), nullptr); _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), nullptr);
#endif
} else { } else {
CHECK(jniStatus == JNI_OK); CHECK(jniStatus == JNI_OK);
} }
...@@ -273,7 +281,11 @@ extern "C" void funcFreeCallback(void *resourceHandle) { ...@@ -273,7 +281,11 @@ extern "C" void funcFreeCallback(void *resourceHandle) {
JNIEnv *env; JNIEnv *env;
int jniStatus = _jvm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6); int jniStatus = _jvm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6);
if (jniStatus == JNI_EDETACHED) { if (jniStatus == JNI_EDETACHED) {
#ifdef TVM4J_ANDROID
_jvm->AttachCurrentThread(&env, nullptr);
#else
_jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), nullptr); _jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), nullptr);
#endif
} else { } else {
CHECK(jniStatus == JNI_OK); CHECK(jniStatus == JNI_OK);
} }
......
"""Util to invoke NDK compiler toolchain."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import subprocess
import os
def create_shared(output,
objects,
options=None):
"""Create shared library.
Parameters
----------
output : str
The target shared library.
objects : list
List of object files.
options : list of str, optional
The additional options.
"""
if "TVM_NDK_CC" not in os.environ:
raise RuntimeError("Require environment variable TVM_NDK_CC"
" to be the NDK standalone compiler")
compiler = os.environ["TVM_NDK_CC"]
cmd = [compiler]
cmd += ["-o", output]
if isinstance(objects, str):
cmd += [objects]
else:
cmd += objects
options = options if options else ["-shared", "-fPIC"]
cmd += options
proc = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Compilation error:\n"
msg += out
raise RuntimeError(msg)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment