Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
5bd4afee
Commit
5bd4afee
authored
Jul 24, 2018
by
Yizhi Liu
Committed by
Tianqi Chen
Jul 24, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[tvm4j] add GraphRuntime (#1472)
parent
05afac09
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
601 additions
and
42 deletions
+601
-42
jvm/core/src/main/java/ml/dmlc/tvm/Function.java
+2
-2
jvm/core/src/main/java/ml/dmlc/tvm/Module.java
+8
-2
jvm/core/src/main/java/ml/dmlc/tvm/NDArray.java
+12
-2
jvm/core/src/main/java/ml/dmlc/tvm/NDArrayBase.java
+1
-2
jvm/core/src/main/java/ml/dmlc/tvm/TVMType.java
+8
-8
jvm/core/src/main/java/ml/dmlc/tvm/TVMValueHandle.java
+34
-0
jvm/core/src/main/java/ml/dmlc/tvm/contrib/GraphModule.java
+170
-0
jvm/core/src/main/java/ml/dmlc/tvm/contrib/GraphRuntime.java
+121
-0
jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java
+5
-0
jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java
+2
-2
jvm/core/src/main/java/ml/dmlc/tvm/rpc/TVMRemoteContext.java
+30
-0
jvm/core/src/test/java/ml/dmlc/tvm/TestUtils.java
+26
-0
jvm/core/src/test/java/ml/dmlc/tvm/contrib/GraphRuntimeTest.java
+114
-0
jvm/core/src/test/java/ml/dmlc/tvm/rpc/RPCTest.java
+10
-24
jvm/core/src/test/scripts/test_graph_runtime.py
+47
-0
jvm/native/src/main/native/jni_helper_func.h
+10
-0
tests/scripts/task_java_unittest.sh
+1
-0
No files found.
jvm/core/src/main/java/ml/dmlc/tvm/Function.java
View file @
5bd4afee
...
...
@@ -109,8 +109,7 @@ public class Function extends TVMValue {
/**
* Release the Function.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy
* and `finalize()` is not guaranteed to be called when GC happens.
* We highly recommend you to do this manually since the GC strategy is lazy.
* </p>
*/
@Override
public
void
release
()
{
...
...
@@ -269,6 +268,7 @@ public class Function extends TVMValue {
case
BYTES:
Base
.
_LIB
.
tvmFuncPushArgBytes
(
tvmArg
.
asBytes
());
break
;
case
HANDLE:
case
ARRAY_HANDLE:
case
MODULE_HANDLE:
case
FUNC_HANDLE:
...
...
jvm/core/src/main/java/ml/dmlc/tvm/Module.java
View file @
5bd4afee
...
...
@@ -72,8 +72,7 @@ public class Module extends TVMValue {
/**
* Release the Module.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy
* and `finalize()` is not guaranteed to be called when GC happens.
* We highly recommend you to do this manually since the GC strategy is lazy.
* </p>
*/
@Override
public
void
release
()
{
...
...
@@ -123,6 +122,13 @@ public class Module extends TVMValue {
}
/**
* @return type key of the module.
*/
public
String
typeKey
()
{
return
getApi
(
"_GetTypeKey"
).
pushArg
(
this
).
invoke
().
asString
();
}
/**
* Load module from file.
* @param path The path to the module file.
* @param fmt The format of the file,
...
...
jvm/core/src/main/java/ml/dmlc/tvm/NDArray.java
View file @
5bd4afee
...
...
@@ -27,10 +27,12 @@ import java.util.List;
*/
public
class
NDArray
extends
NDArrayBase
{
private
final
TVMType
dtype
;
private
final
TVMContext
context
;
NDArray
(
long
handle
,
boolean
isView
,
TVMType
dtype
)
{
NDArray
(
long
handle
,
boolean
isView
,
TVMType
dtype
,
TVMContext
ctx
)
{
super
(
handle
,
isView
);
this
.
dtype
=
dtype
;
this
.
context
=
ctx
;
}
@Override
protected
void
finalize
()
throws
Throwable
{
...
...
@@ -362,6 +364,14 @@ public class NDArray extends NDArrayBase {
}
/**
* Get the context of current array.
* @return the context.
*/
public
TVMContext
ctx
()
{
return
context
;
}
/**
* Create an empty array given shape, type and device.
* @param shape The shape of the array.
* @param dtype The data type of the array.
...
...
@@ -373,7 +383,7 @@ public class NDArray extends NDArrayBase {
Base
.
checkCall
(
Base
.
_LIB
.
tvmArrayAlloc
(
shape
,
dtype
.
typeCode
,
dtype
.
bits
,
dtype
.
lanes
,
ctx
.
deviceType
,
ctx
.
deviceId
,
refHandle
));
return
new
NDArray
(
refHandle
.
value
,
false
,
dtype
);
return
new
NDArray
(
refHandle
.
value
,
false
,
dtype
,
ctx
);
}
/**
...
...
jvm/core/src/main/java/ml/dmlc/tvm/NDArrayBase.java
View file @
5bd4afee
...
...
@@ -57,8 +57,7 @@ public class NDArrayBase extends TVMValue {
/**
* Release the NDArray memory.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy
* and `finalize()` is not guaranteed to be called when GC happens.
* We highly recommend you to do this manually since the GC strategy is lazy.
* </p>
*/
public
void
release
()
{
...
...
jvm/core/src/main/java/ml/dmlc/tvm/TVMType.java
View file @
5bd4afee
...
...
@@ -37,16 +37,16 @@ public class TVMType {
this
.
lanes
=
lanes
;
int
bitsTemp
=
0
;
if
(
typeStr
.
startsWith
(
"int"
))
{
typeCode
=
0
;
typeCode
=
INT
;
bitsTemp
=
Integer
.
parseInt
(
typeStr
.
substring
(
3
));
}
else
if
(
typeStr
.
startsWith
(
"uint"
))
{
typeCode
=
1
;
typeCode
=
UINT
;
bitsTemp
=
Integer
.
parseInt
(
typeStr
.
substring
(
4
));
}
else
if
(
typeStr
.
startsWith
(
"float"
))
{
typeCode
=
2
;
typeCode
=
FLOAT
;
bitsTemp
=
Integer
.
parseInt
(
typeStr
.
substring
(
5
));
}
else
if
(
typeStr
.
startsWith
(
"handle"
))
{
typeCode
=
4
;
typeCode
=
HANDLE
;
bitsTemp
=
64
;
}
else
{
throw
new
IllegalArgumentException
(
"Do not know how to handle type "
+
typeStr
);
...
...
@@ -78,16 +78,16 @@ public class TVMType {
@Override
public
String
toString
()
{
String
typeCodeStr
;
switch
(
typeCode
)
{
case
0
:
case
INT
:
typeCodeStr
=
"int"
;
break
;
case
1
:
case
UINT
:
typeCodeStr
=
"uint"
;
break
;
case
2
:
case
FLOAT
:
typeCodeStr
=
"float"
;
break
;
case
4
:
case
HANDLE
:
typeCodeStr
=
"handle"
;
break
;
default
:
...
...
jvm/core/src/main/java/ml/dmlc/tvm/TVMValueHandle.java
0 → 100644
View file @
5bd4afee
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package
ml
.
dmlc
.
tvm
;
/**
* Java class related to TVM handles (TypeCode.HANDLE)
*/
public
class
TVMValueHandle
extends
TVMValue
{
public
final
long
value
;
public
TVMValueHandle
(
long
value
)
{
super
(
TypeCode
.
HANDLE
);
this
.
value
=
value
;
}
@Override
public
long
asHandle
()
{
return
value
;
}
}
jvm/core/src/main/java/ml/dmlc/tvm/contrib/GraphModule.java
0 → 100644
View file @
5bd4afee
package
ml
.
dmlc
.
tvm
.
contrib
;
import
ml.dmlc.tvm.Function
;
import
ml.dmlc.tvm.Module
;
import
ml.dmlc.tvm.NDArray
;
import
ml.dmlc.tvm.TVMContext
;
/**
* Wrapper runtime module.
* This is a thin wrapper of the underlying TVM module.
* you can also directly call set_input, run, and get_output
* of underlying module functions.
*/
public
class
GraphModule
{
private
Module
module
;
private
TVMContext
ctx
;
private
Function
fsetInput
;
private
Function
frun
;
private
Function
fgetOutput
;
private
Function
fgetInput
;
private
Function
fdebugGetOutput
;
private
Function
floadParams
;
GraphModule
(
Module
module
,
TVMContext
ctx
)
{
this
.
module
=
module
;
this
.
ctx
=
ctx
;
fsetInput
=
module
.
getFunction
(
"set_input"
);
frun
=
module
.
getFunction
(
"run"
);
fgetInput
=
module
.
getFunction
(
"get_input"
);
fgetOutput
=
module
.
getFunction
(
"get_output"
);
try
{
fdebugGetOutput
=
module
.
getFunction
(
"debug_get_output"
);
}
catch
(
IllegalArgumentException
ignored
)
{
// ignore
}
floadParams
=
module
.
getFunction
(
"load_params"
);
}
/**
* Release the GraphModule.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy.
* </p>
*/
public
void
release
()
{
fsetInput
.
release
();
frun
.
release
();
fgetInput
.
release
();
fgetOutput
.
release
();
if
(
fdebugGetOutput
!=
null
)
{
fdebugGetOutput
.
release
();
}
floadParams
.
release
();
module
.
release
();
}
/**
* Set inputs to the module.
* @param key The input key.
* @param value The input value
* @return self.
*/
public
GraphModule
setInput
(
String
key
,
NDArray
value
)
{
NDArray
input
=
value
;
if
(!
value
.
ctx
().
equals
(
ctx
))
{
input
=
NDArray
.
empty
(
value
.
shape
(),
ctx
);
value
.
copyTo
(
input
);
}
fsetInput
.
pushArg
(
key
).
pushArg
(
input
).
invoke
();
return
this
;
}
/**
* Set inputs to the module
* @param key The input key.
* @param value The input value.
* @return self.
*/
public
GraphModule
setInput
(
int
key
,
NDArray
value
)
{
NDArray
input
=
value
;
if
(!
value
.
ctx
().
equals
(
ctx
))
{
input
=
NDArray
.
empty
(
value
.
shape
(),
ctx
);
value
.
copyTo
(
input
);
}
fsetInput
.
pushArg
(
key
).
pushArg
(
input
).
invoke
();
return
this
;
}
/**
* Run forward execution of the graph.
* @return self.
*/
public
GraphModule
run
()
{
frun
.
invoke
();
return
this
;
}
/**
* Get index-th input to out.
* @param index The input index.
* @param out The output array container.
* @return out.
*/
public
NDArray
getInput
(
int
index
,
NDArray
out
)
{
fgetInput
.
pushArg
(
index
).
pushArg
(
out
).
invoke
();
return
out
;
}
/**
* Get index-th output to out.
* @param index The output index.
* @param out The output array container.
* @return out.
*/
public
NDArray
getOutput
(
int
index
,
NDArray
out
)
{
fgetOutput
.
pushArg
(
index
).
pushArg
(
out
).
invoke
();
return
out
;
}
/**
* Run graph up to node and get the output to out.
* @param node The node name.
* @param out The output array container.
* @return out.
*/
public
NDArray
debugGetOutput
(
String
node
,
NDArray
out
)
{
if
(
fdebugGetOutput
!=
null
)
{
fdebugGetOutput
.
pushArg
(
node
).
pushArg
(
out
).
invoke
();
}
else
{
throw
new
RuntimeException
(
"Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0"
);
}
return
out
;
}
/**
* Run graph up to node and get the output to out.
* @param node The node index.
* @param out The output array container.
* @return out.
*/
public
NDArray
debugGetOutput
(
int
node
,
NDArray
out
)
{
if
(
fdebugGetOutput
!=
null
)
{
fdebugGetOutput
.
pushArg
(
node
).
pushArg
(
out
).
invoke
();
}
else
{
throw
new
RuntimeException
(
"Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0"
);
}
return
out
;
}
/**
* Load parameters from serialized byte array of parameter dict.
* @param params The serialized parameter.
* @return self.
*/
public
GraphModule
loadParams
(
byte
[]
params
)
{
floadParams
.
pushArg
(
params
).
invoke
();
return
this
;
}
/**
* Get internal module function.
* @param key The key to the module.
* @return The function.
* @throws IllegalArgumentException if function does not exist.
*/
public
Function
getFunction
(
String
key
)
{
return
module
.
getFunction
(
key
);
}
}
jvm/core/src/main/java/ml/dmlc/tvm/contrib/GraphRuntime.java
0 → 100644
View file @
5bd4afee
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package
ml
.
dmlc
.
tvm
.
contrib
;
import
ml.dmlc.tvm.Function
;
import
ml.dmlc.tvm.Module
;
import
ml.dmlc.tvm.TVMContext
;
import
ml.dmlc.tvm.TVMValue
;
import
ml.dmlc.tvm.rpc.RPC
;
import
ml.dmlc.tvm.rpc.RPCSession
;
import
ml.dmlc.tvm.rpc.TVMRemoteContext
;
import
java.lang.reflect.Field
;
import
java.lang.reflect.InvocationTargetException
;
import
java.lang.reflect.Method
;
public
class
GraphRuntime
{
/**
* Create a runtime executor module given a graph and module.
* @param graphJson The graph deployed in json format output by nnvm graph.
* @param libmod The module of the corresponding function.
* @param ctx The local or remote context to deploy the module.
* @return Runtime graph module that can be used to execute the graph.
*/
public
static
GraphModule
create
(
String
graphJson
,
Module
libmod
,
TVMContext
ctx
)
{
Module
graphModule
=
null
;
if
(
ctx
.
deviceType
>=
RPC
.
RPC_SESS_MASK
)
{
if
(!(
ctx
instanceof
TVMRemoteContext
))
{
throw
new
IllegalArgumentException
(
"Looks like you are using remote context with no RPCSession bind."
+
"Use session.context instead."
);
}
RPCSession
rpcSession
=
((
TVMRemoteContext
)
ctx
).
rpcSession
;
// check arguments
if
(!
"rpc"
.
equals
(
libmod
.
typeKey
()))
{
throw
new
IllegalArgumentException
(
"libmod.typeKey != rpc"
);
}
final
int
sessIndex
=
(
int
)
((
Function
)
reflectionStaticCall
(
RPC
.
class
,
"getApi"
,
"_SessTableIndex"
))
.
pushArg
(
libmod
).
invoke
().
asLong
();
if
(
sessIndex
!=
(
Integer
)
reflectionGetField
(
rpcSession
,
"tblIndex"
))
{
throw
new
IllegalArgumentException
(
String
.
format
(
"libmod SessTableIndex=%d mismatch rpcSession.tblIndex=%d"
,
sessIndex
,
reflectionGetField
(
rpcSession
,
"tblIndex"
)));
}
Function
rpcModuleHandle
=
(
Function
)
reflectionStaticCall
(
RPC
.
class
,
"getApi"
,
"_ModuleHandle"
);
if
(
rpcModuleHandle
==
null
)
{
throw
new
RuntimeException
(
"Cannot find global function tvm.rpc._ModuleHandle."
+
"Did you compile tvm_runtime with the correct version?"
);
}
Function
fcreate
=
Function
.
getFunction
(
"tvm.graph_runtime.remote_create"
);
if
(
fcreate
==
null
)
{
throw
new
RuntimeException
(
"Cannot find global function tvm.graph_runtime.remote_create."
+
"Did you compile tvm_runtime with correct version?"
);
}
TVMValue
hmod
=
rpcModuleHandle
.
pushArg
(
libmod
).
invoke
();
graphModule
=
fcreate
.
call
(
graphJson
,
hmod
,
ctx
.
deviceType
%
RPC
.
RPC_SESS_MASK
,
ctx
.
deviceId
).
asModule
();
}
else
{
Function
fcreate
=
Function
.
getFunction
(
"tvm.graph_runtime.create"
);
if
(
fcreate
==
null
)
{
throw
new
RuntimeException
(
"Cannot find global function tvm.graph_runtime.create."
+
"Did you compile tvm_runtime with correct version?"
);
}
graphModule
=
fcreate
.
pushArg
(
graphJson
)
.
pushArg
(
libmod
).
pushArg
(
ctx
.
deviceType
).
pushArg
(
ctx
.
deviceId
)
.
invoke
().
asModule
();
}
return
new
GraphModule
(
graphModule
,
ctx
);
}
private
static
Object
reflectionGetField
(
Object
obj
,
String
fieldName
)
{
try
{
Field
field
=
obj
.
getClass
().
getDeclaredField
(
fieldName
);
field
.
setAccessible
(
true
);
return
field
.
get
(
obj
);
}
catch
(
NoSuchFieldException
e
)
{
throw
new
RuntimeException
(
e
);
}
catch
(
IllegalAccessException
e
)
{
throw
new
RuntimeException
(
e
);
}
}
private
static
Object
reflectionStaticCall
(
Class
<?>
clazz
,
String
methodName
,
Object
...
args
)
{
Class
<?>[]
types
=
new
Class
<?>[
args
.
length
];
for
(
int
i
=
0
;
i
<
args
.
length
;
++
i
)
{
types
[
i
]
=
args
[
i
].
getClass
();
}
try
{
Method
method
=
clazz
.
getDeclaredMethod
(
methodName
,
types
);
method
.
setAccessible
(
true
);
return
method
.
invoke
(
null
,
args
);
}
catch
(
NoSuchMethodException
e
)
{
throw
new
RuntimeException
(
e
);
}
catch
(
IllegalAccessException
e
)
{
throw
new
RuntimeException
(
e
);
}
catch
(
InvocationTargetException
e
)
{
throw
new
RuntimeException
(
e
);
}
}
}
jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java
View file @
5bd4afee
...
...
@@ -44,6 +44,11 @@ public class RPC {
}
};
/**
* Get internal function starts with namespace tvm.rpc.
* @param name function name.
* @return the function, null if not exists.
*/
static
Function
getApi
(
String
name
)
{
Function
func
=
apiFuncs
.
get
().
get
(
name
);
if
(
func
==
null
)
{
...
...
jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java
View file @
5bd4afee
...
...
@@ -60,7 +60,7 @@ public class RPCSession {
public
TVMContext
context
(
String
devType
,
int
devId
)
{
TVMContext
ctx
=
new
TVMContext
(
devType
,
devId
);
int
encode
=
(
tblIndex
+
1
)
*
RPC
.
RPC_SESS_MASK
;
return
new
TVM
Context
(
ctx
.
deviceType
+
encode
,
devId
);
return
new
TVM
RemoteContext
(
ctx
.
deviceType
+
encode
,
devId
,
this
);
}
/**
...
...
@@ -80,7 +80,7 @@ public class RPCSession {
*/
public
TVMContext
context
(
int
devType
,
int
devId
)
{
int
encode
=
(
tblIndex
+
1
)
*
RPC
.
RPC_SESS_MASK
;
return
new
TVM
Context
(
devType
+
encode
,
devId
);
return
new
TVM
RemoteContext
(
devType
+
encode
,
devId
,
this
);
}
/**
...
...
jvm/core/src/main/java/ml/dmlc/tvm/rpc/TVMRemoteContext.java
0 → 100644
View file @
5bd4afee
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package
ml
.
dmlc
.
tvm
.
rpc
;
import
ml.dmlc.tvm.TVMContext
;
// always related to RPCSession. Cannot construct by users.
public
class
TVMRemoteContext
extends
TVMContext
{
public
final
RPCSession
rpcSession
;
TVMRemoteContext
(
int
deviceType
,
int
deviceId
,
RPCSession
rpcSession
)
{
super
(
deviceType
,
deviceId
);
this
.
rpcSession
=
rpcSession
;
}
}
jvm/core/src/test/java/ml/dmlc/tvm/TestUtils.java
0 → 100644
View file @
5bd4afee
package
ml
.
dmlc
.
tvm
;
import
ml.dmlc.tvm.rpc.Server
;
import
java.io.IOException
;
public
class
TestUtils
{
public
static
class
RefInt
{
public
int
value
;
}
public
static
Server
startServer
(
RefInt
portRef
)
{
Server
server
=
null
;
int
port
=
9981
;
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
try
{
server
=
new
Server
(
port
+
i
);
server
.
start
();
portRef
.
value
=
port
+
i
;
return
server
;
}
catch
(
IOException
e
)
{
}
}
throw
new
RuntimeException
(
"Cannot find an available port."
);
}
}
jvm/core/src/test/java/ml/dmlc/tvm/contrib/GraphRuntimeTest.java
0 → 100644
View file @
5bd4afee
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package
ml
.
dmlc
.
tvm
.
contrib
;
import
ml.dmlc.tvm.*
;
import
ml.dmlc.tvm.rpc.Client
;
import
ml.dmlc.tvm.rpc.RPCSession
;
import
ml.dmlc.tvm.rpc.Server
;
import
org.junit.BeforeClass
;
import
org.junit.Test
;
import
org.slf4j.Logger
;
import
org.slf4j.LoggerFactory
;
import
java.io.File
;
import
java.io.IOException
;
import
java.util.Scanner
;
import
static
org
.
junit
.
Assert
.
assertArrayEquals
;
public
class
GraphRuntimeTest
{
private
final
Logger
logger
=
LoggerFactory
.
getLogger
(
GraphRuntime
.
class
);
private
static
String
loadingDir
;
@BeforeClass
public
static
void
beforeClass
()
{
loadingDir
=
System
.
getProperty
(
"test.tempdir"
);
}
@Test
public
void
test_add_one_local
()
throws
IOException
{
Module
libmod
=
Module
.
load
(
loadingDir
+
File
.
separator
+
"graph_addone_lib.so"
);
String
graphJson
=
new
Scanner
(
new
File
(
loadingDir
+
File
.
separator
+
"graph_addone.json"
))
.
useDelimiter
(
"\\Z"
).
next
();
TVMContext
ctx
=
TVMContext
.
cpu
();
GraphModule
graph
=
GraphRuntime
.
create
(
graphJson
,
libmod
,
ctx
);
long
[]
shape
=
new
long
[]{
4
};
NDArray
arr
=
NDArray
.
empty
(
shape
,
ctx
);
arr
.
copyFrom
(
new
float
[]{
1
f
,
2
f
,
3
f
,
4
f
});
NDArray
out
=
NDArray
.
empty
(
shape
,
ctx
);
graph
.
setInput
(
"x"
,
arr
).
run
();
graph
.
getOutput
(
0
,
out
);
assertArrayEquals
(
new
float
[]{
2
f
,
3
f
,
4
f
,
5
f
},
out
.
asFloatArray
(),
1
e
-
3
f
);
arr
.
release
();
out
.
release
();
graph
.
release
();
}
@Test
public
void
test_add_one_remote
()
throws
IOException
{
if
(!
Module
.
enabled
(
"rpc"
))
{
logger
.
warn
(
"RPC is not enabled. Skip."
);
return
;
}
String
libPath
=
loadingDir
+
File
.
separator
+
"graph_addone_lib.so"
;
String
graphJson
=
new
Scanner
(
new
File
(
loadingDir
+
File
.
separator
+
"graph_addone.json"
))
.
useDelimiter
(
"\\Z"
).
next
();
TestUtils
.
RefInt
port
=
new
TestUtils
.
RefInt
();
Server
server
=
null
;
try
{
server
=
TestUtils
.
startServer
(
port
);
RPCSession
remote
=
Client
.
connect
(
"localhost"
,
port
.
value
);
TVMContext
ctx
=
remote
.
cpu
();
remote
.
upload
(
new
File
(
libPath
));
Module
mlib
=
remote
.
loadModule
(
"graph_addone_lib.so"
);
GraphModule
graph
=
GraphRuntime
.
create
(
graphJson
,
mlib
,
ctx
);
long
[]
shape
=
new
long
[]{
4
};
NDArray
arr
=
NDArray
.
empty
(
shape
,
ctx
);
arr
.
copyFrom
(
new
float
[]{
1
f
,
2
f
,
3
f
,
4
f
});
NDArray
out
=
NDArray
.
empty
(
shape
,
ctx
);
graph
.
setInput
(
"x"
,
arr
).
run
();
graph
.
getOutput
(
0
,
out
);
assertArrayEquals
(
new
float
[]{
2
f
,
3
f
,
4
f
,
5
f
},
out
.
asFloatArray
(),
1
e
-
3
f
);
arr
.
release
();
out
.
release
();
graph
.
release
();
}
finally
{
if
(
server
!=
null
)
{
server
.
terminate
();
}
}
}
}
jvm/core/src/test/java/ml/dmlc/tvm/rpc/RPCTest.java
View file @
5bd4afee
...
...
@@ -20,36 +20,21 @@ package ml.dmlc.tvm.rpc;
import
ml.dmlc.tvm.Function
;
import
ml.dmlc.tvm.Module
;
import
ml.dmlc.tvm.TVMValue
;
import
ml.dmlc.tvm.TestUtils
;
import
org.junit.Ignore
;
import
org.junit.Test
;
import
java.io.IOException
;
import
org.slf4j.Logger
;
import
org.slf4j.LoggerFactory
;
import
static
org
.
junit
.
Assert
.
assertEquals
;
public
class
RPCTest
{
static
class
RefInt
{
public
int
value
;
}
private
static
Server
startServer
(
RefInt
portRef
)
{
Server
server
=
null
;
int
port
=
9981
;
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
try
{
server
=
new
Server
(
port
+
i
);
server
.
start
();
portRef
.
value
=
port
+
i
;
return
server
;
}
catch
(
IOException
e
)
{
}
}
throw
new
RuntimeException
(
"Cannot find an available port."
);
}
private
final
Logger
logger
=
LoggerFactory
.
getLogger
(
RPCTest
.
class
);
@Test
public
void
test_addone
()
{
if
(!
Module
.
enabled
(
"rpc"
))
{
logger
.
warn
(
"RPC is not enabled. Skip."
);
return
;
}
Function
.
register
(
"test.rpc.addone"
,
new
Function
.
Callback
()
{
...
...
@@ -58,10 +43,10 @@ public class RPCTest {
}
});
RefInt
port
=
new
RefInt
();
TestUtils
.
RefInt
port
=
new
TestUtils
.
RefInt
();
Server
server
=
null
;
try
{
server
=
startServer
(
port
);
server
=
TestUtils
.
startServer
(
port
);
RPCSession
client
=
Client
.
connect
(
"localhost"
,
port
.
value
);
Function
func
=
client
.
getFunction
(
"test.rpc.addone"
);
assertEquals
(
11L
,
func
.
call
(
10
).
asLong
());
...
...
@@ -75,6 +60,7 @@ public class RPCTest {
@Test
public
void
test_strcat
()
{
if
(!
Module
.
enabled
(
"rpc"
))
{
logger
.
warn
(
"RPC is not enabled. Skip."
);
return
;
}
Function
.
register
(
"test.rpc.strcat"
,
new
Function
.
Callback
()
{
...
...
@@ -83,10 +69,10 @@ public class RPCTest {
}
});
RefInt
port
=
new
RefInt
();
TestUtils
.
RefInt
port
=
new
TestUtils
.
RefInt
();
Server
server
=
null
;
try
{
server
=
startServer
(
port
);
server
=
TestUtils
.
startServer
(
port
);
RPCSession
client
=
Client
.
connect
(
"localhost"
,
port
.
value
);
Function
func
=
client
.
getFunction
(
"test.rpc.strcat"
);
assertEquals
(
"abc:11"
,
func
.
call
(
"abc"
,
11L
).
asString
());
...
...
jvm/core/src/test/scripts/test_graph_runtime.py
0 → 100644
View file @
5bd4afee
import
os
import
tvm
import
json
from
tvm.contrib
import
graph_runtime
def
dump_graph_lib
(
target_dir
):
dim
=
4
A
=
tvm
.
placeholder
((
dim
,),
name
=
'A'
)
B
=
tvm
.
compute
(
A
.
shape
,
lambda
*
i
:
A
(
*
i
)
+
1.0
,
name
=
'B'
)
sched
=
tvm
.
create_schedule
(
B
.
op
)
node0
=
{
"op"
:
"null"
,
"name"
:
"x"
,
"inputs"
:
[]}
node1
=
{
"op"
:
"tvm_op"
,
"name"
:
"add"
,
"inputs"
:
[[
0
,
0
,
0
]],
"attrs"
:
{
"func_name"
:
"myadd"
,
"flatten_data"
:
"1"
,
"num_inputs"
:
"1"
,
"num_outputs"
:
"1"
}}
nodes
=
[
node0
,
node1
]
arg_nodes
=
[
0
]
node_row_ptr
=
[
0
,
1
,
2
]
outputs
=
[[
1
,
0
,
0
]]
shape
=
(
4
,)
attrs
=
{
"shape"
:
[
"list_shape"
,
[
shape
,
shape
]],
"dltype"
:
[
"list_str"
,
[
"float32"
,
"float32"
]],
"storage_id"
:
[
"list_int"
,
[
0
,
1
]],
}
graph
=
{
"nodes"
:
nodes
,
"arg_nodes"
:
arg_nodes
,
"node_row_ptr"
:
node_row_ptr
,
"heads"
:
outputs
,
"attrs"
:
attrs
}
graph
=
json
.
dumps
(
graph
)
mlib
=
tvm
.
build
(
sched
,
[
A
,
B
],
"llvm"
,
name
=
"myadd"
)
mlib
.
export_library
(
os
.
path
.
join
(
target_dir
,
"graph_addone_lib.so"
))
with
open
(
os
.
path
.
join
(
target_dir
,
"graph_addone.json"
),
"w"
)
as
fo
:
fo
.
write
(
graph
)
if
__name__
==
"__main__"
:
import
sys
if
len
(
sys
.
argv
)
!=
2
:
sys
.
exit
(
-
1
)
dump_graph_lib
(
sys
.
argv
[
1
])
jvm/native/src/main/native/jni_helper_func.h
View file @
5bd4afee
...
...
@@ -72,6 +72,14 @@ jstring getTVMValueStringField(JNIEnv *env, jobject obj) {
return
ret
;
}
jobject
newTVMValueHandle
(
JNIEnv
*
env
,
jlong
value
)
{
jclass
cls
=
env
->
FindClass
(
"ml/dmlc/tvm/TVMValueHandle"
);
jmethodID
constructor
=
env
->
GetMethodID
(
cls
,
"<init>"
,
"(J)V"
);
jobject
object
=
env
->
NewObject
(
cls
,
constructor
,
value
);
env
->
DeleteLocalRef
(
cls
);
return
object
;
}
jobject
newTVMValueLong
(
JNIEnv
*
env
,
jlong
value
)
{
jclass
cls
=
env
->
FindClass
(
"ml/dmlc/tvm/TVMValueLong"
);
jmethodID
constructor
=
env
->
GetMethodID
(
cls
,
"<init>"
,
"(J)V"
);
...
...
@@ -166,6 +174,8 @@ jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) {
return
newTVMValueLong
(
env
,
static_cast
<
jlong
>
(
value
.
v_int64
));
case
kDLFloat
:
return
newTVMValueDouble
(
env
,
static_cast
<
jdouble
>
(
value
.
v_float64
));
case
kHandle
:
return
newTVMValueHandle
(
env
,
reinterpret_cast
<
jlong
>
(
value
.
v_handle
));
case
kModuleHandle
:
return
newModule
(
env
,
reinterpret_cast
<
jlong
>
(
value
.
v_handle
));
case
kFuncHandle
:
...
...
tests/scripts/task_java_unittest.sh
View file @
5bd4afee
...
...
@@ -8,6 +8,7 @@ TEMP_DIR=$(mktemp -d)
python
$SCRIPT_DIR
/test_add_cpu.py
$TEMP_DIR
||
exit
-1
python
$SCRIPT_DIR
/test_add_gpu.py
$TEMP_DIR
||
exit
-1
python
$SCRIPT_DIR
/test_graph_runtime.py
$TEMP_DIR
||
exit
-1
# start rpc proxy server
PORT
=
$((
(
RANDOM
%
1000
)
+
9000
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment