Commit 0d673a9d by eqy Committed by Tianqi Chen

[RPC] android process isolation/watchdog (#1387)

parent 8a90f877
...@@ -2,11 +2,14 @@ ...@@ -2,11 +2,14 @@
<manifest xmlns:android="http://schemas.android.com/apk/res/android" <manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="ml.dmlc.tvm.tvmrpc" > package="ml.dmlc.tvm.tvmrpc" >
<uses-permission android:name="android.permission.INTERNET" />
<application <application
android:allowBackup="true" android:allowBackup="true"
android:label="@string/app_name" android:label="@string/app_name"
android:supportsRtl="true" android:supportsRtl="true"
android:theme="@style/AppTheme" > android:theme="@style/AppTheme"
android:icon="@mipmap/ic_launcher" >
<activity <activity
android:name=".MainActivity" android:name=".MainActivity"
android:label="@string/app_name" android:label="@string/app_name"
...@@ -17,8 +20,9 @@ ...@@ -17,8 +20,9 @@
<category android:name="android.intent.category.LAUNCHER" /> <category android:name="android.intent.category.LAUNCHER" />
</intent-filter> </intent-filter>
</activity> </activity>
<service android:name=".RPCService"
android:process=":RPCServiceProcess"
android:permission="android.permission.BIND_JOB_SERVICE" />
</application> </application>
<uses-permission android:name="android.permission.INTERNET" /> </manifest>
</manifest>
\ No newline at end of file
...@@ -25,30 +25,18 @@ import android.content.SharedPreferences; ...@@ -25,30 +25,18 @@ import android.content.SharedPreferences;
import android.os.Bundle; import android.os.Bundle;
import android.os.Handler; import android.os.Handler;
import android.os.Message; import android.os.Message;
import android.support.v7.app.AppCompatActivity; import android.support.v7.app.AppCompatActivity;
import android.support.v7.widget.Toolbar; import android.support.v7.widget.Toolbar;
import android.widget.CompoundButton; import android.widget.CompoundButton;
import android.widget.EditText; import android.widget.EditText;
import android.widget.Switch; import android.widget.Switch;
import android.content.Intent;
public class MainActivity extends AppCompatActivity { public class MainActivity extends AppCompatActivity {
static final int MSG_RPC_ERROR = 0;
static final String MSG_RPC_ERROR_DATA_KEY = "msg_rpc_error_data_key"; private RPCWatchdog watchdog;
private RPCProcessor tvmServerWorker;
@SuppressLint("HandlerLeak")
private final Handler rpcHandler = new Handler() {
@Override
public void dispatchMessage(Message msg) {
Switch switchConnect = findViewById(R.id.switch_connect);
if (msg.what == MSG_RPC_ERROR && switchConnect.isChecked()) {
// switch off and show alert dialog.
switchConnect.setChecked(false);
String msgBody = msg.getData().getString(MSG_RPC_ERROR_DATA_KEY);
showDialog("Error", msgBody);
}
}
};
private void showDialog(String title, String msg) { private void showDialog(String title, String msg) {
AlertDialog.Builder builder = new AlertDialog.Builder(this); AlertDialog.Builder builder = new AlertDialog.Builder(this);
...@@ -71,10 +59,6 @@ public class MainActivity extends AppCompatActivity { ...@@ -71,10 +59,6 @@ public class MainActivity extends AppCompatActivity {
Toolbar toolbar = findViewById(R.id.toolbar); Toolbar toolbar = findViewById(R.id.toolbar);
setSupportActionBar(toolbar); setSupportActionBar(toolbar);
tvmServerWorker = new RPCProcessor(rpcHandler);
tvmServerWorker.setDaemon(true);
tvmServerWorker.start();
Switch switchConnect = findViewById(R.id.switch_connect); Switch switchConnect = findViewById(R.id.switch_connect);
switchConnect.setOnCheckedChangeListener(new CompoundButton.OnCheckedChangeListener() { switchConnect.setOnCheckedChangeListener(new CompoundButton.OnCheckedChangeListener() {
@Override @Override
...@@ -88,25 +72,33 @@ public class MainActivity extends AppCompatActivity { ...@@ -88,25 +72,33 @@ public class MainActivity extends AppCompatActivity {
} }
} }
}); });
enableInputView(true); enableInputView(true);
} }
@Override @Override
protected void onDestroy() { protected void onDestroy() {
super.onDestroy(); super.onDestroy();
tvmServerWorker.disconnect(); if (watchdog != null) {
watchdog.disconnect();
watchdog = null;
}
} }
private void connectProxy() { private void connectProxy() {
EditText edProxyAddress = findViewById(R.id.input_address); EditText edProxyAddress = findViewById(R.id.input_address);
EditText edProxyPort = findViewById(R.id.input_port); EditText edProxyPort = findViewById(R.id.input_port);
EditText edAppKey = findViewById(R.id.input_key); EditText edAppKey = findViewById(R.id.input_key);
final String proxyHost = edProxyAddress.getText().toString(); final String proxyHost = edProxyAddress.getText().toString();
final int proxyPort = Integer.parseInt(edProxyPort.getText().toString()); final int proxyPort = Integer.parseInt(edProxyPort.getText().toString());
final String key = edAppKey.getText().toString(); final String key = edAppKey.getText().toString();
tvmServerWorker.connect(proxyHost, proxyPort, key); System.err.println("creating watchdog thread...");
watchdog = new RPCWatchdog(proxyHost, proxyPort, key, this);
System.err.println("starting watchdog thread...");
watchdog.start();
SharedPreferences pref = getApplicationContext().getSharedPreferences("RPCProxyPreference", Context.MODE_PRIVATE); SharedPreferences pref = getApplicationContext().getSharedPreferences("RPCProxyPreference", Context.MODE_PRIVATE);
SharedPreferences.Editor editor = pref.edit(); SharedPreferences.Editor editor = pref.edit();
...@@ -117,8 +109,10 @@ public class MainActivity extends AppCompatActivity { ...@@ -117,8 +109,10 @@ public class MainActivity extends AppCompatActivity {
} }
private void disconnect() { private void disconnect() {
tvmServerWorker.disconnect(); if (watchdog != null) {
System.err.println("Disconnected."); watchdog.disconnect();
watchdog = null;
}
} }
private void enableInputView(boolean enable) { private void enableInputView(boolean enable) {
......
...@@ -34,10 +34,11 @@ class RPCProcessor extends Thread { ...@@ -34,10 +34,11 @@ class RPCProcessor extends Thread {
private String host; private String host;
private int port; private int port;
private String key; private String key;
private boolean running = false; private boolean running = false;
private long startTime;
private ConnectProxyServerProcessor currProcessor; private ConnectProxyServerProcessor currProcessor;
private final Handler uiHandler; private boolean kill = false;
public static final int SESSION_TIMEOUT = 30000;
static final SocketFileDescriptorGetter socketFdGetter static final SocketFileDescriptorGetter socketFdGetter
= new SocketFileDescriptorGetter() { = new SocketFileDescriptorGetter() {
...@@ -46,9 +47,18 @@ class RPCProcessor extends Thread { ...@@ -46,9 +47,18 @@ class RPCProcessor extends Thread {
return ParcelFileDescriptor.fromSocket(socket).getFd(); return ParcelFileDescriptor.fromSocket(socket).getFd();
} }
}; };
// callback to initialize the start time of an rpc session
class setTimeCallback implements Runnable {
private RPCProcessor rPCProcessor;
public setTimeCallback(RPCProcessor rPCProcessor) {
this.rPCProcessor = rPCProcessor;
}
RPCProcessor(Handler uiHandler) { @Override
this.uiHandler = uiHandler; public void run() {
rPCProcessor.setStartTime();
}
} }
@Override public void run() { @Override public void run() {
...@@ -61,24 +71,52 @@ class RPCProcessor extends Thread { ...@@ -61,24 +71,52 @@ class RPCProcessor extends Thread {
} catch (InterruptedException e) { } catch (InterruptedException e) {
} }
} }
currProcessor = new ConnectProxyServerProcessor(host, port, key, socketFdGetter); // if kill, we do nothing and wait for app restart
} // to prevent race where timedOut was reported but restart has not
try { // happened yet
currProcessor.run(); if (kill) {
} catch (Throwable e) { System.err.println("waiting for restart...");
disconnect(); currProcessor = null;
// turn connect switch off. }
Message message = new Message(); else {
message.what = MainActivity.MSG_RPC_ERROR; startTime = 0;
Bundle bundle = new Bundle(); currProcessor = new ConnectProxyServerProcessor(host, port, key, socketFdGetter);
bundle.putString(MainActivity.MSG_RPC_ERROR_DATA_KEY, e.getMessage()); currProcessor.setStartTimeCallback(new setTimeCallback(this));
message.setData(bundle); }
uiHandler.sendMessage(message);
} }
if (currProcessor != null)
currProcessor.run();
} }
} }
/** /**
* check if the current RPCProcessor has timed out while in a session
*/
synchronized boolean timedOut(long curTime) {
if (startTime == 0) {
return false;
}
else if ((curTime - startTime) > SESSION_TIMEOUT) {
System.err.println("set kill flag...");
kill = true;
return true;
}
return false;
}
/**
* set the start time of the current RPC session (used in callback)
*/
synchronized void setStartTime() {
startTime = System.currentTimeMillis();
System.err.println("start time set to: " + startTime);
}
synchronized long getStartTime() {
return startTime;
}
/**
* Disconnect from the proxy server. * Disconnect from the proxy server.
*/ */
synchronized void disconnect() { synchronized void disconnect() {
......
package ml.dmlc.tvm.tvmrpc;
import android.app.Service;
import android.os.IBinder;
import android.content.Intent;
public class RPCService extends Service {
private String host;
private int port;
private String key;
private int intentNum;
private RPCProcessor tvmServerWorker;
@Override
public int onStartCommand(Intent intent, int flags, int startId) {
synchronized(this) {
System.err.println("start command intent");
// use an alternate kill to prevent android from recycling the
// process
if (intent.getBooleanExtra("kill", false)) {
System.err.println("rpc service received kill...");
System.exit(0);
}
this.host = intent.getStringExtra("host");
this.port = intent.getIntExtra("port", 9090);
this.key = intent.getStringExtra("key");
System.err.println("got the following: " + this.host + ", " + this.port + ", " + this.key);
System.err.println("intent num: " + this.intentNum);
if (tvmServerWorker == null) {
System.err.println("service created worker...");
tvmServerWorker = new RPCProcessor();
tvmServerWorker.setDaemon(true);
tvmServerWorker.start();
tvmServerWorker.connect(this.host, this.port, this.key);
}
else if (tvmServerWorker.timedOut(System.currentTimeMillis())) {
System.err.println("rpc service timed out, killing self...");
System.exit(0);
}
this.intentNum++;
}
// do not restart unless watchdog/app expliciltly does so
return START_NOT_STICKY;
}
@Override
public IBinder onBind(Intent intent) {
System.err.println("rpc service got onBind, doing nothing...");
return null;
}
@Override
public void onCreate() {
System.err.println("rpc service onCreate...");
}
@Override
public void onDestroy() {
tvmServerWorker.disconnect();
System.err.println("rpc service onDestroy...");
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm.tvmrpc;
import android.content.Context;
import android.content.Intent;
/**
* Watchdog for RPCService
*/
class RPCWatchdog extends Thread {
public static final int WATCHDOG_POLL_INTERVAL = 5000;
private String host;
private int port;
private String key;
private Context context;
private boolean done = false;
public RPCWatchdog(String host, int port, String key, Context context) {
super();
this.host = host;
this.port = port;
this.key = key;
this.context = context;
}
/**
* Polling loop to check on RPCService status
*/
@Override public void run() {
try {
while (true) {
synchronized (this) {
if (done) {
System.err.println("watchdog done, returning...");
return;
}
else {
System.err.println("polling rpc service...");
System.err.println("sending rpc service intent...");
Intent intent = new Intent(context, RPCService.class);
intent.putExtra("host", host);
intent.putExtra("port", port);
intent.putExtra("key", key);
// will implicilty restart the service if it died
context.startService(intent);
}
}
Thread.sleep(WATCHDOG_POLL_INTERVAL);
}
} catch (InterruptedException e) {
}
}
/**
* Disconnect from the proxy server.
*/
synchronized void disconnect() {
// kill service
System.err.println("watchdog disconnect call...");
System.err.println("stopping rpc service...");
done = true;
Intent intent = new Intent(context, RPCService.class);
intent.putExtra("kill", true);
context.startService(intent);
}
}
...@@ -33,6 +33,7 @@ public class ConnectProxyServerProcessor implements ServerProcessor { ...@@ -33,6 +33,7 @@ public class ConnectProxyServerProcessor implements ServerProcessor {
private final SocketFileDescriptorGetter socketFileDescriptorGetter; private final SocketFileDescriptorGetter socketFileDescriptorGetter;
private volatile Socket currSocket = new Socket(); private volatile Socket currSocket = new Socket();
private Runnable callback;
/** /**
* Construct proxy server processor. * Construct proxy server processor.
...@@ -48,6 +49,15 @@ public class ConnectProxyServerProcessor implements ServerProcessor { ...@@ -48,6 +49,15 @@ public class ConnectProxyServerProcessor implements ServerProcessor {
this.key = "server:" + key; this.key = "server:" + key;
socketFileDescriptorGetter = sockFdGetter; socketFileDescriptorGetter = sockFdGetter;
} }
/**
* Set a callback when a connection is received e.g., to record the time for a
* watchdog.
* @param callback Runnable object.
*/
public void setStartTimeCallback(Runnable callback) {
this.callback = callback;
}
/** /**
* Close the socket. * Close the socket.
...@@ -78,7 +88,9 @@ public class ConnectProxyServerProcessor implements ServerProcessor { ...@@ -78,7 +88,9 @@ public class ConnectProxyServerProcessor implements ServerProcessor {
int keylen = Utils.wrapBytes(Utils.recvAll(in, 4)).getInt(); int keylen = Utils.wrapBytes(Utils.recvAll(in, 4)).getInt();
String remoteKey = Utils.decodeToStr(Utils.recvAll(in, keylen)); String remoteKey = Utils.decodeToStr(Utils.recvAll(in, keylen));
System.err.println("RPCProxy connected to " + address); System.err.println("RPCProxy connected to " + address);
if (callback != null) {
callback.run();
}
final int sockFd = socketFileDescriptorGetter.get(currSocket); final int sockFd = socketFileDescriptorGetter.get(currSocket);
if (sockFd != -1) { if (sockFd != -1) {
new NativeServerLoop(sockFd).run(); new NativeServerLoop(sockFd).run();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment