Commit fde98f0e by eqy Committed by Tianqi Chen

[RPC] Android RPC Performance Regression Fix, Update Android RPC to use Tracker (#1457)

parent 30409045
...@@ -51,15 +51,15 @@ Here's a piece of example for `config.mk`. ...@@ -51,15 +51,15 @@ Here's a piece of example for `config.mk`.
```makefile ```makefile
APP_ABI = arm64-v8a APP_ABI = arm64-v8a
APP_PLATFORM = android-17 APP_PLATFORM = android-17
# whether enable OpenCL during compile # whether enable OpenCL during compile
USE_OPENCL = 1 USE_OPENCL = 1
# the additional include headers you want to add, e.g., SDK_PATH/adrenosdk/Development/Inc # the additional include headers you want to add, e.g., SDK_PATH/adrenosdk/Development/Inc
ADD_C_INCLUDES = /opt/adrenosdk-osx/Development/Inc ADD_C_INCLUDES = /opt/adrenosdk-osx/Development/Inc
# the additional link libs you want to add, e.g., ANDROID_LIB_PATH/libOpenCL.so # the additional link libs you want to add, e.g., ANDROID_LIB_PATH/libOpenCL.so
ADD_LDLIBS = libOpenCL.so ADD_LDLIBS = libOpenCL.so
``` ```
...@@ -85,13 +85,14 @@ If everything goes well, you will find compile tools in `/opt/android-toolchain- ...@@ -85,13 +85,14 @@ If everything goes well, you will find compile tools in `/opt/android-toolchain-
### Cross Compile and Upload to the Android Device ### Cross Compile and Upload to the Android Device
First start a proxy server using `python -m tvm.exec.rpc_proxy` and make your Android device connect to this proxy server via TVM RPC application. First start an RPC tracker using `python -m tvm.exec.rpc_tracker --port [PORT]` and make your Android device connect to this RPC tracker via TVM RPC application.
Then checkout [android\_rpc/tests/android\_rpc\_test.py](https://github.com/dmlc/tvm/blob/master/apps/android_rpc/tests/android_rpc_test.py) and run, Then checkout [android\_rpc/tests/android\_rpc\_test.py](https://github.com/dmlc/tvm/blob/master/apps/android_rpc/tests/android_rpc_test.py) and run,
```bash ```bash
# Specify the proxy host # Specify the RPC tracker
export TVM_ANDROID_RPC_PROXY_HOST=0.0.0.0 export TVM_TRACKER_HOST=0.0.0.0
export TVM_TRACKER_PORT=[PORT]
# Specify the standalone Android C++ compiler # Specify the standalone Android C++ compiler
export TVM_NDK_CC=/opt/android-toolchain-arm64/bin/aarch64-linux-android-g++ export TVM_NDK_CC=/opt/android-toolchain-arm64/bin/aarch64-linux-android-g++
python android_rpc_test.py python android_rpc_test.py
......
...@@ -13,7 +13,7 @@ android { ...@@ -13,7 +13,7 @@ android {
buildToolsVersion "26.0.1" buildToolsVersion "26.0.1"
defaultConfig { defaultConfig {
applicationId "ml.dmlc.tvm.tvmrpc" applicationId "ml.dmlc.tvm.tvmrpc"
minSdkVersion 17 minSdkVersion 24
targetSdkVersion 26 targetSdkVersion 26
versionCode 1 versionCode 1
versionName "1.0" versionName "1.0"
......
...@@ -20,9 +20,16 @@ ...@@ -20,9 +20,16 @@
<category android:name="android.intent.category.LAUNCHER" /> <category android:name="android.intent.category.LAUNCHER" />
</intent-filter> </intent-filter>
</activity> </activity>
<service android:name=".RPCService" <service android:name=".RPCWatchdogService"
android:process=":RPCServiceProcess" android:process=":RPCWatchdogServiceProcess"
android:permission="android.permission.BIND_JOB_SERVICE" /> android:permission="android.permission.BIND_JOB_SERVICE" />
<activity
android:name=".RPCActivity"
android:process=":RPCProcess"
android:label="@string/rpc_name"
android:theme="@style/AppTheme.NoActionBar"
android:screenOrientation="portrait">
</activity>
</application> </application>
</manifest> </manifest>
...@@ -31,12 +31,18 @@ import android.support.v7.widget.Toolbar; ...@@ -31,12 +31,18 @@ import android.support.v7.widget.Toolbar;
import android.widget.CompoundButton; import android.widget.CompoundButton;
import android.widget.EditText; import android.widget.EditText;
import android.widget.Switch; import android.widget.Switch;
import android.widget.Button;
import android.view.View;
import android.content.Intent; import android.content.Intent;
import android.app.NotificationChannel;
import android.app.NotificationManager;
public class MainActivity extends AppCompatActivity { public class MainActivity extends AppCompatActivity {
private boolean skipRelaunch = true;
// wait time before automatic restart of RPC Activity
public static final int HANDLER_RESTART_DELAY = 5000;
private RPCWatchdog watchdog;
private void showDialog(String title, String msg) { private void showDialog(String title, String msg) {
AlertDialog.Builder builder = new AlertDialog.Builder(this); AlertDialog.Builder builder = new AlertDialog.Builder(this);
...@@ -52,73 +58,107 @@ public class MainActivity extends AppCompatActivity { ...@@ -52,73 +58,107 @@ public class MainActivity extends AppCompatActivity {
builder.create().show(); builder.create().show();
} }
public Intent updateRPCPrefs() {
System.err.println("updating preferences...");
EditText edProxyAddress = findViewById(R.id.input_address);
EditText edProxyPort = findViewById(R.id.input_port);
EditText edAppKey = findViewById(R.id.input_key);
Switch inputSwitch = findViewById(R.id.switch_persistent);
final String proxyHost = edProxyAddress.getText().toString();
final int proxyPort = Integer.parseInt(edProxyPort.getText().toString());
final String key = edAppKey.getText().toString();
final boolean isChecked = inputSwitch.isChecked();
SharedPreferences pref = getApplicationContext().getSharedPreferences("RPCProxyPreference", Context.MODE_PRIVATE);
SharedPreferences.Editor editor = pref.edit();
editor.putString("input_address", proxyHost);
editor.putString("input_port", edProxyPort.getText().toString());
editor.putString("input_key", key);
editor.putBoolean("input_switch", isChecked);
editor.commit();
Intent intent = new Intent(this, RPCActivity.class);
intent.putExtra("host", proxyHost);
intent.putExtra("port", proxyPort);
intent.putExtra("key", key);
return intent;
}
private void setupRelaunch() {
final Context context = this;
final Switch switchPersistent = findViewById(R.id.switch_persistent);
final Runnable rPCStarter = new Runnable() {
public void run() {
if (switchPersistent.isChecked()) {
System.err.println("relaunching RPC activity in 5s...");
Intent intent = ((MainActivity) context).updateRPCPrefs();
startActivity(intent);
}
}
};
Handler handler = new Handler();
handler.postDelayed(rPCStarter, HANDLER_RESTART_DELAY);
}
@Override @Override
protected void onCreate(Bundle savedInstanceState) { protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState); super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main); setContentView(R.layout.activity_main);
Toolbar toolbar = findViewById(R.id.toolbar); Toolbar toolbar = findViewById(R.id.toolbar);
setSupportActionBar(toolbar); setSupportActionBar(toolbar);
final Context context = this;
Switch switchConnect = findViewById(R.id.switch_connect); Switch switchPersistent = findViewById(R.id.switch_persistent);
switchConnect.setOnCheckedChangeListener(new CompoundButton.OnCheckedChangeListener() { switchPersistent.setOnCheckedChangeListener(new CompoundButton.OnCheckedChangeListener() {
@Override @Override
public void onCheckedChanged(CompoundButton buttonView, boolean isChecked) { public void onCheckedChanged(CompoundButton buttonView, boolean isChecked) {
if (isChecked) { if (isChecked) {
enableInputView(false); System.err.println("automatic RPC restart enabled...");
connectProxy(); updateRPCPrefs();
} else { } else {
disconnect(); System.err.println("automatic RPC restart disabled...");
enableInputView(true); updateRPCPrefs();
} }
} }
}); });
enableInputView(true);
Button startRPC = findViewById(R.id.button_start_rpc);
startRPC.setOnClickListener(new View.OnClickListener() {
public void onClick(View v) {
Intent intent = ((MainActivity) context).updateRPCPrefs();
startActivity(intent);
}
});
enableInputView(true);
} }
@Override @Override
protected void onDestroy() { protected void onResume() {
super.onDestroy(); System.err.println("MainActivity onResume...");
if (watchdog != null) { System.err.println("skipRelaunch: " + skipRelaunch);
watchdog.disconnect(); // if this is the first time onResume is called, do nothing, otherwise we
watchdog = null; // may double launch
if (!skipRelaunch) {
enableInputView(true);
setupRelaunch();
} else {
skipRelaunch = false;
} }
super.onResume();
} }
private void connectProxy() { @Override
EditText edProxyAddress = findViewById(R.id.input_address); protected void onDestroy() {
EditText edProxyPort = findViewById(R.id.input_port); super.onDestroy();
EditText edAppKey = findViewById(R.id.input_key);
final String proxyHost = edProxyAddress.getText().toString();
final int proxyPort = Integer.parseInt(edProxyPort.getText().toString());
final String key = edAppKey.getText().toString();
System.err.println("creating watchdog thread...");
watchdog = new RPCWatchdog(proxyHost, proxyPort, key, this);
System.err.println("starting watchdog thread...");
watchdog.start();
SharedPreferences pref = getApplicationContext().getSharedPreferences("RPCProxyPreference", Context.MODE_PRIVATE);
SharedPreferences.Editor editor = pref.edit();
editor.putString("input_address", proxyHost);
editor.putString("input_port", edProxyPort.getText().toString());
editor.putString("input_key", key);
editor.commit();
}
private void disconnect() {
if (watchdog != null) {
watchdog.disconnect();
watchdog = null;
}
} }
private void enableInputView(boolean enable) { private void enableInputView(boolean enable) {
EditText edProxyAddress = findViewById(R.id.input_address); EditText edProxyAddress = findViewById(R.id.input_address);
EditText edProxyPort = findViewById(R.id.input_port); EditText edProxyPort = findViewById(R.id.input_port);
EditText edAppKey = findViewById(R.id.input_key); EditText edAppKey = findViewById(R.id.input_key);
Switch input_switch = findViewById(R.id.switch_persistent);
edProxyAddress.setEnabled(enable); edProxyAddress.setEnabled(enable);
edProxyPort.setEnabled(enable); edProxyPort.setEnabled(enable);
edAppKey.setEnabled(enable); edAppKey.setEnabled(enable);
...@@ -134,6 +174,8 @@ public class MainActivity extends AppCompatActivity { ...@@ -134,6 +174,8 @@ public class MainActivity extends AppCompatActivity {
String inputKey = pref.getString("input_key", null); String inputKey = pref.getString("input_key", null);
if (null != inputKey) if (null != inputKey)
edAppKey.setText(inputKey); edAppKey.setText(inputKey);
boolean isChecked = pref.getBoolean("input_switch", false);
input_switch.setChecked(isChecked);
} }
} }
} }
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.tvmrpc;
import android.os.Bundle;
import android.support.v7.app.AppCompatActivity;
import android.content.Intent;
import android.widget.Button;
import android.view.View;
public class RPCActivity extends AppCompatActivity {
private RPCProcessor tvmServerWorker;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_rpc);
Button stopRPC = findViewById(R.id.button_stop_rpc);
stopRPC.setOnClickListener(new View.OnClickListener() {
public void onClick(View v) {
System.err.println(tvmServerWorker == null);
if (tvmServerWorker != null) {
// currently will raise a socket closed exception
tvmServerWorker.disconnect();
}
finish();
// prevent Android from recycling the process
System.exit(0);
}
});
System.err.println("rpc activity onCreate...");
Intent intent = getIntent();
String host = intent.getStringExtra("host");
int port = intent.getIntExtra("port", 9090);
String key = intent.getStringExtra("key");
tvmServerWorker = new RPCProcessor();
tvmServerWorker.setDaemon(true);
tvmServerWorker.start();
tvmServerWorker.connect(host, port, key);
}
@Override
protected void onDestroy() {
System.err.println("rpc activity onDestroy");
tvmServerWorker.disconnect();
super.onDestroy();
}
}
...@@ -17,15 +17,11 @@ ...@@ -17,15 +17,11 @@
package ml.dmlc.tvm.tvmrpc; package ml.dmlc.tvm.tvmrpc;
import android.os.Bundle;
import android.os.Handler;
import android.os.Message;
import android.os.ParcelFileDescriptor; import android.os.ParcelFileDescriptor;
import java.net.Socket; import java.net.Socket;
import ml.dmlc.tvm.rpc.ConnectTrackerServerProcessor;
import ml.dmlc.tvm.rpc.ConnectProxyServerProcessor;
import ml.dmlc.tvm.rpc.SocketFileDescriptorGetter; import ml.dmlc.tvm.rpc.SocketFileDescriptorGetter;
import ml.dmlc.tvm.rpc.RPCWatchdog;
/** /**
* Connect to RPC proxy and deal with requests. * Connect to RPC proxy and deal with requests.
...@@ -36,9 +32,8 @@ class RPCProcessor extends Thread { ...@@ -36,9 +32,8 @@ class RPCProcessor extends Thread {
private String key; private String key;
private boolean running = false; private boolean running = false;
private long startTime; private long startTime;
private ConnectProxyServerProcessor currProcessor; private ConnectTrackerServerProcessor currProcessor;
private boolean kill = false; private boolean first = true;
public static final int SESSION_TIMEOUT = 30000;
static final SocketFileDescriptorGetter socketFdGetter static final SocketFileDescriptorGetter socketFdGetter
= new SocketFileDescriptorGetter() { = new SocketFileDescriptorGetter() {
...@@ -47,21 +42,10 @@ class RPCProcessor extends Thread { ...@@ -47,21 +42,10 @@ class RPCProcessor extends Thread {
return ParcelFileDescriptor.fromSocket(socket).getFd(); return ParcelFileDescriptor.fromSocket(socket).getFd();
} }
}; };
// callback to initialize the start time of an rpc session
class setTimeCallback implements Runnable {
private RPCProcessor rPCProcessor;
public setTimeCallback(RPCProcessor rPCProcessor) {
this.rPCProcessor = rPCProcessor;
}
@Override
public void run() {
rPCProcessor.setStartTime();
}
}
@Override public void run() { @Override public void run() {
RPCWatchdog watchdog = new RPCWatchdog();
watchdog.start();
while (true) { while (true) {
synchronized (this) { synchronized (this) {
currProcessor = null; currProcessor = null;
...@@ -71,49 +55,18 @@ class RPCProcessor extends Thread { ...@@ -71,49 +55,18 @@ class RPCProcessor extends Thread {
} catch (InterruptedException e) { } catch (InterruptedException e) {
} }
} }
// if kill, we do nothing and wait for app restart try {
// to prevent race where timedOut was reported but restart has not currProcessor = new ConnectTrackerServerProcessor(host, port, key, socketFdGetter, watchdog);
// happened yet } catch (Throwable e) {
if (kill) { e.printStackTrace();
System.err.println("waiting for restart..."); // kill if creating a new processor failed
currProcessor = null; System.exit(0);
}
else {
startTime = 0;
currProcessor = new ConnectProxyServerProcessor(host, port, key, socketFdGetter);
currProcessor.setStartTimeCallback(new setTimeCallback(this));
} }
} }
if (currProcessor != null) if (currProcessor != null)
currProcessor.run(); currProcessor.run();
} watchdog.finishTimeout();
}
/**
* check if the current RPCProcessor has timed out while in a session
*/
synchronized boolean timedOut(long curTime) {
if (startTime == 0) {
return false;
} }
else if ((curTime - startTime) > SESSION_TIMEOUT) {
System.err.println("set kill flag...");
kill = true;
return true;
}
return false;
}
/**
* set the start time of the current RPC session (used in callback)
*/
synchronized void setStartTime() {
startTime = System.currentTimeMillis();
System.err.println("start time set to: " + startTime);
}
synchronized long getStartTime() {
return startTime;
} }
/** /**
...@@ -139,6 +92,6 @@ class RPCProcessor extends Thread { ...@@ -139,6 +92,6 @@ class RPCProcessor extends Thread {
this.port = port; this.port = port;
this.key = key; this.key = key;
running = true; running = true;
notify(); this.notify();
} }
} }
package ml.dmlc.tvm.tvmrpc;
import android.app.Service;
import android.os.IBinder;
import android.content.Intent;
public class RPCService extends Service {
private String host;
private int port;
private String key;
private int intentNum;
private RPCProcessor tvmServerWorker;
@Override
public int onStartCommand(Intent intent, int flags, int startId) {
synchronized(this) {
System.err.println("start command intent");
// use an alternate kill to prevent android from recycling the
// process
if (intent.getBooleanExtra("kill", false)) {
System.err.println("rpc service received kill...");
System.exit(0);
}
this.host = intent.getStringExtra("host");
this.port = intent.getIntExtra("port", 9090);
this.key = intent.getStringExtra("key");
System.err.println("got the following: " + this.host + ", " + this.port + ", " + this.key);
System.err.println("intent num: " + this.intentNum);
if (tvmServerWorker == null) {
System.err.println("service created worker...");
tvmServerWorker = new RPCProcessor();
tvmServerWorker.setDaemon(true);
tvmServerWorker.start();
tvmServerWorker.connect(this.host, this.port, this.key);
}
else if (tvmServerWorker.timedOut(System.currentTimeMillis())) {
System.err.println("rpc service timed out, killing self...");
System.exit(0);
}
this.intentNum++;
}
// do not restart unless watchdog/app expliciltly does so
return START_NOT_STICKY;
}
@Override
public IBinder onBind(Intent intent) {
System.err.println("rpc service got onBind, doing nothing...");
return null;
}
@Override
public void onCreate() {
System.err.println("rpc service onCreate...");
}
@Override
public void onDestroy() {
tvmServerWorker.disconnect();
System.err.println("rpc service onDestroy...");
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.tvmrpc;
import android.content.Context;
import android.content.Intent;
/**
* Watchdog for RPCService
*/
class RPCWatchdog extends Thread {
public static final int WATCHDOG_POLL_INTERVAL = 5000;
private String host;
private int port;
private String key;
private Context context;
private boolean done = false;
public RPCWatchdog(String host, int port, String key, Context context) {
super();
this.host = host;
this.port = port;
this.key = key;
this.context = context;
}
/**
* Polling loop to check on RPCService status
*/
@Override public void run() {
try {
while (true) {
synchronized (this) {
if (done) {
System.err.println("watchdog done, returning...");
return;
}
else {
System.err.println("polling rpc service...");
System.err.println("sending rpc service intent...");
Intent intent = new Intent(context, RPCService.class);
intent.putExtra("host", host);
intent.putExtra("port", port);
intent.putExtra("key", key);
// will implicilty restart the service if it died
context.startService(intent);
}
}
Thread.sleep(WATCHDOG_POLL_INTERVAL);
}
} catch (InterruptedException e) {
}
}
/**
* Disconnect from the proxy server.
*/
synchronized void disconnect() {
// kill service
System.err.println("watchdog disconnect call...");
System.err.println("stopping rpc service...");
done = true;
Intent intent = new Intent(context, RPCService.class);
intent.putExtra("kill", true);
context.startService(intent);
}
}
...@@ -12,15 +12,14 @@ include $(config) ...@@ -12,15 +12,14 @@ include $(config)
# 1) armeabi is deprecated in NDK r16 and removed in r17 # 1) armeabi is deprecated in NDK r16 and removed in r17
# 2) vulkan is not supported in armeabi # 2) vulkan is not supported in armeabi
APP_ABI := armeabi-v7a arm64-v8a x86 x86_64 mips APP_ABI := armeabi-v7a arm64-v8a x86 x86_64 mips
APP_STL := c++_static APP_STL := c++_static
APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++11 -Oz -frtti APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++11 -Oz -frtti
ifeq ($(USE_OPENCL), 1) ifeq ($(USE_OPENCL), 1)
APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1 APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1
endif endif
ifeq ($(USE_VULKAN), 1) ifeq ($(USE_VULKAN), 1)
APP_CPPFLAGS += -DTVM_VULKAN_RUNTIME=1 APP_CPPFLAGS += -DTVM_VULKAN_RUNTIME=1
APP_LDFLAGS += -lvulkan APP_LDFLAGS += -lvulkan
endif endif
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
APP_ABI = all APP_ABI = all
APP_PLATFORM = android-17 APP_PLATFORM = android-24
# whether enable OpenCL during compile # whether enable OpenCL during compile
USE_OPENCL = 0 USE_OPENCL = 0
......
...@@ -24,4 +24,3 @@ ...@@ -24,4 +24,3 @@
<include layout="@layout/content_main"/> <include layout="@layout/content_main"/>
</android.support.design.widget.CoordinatorLayout> </android.support.design.widget.CoordinatorLayout>
<?xml version="1.0" encoding="utf-8"?>
<android.support.design.widget.CoordinatorLayout
xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context="ml.dmlc.tvm.tvmrpc.RPCActivity">
<android.support.design.widget.AppBarLayout
android:layout_height="wrap_content"
android:layout_width="match_parent"
android:theme="@style/AppTheme.AppBarOverlay">
<android.support.v7.widget.Toolbar
android:id="@+id/toolbar"
android:layout_width="match_parent"
android:layout_height="?attr/actionBarSize"
android:background="?attr/colorPrimary"
app:popupTheme="@style/AppTheme.PopupOverlay" />
</android.support.design.widget.AppBarLayout>
<include layout="@layout/content_rpc"/>
</android.support.design.widget.CoordinatorLayout>
...@@ -64,9 +64,9 @@ ...@@ -64,9 +64,9 @@
<TextView <TextView
android:layout_width="wrap_content" android:layout_width="wrap_content"
android:layout_height="wrap_content" android:layout_height="wrap_content"
android:text="@string/label_connect"/> android:text="@string/label_persistent"/>
<Switch <Switch
android:id="@+id/switch_connect" android:id="@+id/switch_persistent"
android:layout_width="wrap_content" android:layout_width="wrap_content"
android:layout_height="wrap_content" android:layout_height="wrap_content"
android:switchMinWidth="55dp" android:switchMinWidth="55dp"
...@@ -76,4 +76,15 @@ ...@@ -76,4 +76,15 @@
android:textOn="@string/switch_on" /> android:textOn="@string/switch_on" />
</LinearLayout> </LinearLayout>
<LinearLayout
android:orientation="horizontal"
android:layout_width="fill_parent"
android:layout_height="wrap_content">
<Button
android:id="@+id/button_start_rpc"
android:layout_height="wrap_content"
android:layout_width="wrap_content"
android:text="@string/start_rpc" />
</LinearLayout>
</LinearLayout> </LinearLayout>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
xmlns:app="http://schemas.android.com/apk/res-auto"
android:orientation="vertical"
android:layout_width="fill_parent"
android:layout_height="wrap_content"
app:layout_behavior="@string/appbar_scrolling_view_behavior"
tools:showIn="@layout/activity_rpc">
<Button
android:id="@+id/button_stop_rpc"
android:layout_height="wrap_content"
android:layout_width="wrap_content"
android:text="@string/stop_rpc" />
</LinearLayout>
<resources> <resources>
<string name="app_name">TVM RPC</string> <string name="app_name">TVM RPC</string>
<string name="rpc_name">RPC</string>
<string name="input_address">Enter the proxy server address</string> <string name="input_address">Enter the tracker server address</string>
<string name="input_port">Enter the proxy server port</string> <string name="input_port">Enter the tracker server port</string>
<string name="input_key">Enter the app connection key</string> <string name="input_key">Enter the app connection key</string>
<string name="label_address">Address</string> <string name="label_address">Address</string>
<string name="label_port">Port</string> <string name="label_port">Port</string>
<string name="label_key">Key</string> <string name="label_key">Key</string>
<string name="label_connect">Connect to Proxy</string> <string name="label_persistent">Keep RPC Alive</string>
<string name="switch_on">Connected</string> <string name="switch_on">Enabled</string>
<string name="switch_off">Disconnected</string> <string name="switch_off">Disabled</string>
<string name="start_rpc">Start RPC</string>
<string name="stop_rpc">Stop RPC</string>
</resources> </resources>
...@@ -11,8 +11,8 @@ from tvm.contrib import util, ndk ...@@ -11,8 +11,8 @@ from tvm.contrib import util, ndk
import numpy as np import numpy as np
# Set to be address of tvm proxy. # Set to be address of tvm proxy.
proxy_host = os.environ["TVM_ANDROID_RPC_PROXY_HOST"] tracker_host = os.environ["TVM_TRACKER_HOST"]
proxy_port = 9090 tracker_port = int(os.environ["TVM_TRACKER_PORT"])
key = "android" key = "android"
# Change target configuration. # Change target configuration.
...@@ -33,7 +33,7 @@ def test_rpc_module(): ...@@ -33,7 +33,7 @@ def test_rpc_module():
# Build the dynamic lib. # Build the dynamic lib.
# If we don't want to do metal and only use cpu, just set target to be target # If we don't want to do metal and only use cpu, just set target to be target
f = tvm.build(s, [A, B], "opencl", target_host=target, name="myadd") f = tvm.build(s, [A, B], "opencl", target_host=target, name="myadd")
path_dso1 = temp.relpath("dev_lib.so") path_dso1 = temp.relpath("dev_lib2.so")
f.export_library(path_dso1, ndk.create_shared) f.export_library(path_dso1, ndk.create_shared)
s = tvm.create_schedule(B.op) s = tvm.create_schedule(B.op)
...@@ -45,29 +45,31 @@ def test_rpc_module(): ...@@ -45,29 +45,31 @@ def test_rpc_module():
path_dso2 = temp.relpath("cpu_lib.so") path_dso2 = temp.relpath("cpu_lib.so")
f.export_library(path_dso2, ndk.create_shared) f.export_library(path_dso2, ndk.create_shared)
# connect to the proxy tracker = rpc.connect_tracker(tracker_host, tracker_port)
remote = rpc.connect(proxy_host, proxy_port, key=key) remote = tracker.request(key, priority=0,
session_timeout=60)
print('Run GPU test ...') print('Run CPU test ...')
ctx = remote.cl(0) ctx = remote.cpu(0)
remote.upload(path_dso1) remote.upload(path_dso2)
f1 = remote.load_module("dev_lib.so") f2 = remote.load_module("cpu_lib.so")
a_np = np.random.uniform(size=1024).astype(A.dtype) a_np = np.random.uniform(size=1024).astype(A.dtype)
a = tvm.nd.array(a_np, ctx) a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
time_f = f1.time_evaluator(f1.entry_name, ctx, number=10) time_f = f2.time_evaluator(f2.entry_name, ctx, number=10)
cost = time_f(a, b).mean cost = time_f(a, b).mean
print('%g secs/op' % cost) print('%g secs/op' % cost)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
print('Run CPU test ...')
ctx = remote.cpu(0) print('Run GPU test ...')
remote.upload(path_dso2) ctx = remote.cl(0)
f2 = remote.load_module("cpu_lib.so") remote.upload(path_dso1)
f1 = remote.load_module("dev_lib2.so")
a_np = np.random.uniform(size=1024).astype(A.dtype) a_np = np.random.uniform(size=1024).astype(A.dtype)
a = tvm.nd.array(a_np, ctx) a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
time_f = f2.time_evaluator(f2.entry_name, ctx, number=10) time_f = f1.time_evaluator(f1.entry_name, ctx, number=10)
cost = time_f(a, b).mean cost = time_f(a, b).mean
print('%g secs/op' % cost) print('%g secs/op' % cost)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.rpc;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.BindException;
import java.net.ConnectException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketException;
import java.net.SocketTimeoutException;
/**
* Server processor with tracker connection (based on standalone).
* This RPC Server registers itself with an RPC Tracker for a specific queue
* (using its device key) and listens for incoming requests.
*/
public class ConnectTrackerServerProcessor implements ServerProcessor {
private ServerSocket server;
private final SocketFileDescriptorGetter socketFileDescriptorGetter;
private final String trackerHost;
private final int trackerPort;
// device key
private final String key;
// device key plus randomly generated key (per-session)
private final String matchKey;
private int serverPort = 5001;
public static final int MAX_SERVER_PORT = 5555;
// time to wait before aborting tracker connection (ms)
public static final int TRACKER_TIMEOUT = 6000;
// time to wait before retrying tracker connection (ms)
public static final int RETRY_PERIOD = TRACKER_TIMEOUT;
// time to wait for a connection before refreshing tracker connection (ms)
public static final int STALE_TRACKER_TIMEOUT = 300000;
// time to wait if no timeout value is specified (seconds)
public static final int HARD_TIMEOUT_DEFAULT = 300;
private RPCWatchdog watchdog;
private Socket trackerSocket;
/**
* Construct tracker server processor.
* @param trackerHost Tracker host.
* @param trackerPort Tracker port.
* @param key Device key.
* @param sockFdGetter Method to get file descriptor from Java socket.
*/
public ConnectTrackerServerProcessor(String trackerHost, int trackerPort, String key,
SocketFileDescriptorGetter sockFdGetter, RPCWatchdog watchdog) throws IOException {
while (true) {
try {
this.server = new ServerSocket(serverPort);
server.setSoTimeout(STALE_TRACKER_TIMEOUT);
break;
} catch (BindException e) {
System.err.println(serverPort);
System.err.println(e);
serverPort++;
if (serverPort > MAX_SERVER_PORT) {
throw e;
}
}
}
System.err.println("using port: " + serverPort);
this.socketFileDescriptorGetter = sockFdGetter;
this.trackerHost = trackerHost;
this.trackerPort = trackerPort;
this.key = key;
this.matchKey = key + ":" + Math.random();
this.watchdog = watchdog;
}
public String getMatchKey() {
return matchKey;
}
@Override public void terminate() {
try {
server.close();
} catch (IOException e) {
e.printStackTrace();
}
}
@Override public void run() {
String recvKey = null;
try {
trackerSocket = connectToTracker();
// open a socket and handshake with tracker
register();
Socket socket = null;
InputStream in = null;
OutputStream out = null;
while (true) {
try {
System.err.println("waiting for requests...");
// wait for client request
socket = server.accept();
in = socket.getInputStream();
out = socket.getOutputStream();
int magic = Utils.wrapBytes(Utils.recvAll(in, 4)).getInt();
if (magic != RPC.RPC_MAGIC) {
out.write(Utils.toBytes(RPC.RPC_CODE_MISMATCH));
System.err.println("incorrect RPC magic");
Utils.closeQuietly(socket);
continue;
}
recvKey = Utils.recvString(in);
System.err.println("matchKey:" + matchKey);
System.err.println("key: " + recvKey);
// incorrect key
if (recvKey.indexOf(matchKey) == -1) {
out.write(Utils.toBytes(RPC.RPC_CODE_MISMATCH));
System.err.println("key mismatch, expected: " + matchKey + " got: " + recvKey);
Utils.closeQuietly(socket);
continue;
}
// successfully got client request and completed handshake with client
break;
} catch (SocketTimeoutException e) {
System.err.println("no incoming connections, refreshing...");
// need to reregister, if the tracker died we should see a socked closed exception
if (!needRefreshKey()) {
System.err.println("reregistering...");
register();
}
}
}
int timeout = HARD_TIMEOUT_DEFAULT;
int timeoutArgIndex = recvKey.indexOf(RPC.TIMEOUT_ARG);
if (timeoutArgIndex != -1) {
timeout = Integer.parseInt(recvKey.substring(timeoutArgIndex + RPC.TIMEOUT_ARG.length()));
}
System.err.println("alloted timeout: " + timeout);
if (!recvKey.startsWith("client:")) {
System.err.println("recv key mismatch...");
out.write(Utils.toBytes(RPC.RPC_CODE_MISMATCH));
} else {
out.write(Utils.toBytes(RPC.RPC_MAGIC));
// send server key to the client
Utils.sendString(out, recvKey);
}
System.err.println("Connection from " + socket.getRemoteSocketAddress().toString());
// received timeout in seconds
watchdog.startTimeout(timeout * 1000);
final int sockFd = socketFileDescriptorGetter.get(socket);
if (sockFd != -1) {
new NativeServerLoop(sockFd).run();
System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
}
Utils.closeQuietly(socket);
} catch (ConnectException e) {
// if the tracker connection failed, wait a bit before retrying
try {
Thread.sleep(RETRY_PERIOD);
} catch (InterruptedException e_) {
System.err.println("interrupted before retrying to connect to tracker...");
}
} catch (Throwable e) {
e.printStackTrace();
} finally {
try {
if (trackerSocket != null) {
trackerSocket.close();
}
server.close();
} catch (Throwable e) {
e.printStackTrace();
}
}
}
private Socket connectToTracker() throws IOException {
trackerSocket = new Socket();
SocketAddress address = new InetSocketAddress(trackerHost, trackerPort);
trackerSocket.connect(address, TRACKER_TIMEOUT);
InputStream trackerIn = trackerSocket.getInputStream();
OutputStream trackerOut = trackerSocket.getOutputStream();
trackerOut.write(Utils.toBytes(RPC.RPC_TRACKER_MAGIC));
int trackerMagic = Utils.wrapBytes(Utils.recvAll(trackerIn, 4)).getInt();
if (trackerMagic != RPC.RPC_TRACKER_MAGIC) {
throw new SocketException("failed to connect to tracker (WRONG MAGIC)");
}
return trackerSocket;
}
/*
* Register the RPC Server with the RPC Tracker.
*/
private void register() throws IOException {
InputStream trackerIn = trackerSocket.getInputStream();
OutputStream trackerOut = trackerSocket.getOutputStream();
// send a JSON with PUT, device key, RPC server port, and the randomly
// generated key
String putJSON = generatePut(RPC.TrackerCode.PUT, key, serverPort, matchKey);
Utils.sendString(trackerOut, putJSON);
int recvCode = Integer.parseInt(Utils.recvString(trackerIn));
if (recvCode != RPC.TrackerCode.SUCCESS) {
throw new SocketException("failed to register with tracker (not SUCCESS)");
}
System.err.println("registered with tracker...");
}
/*
* Check if the RPC Tracker has our key.
*/
private boolean needRefreshKey() throws IOException {
InputStream trackerIn = trackerSocket.getInputStream();
OutputStream trackerOut = trackerSocket.getOutputStream();
String getJSON = generateGetPendingMatchKeys(RPC.TrackerCode.GET_PENDING_MATCHKEYS);
Utils.sendString(trackerOut, getJSON);
String recvJSON = Utils.recvString(trackerIn);
System.err.println("pending matchkeys: " + recvJSON);
// fairly expensive string operation...
if (recvJSON.indexOf(matchKey) != -1 ) {
return true;
}
return false;
}
// handcrafted JSON
private String generatePut(int code, String key, int port, String matchKey) {
return "[" + code + ", " + "\"" + key + "\"" + ", " + "[" + port + ", "
+ "\"" + matchKey + "\"" + "]" + ", " + "null" + "]";
}
// handcrafted JSON
private String generateGetPendingMatchKeys(int code) {
return "[" + code + "]";
}
}
...@@ -42,7 +42,9 @@ public class NativeServerLoop implements Runnable { ...@@ -42,7 +42,9 @@ public class NativeServerLoop implements Runnable {
File tempDir = null; File tempDir = null;
try { try {
tempDir = serverEnv(); tempDir = serverEnv();
System.err.println("starting server loop...");
RPC.getApi("_ServerLoop").pushArg(sockFd).invoke(); RPC.getApi("_ServerLoop").pushArg(sockFd).invoke();
System.err.println("done server loop...");
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); e.printStackTrace();
} finally { } finally {
......
...@@ -23,9 +23,19 @@ import java.util.HashMap; ...@@ -23,9 +23,19 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
public class RPC { public class RPC {
public static final int RPC_TRACKER_MAGIC = 0x2f271;
public static final int RPC_MAGIC = 0xff271; public static final int RPC_MAGIC = 0xff271;
public static final int RPC_CODE_MISMATCH = RPC_MAGIC + 2;
public static final int RPC_SESS_MASK = 128; public static final int RPC_SESS_MASK = 128;
public static final String TIMEOUT_ARG = "-timeout=";
public class TrackerCode {
public static final int PUT = 3;
public static final int GET_PENDING_MATCHKEYS = 7;
public static final int SUCCESS = 0;
}
private static ThreadLocal<Map<String, Function>> apiFuncs private static ThreadLocal<Map<String, Function>> apiFuncs
= new ThreadLocal<Map<String, Function>>() { = new ThreadLocal<Map<String, Function>>() {
@Override @Override
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.rpc;
/**
* Watchdog for RPC.
*/
public class RPCWatchdog extends Thread {
private int timeout = 0;
private boolean started = false;
public RPCWatchdog() {
super();
}
/**
* Start a timeout with watchdog (must be called before finishTimeout).
* @param timeout watchdog timeout in ms.
*/
public synchronized void startTimeout(int timeout) {
this.timeout = timeout;
started = true;
this.notify();
}
/**
* Finish a timeout with watchdog (must be called after startTimeout).
*/
public synchronized void finishTimeout() {
started = false;
this.notify();
}
/**
* Wait and kill RPC if timeout is exceeded.
*/
@Override public void run() {
while (true) {
// timeout not started
synchronized (this) {
while (!started) {
try {
this.wait();
} catch (InterruptedException e) {
System.err.println("watchdog interrupted...");
}
}
}
synchronized (this) {
while (started) {
try {
System.err.println("waiting for timeout: " + timeout);
this.wait(timeout);
if (!started) {
System.err.println("watchdog woken up, ok...");
} else {
System.err.println("watchdog woke up!");
System.err.println("terminating...");
System.exit(0);
}
} catch (InterruptedException e) {
System.err.println("watchdog interrupted...");
}
}
}
}
}
}
...@@ -19,6 +19,7 @@ package ml.dmlc.tvm.rpc; ...@@ -19,6 +19,7 @@ package ml.dmlc.tvm.rpc;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket; import java.net.Socket;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.ByteOrder; import java.nio.ByteOrder;
...@@ -76,4 +77,16 @@ class Utils { ...@@ -76,4 +77,16 @@ class Utils {
} }
return builder.toString(); return builder.toString();
} }
public static String recvString(InputStream in) throws IOException {
String recvString = null;
int len = wrapBytes(Utils.recvAll(in, 4)).getInt();
recvString = decodeToStr(Utils.recvAll(in, len));
return recvString;
}
public static void sendString(OutputStream out, String string) throws IOException {
out.write(toBytes(string.length()));
out.write(toBytes(string));
}
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment