Commit 71cb07ae by Tianqi Chen Committed by GitHub

[RPC] IOS RPC (#261)

parent 9037a4c2
...@@ -154,12 +154,12 @@ verilog: $(VER_LIBS) ...@@ -154,12 +154,12 @@ verilog: $(VER_LIBS)
# Special rules for LLVM related modules. # Special rules for LLVM related modules.
build/codegen/llvm/%.o: src/codegen/llvm/%.cc build/codegen/llvm/%.o: src/codegen/llvm/%.cc
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build/codegen/llvm/$*.o $< >build/codegen/llvm/$*.d $(CXX) $(CFLAGS) $(LLVM_CFLAGS) -MM -MT build/codegen/llvm/$*.o $< >build/codegen/llvm/$*.d
$(CXX) -c $(CFLAGS) $(LLVM_CFLAGS) -c $< -o $@ $(CXX) -c $(CFLAGS) $(LLVM_CFLAGS) -c $< -o $@
build/runtime/metal/%.o: src/runtime/metal/%.mm build/runtime/metal/%.o: src/runtime/metal/%.mm
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build/runtime/metal/$*.o $< >build/runtime/metal/$*.d $(CXX) $(OBJCFLAGS) $(CFLAGS) -MM -MT build/runtime/metal/$*.o $< >build/runtime/metal/$*.d
$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@ $(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
build/%.o: src/%.cc build/%.o: src/%.cc
...@@ -233,3 +233,4 @@ clean: ...@@ -233,3 +233,4 @@ clean:
-include build/*.d -include build/*.d
-include build/*/*.d -include build/*/*.d
-include build/*/*/*.d -include build/*/*/*.d
-include build/*/*/*/*.d
# iOS TVM RPC
This folder contains iOS RPC app that allows us to launch an rpc server on a iOS device(e.g. ipython)
and connect to it through python script and do testing on the python side as normal TVM RPC.
You will need XCode and an iOS device to use this.
## Workflow
Due to security restriction of iOS10. We cannot upload dynamic libraries to the App and load it from sandbox.
Instead, we need to build a list of libraries, pack them into the app bundle, launch the RPC server and
connect to test the bundled libraries. We use ```xcodebuild test``` to automate this process.
See [tests/ios_rpc_test.py](tests/ios_rpc_test.py) for an example.
## Environment Variables
To use the utilities, you need to configure the following environment variables
- ```TVM_IOS_CODESIGN``` The signature you use to codesign the app and libraries(e.g. ```iPhone Developer: Name (XXXX)```)
- ```TVM_IOS_RPC_ROOT``` The root directory of the iOS rpc project
## Launch RPC from XCode IDE
Let us first explain how it works, the project look for ```rpc_config.txt``` file in the project root folder.
The ```rpc_config.txt``` file should be in the following format:
```
<url> <port> <key>
[path to dylib1]
[path to dylib2]
...
```
The build script will copy all the dynamic libraries into bundle ```tvmrpc.app/Frameworks/tvm```,
which you will be able to load via RPC using ```remote.load_module```.
It will also create an ```tvmrpc.app/Frameworks/tvm/rpc_config.txt``` contaiing the first line.
When we run the testcase, the testcase read the configuration from ```tvmrpc.app/Frameworks/tvm/rpc_config.txt```
and connect to the specified RPC proxy, start serving loop.
So if we want to start the RPC from XCode IDE, simply manually modify ```rpc_config.txt``` file and click test.
Then connect to the proxy via the python script.
## Use RPC via App
We can also use the RPC App directly, by typing in the address and press connect to connect to the proxy.
However, the restriction is we can only load the modules that are bundled to the App.
"""Testcode for iOS RPC.
To use it, start a rpc proxy with "python -m tvm.exec.rpc_proxy".
And configure the proxy host field as commented.
"""
import tvm
import os
from tvm.contrib import rpc, util, xcode
import numpy as np
# Set to be address of tvm proxy.
proxy_host = os.environ["TVM_IOS_RPC_PROXY_HOST"]
# Set your desination via env variable.
# Should in format "platform=iOS,id=<the test device uuid>"
destination = os.environ["TVM_IOS_RPC_DESTINATION"]
proxy_port = 9090
key = "iphone"
# Change target configuration, this is setting for iphone6s
arch = "arm64"
sdk = "iphoneos"
target = "llvm -target=%s-apple-darwin" % arch
# override metal compiler to compile to iphone
@tvm.register_func("tvm_callback_metal_compile")
def compile_metal(src):
return xcode.compile_metal(src, sdk=sdk)
def test_rpc_module():
# graph
n = tvm.convert(1024)
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=64)
s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
temp = util.tempdir()
# Build the dynamic lib.
# If we don't want to do metal and only use cpu, just set target to be target
f = tvm.build(s, [A, B], "metal", target_host=target, name="myadd")
path_dso = temp.relpath("dev_lib.dylib")
f.export_library(path_dso, xcode.create_dylib,
arch=arch, sdk=sdk)
xcode.codesign(path_dso)
# Start RPC test server that contains the compiled library.
server = xcode.popen_test_rpc(proxy_host, proxy_port, key,
destination=destination,
libs=[path_dso],
options=["-quiet"])
# connect to the proxy
remote = rpc.connect(proxy_host, proxy_port, key=key)
ctx = remote.metal(0)
f1 = remote.load_module("dev_lib.dylib")
a_np = np.random.uniform(size=1024).astype(A.dtype)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
time_f = f1.time_evaluator(f1.entry_name, ctx, number=10)
cost = time_f(a, b).mean
print('%g secs/op' % cost)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
test_rpc_module()
<?xml version="1.0" encoding="UTF-8"?>
<Workspace
version = "1.0">
<FileRef
location = "self:tvmrpc.xcodeproj">
</FileRef>
</Workspace>
/*!
* Copyright (c) 2017 by Contributors
* \file AppDelegate.h
*/
#import <UIKit/UIKit.h>
@interface AppDelegate : UIResponder <UIApplicationDelegate>
@property (strong, nonatomic) UIWindow *window;
@end
/*!
* Copyright (c) 2017 by Contributors
* \file AppDelegate.mm
*/
#import "AppDelegate.h"
@interface AppDelegate ()
@end
@implementation AppDelegate
- (BOOL)application:(UIApplication *)application didFinishLaunchingWithOptions:(NSDictionary *)launchOptions {
// Override point for customization after application launch.
return YES;
}
- (void)applicationWillResignActive:(UIApplication *)application {
// Sent when the application is about to move from active to inactive state. This can occur for certain types of temporary interruptions (such as an incoming phone call or SMS message) or when the user quits the application and it begins the transition to the background state.
// Use this method to pause ongoing tasks, disable timers, and invalidate graphics rendering callbacks. Games should use this method to pause the game.
}
- (void)applicationDidEnterBackground:(UIApplication *)application {
// Use this method to release shared resources, save user data, invalidate timers, and store enough application state information to restore your application to its current state in case it is terminated later.
// If your application supports background execution, this method is called instead of applicationWillTerminate: when the user quits.
}
- (void)applicationWillEnterForeground:(UIApplication *)application {
// Called as part of the transition from the background to the active state; here you can undo many of the changes made on entering the background.
}
- (void)applicationDidBecomeActive:(UIApplication *)application {
// Restart any tasks that were paused (or not yet started) while the application was inactive. If the application was previously in the background, optionally refresh the user interface.
}
- (void)applicationWillTerminate:(UIApplication *)application {
// Called when the application is about to terminate. Save data if appropriate. See also applicationDidEnterBackground:.
}
@end
{
"images" : [
{
"idiom" : "iphone",
"size" : "20x20",
"scale" : "2x"
},
{
"idiom" : "iphone",
"size" : "20x20",
"scale" : "3x"
},
{
"idiom" : "iphone",
"size" : "29x29",
"scale" : "2x"
},
{
"idiom" : "iphone",
"size" : "29x29",
"scale" : "3x"
},
{
"idiom" : "iphone",
"size" : "40x40",
"scale" : "2x"
},
{
"idiom" : "iphone",
"size" : "40x40",
"scale" : "3x"
},
{
"idiom" : "iphone",
"size" : "60x60",
"scale" : "2x"
},
{
"idiom" : "iphone",
"size" : "60x60",
"scale" : "3x"
},
{
"idiom" : "ipad",
"size" : "20x20",
"scale" : "1x"
},
{
"idiom" : "ipad",
"size" : "20x20",
"scale" : "2x"
},
{
"idiom" : "ipad",
"size" : "29x29",
"scale" : "1x"
},
{
"idiom" : "ipad",
"size" : "29x29",
"scale" : "2x"
},
{
"idiom" : "ipad",
"size" : "40x40",
"scale" : "1x"
},
{
"idiom" : "ipad",
"size" : "40x40",
"scale" : "2x"
},
{
"idiom" : "ipad",
"size" : "76x76",
"scale" : "1x"
},
{
"idiom" : "ipad",
"size" : "76x76",
"scale" : "2x"
},
{
"idiom" : "ipad",
"size" : "83.5x83.5",
"scale" : "2x"
}
],
"info" : {
"version" : 1,
"author" : "xcode"
}
}
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<document type="com.apple.InterfaceBuilder3.CocoaTouch.Storyboard.XIB" version="3.0" toolsVersion="11134" systemVersion="15F34" targetRuntime="iOS.CocoaTouch" propertyAccessControl="none" useAutolayout="YES" launchScreen="YES" useTraitCollections="YES" colorMatched="YES" initialViewController="01J-lp-oVM">
<dependencies>
<plugIn identifier="com.apple.InterfaceBuilder.IBCocoaTouchPlugin" version="11106"/>
<capability name="documents saved in the Xcode 8 format" minToolsVersion="8.0"/>
</dependencies>
<scenes>
<!--View Controller-->
<scene sceneID="EHf-IW-A2E">
<objects>
<viewController id="01J-lp-oVM" sceneMemberID="viewController">
<layoutGuides>
<viewControllerLayoutGuide type="top" id="Llm-lL-Icb"/>
<viewControllerLayoutGuide type="bottom" id="xb3-aO-Qok"/>
</layoutGuides>
<view key="view" contentMode="scaleToFill" id="Ze5-6b-2t3">
<rect key="frame" x="0.0" y="0.0" width="375" height="667"/>
<autoresizingMask key="autoresizingMask" widthSizable="YES" heightSizable="YES"/>
<color key="backgroundColor" red="1" green="1" blue="1" alpha="1" colorSpace="custom" customColorSpace="sRGB"/>
</view>
</viewController>
<placeholder placeholderIdentifier="IBFirstResponder" id="iYj-Kq-Ea1" userLabel="First Responder" sceneMemberID="firstResponder"/>
</objects>
<point key="canvasLocation" x="53" y="375"/>
</scene>
</scenes>
</document>
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>CFBundleDevelopmentRegion</key>
<string>en</string>
<key>CFBundleExecutable</key>
<string>$(EXECUTABLE_NAME)</string>
<key>CFBundleIdentifier</key>
<string>$(PRODUCT_BUNDLE_IDENTIFIER)</string>
<key>CFBundleInfoDictionaryVersion</key>
<string>6.0</string>
<key>CFBundleName</key>
<string>$(PRODUCT_NAME)</string>
<key>CFBundlePackageType</key>
<string>APPL</string>
<key>CFBundleShortVersionString</key>
<string>1.0</string>
<key>CFBundleVersion</key>
<string>1</string>
<key>LSRequiresIPhoneOS</key>
<true/>
<key>UILaunchStoryboardName</key>
<string>LaunchScreen</string>
<key>UIMainStoryboardFile</key>
<string>Main</string>
<key>UIRequiredDeviceCapabilities</key>
<array>
<string>armv7</string>
</array>
<key>UISupportedInterfaceOrientations</key>
<array>
<string>UIInterfaceOrientationPortrait</string>
<string>UIInterfaceOrientationLandscapeLeft</string>
<string>UIInterfaceOrientationLandscapeRight</string>
</array>
<key>UISupportedInterfaceOrientations~ipad</key>
<array>
<string>UIInterfaceOrientationPortrait</string>
<string>UIInterfaceOrientationPortraitUpsideDown</string>
<string>UIInterfaceOrientationLandscapeLeft</string>
<string>UIInterfaceOrientationLandscapeRight</string>
</array>
</dict>
</plist>
/*!
* Copyright (c) 2017 by Contributors
* \file TVMRuntime.h
*/
#import <Foundation/Foundation.h>
// Customize logging mechanism, redirect to NSLOG
#define DMLC_LOG_CUSTOMIZE 1
#define TVM_METAL_RUNTIME 1
#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
#include <functional>
namespace tvm {
namespace runtime {
/*!
* \brief Message handling function for event driven server.
*
* \param in_bytes The incoming bytes.
* \param event_flag 1: read_available, 2: write_avaiable.
* \return State flag.
* 1: continue running, no need to write,
* 2: need to write
* 0: shutdown
*/
using FEventHandler = std::function<int(const std::string& in_bytes, int event_flag)>;
/*!
* \brief Create a server event handler.
*
* \param outputStream The output stream used to send outputs.
* \param name The name of the server.
* \return The event handler.
*/
FEventHandler CreateServerEventHandler(NSOutputStream *outputStream, std::string name);
} // namespace runtime
} // namespace tvm
@interface TVMRuntime : NSObject
+ (void)launchSyncServer;
@end
/*!
* Copyright (c) 2017 by Contributors
* \file TVMRuntime.mm
*/
#include "TVMRuntime.h"
// Runtime API
#include "../../src/runtime/c_runtime_api.cc"
#include "../../src/runtime/cpu_device_api.cc"
#include "../../src/runtime/workspace_pool.cc"
#include "../../src/runtime/module_util.cc"
#include "../../src/runtime/system_lib_module.cc"
#include "../../src/runtime/module.cc"
#include "../../src/runtime/registry.cc"
#include "../../src/runtime/file_util.cc"
#include "../../src/runtime/dso_module.cc"
// RPC server
#include "../../src/runtime/rpc/rpc_session.cc"
#include "../../src/runtime/rpc/rpc_server_env.cc"
#include "../../src/runtime/rpc/rpc_socket_impl.cc"
#include "../../src/runtime/rpc/rpc_module.cc"
// Metal
#include "../../src/runtime/metal/metal_module.mm"
#include "../../src/runtime/metal/metal_device_api.mm"
namespace dmlc {
// Override logging mechanism
void CustomLogMessage::Log(const std::string& msg) {
NSLog(@"%s", msg.c_str());
}
} // namespace dmlc
namespace tvm {
namespace runtime {
class NSStreamChannel final : public RPCChannel {
public:
explicit NSStreamChannel(NSOutputStream* stream)
: stream_(stream) {}
size_t Send(const void* data, size_t size) final {
ssize_t nbytes = [stream_ write:reinterpret_cast<const uint8_t*>(data)
maxLength:size];
if (nbytes < 0) {
NSLog(@"%@",[stream_ streamError].localizedDescription);
throw dmlc::Error("Stream error");
}
return nbytes;
}
size_t Recv(void* data, size_t size) final {
LOG(FATAL) << "Do not allow explicit receive for";
return 0;
}
private:
NSOutputStream* stream_;
};
FEventHandler CreateServerEventHandler(NSOutputStream *outputStream, std::string name) {
std::unique_ptr<NSStreamChannel> ch(new NSStreamChannel(outputStream));
std::shared_ptr<RPCSession> sess = RPCSession::Create(std::move(ch), name);
return [sess](const std::string& in_bytes, int flag) {
return sess->ServerEventHandler(in_bytes, flag);
};
}
// Runtime environment
struct RPCEnv {
public:
RPCEnv() {
NSString* path = NSTemporaryDirectory();
base_ = [path UTF8String];
if (base_[base_.length() - 1] != '/') {
base_ = base_ + '/';
}
}
// Get Path.
std::string GetPath(const std::string& file_name) {
return base_ + file_name;
}
private:
std::string base_;
};
void LaunchSyncServer() {
// only load dylib from frameworks.
NSBundle* bundle = [NSBundle mainBundle];
NSString* base = [bundle privateFrameworksPath];
NSString* path = [base stringByAppendingPathComponent: @"tvm/rpc_config.txt"];
std::string name = [path UTF8String];
std::ifstream fs(name, std::ios::in);
std::string url, key;
int port;
CHECK(fs >> url >> port >> key)
<< "Invalid RPC config file " << name;
RPCConnect(url, port, "server:" + key)
->ServerLoop();
}
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.workpath")
.set_body([](TVMArgs args, TVMRetValue* rv) {
static RPCEnv env;
*rv = env.GetPath(args[0]);
});
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.load_module")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string name = args[0];
std::string fmt = GetFileFormat(name, "");
NSString* base;
if (fmt == "dylib") {
// only load dylib from frameworks.
NSBundle* bundle = [NSBundle mainBundle];
base = [[bundle privateFrameworksPath]
stringByAppendingPathComponent: @"tvm"];
} else {
// Load other modules in tempdir.
base = NSTemporaryDirectory();
}
NSString* path = [base stringByAppendingPathComponent:
[NSString stringWithUTF8String:name.c_str()]];
name = [path UTF8String];
*rv = Module::LoadFromFile(name, fmt);
LOG(INFO) << "Load module from " << name << " ...";
});
} // namespace runtime
} // namespace tvm
@implementation TVMRuntime
+(void) launchSyncServer {
tvm::runtime::LaunchSyncServer();
}
@end
/*!
* Copyright (c) 2017 by Contributors
* \file ViewController.h
*/
#import <UIKit/UIKit.h>
#include "TVMRuntime.h"
@interface ViewController : UIViewController<NSStreamDelegate>
{
// input socket stream
NSInputStream *inputStream_;
// output socket stream
NSOutputStream *outputStream_;
// temporal receive buffer.
std::string recvBuffer_;
// Whether connection is initialized.
bool initialized_;
// Whether auto reconnect when a session is done.
bool auto_reconnect_;
// The key of the server.
std::string key_;
// Initial bytes to be send to remote
std::string initBytes_;
// Send pointer of initial bytes.
size_t initSendPtr_;
// Event handler.
tvm::runtime::FEventHandler handler_;
}
@property (weak, nonatomic) IBOutlet UITextField *proxyURL;
@property (weak, nonatomic) IBOutlet UITextField *proxyPort;
@property (weak, nonatomic) IBOutlet UITextField *proxyKey;
@property (weak, nonatomic) IBOutlet UILabel *statusLabel;
@property (weak, nonatomic) IBOutlet UITextView *infoText;
- (IBAction)connect:(id)sender;
- (IBAction)disconnect:(id)sender;
@end
/*!
* Copyright (c) 2017 by Contributors
* \file ViewController.mm
*/
#include <string>
#import "ViewController.h"
@implementation ViewController
- (void)stream:(NSStream *)strm handleEvent:(NSStreamEvent)event {
std::string buffer;
switch (event) {
case NSStreamEventOpenCompleted: {
self.statusLabel.text = @"Connected";
break;
}
case NSStreamEventHasBytesAvailable:
if (strm == inputStream_) {
[self onReadAvailable];
}
break;
case NSStreamEventHasSpaceAvailable: {
if (strm == outputStream_) {
[self onWriteAvailable];
}
break;
}
case NSStreamEventErrorOccurred: {
NSLog(@"%@",[strm streamError].localizedDescription);
break;
}
case NSStreamEventEndEncountered: {
[self close];
// auto reconnect when normal end.
[self open];
break;
}
default: {
NSLog(@"Unknown event");
}
}
}
- (void)onReadAvailable {
constexpr int kRPCMagic = 0xff271;
if (!initialized_) {
int code;
size_t nbytes = [inputStream_ read:reinterpret_cast<uint8_t*>(&code)
maxLength:sizeof(code)];
if (nbytes != sizeof(code)) {
self.infoText.text = @"Fail to receive remote confirmation code.";
[self close];
} else if (code == kRPCMagic + 2) {
self.infoText.text = @"Proxy server cannot find client that matches the key";
[self close];
} else if (code == kRPCMagic + 1) {
self.infoText.text = @"Proxy server already have another server with same key";
[self close];
} else if (code != kRPCMagic) {
self.infoText.text = @"Given address is not a TVM RPC Proxy";
[self close];
} else {
initialized_ = true;
self.statusLabel.text = @"Proxy connected.";
CHECK(handler_ != nullptr);
}
}
const int kBufferSize = 4 << 10;
if (initialized_) {
while ([inputStream_ hasBytesAvailable]) {
recvBuffer_.resize(kBufferSize);
uint8_t* bptr = reinterpret_cast<uint8_t*>(&recvBuffer_[0]);
size_t nbytes = [inputStream_ read:bptr maxLength:kBufferSize];
recvBuffer_.resize(nbytes);
int flag = 1;
if ([outputStream_ hasSpaceAvailable]) {
flag |= 2;
}
// always try to write
try {
flag = handler_(recvBuffer_, flag);
if (flag == 2) {
[self onShutdownReceived];
}
} catch (const dmlc::Error& e) {
[self close];
}
}
}
}
- (void)onShutdownReceived {
[self close];
}
- (void)onWriteAvailable {
if (initSendPtr_ < initBytes_.length()) {
initSendPtr_ += [outputStream_ write:reinterpret_cast<uint8_t*>(&initBytes_[initSendPtr_])
maxLength:(initBytes_.length() - initSendPtr_)];
}
if (initialized_) {
try {
std::string dummy;
int flag = handler_(dummy, 2);
if (flag == 2) {
[self onShutdownReceived];
}
} catch (const dmlc::Error& e) {
[self close];
}
}
}
- (void)open {
constexpr int kRPCMagic = 0xff271;
NSLog(@"Connecting to the proxy server..");
// Initialize the data states.
key_ = [self.proxyKey.text UTF8String];
key_ = "server:" + key_;
std::ostringstream os;
int rpc_magic = kRPCMagic;
os.write(reinterpret_cast<char*>(&rpc_magic), sizeof(rpc_magic));
int keylen = static_cast<int>(key_.length());
os.write(reinterpret_cast<char*>(&keylen), sizeof(keylen));
os.write(key_.c_str(), key_.length());
initialized_ = false;
initBytes_ = os.str();
initSendPtr_ = 0;
// Initialize the network.
CFReadStreamRef readStream;
CFWriteStreamRef writeStream;
CFStreamCreatePairWithSocketToHost(
NULL,
(__bridge CFStringRef) self.proxyURL.text,
[self.proxyPort.text intValue],
&readStream, &writeStream);
inputStream_ = (__bridge_transfer NSInputStream *)readStream;
outputStream_ = (__bridge_transfer NSOutputStream *)writeStream;
[inputStream_ setDelegate:self];
[outputStream_ setDelegate:self];
[inputStream_ scheduleInRunLoop:[NSRunLoop currentRunLoop] forMode:NSDefaultRunLoopMode];
[outputStream_ scheduleInRunLoop:[NSRunLoop currentRunLoop] forMode:NSDefaultRunLoopMode];
[outputStream_ open];
[inputStream_ open];
handler_ = tvm::runtime::CreateServerEventHandler(outputStream_, key_);
CHECK(handler_ != nullptr);
self.infoText.text = @"";
self.statusLabel.text = @"Connecting...";
}
- (void)close {
NSLog(@"Closing the streams.");
[inputStream_ close];
[outputStream_ close];
[inputStream_ removeFromRunLoop:[NSRunLoop currentRunLoop] forMode:NSDefaultRunLoopMode];
[outputStream_ removeFromRunLoop:[NSRunLoop currentRunLoop] forMode:NSDefaultRunLoopMode];
[inputStream_ setDelegate:nil];
[outputStream_ setDelegate:nil];
inputStream_ = nil;
outputStream_ = nil;
handler_ = nullptr;
self.statusLabel.text = @"Disconnected";
}
- (IBAction)connect:(id)sender {
[self open];
[[self view] endEditing:YES];
}
- (IBAction)disconnect:(id)sender {
[self close];
}
@end
/*!
* Copyright (c) 2017 by Contributors
* \file main.m
*/
#import <UIKit/UIKit.h>
#import "AppDelegate.h"
int main(int argc, char * argv[]) {
@autoreleasepool {
return UIApplicationMain(argc, argv, nil, NSStringFromClass([AppDelegate class]));
}
}
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>CFBundleDevelopmentRegion</key>
<string>en</string>
<key>CFBundleExecutable</key>
<string>$(EXECUTABLE_NAME)</string>
<key>CFBundleIdentifier</key>
<string>$(PRODUCT_BUNDLE_IDENTIFIER)</string>
<key>CFBundleInfoDictionaryVersion</key>
<string>6.0</string>
<key>CFBundleName</key>
<string>$(PRODUCT_NAME)</string>
<key>CFBundlePackageType</key>
<string>BNDL</string>
<key>CFBundleShortVersionString</key>
<string>1.0</string>
<key>CFBundleVersion</key>
<string>1</string>
</dict>
</plist>
/*!
* Copyright (c) 2017 by Contributors
* \brief A hook to launch RPC server via xcodebuild test
* \file tvmrpcLauncher.mm
*/
#import <XCTest/XCTest.h>
#import "TVMRuntime.h"
@interface tvmrpcLauncher : XCTestCase
@end
@implementation tvmrpcLauncher
- (void)setUp {
[super setUp];
}
- (void)tearDown {
[super tearDown];
}
- (void)testRPC {
[TVMRuntime launchSyncServer];
}
@end
...@@ -390,7 +390,7 @@ class Proxy(object): ...@@ -390,7 +390,7 @@ class Proxy(object):
port=9091, port=9091,
port_end=9199, port_end=9199,
web_port=0, web_port=0,
timeout_client=10, timeout_client=240,
timeout_server=600, timeout_server=600,
index_page=None, index_page=None,
resource_files=None): resource_files=None):
......
...@@ -11,11 +11,12 @@ class TempDirectory(object): ...@@ -11,11 +11,12 @@ class TempDirectory(object):
""" """
def __init__(self): def __init__(self):
self.temp_dir = tempfile.mkdtemp() self.temp_dir = tempfile.mkdtemp()
self._rmtree = shutil.rmtree
def remove(self): def remove(self):
"""Remote the tmp dir""" """Remote the tmp dir"""
if self.temp_dir: if self.temp_dir:
shutil.rmtree(self.temp_dir) self._rmtree(self.temp_dir)
self.temp_dir = None self.temp_dir = None
def __del__(self): def __del__(self):
......
# pylint: disable=invalid-name # pylint: disable=invalid-name
"""Utility to invoke Xcode compiler toolchain""" """Utility to invoke Xcode compiler toolchain"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import os
import sys import sys
import subprocess import subprocess
from . import util from . import util
def xcrun(cmd):
"""Run xcrun and return the output.
Parameters
----------
cmd : list of str
The command sequence.
Returns
-------
out : str
The output string.
"""
cmd = ["xcrun"] + cmd
proc = subprocess.Popen(cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
return out.strip()
def codesign(lib):
"""Codesign the shared libary
This is an required step for library to be loaded in
the app.
Parameters
----------
lib : The path to the library.
"""
if "TVM_IOS_CODESIGN" not in os.environ:
raise RuntimeError("Require environment variable TVM_IOS_CODESIGN "
" to be the signature")
signature = os.environ["TVM_IOS_CODESIGN"]
cmd = ["codesign", "--force", "--sign", signature]
cmd += [lib]
proc = subprocess.Popen(cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Codesign error:\n"
msg += out
raise RuntimeError(msg)
def create_dylib(output, objects, arch, sdk="macosx"):
"""Create dynamic library.
Parameters
----------
output : str
The target shared library.
objects : list
List of object files.
options : str
The additional options.
arch : str
Target major architectures
sdk : str
The sdk to be used.
"""
clang = xcrun(["-sdk", sdk, "-find", "clang"])
sdk_path = xcrun(["-sdk", sdk, "--show-sdk-path"])
cmd = [clang]
cmd += ["-dynamiclib"]
cmd += ["-arch", arch]
cmd += ["-isysroot", sdk_path]
cmd += ["-o", output]
if isinstance(objects, str):
cmd += [objects]
else:
cmd += objects
proc = subprocess.Popen(
cmd, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Compilation error:\n"
msg += out
raise RuntimeError(msg)
def compile_metal(code, path_target=None, sdk="macosx"): def compile_metal(code, path_target=None, sdk="macosx"):
"""Compile metal with CLI tool from env. """Compile metal with CLI tool from env.
...@@ -51,3 +142,64 @@ def compile_metal(code, path_target=None, sdk="macosx"): ...@@ -51,3 +142,64 @@ def compile_metal(code, path_target=None, sdk="macosx"):
else: else:
libbin = bytearray(open(file_target, "rb").read()) libbin = bytearray(open(file_target, "rb").read())
return libbin return libbin
def popen_test_rpc(host,
port,
key,
destination,
libs=None,
options=None):
"""Launch rpc server via xcodebuild test through another process.
Parameters
----------
host : str
The address of RPC proxy host.
port : int
The port of RPC proxy host
key : str
The key of the RPC server
destination : str
Destination device of deployment, as in xcodebuild
libs : list of str
List of files to be packed into app/Frameworks/tvm
These can be dylibs that can be loaed remoted by RPC.
options : list of str
Additional options to xcodebuild
Returns
-------
proc : Popen
The test rpc server process.
Don't do wait() on proc, since it can terminate normally.
"""
if "TVM_IOS_RPC_ROOT" in os.environ:
rpc_root = os.environ["TVM_IOS_RPC_ROOT"]
else:
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
rpc_root = os.path.join(curr_path, "../../../apps/ios_rpc")
proj_path = os.path.abspath(os.path.join(rpc_root, "tvmrpc.xcodeproj"))
if not os.path.exists(proj_path):
raise RuntimeError("Cannot find tvmrpc.xcodeproj in %s," +
(" please set env TVM_IOS_RPC_ROOT correctly" % rpc_root))
with open(os.path.join(rpc_root, "rpc_config.txt"), "w") as fo:
fo.write("%s %d %s\n" % (host, port, key))
libs = libs if libs else []
for file_name in libs:
fo.write("%s\n" % file_name)
cmd = ["xcrun", "xcodebuild",
"-scheme", "tvmrpc",
"-project", proj_path,
"-destination", destination]
if options:
cmd += options
cmd += ["test"]
proc = subprocess.Popen(cmd)
return proc
...@@ -64,7 +64,10 @@ class Module(ModuleBase): ...@@ -64,7 +64,10 @@ class Module(ModuleBase):
""" """
_SaveToFile(self, file_name, fmt) _SaveToFile(self, file_name, fmt)
def export_library(self, file_name): def export_library(self,
file_name,
fcompile=None,
**kwargs):
"""Export the module and its imported device code one library. """Export the module and its imported device code one library.
This function only works on host llvm modules. This function only works on host llvm modules.
...@@ -74,6 +77,12 @@ class Module(ModuleBase): ...@@ -74,6 +77,12 @@ class Module(ModuleBase):
---------- ----------
file_name : str file_name : str
The name of the shared library. The name of the shared library.
fcompile : function(target, file_list, **kwargs), optional
Compilation function to use create dynamic library.
kwargs : dict, optiona;
Additional arguments passed to fcompile
""" """
if self.type_key == "stacktvm": if self.type_key == "stacktvm":
raise ValueError("Module[%s]: export_library requires llvm module," raise ValueError("Module[%s]: export_library requires llvm module,"
...@@ -85,18 +94,14 @@ class Module(ModuleBase): ...@@ -85,18 +94,14 @@ class Module(ModuleBase):
path_obj = temp.relpath("lib.o") path_obj = temp.relpath("lib.o")
self.save(path_obj) self.save(path_obj)
files = [path_obj] files = [path_obj]
try: is_system_lib = self.get_function("__tvm_is_system_module")()
self.get_function("__tvm_module_startup")
is_system_lib = True
except AttributeError:
is_system_lib = False
if self.imported_modules: if self.imported_modules:
path_cc = temp.relpath("devc.cc") path_cc = temp.relpath("devc.cc")
with open(path_cc, "w") as f: with open(path_cc, "w") as f:
f.write(_PackImportsToC(self, is_system_lib)) f.write(_PackImportsToC(self, is_system_lib))
files.append(path_cc) files.append(path_cc)
_cc.create_shared(file_name, files) fcompile = fcompile if fcompile else _cc.create_shared
fcompile(file_name, files, **kwargs)
def time_evaluator(self, func_name, ctx, number): def time_evaluator(self, func_name, ctx, number):
"""Get an evaluator that measures time cost of running function. """Get an evaluator that measures time cost of running function.
......
...@@ -36,6 +36,13 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -36,6 +36,13 @@ class LLVMModuleNode final : public runtime::ModuleNode {
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final { const std::shared_ptr<ModuleNode>& sptr_to_self) final {
if (name == "__tvm_is_system_module") {
bool flag =
(mptr_->getFunction("__tvm_module_startup") != nullptr);
return PackedFunc([flag](TVMArgs args, TVMRetValue *rv) {
* rv = flag;
});
}
if (ee_ == nullptr) LazyInitJIT(); if (ee_ == nullptr) LazyInitJIT();
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
const std::string& fname = (name == runtime::symbol::tvm_module_main ? const std::string& fname = (name == runtime::symbol::tvm_module_main ?
...@@ -118,8 +125,11 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -118,8 +125,11 @@ class LLVMModuleNode final : public runtime::ModuleNode {
ctx_ = std::make_shared<llvm::LLVMContext>(); ctx_ = std::make_shared<llvm::LLVMContext>();
llvm::SMDiagnostic err; llvm::SMDiagnostic err;
module_ = llvm::parseIRFile(file_name, err, *ctx_); module_ = llvm::parseIRFile(file_name, err, *ctx_);
CHECK(module_.get() != nullptr) if (module_.get() == nullptr) {
<< "Fail to load ir file " << file_name; std::string msg = err.getMessage();
LOG(FATAL) << "Fail to load ir file " << file_name << "\n"
<< "line " << err.getLineNo() << ":" << msg;
}
std::string target = module_->getTargetTriple(); std::string target = module_->getTargetTriple();
mptr_ = module_.get(); mptr_ = module_.get();
std::ostringstream os; std::ostringstream os;
......
...@@ -135,7 +135,9 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -135,7 +135,9 @@ LoweredFunc MakeAPI(Stmt body,
n->handle_data_type = binder.def_handle_dtype(); n->handle_data_type = binder.def_handle_dtype();
n->is_packed_func = num_unpacked_args == 0; n->is_packed_func = num_unpacked_args == 0;
n->is_restricted = is_restricted; n->is_restricted = is_restricted;
body = AttrStmt::make(
make_zero(Int(32)), attr::compute_scope,
StringImm::make(name + "_compute_"), body);
// Set device context // Set device context
if (vmap.count(device_id.get())) { if (vmap.count(device_id.get())) {
Expr node = StringImm::make("default"); Expr node = StringImm::make("default");
...@@ -149,9 +151,6 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -149,9 +151,6 @@ LoweredFunc MakeAPI(Stmt body,
Int(32), intrinsic::tvm_call_packed, Int(32), intrinsic::tvm_call_packed,
{StringImm::make(runtime::symbol::tvm_set_device), {StringImm::make(runtime::symbol::tvm_set_device),
device_type, device_id}, Call::Intrinsic))); device_type, device_id}, Call::Intrinsic)));
body = AttrStmt::make(
make_zero(Int(32)), attr::compute_scope,
StringImm::make(name + "_compute_"), body);
body = Block::make(set_device, body); body = Block::make(set_device, body);
} }
n->body = MergeNest( n->body = MergeNest(
......
...@@ -241,7 +241,7 @@ void* TVMBackendAllocWorkspace(int device_type, ...@@ -241,7 +241,7 @@ void* TVMBackendAllocWorkspace(int device_type,
TVMContext ctx; TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx, size); return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx, static_cast<size_t>(size));
} }
int TVMBackendFreeWorkspace(int device_type, int TVMBackendFreeWorkspace(int device_type,
...@@ -437,8 +437,8 @@ int TVMArrayCopyFromTo(TVMArrayHandle from, ...@@ -437,8 +437,8 @@ int TVMArrayCopyFromTo(TVMArrayHandle from,
<< "Can not copy across different ctx types directly"; << "Can not copy across different ctx types directly";
} }
DeviceAPIManager::Get(ctx)->CopyDataFromTo( DeviceAPIManager::Get(ctx)->CopyDataFromTo(
from->data, from->byte_offset, from->data, static_cast<size_t>(from->byte_offset),
to->data, to->byte_offset, to->data, static_cast<size_t>(to->byte_offset),
from_size, from->ctx, to->ctx, stream); from_size, from->ctx, to->ctx, stream);
API_END(); API_END();
} }
...@@ -455,7 +455,7 @@ int TVMArrayCopyFromBytes(TVMArrayHandle handle, ...@@ -455,7 +455,7 @@ int TVMArrayCopyFromBytes(TVMArrayHandle handle,
<< "TVMArrayCopyFromBytes: size mismatch"; << "TVMArrayCopyFromBytes: size mismatch";
DeviceAPIManager::Get(handle->ctx)->CopyDataFromTo( DeviceAPIManager::Get(handle->ctx)->CopyDataFromTo(
data, 0, data, 0,
handle->data, handle->byte_offset, handle->data, static_cast<size_t>(handle->byte_offset),
nbytes, cpu_ctx, handle->ctx, nullptr); nbytes, cpu_ctx, handle->ctx, nullptr);
API_END(); API_END();
} }
...@@ -471,7 +471,7 @@ int TVMArrayCopyToBytes(TVMArrayHandle handle, ...@@ -471,7 +471,7 @@ int TVMArrayCopyToBytes(TVMArrayHandle handle,
CHECK_EQ(arr_size, nbytes) CHECK_EQ(arr_size, nbytes)
<< "TVMArrayCopyToBytes: size mismatch"; << "TVMArrayCopyToBytes: size mismatch";
DeviceAPIManager::Get(handle->ctx)->CopyDataFromTo( DeviceAPIManager::Get(handle->ctx)->CopyDataFromTo(
handle->data, handle->byte_offset, handle->data, static_cast<size_t>(handle->byte_offset),
data, 0, data, 0,
nbytes, handle->ctx, cpu_ctx, nullptr); nbytes, handle->ctx, cpu_ctx, nullptr);
API_END(); API_END();
......
...@@ -91,7 +91,8 @@ class DSOModuleNode final : public ModuleNode { ...@@ -91,7 +91,8 @@ class DSOModuleNode final : public ModuleNode {
void Load(const std::string& name) { void Load(const std::string& name) {
lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
CHECK(lib_handle_ != nullptr) CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name; << "Failed to load dynamic shared library " << name
<< " " << dlerror();
} }
void* GetSymbol(const char* name) { void* GetSymbol(const char* name) {
return dlsym(lib_handle_, name); return dlsym(lib_handle_, name);
......
...@@ -81,7 +81,7 @@ int GetWarpSize(id<MTLDevice> dev) { ...@@ -81,7 +81,7 @@ int GetWarpSize(id<MTLDevice> dev) {
newComputePipelineStateWithFunction:f newComputePipelineStateWithFunction:f
error:&error_msg]; error:&error_msg];
CHECK(state != nil) << [[error_msg localizedDescription] UTF8String]; CHECK(state != nil) << [[error_msg localizedDescription] UTF8String];
return state.threadExecutionWidth; return static_cast<int>(state.threadExecutionWidth);
} }
MetalWorkspace::~MetalWorkspace() { MetalWorkspace::~MetalWorkspace() {
...@@ -99,6 +99,12 @@ void MetalWorkspace::Init() { ...@@ -99,6 +99,12 @@ void MetalWorkspace::Init() {
if (initialized_) return; if (initialized_) return;
initialized_ = true; initialized_ = true;
if (devices.size() != 0) return; if (devices.size() != 0) return;
#if TARGET_OS_IPHONE
// on iPhone
id<MTLDevice> d = MTLCreateSystemDefaultDevice();
devices.push_back([d retain]);
queues.push_back([[d newCommandQueue] retain]);
#else
NSArray<id<MTLDevice>>* devs = MTLCopyAllDevices(); NSArray<id<MTLDevice>>* devs = MTLCopyAllDevices();
for (size_t i = 0; i < devs.count; ++i) { for (size_t i = 0; i < devs.count; ++i) {
id<MTLDevice> d = [devs objectAtIndex:i]; id<MTLDevice> d = [devs objectAtIndex:i];
...@@ -108,6 +114,7 @@ void MetalWorkspace::Init() { ...@@ -108,6 +114,7 @@ void MetalWorkspace::Init() {
<< ", name=" << d.name; << ", name=" << d.name;
warp_size.push_back(GetWarpSize(d)); warp_size.push_back(GetWarpSize(d));
} }
#endif
} }
void MetalWorkspace::SetDevice(TVMContext ctx) { void MetalWorkspace::SetDevice(TVMContext ctx) {
...@@ -122,6 +129,7 @@ void* MetalWorkspace::AllocDataSpace( ...@@ -122,6 +129,7 @@ void* MetalWorkspace::AllocDataSpace(
id<MTLBuffer> buf = [ id<MTLBuffer> buf = [
dev newBufferWithLength:size dev newBufferWithLength:size
options:MTLResourceStorageModePrivate]; options:MTLResourceStorageModePrivate];
CHECK(buf != nil);
return (__bridge void*)([buf retain]); return (__bridge void*)([buf retain]);
} }
...@@ -144,13 +152,13 @@ void MetalWorkspace::CopyDataFromTo(const void* from, ...@@ -144,13 +152,13 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
if (ctx_from.device_type == kCPU) ctx = ctx_to; if (ctx_from.device_type == kCPU) ctx = ctx_to;
id<MTLCommandQueue> queue = GetCommandQueue(ctx); id<MTLCommandQueue> queue = GetCommandQueue(ctx);
id<MTLCommandBuffer> cb = [queue commandBuffer]; id<MTLCommandBuffer> cb = [queue commandBuffer];
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
int from_dev_type = static_cast<int>(ctx_from.device_type); int from_dev_type = static_cast<int>(ctx_from.device_type);
int to_dev_type = static_cast<int>(ctx_to.device_type); int to_dev_type = static_cast<int>(ctx_to.device_type);
if (from_dev_type == kMetal && to_dev_type == kMetal) { if (from_dev_type == kMetal && to_dev_type == kMetal) {
CHECK_EQ(ctx_from.device_id, ctx_to.device_id) CHECK_EQ(ctx_from.device_id, ctx_to.device_id)
<< "Metal disallow cross device copy."; << "Metal disallow cross device copy.";
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
[encoder copyFromBuffer:(__bridge id<MTLBuffer>)(from) [encoder copyFromBuffer:(__bridge id<MTLBuffer>)(from)
sourceOffset:from_offset sourceOffset:from_offset
toBuffer:(__bridge id<MTLBuffer>)(to) toBuffer:(__bridge id<MTLBuffer>)(to)
...@@ -164,6 +172,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from, ...@@ -164,6 +172,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
if (from_buf.storageMode != MTLStorageModeShared) { if (from_buf.storageMode != MTLStorageModeShared) {
id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal() id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()
->GetTempBuffer(ctx_from, size); ->GetTempBuffer(ctx_from, size);
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
[encoder copyFromBuffer:from_buf [encoder copyFromBuffer:from_buf
sourceOffset:from_offset sourceOffset:from_offset
toBuffer:temp toBuffer:temp
...@@ -188,6 +197,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from, ...@@ -188,6 +197,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
memcpy([temp contents], memcpy([temp contents],
static_cast<const char*>(from) + from_offset, static_cast<const char*>(from) + from_offset,
size); size);
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
[encoder copyFromBuffer:temp [encoder copyFromBuffer:temp
sourceOffset:0 sourceOffset:0
toBuffer:to_buf toBuffer:to_buf
......
...@@ -66,7 +66,7 @@ class MetalModuleNode final :public runtime::ModuleNode { ...@@ -66,7 +66,7 @@ class MetalModuleNode final :public runtime::ModuleNode {
return ""; return "";
} }
} }
// get a CUfunction from primary context in device_id // get a from primary context in device_id
id<MTLComputePipelineState> GetPipelineState( id<MTLComputePipelineState> GetPipelineState(
size_t device_id, const std::string& func_name) { size_t device_id, const std::string& func_name) {
metal::MetalWorkspace* w = metal::MetalWorkspace::Global().get(); metal::MetalWorkspace* w = metal::MetalWorkspace::Global().get();
...@@ -194,7 +194,7 @@ class MetalWrappedFunc { ...@@ -194,7 +194,7 @@ class MetalWrappedFunc {
id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder]; id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder];
[encoder setComputePipelineState:scache_[device_id]]; [encoder setComputePipelineState:scache_[device_id]];
for (size_t i = 0; i < num_buffer_args_; ++i) { for (size_t i = 0; i < num_buffer_args_; ++i) {
void* buf = args[i]; void* buf = args[static_cast<int>(i)];
[encoder setBuffer:(__bridge id<MTLBuffer>)(buf) offset:0 atIndex:i]; [encoder setBuffer:(__bridge id<MTLBuffer>)(buf) offset:0 atIndex:i];
} }
if (num_pack_args_ != 0) { if (num_pack_args_ != 0) {
......
...@@ -19,7 +19,7 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) { ...@@ -19,7 +19,7 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
nbytes |= (c & 0xffUL) << (i * 8); nbytes |= (c & 0xffUL) << (i * 8);
} }
dmlc::MemoryFixedSizeStream fs( dmlc::MemoryFixedSizeStream fs(
const_cast<char*>(mblob + sizeof(nbytes)), nbytes); const_cast<char*>(mblob + sizeof(nbytes)), static_cast<size_t>(nbytes));
dmlc::Stream* stream = &fs; dmlc::Stream* stream = &fs;
uint64_t size; uint64_t size;
CHECK(stream->Read(&size)); CHECK(stream->Read(&size));
......
...@@ -32,7 +32,7 @@ union ArgUnion { ...@@ -32,7 +32,7 @@ union ArgUnion {
* *
* \param f with signiture (TVMArgs args, TVMRetValue* rv, void* void_args) * \param f with signiture (TVMArgs args, TVMRetValue* rv, void* void_args)
* \param arg_types The arguments that wish to get from * \param arg_types The arguments that wish to get from
* \tparam T the function type * \tparam F the function type
* *
* \return The wrapped packed function. * \return The wrapped packed function.
*/ */
...@@ -43,7 +43,7 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types); ...@@ -43,7 +43,7 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types);
* *
* \param f with signiture (TVMArgs args, TVMRetValue* rv, ArgUnion* pack_args) * \param f with signiture (TVMArgs args, TVMRetValue* rv, ArgUnion* pack_args)
* \param arg_types The arguments that wish to get from * \param arg_types The arguments that wish to get from
* \tparam T the function type * \tparam F the function type
* *
* \return The wrapped packed function. * \return The wrapped packed function.
*/ */
......
...@@ -36,12 +36,5 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.download") ...@@ -36,12 +36,5 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.download")
*rv = arr; *rv = arr;
}); });
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.load_module")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string file_name = RPCGetPath(args[0]);
*rv = Module::LoadFromFile(file_name, "");
LOG(INFO) << "Load module from " << file_name << " ...";
});
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -946,7 +946,11 @@ void RPCModuleGetFunc(TVMArgs args, TVMRetValue *rv) { ...@@ -946,7 +946,11 @@ void RPCModuleGetFunc(TVMArgs args, TVMRetValue *rv) {
void* mhandle = args[0]; void* mhandle = args[0];
PackedFunc pf = static_cast<Module*>(mhandle)->GetFunction( PackedFunc pf = static_cast<Module*>(mhandle)->GetFunction(
args[1], false); args[1], false);
if (pf != nullptr) {
*rv = static_cast<void*>(new PackedFunc(pf)); *rv = static_cast<void*>(new PackedFunc(pf));
} else {
*rv = nullptr;
}
} }
void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) { void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
......
...@@ -39,7 +39,8 @@ class SockChannel final : public RPCChannel { ...@@ -39,7 +39,8 @@ class SockChannel final : public RPCChannel {
common::TCPSocket sock_; common::TCPSocket sock_;
}; };
Module RPCConnect(std::string url, int port, std::string key) { std::shared_ptr<RPCSession>
RPCConnect(std::string url, int port, std::string key) {
common::TCPSocket sock; common::TCPSocket sock;
common::SockAddr addr(url.c_str(), port); common::SockAddr addr(url.c_str(), port);
sock.Create(); sock.Create();
...@@ -47,8 +48,6 @@ Module RPCConnect(std::string url, int port, std::string key) { ...@@ -47,8 +48,6 @@ Module RPCConnect(std::string url, int port, std::string key) {
<< "Connect to " << addr.AsString() << " failed"; << "Connect to " << addr.AsString() << " failed";
// hand shake // hand shake
std::ostringstream os; std::ostringstream os;
os << "client:" << key;
key = os.str();
int code = kRPCMagic; int code = kRPCMagic;
int keylen = static_cast<int>(key.length()); int keylen = static_cast<int>(key.length());
CHECK_EQ(sock.SendAll(&code, sizeof(code)), sizeof(code)); CHECK_EQ(sock.SendAll(&code, sizeof(code)), sizeof(code));
...@@ -64,15 +63,16 @@ Module RPCConnect(std::string url, int port, std::string key) { ...@@ -64,15 +63,16 @@ Module RPCConnect(std::string url, int port, std::string key) {
} else if (code == kRPCMagic + 1) { } else if (code == kRPCMagic + 1) {
sock.Close(); sock.Close();
LOG(FATAL) << "URL " << url << ":" << port LOG(FATAL) << "URL " << url << ":" << port
<< " server already have client key=" << key; << " server already have key=" << key;
} else if (code != kRPCMagic) { } else if (code != kRPCMagic) {
sock.Close(); sock.Close();
LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server"; LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server";
} }
return CreateRPCModule( return RPCSession::Create(std::unique_ptr<SockChannel>(new SockChannel(sock)), key);
RPCSession::Create( }
std::unique_ptr<SockChannel>(new SockChannel(sock)),
"SockClient")); Module RPCClientConnect(std::string url, int port, std::string key) {
return CreateRPCModule(RPCConnect(url, port, "client:" + key));
} }
void RPCServerLoop(int sockfd) { void RPCServerLoop(int sockfd) {
...@@ -85,7 +85,7 @@ void RPCServerLoop(int sockfd) { ...@@ -85,7 +85,7 @@ void RPCServerLoop(int sockfd) {
TVM_REGISTER_GLOBAL("contrib.rpc._Connect") TVM_REGISTER_GLOBAL("contrib.rpc._Connect")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = RPCConnect(args[0], args[1], args[2]); *rv = RPCClientConnect(args[0], args[1], args[2]);
}); });
TVM_REGISTER_GLOBAL("contrib.rpc._ServerLoop") TVM_REGISTER_GLOBAL("contrib.rpc._ServerLoop")
......
...@@ -41,5 +41,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.workpath") ...@@ -41,5 +41,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.workpath")
static RPCEnv env; static RPCEnv env;
*rv = env.GetPath(args[0]); *rv = env.GetPath(args[0]);
}); });
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.load_module")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string file_name = "/rpc/" + args[0].operator std::string();
*rv = Module::LoadFromFile(file_name, "");
LOG(INFO) << "Load module from " << file_name << " ...";
});
} // namespace contrib } // namespace contrib
} // namespace tvm } // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment