diff --git a/apps/android_rpc/README.md b/apps/android_rpc/README.md index 00cb597cec287258dff0c15871abbacd529f1033..cedd9ab0ac0c34839022e0fc6282d03ed455dc96 100644 --- a/apps/android_rpc/README.md +++ b/apps/android_rpc/README.md @@ -51,15 +51,15 @@ Here's a piece of example for `config.mk`. ```makefile APP_ABI = arm64-v8a - + APP_PLATFORM = android-17 - + # whether enable OpenCL during compile USE_OPENCL = 1 - + # the additional include headers you want to add, e.g., SDK_PATH/adrenosdk/Development/Inc ADD_C_INCLUDES = /opt/adrenosdk-osx/Development/Inc - + # the additional link libs you want to add, e.g., ANDROID_LIB_PATH/libOpenCL.so ADD_LDLIBS = libOpenCL.so ``` @@ -85,13 +85,14 @@ If everything goes well, you will find compile tools in `/opt/android-toolchain- ### Cross Compile and Upload to the Android Device -First start a proxy server using `python -m tvm.exec.rpc_proxy` and make your Android device connect to this proxy server via TVM RPC application. +First start an RPC tracker using `python -m tvm.exec.rpc_tracker --port [PORT]` and make your Android device connect to this RPC tracker via TVM RPC application. Then checkout [android\_rpc/tests/android\_rpc\_test.py](https://github.com/dmlc/tvm/blob/master/apps/android_rpc/tests/android_rpc_test.py) and run, ```bash -# Specify the proxy host -export TVM_ANDROID_RPC_PROXY_HOST=0.0.0.0 +# Specify the RPC tracker +export TVM_TRACKER_HOST=0.0.0.0 +export TVM_TRACKER_PORT=[PORT] # Specify the standalone Android C++ compiler export TVM_NDK_CC=/opt/android-toolchain-arm64/bin/aarch64-linux-android-g++ python android_rpc_test.py diff --git a/apps/android_rpc/app/build.gradle b/apps/android_rpc/app/build.gradle index 97364da5cd87b7ff44f02784b67b07b03b992c48..a91455fc54772d2e5f1433a8cc2602b4d6ad93c6 100644 --- a/apps/android_rpc/app/build.gradle +++ b/apps/android_rpc/app/build.gradle @@ -13,7 +13,7 @@ android { buildToolsVersion "26.0.1" defaultConfig { applicationId "ml.dmlc.tvm.tvmrpc" - minSdkVersion 17 + minSdkVersion 24 targetSdkVersion 26 versionCode 1 versionName "1.0" diff --git a/apps/android_rpc/app/src/main/AndroidManifest.xml b/apps/android_rpc/app/src/main/AndroidManifest.xml index d8385e3b15e2508051c74d806156a3bb49eb1926..2dbc06ece6e36656637e83dea4519e69d644fced 100644 --- a/apps/android_rpc/app/src/main/AndroidManifest.xml +++ b/apps/android_rpc/app/src/main/AndroidManifest.xml @@ -20,9 +20,16 @@ <category android:name="android.intent.category.LAUNCHER" /> </intent-filter> </activity> - <service android:name=".RPCService" - android:process=":RPCServiceProcess" - android:permission="android.permission.BIND_JOB_SERVICE" /> + <service android:name=".RPCWatchdogService" + android:process=":RPCWatchdogServiceProcess" + android:permission="android.permission.BIND_JOB_SERVICE" /> + <activity + android:name=".RPCActivity" + android:process=":RPCProcess" + android:label="@string/rpc_name" + android:theme="@style/AppTheme.NoActionBar" + android:screenOrientation="portrait"> + </activity> </application> </manifest> diff --git a/apps/android_rpc/app/src/main/java/ml/dmlc/tvm/tvmrpc/MainActivity.java b/apps/android_rpc/app/src/main/java/ml/dmlc/tvm/tvmrpc/MainActivity.java index 0fca08a2e463ec1da07acd98738bd537dcd0398f..d80008bbe25811ed1b5c1c191ad79784c7e7517a 100644 --- a/apps/android_rpc/app/src/main/java/ml/dmlc/tvm/tvmrpc/MainActivity.java +++ b/apps/android_rpc/app/src/main/java/ml/dmlc/tvm/tvmrpc/MainActivity.java @@ -31,12 +31,18 @@ import android.support.v7.widget.Toolbar; import android.widget.CompoundButton; import android.widget.EditText; import android.widget.Switch; +import android.widget.Button; +import android.view.View; import android.content.Intent; +import android.app.NotificationChannel; +import android.app.NotificationManager; public class MainActivity extends AppCompatActivity { + private boolean skipRelaunch = true; + // wait time before automatic restart of RPC Activity + public static final int HANDLER_RESTART_DELAY = 5000; - private RPCWatchdog watchdog; private void showDialog(String title, String msg) { AlertDialog.Builder builder = new AlertDialog.Builder(this); @@ -52,73 +58,107 @@ public class MainActivity extends AppCompatActivity { builder.create().show(); } + public Intent updateRPCPrefs() { + System.err.println("updating preferences..."); + EditText edProxyAddress = findViewById(R.id.input_address); + EditText edProxyPort = findViewById(R.id.input_port); + EditText edAppKey = findViewById(R.id.input_key); + Switch inputSwitch = findViewById(R.id.switch_persistent); + + final String proxyHost = edProxyAddress.getText().toString(); + final int proxyPort = Integer.parseInt(edProxyPort.getText().toString()); + final String key = edAppKey.getText().toString(); + final boolean isChecked = inputSwitch.isChecked(); + + SharedPreferences pref = getApplicationContext().getSharedPreferences("RPCProxyPreference", Context.MODE_PRIVATE); + SharedPreferences.Editor editor = pref.edit(); + editor.putString("input_address", proxyHost); + editor.putString("input_port", edProxyPort.getText().toString()); + editor.putString("input_key", key); + editor.putBoolean("input_switch", isChecked); + editor.commit(); + + Intent intent = new Intent(this, RPCActivity.class); + intent.putExtra("host", proxyHost); + intent.putExtra("port", proxyPort); + intent.putExtra("key", key); + return intent; + } + + private void setupRelaunch() { + final Context context = this; + final Switch switchPersistent = findViewById(R.id.switch_persistent); + final Runnable rPCStarter = new Runnable() { + public void run() { + if (switchPersistent.isChecked()) { + System.err.println("relaunching RPC activity in 5s..."); + Intent intent = ((MainActivity) context).updateRPCPrefs(); + startActivity(intent); + } + } + }; + Handler handler = new Handler(); + handler.postDelayed(rPCStarter, HANDLER_RESTART_DELAY); + } + @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); Toolbar toolbar = findViewById(R.id.toolbar); setSupportActionBar(toolbar); + final Context context = this; - Switch switchConnect = findViewById(R.id.switch_connect); - switchConnect.setOnCheckedChangeListener(new CompoundButton.OnCheckedChangeListener() { + Switch switchPersistent = findViewById(R.id.switch_persistent); + switchPersistent.setOnCheckedChangeListener(new CompoundButton.OnCheckedChangeListener() { @Override public void onCheckedChanged(CompoundButton buttonView, boolean isChecked) { if (isChecked) { - enableInputView(false); - connectProxy(); + System.err.println("automatic RPC restart enabled..."); + updateRPCPrefs(); } else { - disconnect(); - enableInputView(true); + System.err.println("automatic RPC restart disabled..."); + updateRPCPrefs(); } } }); - - enableInputView(true); + Button startRPC = findViewById(R.id.button_start_rpc); + startRPC.setOnClickListener(new View.OnClickListener() { + public void onClick(View v) { + Intent intent = ((MainActivity) context).updateRPCPrefs(); + startActivity(intent); + } + }); + + enableInputView(true); } @Override - protected void onDestroy() { - super.onDestroy(); - if (watchdog != null) { - watchdog.disconnect(); - watchdog = null; + protected void onResume() { + System.err.println("MainActivity onResume..."); + System.err.println("skipRelaunch: " + skipRelaunch); + // if this is the first time onResume is called, do nothing, otherwise we + // may double launch + if (!skipRelaunch) { + enableInputView(true); + setupRelaunch(); + } else { + skipRelaunch = false; } + super.onResume(); } - private void connectProxy() { - EditText edProxyAddress = findViewById(R.id.input_address); - EditText edProxyPort = findViewById(R.id.input_port); - EditText edAppKey = findViewById(R.id.input_key); - final String proxyHost = edProxyAddress.getText().toString(); - final int proxyPort = Integer.parseInt(edProxyPort.getText().toString()); - final String key = edAppKey.getText().toString(); - - System.err.println("creating watchdog thread..."); - watchdog = new RPCWatchdog(proxyHost, proxyPort, key, this); - - System.err.println("starting watchdog thread..."); - watchdog.start(); - - SharedPreferences pref = getApplicationContext().getSharedPreferences("RPCProxyPreference", Context.MODE_PRIVATE); - SharedPreferences.Editor editor = pref.edit(); - editor.putString("input_address", proxyHost); - editor.putString("input_port", edProxyPort.getText().toString()); - editor.putString("input_key", key); - editor.commit(); - } - - private void disconnect() { - if (watchdog != null) { - watchdog.disconnect(); - watchdog = null; - } + @Override + protected void onDestroy() { + super.onDestroy(); } private void enableInputView(boolean enable) { EditText edProxyAddress = findViewById(R.id.input_address); EditText edProxyPort = findViewById(R.id.input_port); EditText edAppKey = findViewById(R.id.input_key); + Switch input_switch = findViewById(R.id.switch_persistent); edProxyAddress.setEnabled(enable); edProxyPort.setEnabled(enable); edAppKey.setEnabled(enable); @@ -134,6 +174,8 @@ public class MainActivity extends AppCompatActivity { String inputKey = pref.getString("input_key", null); if (null != inputKey) edAppKey.setText(inputKey); + boolean isChecked = pref.getBoolean("input_switch", false); + input_switch.setChecked(isChecked); } } } diff --git a/apps/android_rpc/app/src/main/java/ml/dmlc/tvm/tvmrpc/RPCActivity.java b/apps/android_rpc/app/src/main/java/ml/dmlc/tvm/tvmrpc/RPCActivity.java new file mode 100644 index 0000000000000000000000000000000000000000..912a7c9e69a65943146f8c23d2647b4db81d6086 --- /dev/null +++ b/apps/android_rpc/app/src/main/java/ml/dmlc/tvm/tvmrpc/RPCActivity.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ml.dmlc.tvm.tvmrpc; + +import android.os.Bundle; +import android.support.v7.app.AppCompatActivity; +import android.content.Intent; +import android.widget.Button; +import android.view.View; + +public class RPCActivity extends AppCompatActivity { + private RPCProcessor tvmServerWorker; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_rpc); + + Button stopRPC = findViewById(R.id.button_stop_rpc); + stopRPC.setOnClickListener(new View.OnClickListener() { + public void onClick(View v) { + System.err.println(tvmServerWorker == null); + if (tvmServerWorker != null) { + // currently will raise a socket closed exception + tvmServerWorker.disconnect(); + } + finish(); + // prevent Android from recycling the process + System.exit(0); + } + }); + + System.err.println("rpc activity onCreate..."); + Intent intent = getIntent(); + String host = intent.getStringExtra("host"); + int port = intent.getIntExtra("port", 9090); + String key = intent.getStringExtra("key"); + + tvmServerWorker = new RPCProcessor(); + tvmServerWorker.setDaemon(true); + tvmServerWorker.start(); + tvmServerWorker.connect(host, port, key); + } + + @Override + protected void onDestroy() { + System.err.println("rpc activity onDestroy"); + tvmServerWorker.disconnect(); + super.onDestroy(); + } +} diff --git a/apps/android_rpc/app/src/main/java/ml/dmlc/tvm/tvmrpc/RPCProcessor.java b/apps/android_rpc/app/src/main/java/ml/dmlc/tvm/tvmrpc/RPCProcessor.java index 4099084d4ebfec6945e87a55b9146db1917416ca..6da89093110460f49fc9a78b1073cf87bd1bd0e8 100644 --- a/apps/android_rpc/app/src/main/java/ml/dmlc/tvm/tvmrpc/RPCProcessor.java +++ b/apps/android_rpc/app/src/main/java/ml/dmlc/tvm/tvmrpc/RPCProcessor.java @@ -17,15 +17,11 @@ package ml.dmlc.tvm.tvmrpc; -import android.os.Bundle; -import android.os.Handler; -import android.os.Message; import android.os.ParcelFileDescriptor; - import java.net.Socket; - -import ml.dmlc.tvm.rpc.ConnectProxyServerProcessor; +import ml.dmlc.tvm.rpc.ConnectTrackerServerProcessor; import ml.dmlc.tvm.rpc.SocketFileDescriptorGetter; +import ml.dmlc.tvm.rpc.RPCWatchdog; /** * Connect to RPC proxy and deal with requests. @@ -36,9 +32,8 @@ class RPCProcessor extends Thread { private String key; private boolean running = false; private long startTime; - private ConnectProxyServerProcessor currProcessor; - private boolean kill = false; - public static final int SESSION_TIMEOUT = 30000; + private ConnectTrackerServerProcessor currProcessor; + private boolean first = true; static final SocketFileDescriptorGetter socketFdGetter = new SocketFileDescriptorGetter() { @@ -47,21 +42,10 @@ class RPCProcessor extends Thread { return ParcelFileDescriptor.fromSocket(socket).getFd(); } }; - // callback to initialize the start time of an rpc session - class setTimeCallback implements Runnable { - private RPCProcessor rPCProcessor; - - public setTimeCallback(RPCProcessor rPCProcessor) { - this.rPCProcessor = rPCProcessor; - } - - @Override - public void run() { - rPCProcessor.setStartTime(); - } - } @Override public void run() { + RPCWatchdog watchdog = new RPCWatchdog(); + watchdog.start(); while (true) { synchronized (this) { currProcessor = null; @@ -71,49 +55,18 @@ class RPCProcessor extends Thread { } catch (InterruptedException e) { } } - // if kill, we do nothing and wait for app restart - // to prevent race where timedOut was reported but restart has not - // happened yet - if (kill) { - System.err.println("waiting for restart..."); - currProcessor = null; - } - else { - startTime = 0; - currProcessor = new ConnectProxyServerProcessor(host, port, key, socketFdGetter); - currProcessor.setStartTimeCallback(new setTimeCallback(this)); + try { + currProcessor = new ConnectTrackerServerProcessor(host, port, key, socketFdGetter, watchdog); + } catch (Throwable e) { + e.printStackTrace(); + // kill if creating a new processor failed + System.exit(0); } } - if (currProcessor != null) - currProcessor.run(); - } - } - - /** - * check if the current RPCProcessor has timed out while in a session - */ - synchronized boolean timedOut(long curTime) { - if (startTime == 0) { - return false; + if (currProcessor != null) + currProcessor.run(); + watchdog.finishTimeout(); } - else if ((curTime - startTime) > SESSION_TIMEOUT) { - System.err.println("set kill flag..."); - kill = true; - return true; - } - return false; - } - - /** - * set the start time of the current RPC session (used in callback) - */ - synchronized void setStartTime() { - startTime = System.currentTimeMillis(); - System.err.println("start time set to: " + startTime); - } - - synchronized long getStartTime() { - return startTime; } /** @@ -139,6 +92,6 @@ class RPCProcessor extends Thread { this.port = port; this.key = key; running = true; - notify(); + this.notify(); } } diff --git a/apps/android_rpc/app/src/main/java/ml/dmlc/tvm/tvmrpc/RPCService.java b/apps/android_rpc/app/src/main/java/ml/dmlc/tvm/tvmrpc/RPCService.java deleted file mode 100644 index facc230f60ae78c2a0bfa7971c181485e1c70498..0000000000000000000000000000000000000000 --- a/apps/android_rpc/app/src/main/java/ml/dmlc/tvm/tvmrpc/RPCService.java +++ /dev/null @@ -1,64 +0,0 @@ -package ml.dmlc.tvm.tvmrpc; - -import android.app.Service; -import android.os.IBinder; -import android.content.Intent; - -public class RPCService extends Service { - private String host; - private int port; - private String key; - private int intentNum; - private RPCProcessor tvmServerWorker; - - @Override - public int onStartCommand(Intent intent, int flags, int startId) { - synchronized(this) { - System.err.println("start command intent"); - // use an alternate kill to prevent android from recycling the - // process - if (intent.getBooleanExtra("kill", false)) { - System.err.println("rpc service received kill..."); - System.exit(0); - } - - this.host = intent.getStringExtra("host"); - this.port = intent.getIntExtra("port", 9090); - this.key = intent.getStringExtra("key"); - System.err.println("got the following: " + this.host + ", " + this.port + ", " + this.key); - System.err.println("intent num: " + this.intentNum); - - if (tvmServerWorker == null) { - System.err.println("service created worker..."); - tvmServerWorker = new RPCProcessor(); - tvmServerWorker.setDaemon(true); - tvmServerWorker.start(); - tvmServerWorker.connect(this.host, this.port, this.key); - } - else if (tvmServerWorker.timedOut(System.currentTimeMillis())) { - System.err.println("rpc service timed out, killing self..."); - System.exit(0); - } - this.intentNum++; - } - // do not restart unless watchdog/app expliciltly does so - return START_NOT_STICKY; - } - - @Override - public IBinder onBind(Intent intent) { - System.err.println("rpc service got onBind, doing nothing..."); - return null; - } - - @Override - public void onCreate() { - System.err.println("rpc service onCreate..."); - } - - @Override - public void onDestroy() { - tvmServerWorker.disconnect(); - System.err.println("rpc service onDestroy..."); - } -} diff --git a/apps/android_rpc/app/src/main/java/ml/dmlc/tvm/tvmrpc/RPCWatchdog.java b/apps/android_rpc/app/src/main/java/ml/dmlc/tvm/tvmrpc/RPCWatchdog.java deleted file mode 100644 index 548a2dfc0e7d4766ecd8cf42ac229189342b5a2b..0000000000000000000000000000000000000000 --- a/apps/android_rpc/app/src/main/java/ml/dmlc/tvm/tvmrpc/RPCWatchdog.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ml.dmlc.tvm.tvmrpc; - -import android.content.Context; -import android.content.Intent; - -/** - * Watchdog for RPCService - */ -class RPCWatchdog extends Thread { - public static final int WATCHDOG_POLL_INTERVAL = 5000; - private String host; - private int port; - private String key; - private Context context; - private boolean done = false; - - public RPCWatchdog(String host, int port, String key, Context context) { - super(); - this.host = host; - this.port = port; - this.key = key; - this.context = context; - } - - /** - * Polling loop to check on RPCService status - */ - @Override public void run() { - try { - while (true) { - synchronized (this) { - if (done) { - System.err.println("watchdog done, returning..."); - return; - } - else { - System.err.println("polling rpc service..."); - System.err.println("sending rpc service intent..."); - Intent intent = new Intent(context, RPCService.class); - intent.putExtra("host", host); - intent.putExtra("port", port); - intent.putExtra("key", key); - // will implicilty restart the service if it died - context.startService(intent); - } - } - Thread.sleep(WATCHDOG_POLL_INTERVAL); - } - } catch (InterruptedException e) { - } - } - - /** - * Disconnect from the proxy server. - */ - synchronized void disconnect() { - // kill service - System.err.println("watchdog disconnect call..."); - System.err.println("stopping rpc service..."); - done = true; - Intent intent = new Intent(context, RPCService.class); - intent.putExtra("kill", true); - context.startService(intent); - } -} diff --git a/apps/android_rpc/app/src/main/jni/Application.mk b/apps/android_rpc/app/src/main/jni/Application.mk index 09ee723a10d26e8d0552af0994b3dbc2d59d44af..c5509c1f450628fe2bf7d6eb0f10329398541b17 100644 --- a/apps/android_rpc/app/src/main/jni/Application.mk +++ b/apps/android_rpc/app/src/main/jni/Application.mk @@ -12,15 +12,14 @@ include $(config) # 1) armeabi is deprecated in NDK r16 and removed in r17 # 2) vulkan is not supported in armeabi APP_ABI := armeabi-v7a arm64-v8a x86 x86_64 mips - APP_STL := c++_static APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++11 -Oz -frtti -ifeq ($(USE_OPENCL), 1) +ifeq ($(USE_OPENCL), 1) APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1 endif -ifeq ($(USE_VULKAN), 1) +ifeq ($(USE_VULKAN), 1) APP_CPPFLAGS += -DTVM_VULKAN_RUNTIME=1 APP_LDFLAGS += -lvulkan endif diff --git a/apps/android_rpc/app/src/main/jni/make/config.mk b/apps/android_rpc/app/src/main/jni/make/config.mk index e57a907441b001eed81a0f9837fde640132b2e0a..c40ce4ba3ec7d5e8a96a86b3d056114f13846f19 100644 --- a/apps/android_rpc/app/src/main/jni/make/config.mk +++ b/apps/android_rpc/app/src/main/jni/make/config.mk @@ -14,7 +14,7 @@ #------------------------------------------------------------------------------- APP_ABI = all -APP_PLATFORM = android-17 +APP_PLATFORM = android-24 # whether enable OpenCL during compile USE_OPENCL = 0 diff --git a/apps/android_rpc/app/src/main/res/layout/activity_main.xml b/apps/android_rpc/app/src/main/res/layout/activity_main.xml index f617cf2a04bb6fe4d1559a13ea4ccbad5c900fec..53d48bbd60d9d8e632dc0f5ce4b166f9fac739e7 100644 --- a/apps/android_rpc/app/src/main/res/layout/activity_main.xml +++ b/apps/android_rpc/app/src/main/res/layout/activity_main.xml @@ -24,4 +24,3 @@ <include layout="@layout/content_main"/> </android.support.design.widget.CoordinatorLayout> - diff --git a/apps/android_rpc/app/src/main/res/layout/activity_rpc.xml b/apps/android_rpc/app/src/main/res/layout/activity_rpc.xml new file mode 100644 index 0000000000000000000000000000000000000000..ba3102a6033cc1a87ae637ad6dbc5aaa23894b27 --- /dev/null +++ b/apps/android_rpc/app/src/main/res/layout/activity_rpc.xml @@ -0,0 +1,26 @@ +<?xml version="1.0" encoding="utf-8"?> +<android.support.design.widget.CoordinatorLayout + xmlns:android="http://schemas.android.com/apk/res/android" + xmlns:app="http://schemas.android.com/apk/res-auto" + xmlns:tools="http://schemas.android.com/tools" + android:layout_width="match_parent" + android:layout_height="match_parent" + tools:context="ml.dmlc.tvm.tvmrpc.RPCActivity"> + + <android.support.design.widget.AppBarLayout + android:layout_height="wrap_content" + android:layout_width="match_parent" + android:theme="@style/AppTheme.AppBarOverlay"> + + <android.support.v7.widget.Toolbar + android:id="@+id/toolbar" + android:layout_width="match_parent" + android:layout_height="?attr/actionBarSize" + android:background="?attr/colorPrimary" + app:popupTheme="@style/AppTheme.PopupOverlay" /> + + </android.support.design.widget.AppBarLayout> + + <include layout="@layout/content_rpc"/> + +</android.support.design.widget.CoordinatorLayout> diff --git a/apps/android_rpc/app/src/main/res/layout/content_main.xml b/apps/android_rpc/app/src/main/res/layout/content_main.xml index 827cdfb01b8afc317b20b3f61382ae680fd83307..0f2564833ecda9d005c5d3f4944d36da5462969b 100644 --- a/apps/android_rpc/app/src/main/res/layout/content_main.xml +++ b/apps/android_rpc/app/src/main/res/layout/content_main.xml @@ -64,9 +64,9 @@ <TextView android:layout_width="wrap_content" android:layout_height="wrap_content" - android:text="@string/label_connect"/> + android:text="@string/label_persistent"/> <Switch - android:id="@+id/switch_connect" + android:id="@+id/switch_persistent" android:layout_width="wrap_content" android:layout_height="wrap_content" android:switchMinWidth="55dp" @@ -76,4 +76,15 @@ android:textOn="@string/switch_on" /> </LinearLayout> + <LinearLayout + android:orientation="horizontal" + android:layout_width="fill_parent" + android:layout_height="wrap_content"> + <Button + android:id="@+id/button_start_rpc" + android:layout_height="wrap_content" + android:layout_width="wrap_content" + android:text="@string/start_rpc" /> + </LinearLayout> + </LinearLayout> diff --git a/apps/android_rpc/app/src/main/res/layout/content_rpc.xml b/apps/android_rpc/app/src/main/res/layout/content_rpc.xml new file mode 100644 index 0000000000000000000000000000000000000000..fb9ab2f66a9bb0133959d53430cafe541ff1a2bd --- /dev/null +++ b/apps/android_rpc/app/src/main/res/layout/content_rpc.xml @@ -0,0 +1,14 @@ +<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android" + xmlns:tools="http://schemas.android.com/tools" + xmlns:app="http://schemas.android.com/apk/res-auto" + android:orientation="vertical" + android:layout_width="fill_parent" + android:layout_height="wrap_content" + app:layout_behavior="@string/appbar_scrolling_view_behavior" + tools:showIn="@layout/activity_rpc"> + <Button + android:id="@+id/button_stop_rpc" + android:layout_height="wrap_content" + android:layout_width="wrap_content" + android:text="@string/stop_rpc" /> +</LinearLayout> diff --git a/apps/android_rpc/app/src/main/res/values/strings.xml b/apps/android_rpc/app/src/main/res/values/strings.xml index 468fbed8ceaa346324fec64638b7c871dfa9872b..33caa374b4969a423f3a6dc71421b61558c7c8a4 100644 --- a/apps/android_rpc/app/src/main/res/values/strings.xml +++ b/apps/android_rpc/app/src/main/res/values/strings.xml @@ -1,15 +1,19 @@ <resources> <string name="app_name">TVM RPC</string> + <string name="rpc_name">RPC</string> - <string name="input_address">Enter the proxy server address</string> - <string name="input_port">Enter the proxy server port</string> + <string name="input_address">Enter the tracker server address</string> + <string name="input_port">Enter the tracker server port</string> <string name="input_key">Enter the app connection key</string> <string name="label_address">Address</string> <string name="label_port">Port</string> <string name="label_key">Key</string> - <string name="label_connect">Connect to Proxy</string> + <string name="label_persistent">Keep RPC Alive</string> - <string name="switch_on">Connected</string> - <string name="switch_off">Disconnected</string> + <string name="switch_on">Enabled</string> + <string name="switch_off">Disabled</string> + + <string name="start_rpc">Start RPC</string> + <string name="stop_rpc">Stop RPC</string> </resources> diff --git a/apps/android_rpc/tests/android_rpc_test.py b/apps/android_rpc/tests/android_rpc_test.py index a3362fafa63aff8606d4ddcd8e799b67bba9b7db..82575e656284930f74d9eb0d7b36c52cdf20d6f7 100644 --- a/apps/android_rpc/tests/android_rpc_test.py +++ b/apps/android_rpc/tests/android_rpc_test.py @@ -11,8 +11,8 @@ from tvm.contrib import util, ndk import numpy as np # Set to be address of tvm proxy. -proxy_host = os.environ["TVM_ANDROID_RPC_PROXY_HOST"] -proxy_port = 9090 +tracker_host = os.environ["TVM_TRACKER_HOST"] +tracker_port = int(os.environ["TVM_TRACKER_PORT"]) key = "android" # Change target configuration. @@ -33,7 +33,7 @@ def test_rpc_module(): # Build the dynamic lib. # If we don't want to do metal and only use cpu, just set target to be target f = tvm.build(s, [A, B], "opencl", target_host=target, name="myadd") - path_dso1 = temp.relpath("dev_lib.so") + path_dso1 = temp.relpath("dev_lib2.so") f.export_library(path_dso1, ndk.create_shared) s = tvm.create_schedule(B.op) @@ -45,29 +45,31 @@ def test_rpc_module(): path_dso2 = temp.relpath("cpu_lib.so") f.export_library(path_dso2, ndk.create_shared) - # connect to the proxy - remote = rpc.connect(proxy_host, proxy_port, key=key) + tracker = rpc.connect_tracker(tracker_host, tracker_port) + remote = tracker.request(key, priority=0, + session_timeout=60) - print('Run GPU test ...') - ctx = remote.cl(0) - remote.upload(path_dso1) - f1 = remote.load_module("dev_lib.so") + print('Run CPU test ...') + ctx = remote.cpu(0) + remote.upload(path_dso2) + f2 = remote.load_module("cpu_lib.so") a_np = np.random.uniform(size=1024).astype(A.dtype) a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) - time_f = f1.time_evaluator(f1.entry_name, ctx, number=10) + time_f = f2.time_evaluator(f2.entry_name, ctx, number=10) cost = time_f(a, b).mean print('%g secs/op' % cost) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) - print('Run CPU test ...') - ctx = remote.cpu(0) - remote.upload(path_dso2) - f2 = remote.load_module("cpu_lib.so") + + print('Run GPU test ...') + ctx = remote.cl(0) + remote.upload(path_dso1) + f1 = remote.load_module("dev_lib2.so") a_np = np.random.uniform(size=1024).astype(A.dtype) a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) - time_f = f2.time_evaluator(f2.entry_name, ctx, number=10) + time_f = f1.time_evaluator(f1.entry_name, ctx, number=10) cost = time_f(a, b).mean print('%g secs/op' % cost) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectTrackerServerProcessor.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectTrackerServerProcessor.java new file mode 100644 index 0000000000000000000000000000000000000000..ca3af923c448260a77d59a54159d397eaea3c5f9 --- /dev/null +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectTrackerServerProcessor.java @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ml.dmlc.tvm.rpc; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.BindException; +import java.net.ConnectException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.SocketAddress; +import java.net.SocketException; +import java.net.SocketTimeoutException; + + +/** + * Server processor with tracker connection (based on standalone). + * This RPC Server registers itself with an RPC Tracker for a specific queue + * (using its device key) and listens for incoming requests. + */ +public class ConnectTrackerServerProcessor implements ServerProcessor { + private ServerSocket server; + private final SocketFileDescriptorGetter socketFileDescriptorGetter; + private final String trackerHost; + private final int trackerPort; + // device key + private final String key; + // device key plus randomly generated key (per-session) + private final String matchKey; + private int serverPort = 5001; + public static final int MAX_SERVER_PORT = 5555; + // time to wait before aborting tracker connection (ms) + public static final int TRACKER_TIMEOUT = 6000; + // time to wait before retrying tracker connection (ms) + public static final int RETRY_PERIOD = TRACKER_TIMEOUT; + // time to wait for a connection before refreshing tracker connection (ms) + public static final int STALE_TRACKER_TIMEOUT = 300000; + // time to wait if no timeout value is specified (seconds) + public static final int HARD_TIMEOUT_DEFAULT = 300; + private RPCWatchdog watchdog; + private Socket trackerSocket; + + /** + * Construct tracker server processor. + * @param trackerHost Tracker host. + * @param trackerPort Tracker port. + * @param key Device key. + * @param sockFdGetter Method to get file descriptor from Java socket. + */ + public ConnectTrackerServerProcessor(String trackerHost, int trackerPort, String key, + SocketFileDescriptorGetter sockFdGetter, RPCWatchdog watchdog) throws IOException { + while (true) { + try { + this.server = new ServerSocket(serverPort); + server.setSoTimeout(STALE_TRACKER_TIMEOUT); + break; + } catch (BindException e) { + System.err.println(serverPort); + System.err.println(e); + serverPort++; + if (serverPort > MAX_SERVER_PORT) { + throw e; + } + } + } + System.err.println("using port: " + serverPort); + this.socketFileDescriptorGetter = sockFdGetter; + this.trackerHost = trackerHost; + this.trackerPort = trackerPort; + this.key = key; + this.matchKey = key + ":" + Math.random(); + this.watchdog = watchdog; + } + + public String getMatchKey() { + return matchKey; + } + + @Override public void terminate() { + try { + server.close(); + } catch (IOException e) { + e.printStackTrace(); + } + } + + @Override public void run() { + String recvKey = null; + try { + trackerSocket = connectToTracker(); + // open a socket and handshake with tracker + register(); + Socket socket = null; + InputStream in = null; + OutputStream out = null; + while (true) { + try { + System.err.println("waiting for requests..."); + // wait for client request + socket = server.accept(); + in = socket.getInputStream(); + out = socket.getOutputStream(); + int magic = Utils.wrapBytes(Utils.recvAll(in, 4)).getInt(); + if (magic != RPC.RPC_MAGIC) { + out.write(Utils.toBytes(RPC.RPC_CODE_MISMATCH)); + System.err.println("incorrect RPC magic"); + Utils.closeQuietly(socket); + continue; + } + recvKey = Utils.recvString(in); + System.err.println("matchKey:" + matchKey); + System.err.println("key: " + recvKey); + // incorrect key + if (recvKey.indexOf(matchKey) == -1) { + out.write(Utils.toBytes(RPC.RPC_CODE_MISMATCH)); + System.err.println("key mismatch, expected: " + matchKey + " got: " + recvKey); + Utils.closeQuietly(socket); + continue; + } + // successfully got client request and completed handshake with client + break; + } catch (SocketTimeoutException e) { + System.err.println("no incoming connections, refreshing..."); + // need to reregister, if the tracker died we should see a socked closed exception + if (!needRefreshKey()) { + System.err.println("reregistering..."); + register(); + } + } + } + int timeout = HARD_TIMEOUT_DEFAULT; + int timeoutArgIndex = recvKey.indexOf(RPC.TIMEOUT_ARG); + if (timeoutArgIndex != -1) { + timeout = Integer.parseInt(recvKey.substring(timeoutArgIndex + RPC.TIMEOUT_ARG.length())); + } + System.err.println("alloted timeout: " + timeout); + if (!recvKey.startsWith("client:")) { + System.err.println("recv key mismatch..."); + out.write(Utils.toBytes(RPC.RPC_CODE_MISMATCH)); + } else { + out.write(Utils.toBytes(RPC.RPC_MAGIC)); + // send server key to the client + Utils.sendString(out, recvKey); + } + + System.err.println("Connection from " + socket.getRemoteSocketAddress().toString()); + // received timeout in seconds + watchdog.startTimeout(timeout * 1000); + final int sockFd = socketFileDescriptorGetter.get(socket); + if (sockFd != -1) { + new NativeServerLoop(sockFd).run(); + System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString()); + } + Utils.closeQuietly(socket); + } catch (ConnectException e) { + // if the tracker connection failed, wait a bit before retrying + try { + Thread.sleep(RETRY_PERIOD); + } catch (InterruptedException e_) { + System.err.println("interrupted before retrying to connect to tracker..."); + } + } catch (Throwable e) { + e.printStackTrace(); + } finally { + try { + if (trackerSocket != null) { + trackerSocket.close(); + } + server.close(); + } catch (Throwable e) { + e.printStackTrace(); + } + } + } + + private Socket connectToTracker() throws IOException { + trackerSocket = new Socket(); + SocketAddress address = new InetSocketAddress(trackerHost, trackerPort); + trackerSocket.connect(address, TRACKER_TIMEOUT); + InputStream trackerIn = trackerSocket.getInputStream(); + OutputStream trackerOut = trackerSocket.getOutputStream(); + trackerOut.write(Utils.toBytes(RPC.RPC_TRACKER_MAGIC)); + int trackerMagic = Utils.wrapBytes(Utils.recvAll(trackerIn, 4)).getInt(); + if (trackerMagic != RPC.RPC_TRACKER_MAGIC) { + throw new SocketException("failed to connect to tracker (WRONG MAGIC)"); + } + return trackerSocket; + } + + /* + * Register the RPC Server with the RPC Tracker. + */ + private void register() throws IOException { + InputStream trackerIn = trackerSocket.getInputStream(); + OutputStream trackerOut = trackerSocket.getOutputStream(); + // send a JSON with PUT, device key, RPC server port, and the randomly + // generated key + String putJSON = generatePut(RPC.TrackerCode.PUT, key, serverPort, matchKey); + Utils.sendString(trackerOut, putJSON); + int recvCode = Integer.parseInt(Utils.recvString(trackerIn)); + if (recvCode != RPC.TrackerCode.SUCCESS) { + throw new SocketException("failed to register with tracker (not SUCCESS)"); + } + System.err.println("registered with tracker..."); + } + + /* + * Check if the RPC Tracker has our key. + */ + private boolean needRefreshKey() throws IOException { + InputStream trackerIn = trackerSocket.getInputStream(); + OutputStream trackerOut = trackerSocket.getOutputStream(); + String getJSON = generateGetPendingMatchKeys(RPC.TrackerCode.GET_PENDING_MATCHKEYS); + Utils.sendString(trackerOut, getJSON); + String recvJSON = Utils.recvString(trackerIn); + System.err.println("pending matchkeys: " + recvJSON); + // fairly expensive string operation... + if (recvJSON.indexOf(matchKey) != -1 ) { + return true; + } + return false; + } + + // handcrafted JSON + private String generatePut(int code, String key, int port, String matchKey) { + return "[" + code + ", " + "\"" + key + "\"" + ", " + "[" + port + ", " + + "\"" + matchKey + "\"" + "]" + ", " + "null" + "]"; + } + + // handcrafted JSON + private String generateGetPendingMatchKeys(int code) { + return "[" + code + "]"; + } +} diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/NativeServerLoop.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/NativeServerLoop.java index 4bc4c34c2d0abc2b466a90810d65112aa2f02057..666b15aed6154b54f2098e35dad4ec5e0f6de25c 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/NativeServerLoop.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/NativeServerLoop.java @@ -42,7 +42,9 @@ public class NativeServerLoop implements Runnable { File tempDir = null; try { tempDir = serverEnv(); + System.err.println("starting server loop..."); RPC.getApi("_ServerLoop").pushArg(sockFd).invoke(); + System.err.println("done server loop..."); } catch (IOException e) { e.printStackTrace(); } finally { diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java index 225bcaeefbd32e34196090bfdfb9e9340d392b5d..757fc0df32650c20cf37ee74fe75d298ff8c5d83 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java @@ -23,9 +23,19 @@ import java.util.HashMap; import java.util.Map; public class RPC { + public static final int RPC_TRACKER_MAGIC = 0x2f271; public static final int RPC_MAGIC = 0xff271; + public static final int RPC_CODE_MISMATCH = RPC_MAGIC + 2; public static final int RPC_SESS_MASK = 128; + public static final String TIMEOUT_ARG = "-timeout="; + + public class TrackerCode { + public static final int PUT = 3; + public static final int GET_PENDING_MATCHKEYS = 7; + public static final int SUCCESS = 0; + } + private static ThreadLocal<Map<String, Function>> apiFuncs = new ThreadLocal<Map<String, Function>>() { @Override diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCWatchdog.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCWatchdog.java new file mode 100644 index 0000000000000000000000000000000000000000..4df858cbd6bbadf701f840d04ddc6d94318c2ada --- /dev/null +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCWatchdog.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ml.dmlc.tvm.rpc; + +/** + * Watchdog for RPC. + */ +public class RPCWatchdog extends Thread { + private int timeout = 0; + private boolean started = false; + + public RPCWatchdog() { + super(); + } + + /** + * Start a timeout with watchdog (must be called before finishTimeout). + * @param timeout watchdog timeout in ms. + */ + public synchronized void startTimeout(int timeout) { + this.timeout = timeout; + started = true; + this.notify(); + } + + /** + * Finish a timeout with watchdog (must be called after startTimeout). + */ + public synchronized void finishTimeout() { + started = false; + this.notify(); + } + + /** + * Wait and kill RPC if timeout is exceeded. + */ + @Override public void run() { + while (true) { + // timeout not started + synchronized (this) { + while (!started) { + try { + this.wait(); + } catch (InterruptedException e) { + System.err.println("watchdog interrupted..."); + } + } + } + synchronized (this) { + while (started) { + try { + System.err.println("waiting for timeout: " + timeout); + this.wait(timeout); + if (!started) { + System.err.println("watchdog woken up, ok..."); + } else { + System.err.println("watchdog woke up!"); + System.err.println("terminating..."); + System.exit(0); + } + } catch (InterruptedException e) { + System.err.println("watchdog interrupted..."); + } + } + } + } + } +} diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Utils.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Utils.java index d6d9efb6e6042b89abd222c012ec8f33baac53ce..0f241d12c5589245562122acad0bbff8e8a87f47 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Utils.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Utils.java @@ -19,6 +19,7 @@ package ml.dmlc.tvm.rpc; import java.io.IOException; import java.io.InputStream; +import java.io.OutputStream; import java.net.Socket; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -76,4 +77,16 @@ class Utils { } return builder.toString(); } + + public static String recvString(InputStream in) throws IOException { + String recvString = null; + int len = wrapBytes(Utils.recvAll(in, 4)).getInt(); + recvString = decodeToStr(Utils.recvAll(in, len)); + return recvString; + } + + public static void sendString(OutputStream out, String string) throws IOException { + out.write(toBytes(string.length())); + out.write(toBytes(string)); + } }