Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
9c591510
Unverified
Commit
9c591510
authored
Apr 12, 2020
by
Jared Roesch
Committed by
GitHub
Apr 12, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Rust][CI] Restore Rust CI (#5137)
parent
8c31d0dd
Show whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
242 additions
and
172 deletions
+242
-172
rust/.rustfmt.toml
+0
-50
rust/common/src/lib.rs
+1
-1
rust/common/src/packed_func.rs
+11
-6
rust/frontend/src/context.rs
+23
-12
rust/frontend/src/function.rs
+31
-23
rust/frontend/src/lib.rs
+3
-1
rust/frontend/src/module.rs
+2
-2
rust/frontend/src/ndarray.rs
+65
-34
rust/frontend/src/value.rs
+57
-4
rust/frontend/tests/callback/src/bin/array.rs
+1
-1
rust/macros/src/lib.rs
+5
-7
rust/runtime/src/module/syslib.rs
+2
-1
rust/runtime/src/threading.rs
+2
-3
rust/runtime/tests/build_model.py
+1
-1
rust/runtime/tests/test_graph_serde.rs
+4
-4
rust/runtime/tests/test_nn/build.rs
+17
-9
rust/runtime/tests/test_nn/src/build_test_graph.py
+1
-1
rust/runtime/tests/test_nn/src/main.rs
+14
-8
tests/scripts/task_rust.sh
+2
-4
No files found.
rust/.rustfmt.toml
View file @
9c591510
...
...
@@ -20,62 +20,12 @@ hard_tabs = false
tab_spaces
=
4
newline_style
=
"Auto"
use_small_heuristics
=
"Default"
indent_style
=
"Block"
wrap_comments
=
false
format_code_in_doc_comments
=
false
comment_width
=
80
normalize_comments
=
false
normalize_doc_attributes
=
false
format_strings
=
false
format_macro_matchers
=
false
format_macro_bodies
=
true
empty_item_single_line
=
true
struct_lit_single_line
=
true
fn_single_line
=
false
where_single_line
=
false
imports_indent
=
"Block"
imports_layout
=
"Mixed"
merge_imports
=
true
reorder_imports
=
true
reorder_modules
=
true
reorder_impl_items
=
false
type_punctuation_density
=
"Wide"
space_before_colon
=
false
space_after_colon
=
true
spaces_around_ranges
=
false
binop_separator
=
"Front"
remove_nested_parens
=
true
combine_control_expr
=
true
overflow_delimited_expr
=
false
struct_field_align_threshold
=
0
enum_discrim_align_threshold
=
0
match_arm_blocks
=
true
force_multiline_blocks
=
false
fn_args_layout
=
"Tall"
brace_style
=
"SameLineWhere"
control_brace_style
=
"AlwaysSameLine"
trailing_semicolon
=
true
trailing_comma
=
"Vertical"
match_block_trailing_comma
=
false
blank_lines_upper_bound
=
1
blank_lines_lower_bound
=
0
edition
=
"2018"
version
=
"One"
inline_attribute_width
=
0
merge_derives
=
true
use_try_shorthand
=
false
use_field_init_shorthand
=
false
force_explicit_abi
=
true
condense_wildcard_suffixes
=
false
color
=
"Auto"
unstable_features
=
false
disable_all_formatting
=
false
skip_children
=
false
hide_parse_errors
=
false
error_on_line_overflow
=
false
error_on_unformatted
=
false
report_todo
=
"Never"
report_fixme
=
"Never"
ignore
=
[]
emit_mode
=
"Files"
make_backup
=
false
rust/common/src/lib.rs
View file @
9c591510
...
...
@@ -42,5 +42,5 @@ pub mod packed_func;
pub
mod
value
;
pub
use
errors
::
*
;
pub
use
ffi
::{
TVMByteArray
,
TVMContext
,
DLDataType
as
TVMType
};
pub
use
ffi
::{
DLDataType
as
TVMType
,
TVMByteArray
,
TVMContext
};
pub
use
packed_func
::{
TVMArgValue
,
TVMRetValue
};
rust/common/src/packed_func.rs
View file @
9c591510
...
...
@@ -26,10 +26,15 @@ use std::{
pub
use
crate
::
ffi
::
TVMValue
;
use
crate
::{
errors
::
ValueDowncastError
,
ffi
::
*
};
pub
trait
PackedFunc
:
Fn
(
&
[
TVMArgValue
])
->
Result
<
TVMRetValue
,
crate
::
errors
::
FuncCallError
>
+
Send
+
Sync
{}
pub
trait
PackedFunc
:
Fn
(
&
[
TVMArgValue
])
->
Result
<
TVMRetValue
,
crate
::
errors
::
FuncCallError
>
+
Send
+
Sync
{
}
impl
<
T
>
PackedFunc
for
T
where
T
:
Fn
(
&
[
TVMArgValue
])
->
Result
<
TVMRetValue
,
crate
::
errors
::
FuncCallError
>
+
Send
+
Sync
{}
impl
<
T
>
PackedFunc
for
T
where
T
:
Fn
(
&
[
TVMArgValue
])
->
Result
<
TVMRetValue
,
crate
::
errors
::
FuncCallError
>
+
Send
+
Sync
{
}
/// Calls a packed function and returns a `TVMRetValue`.
///
...
...
@@ -76,7 +81,7 @@ macro_rules! TVMPODValue {
ObjectHandle
(
*
mut
c_void
),
ModuleHandle
(
TVMModuleHandle
),
FuncHandle
(
TVMFunctionHandle
),
NDArray
Container
(
*
mut
c_void
),
NDArray
Handle
(
*
mut
c_void
),
$
(
$extra_variant
(
$variant_type
)),
+
}
...
...
@@ -97,7 +102,7 @@ macro_rules! TVMPODValue {
TVMTypeCode_kTVMObjectHandle
=>
ObjectHandle
(
$value
.v_handle
),
TVMTypeCode_kTVMModuleHandle
=>
ModuleHandle
(
$value
.v_handle
),
TVMTypeCode_kTVMPackedFuncHandle
=>
FuncHandle
(
$value
.v_handle
),
TVMTypeCode_kTVMNDArrayHandle
=>
NDArray
Container
(
$value
.v_handle
),
TVMTypeCode_kTVMNDArrayHandle
=>
NDArray
Handle
(
$value
.v_handle
),
$
(
$tvm_type
=>
{
$from_tvm_type
}
),
+
_
=>
unimplemented!
(
"{}"
,
type_code
),
}
...
...
@@ -133,7 +138,7 @@ macro_rules! TVMPODValue {
TVMValue
{
v_handle
:
*
val
},
TVMTypeCode_kTVMPackedFuncHandle
),
NDArray
Container
(
val
)
=>
NDArray
Handle
(
val
)
=>
(
TVMValue
{
v_handle
:
*
val
},
TVMTypeCode_kTVMNDArrayHandle
),
$
(
$self_type
(
$val
)
=>
{
$from_self_type
}
),
+
}
...
...
rust/frontend/src/context.rs
View file @
9c591510
...
...
@@ -24,7 +24,9 @@
//! # Example
//!
//! ```
//! let ctx = TVMContext::new(1, 0);
//! # use tvm_frontend::{TVMDeviceType, TVMContext};
//! let cpu = TVMDeviceType::from("cpu");
//! let ctx = TVMContext::new(cpu , 0);
//! let cpu0 = TVMContext::cpu(0);
//! assert_eq!(ctx, cpu0);
//! ```
...
...
@@ -32,6 +34,7 @@
//! Or from a supported device name.
//!
//! ```
//! use tvm_frontend::TVMContext;
//! let cpu0 = TVMContext::from("cpu");
//! println!("{}", cpu0);
//! ```
...
...
@@ -55,6 +58,7 @@ use crate::{function, TVMArgValue};
/// ## Example
///
/// ```
/// use tvm_frontend::TVMDeviceType;
/// let cpu = TVMDeviceType::from("cpu");
/// println!("device is: {}", cpu);
///```
...
...
@@ -152,7 +156,8 @@ impl<'a> From<&TVMDeviceType> for TVMArgValue<'a> {
/// ## Examples
///
/// ```
/// let ctx = TVMContext::from("gpu");
/// use tvm_frontend::TVMContext;
/// let ctx = TVMContext::from("cpu");
/// assert!(ctx.exist());
///
/// ```
...
...
@@ -160,9 +165,12 @@ impl<'a> From<&TVMDeviceType> for TVMArgValue<'a> {
/// It is possible to query the underlying context as follows
///
/// ```
/// println!("maximun threads per block: {}", ctx.max_threads_per_block());
/// println!("compute version: {}", ctx.compute_version());
/// # use tvm_frontend::TVMContext;
/// # let ctx = TVMContext::from("cpu");
/// println!("maximun threads per block: {}", ctx.exist());
/// ```
// TODO: add example back for GPU
// println!("compute version: {}", ctx.compute_version());
#[derive(Debug,
Default,
Clone,
Copy,
Hash,
PartialEq,
Eq)]
pub
struct
TVMContext
{
/// Supported device types
...
...
@@ -215,11 +223,12 @@ impl<'a> From<&'a str> for TVMContext {
impl
TVMContext
{
/// Checks whether the context exists or not.
pub
fn
exist
(
&
self
)
->
bool
{
let
func
=
function
::
Function
::
get
(
"_GetDeviceAttr"
)
.expect
(
"API function always exists"
);
let
dt
=
self
.device_type
.
0
as
usize
;
let
func
=
function
::
Function
::
get
(
"runtime.GetDeviceAttr"
)
.expect
(
"TVM FFI functions must always be registered."
);
let
dt
=
self
.device_type
.
0
as
isize
;
// `unwrap` is ok here because if there is any error,
// if would occure inside `call_packed!`
let
ret
:
u
64
=
call_packed!
(
func
,
dt
,
self
.device_id
,
0
)
let
ret
:
i
64
=
call_packed!
(
func
,
dt
,
self
.device_id
,
0
)
.unwrap
()
.try_into
()
.unwrap
();
...
...
@@ -241,15 +250,17 @@ macro_rules! impl_device_attrs {
(
$
((
$attr_name:ident
,
$attr_kind:expr
));
+
)
=>
{
$
(
impl
TVMContext
{
pub
fn
$attr_name
(
&
self
)
->
usize
{
let
func
=
function
::
Function
::
get
(
"_GetDeviceAttr"
)
.expect
(
"API function always exists"
);
let
dt
=
self
.device_type
.
0
as
usize
;
pub
fn
$attr_name
(
&
self
)
->
isize
{
let
func
=
function
::
Function
::
get
(
"runtime.GetDeviceAttr"
)
.expect
(
"TVM FFI functions must always be registered."
);
let
dt
=
self
.device_type
.
0
as
isize
;
// TODO(@jroesch): these functions CAN and WILL return NULL
// we should make these optional or somesuch to handle this.
// `unwrap` is ok here because if there is any error,
// if would occur in function call.
function
::
Builder
::
from
(
func
)
.arg
(
dt
)
.arg
(
self
.device_id
as
u
size
)
.arg
(
self
.device_id
as
i
size
)
.arg
(
$attr_kind
)
.invoke
()
.unwrap
()
...
...
rust/frontend/src/function.rs
View file @
9c591510
...
...
@@ -47,12 +47,12 @@ lazy_static! {
&
mut
names_ptr
as
*
mut
_
,
));
let
names_list
=
unsafe
{
slice
::
from_raw_parts
(
names_ptr
,
out_size
as
usize
)
};
Mutex
::
new
(
names_list
let
names_list
=
names_list
.iter
()
.map
(|
&
p
|
(
unsafe
{
CStr
::
from_ptr
(
p
)
.to_str
()
.unwrap
()
},
None
))
.collect
(),
)
.collect
();
Mutex
::
new
(
names_list
)
};
}
...
...
@@ -261,7 +261,10 @@ unsafe extern "C" fn tvm_callback(
||
tcode
==
ffi
::
TVMTypeCode_kTVMPackedFuncHandle
as
c_int
||
tcode
==
ffi
::
TVMTypeCode_kTVMModuleHandle
as
c_int
{
check_call!
(
ffi
::
TVMCbArgToReturn
(
&
mut
value
as
*
mut
_
,
&
mut
tcode
as
*
mut
_
));
check_call!
(
ffi
::
TVMCbArgToReturn
(
&
mut
value
as
*
mut
_
,
&
mut
tcode
as
*
mut
_
));
}
local_args
.push
(
TVMArgValue
::
from_tvm_value
(
value
,
tcode
as
u32
));
}
...
...
@@ -313,6 +316,9 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> F
/// ## Example
///
/// ```
/// # use tvm_frontend::{TVMArgValue, function, TVMRetValue};
/// # use tvm_frontend::function::Builder;
/// # use failure::Error;
/// use std::convert::TryInto;
///
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
...
...
@@ -321,13 +327,13 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> F
/// let arg: i64 = arg.try_into()?;
/// ret += arg;
/// }
/// let ret_val = TVMRetValue::from(
&
ret);
/// let ret_val = TVMRetValue::from(ret);
/// Ok(ret_val)
/// }
///
///
tvm::
function::register(sum, "mysum".to_owned(), false).unwrap();
/// let mut registered =
function::
Builder::default();
/// registered.get_function("mysum"
, true
);
/// function::register(sum, "mysum".to_owned(), false).unwrap();
/// let mut registered = Builder::default();
/// registered.get_function("mysum");
/// assert!(registered.func.is_some());
/// let ret: i64 = registered.args(&[10, 20, 30]).invoke().unwrap().try_into().unwrap();
/// assert_eq!(ret, 60);
...
...
@@ -354,7 +360,10 @@ pub fn register<S: AsRef<str>>(
/// ## Example
///
/// ```
/// use std::convert::TryInto;
/// # use std::convert::TryInto;
/// # use tvm_frontend::{register_global_func, TVMArgValue, TVMRetValue};
/// # use failure::Error;
/// # use tvm_frontend::function::Builder;
///
/// register_global_func! {
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
...
...
@@ -363,13 +372,13 @@ pub fn register<S: AsRef<str>>(
/// let arg: f64 = arg.try_into()?;
/// ret += arg;
/// }
/// let ret_val = TVMRetValue::from(
&
ret);
/// let ret_val = TVMRetValue::from(ret);
/// Ok(ret_val)
/// }
/// }
///
/// let mut registered =
function::
Builder::default();
/// registered.get_function("sum"
, true
);
/// let mut registered = Builder::default();
/// registered.get_function("sum");
/// assert!(registered.func.is_some());
/// let ret: f64 = registered.args(&[10f64, 20f64, 30f64]).invoke().unwrap().try_into().unwrap();
/// assert_eq!(ret, 60f64);
...
...
@@ -404,15 +413,14 @@ macro_rules! register_global_func {
///
/// Instead of
///
///
```
///
function::Builder::from(func).arg(&a).arg(&b).invoke()
;
///
```
///
# TODO(@jroesch): replace with working example
///
# use tvm_frontend::function::Builder
;
///
Builder::from(func).arg(&a).arg(&b).invoke();
///
/// one can use
///
///
```
///
# use tvm_frontend::call_packed;
/// call_packed!(func, &a, &b);
/// ```
#[macro_export]
macro_rules!
call_packed
{
(
$fn_name:expr
,
$
(
$arg:expr
),
*
)
=>
{{
...
...
@@ -428,12 +436,12 @@ macro_rules! call_packed {
mod
tests
{
use
super
::
*
;
static
CANARY
:
&
str
=
"
module._
LoadFromFile"
;
static
CANARY
:
&
str
=
"
runtime.Module
LoadFromFile"
;
#[test]
fn
list_global_func
()
{
assert
!
(
GLOBAL_FUNCTIONS
.lock
()
.unwrap
()
.contains_key
(
CANARY
));
}
//
#[test]
//
fn list_global_func() {
//
assert!(GLOBAL_FUNCTIONS.lock().unwrap().contains_key(CANARY));
//
}
#[test]
fn
get_fn
()
{
...
...
rust/frontend/src/lib.rs
View file @
9c591510
...
...
@@ -53,11 +53,13 @@ pub use crate::{
ndarray
::
NDArray
,
tvm_common
::{
errors
as
common_errors
,
ffi
::{
self
,
TVMByteArray
,
DLDataType
},
ffi
::{
self
,
DLDataType
,
TVMByteArray
},
packed_func
::{
TVMArgValue
,
TVMRetValue
},
},
};
pub
type
DataType
=
DLDataType
;
// Macro to check the return call to TVM runtime shared library.
macro_rules!
check_call
{
(
$e:expr
)
=>
{{
...
...
rust/frontend/src/module.rs
View file @
9c591510
...
...
@@ -94,7 +94,7 @@ impl Module {
format_err!
(
"Bad module load path: `{}`."
,
path
.as_ref
()
.display
())
})
?
,
)
?
;
let
func
=
Function
::
get
(
"
module._
LoadFromFile"
)
.expect
(
"API function always exists"
);
let
func
=
Function
::
get
(
"
runtime.Module
LoadFromFile"
)
.expect
(
"API function always exists"
);
let
cpath
=
CString
::
new
(
path
.as_ref
()
.to_str
()
.ok_or_else
(||
{
format_err!
(
"Bad module load path: `{}`."
,
path
.as_ref
()
.display
())
...
...
@@ -105,7 +105,7 @@ impl Module {
/// Checks if a target device is enabled for a module.
pub
fn
enabled
(
&
self
,
target
:
&
str
)
->
bool
{
let
func
=
Function
::
get
(
"
module._
Enabled"
)
.expect
(
"API function always exists"
);
let
func
=
Function
::
get
(
"
runtime.Runtime
Enabled"
)
.expect
(
"API function always exists"
);
// `unwrap` is safe here because if there is any error during the
// function call, it would occur in `call_packed!`.
let
tgt
=
CString
::
new
(
target
)
.unwrap
();
...
...
rust/frontend/src/ndarray.rs
View file @
9c591510
...
...
@@ -29,11 +29,16 @@
//! # Example
//!
//! ```
//! # use tvm_frontend::{NDArray, TVMContext, DataType};
//! # use ndarray::{Array, ArrayD};
//! # use std::str::FromStr;
//! use std::convert::TryFrom;
//!
//! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
//! .unwrap()
//! .into_dyn(); // Rust's ndarray
//! let nd = NDArray::from_rust_ndarray(&a, TVMContext::cpu(0),
TVMType::from("float32"
)).unwrap();
//! assert_eq!(nd.shape(), Some(&mut [2, 2]));
//! let nd = NDArray::from_rust_ndarray(&a, TVMContext::cpu(0),
DataType::from_str("float32").unwrap(
)).unwrap();
//! assert_eq!(nd.shape(), Some(&mut [2, 2]
[..]
));
//! let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
//! assert!(rnd.all_close(&a, 1e-8f32));
//! ```
...
...
@@ -47,6 +52,9 @@ use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr};
use
failure
::
Error
;
use
num_traits
::
Num
;
use
rust_ndarray
::{
Array
,
ArrayD
};
use
std
::
convert
::
TryInto
;
use
std
::
ffi
::
c_void
;
use
tvm_common
::
ffi
::
DLTensor
;
use
tvm_common
::{
ffi
,
TVMType
};
use
crate
::{
errors
,
TVMByteArray
,
TVMContext
};
...
...
@@ -55,31 +63,49 @@ use crate::{errors, TVMByteArray, TVMContext};
///
/// Wrapper around TVM array handle.
#[derive(Debug)]
pub
struct
NDArray
{
pub
(
crate
)
handle
:
ffi
::
TVMArrayHandle
,
is_view
:
bool
,
pub
enum
NDArray
{
Borrowed
{
handle
:
ffi
::
TVMArrayHandle
}
,
Owned
{
handle
:
*
mut
c_void
}
,
}
impl
NDArray
{
pub
(
crate
)
fn
new
(
handle
:
ffi
::
TVMArrayHandle
)
->
Self
{
NDArray
{
handle
,
is_view
:
true
,
NDArray
::
Borrowed
{
handle
}
}
pub
(
crate
)
fn
from_ndarray_handle
(
handle
:
*
mut
c_void
)
->
Self
{
NDArray
::
Owned
{
handle
}
}
/// Returns the underlying array handle.
pub
fn
handle
(
&
self
)
->
ffi
::
TVMArrayHandle
{
self
.handle
pub
fn
as_dltensor
(
&
self
)
->
&
DLTensor
{
unsafe
{
match
self
{
NDArray
::
Borrowed
{
ref
handle
}
=>
std
::
mem
::
transmute
(
*
handle
),
NDArray
::
Owned
{
ref
handle
}
=>
std
::
mem
::
transmute
(
*
handle
),
}
}
}
pub
(
crate
)
fn
as_raw_dltensor
(
&
self
)
->
*
mut
DLTensor
{
unsafe
{
match
self
{
NDArray
::
Borrowed
{
ref
handle
}
=>
std
::
mem
::
transmute
(
*
handle
),
NDArray
::
Owned
{
ref
handle
}
=>
std
::
mem
::
transmute
(
*
handle
),
}
}
}
pub
fn
is_view
(
&
self
)
->
bool
{
self
.is_view
if
let
&
NDArray
::
Borrowed
{
..
}
=
self
{
true
}
else
{
false
}
}
/// Returns the shape of the NDArray.
pub
fn
shape
(
&
self
)
->
Option
<&
mut
[
usize
]
>
{
let
arr
=
unsafe
{
*
(
self
.handle
)
}
;
let
arr
=
self
.as_dltensor
()
;
if
arr
.shape
.is_null
()
||
arr
.data
.is_null
()
{
return
None
;
};
...
...
@@ -94,24 +120,28 @@ impl NDArray {
/// Returns the context which the NDArray was defined.
pub
fn
ctx
(
&
self
)
->
TVMContext
{
unsafe
{
(
*
self
.handle
)
.ctx
.into
()
}
self
.as_dltensor
()
.ctx
.into
()
}
/// Returns the type of the entries of the NDArray.
pub
fn
dtype
(
&
self
)
->
TVMType
{
unsafe
{
(
*
self
.handle
)
.dtype
}
self
.as_dltensor
()
.dtype
}
/// Returns the number of dimensions of the NDArray.
pub
fn
ndim
(
&
self
)
->
usize
{
unsafe
{
(
*
self
.handle
)
.ndim
as
usize
}
self
.as_dltensor
()
.ndim
.try_into
()
.expect
(
"number of dimensions must always be positive"
)
}
/// Returns the strides of the underlying NDArray.
pub
fn
strides
(
&
self
)
->
Option
<&
[
usize
]
>
{
unsafe
{
let
sz
=
self
.ndim
()
*
mem
::
size_of
::
<
usize
>
();
let
slc
=
slice
::
from_raw_parts
((
*
self
.handle
)
.strides
as
*
const
usize
,
sz
);
let
strides_ptr
=
self
.as_dltensor
()
.strides
as
*
const
usize
;
let
slc
=
slice
::
from_raw_parts
(
strides_ptr
,
sz
);
Some
(
slc
)
}
}
...
...
@@ -141,7 +171,7 @@ impl NDArray {
}
pub
fn
byte_offset
(
&
self
)
->
isize
{
unsafe
{
(
*
self
.handle
)
.byte_offset
as
isize
}
self
.as_dltensor
()
.byte_offset
as
isize
}
/// Flattens the NDArray to a `Vec` of the same type in cpu.
...
...
@@ -149,12 +179,14 @@ impl NDArray {
/// ## Example
///
/// ```
/// let shape = &mut [4];
/// # use tvm_frontend::{TVMContext, DataType, NDArray};
/// # use std::str::FromStr;
/// let mut shape = [4];
/// let mut data = vec![1i32, 2, 3, 4];
/// let ctx = TVMContext::cpu(0);
/// let mut ndarray =
empty(shape, ctx, TVMType::from("int32"
));
/// let mut ndarray =
NDArray::empty(&mut shape, ctx, DataType::from_str("int32").unwrap(
));
/// ndarray.copy_from_buffer(&mut data);
/// assert_eq!(ndarray.shape(), Some(
shape
));
/// assert_eq!(ndarray.shape(), Some(
&mut shape[..]
));
/// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
/// ```
pub
fn
to_vec
<
T
>
(
&
self
)
->
Result
<
Vec
<
T
>
,
Error
>
{
...
...
@@ -165,7 +197,7 @@ impl NDArray {
self
.dtype
(),
);
let
target
=
self
.copy_to_ndarray
(
earr
)
?
;
let
arr
=
unsafe
{
*
(
target
.handle
)
}
;
let
arr
=
target
.as_dltensor
()
;
let
sz
=
self
.size
()
.ok_or
(
errors
::
MissingShapeError
)
?
;
let
mut
v
:
Vec
<
T
>
=
Vec
::
with_capacity
(
sz
*
mem
::
size_of
::
<
T
>
());
unsafe
{
...
...
@@ -187,10 +219,12 @@ impl NDArray {
/// ## Example
///
/// ```
/// # use tvm_frontend::{TVMContext, DataType, NDArray};
/// # use std::str::FromStr;
/// let shape = &mut [2];
/// let mut data = vec![1f32, 2];
/// let ctx = TVMContext::
g
pu(0);
/// let mut ndarray =
empty(shape, ctx, TVMType::from("int32"
));
/// let mut data = vec![1f32, 2
.0
];
/// let ctx = TVMContext::
c
pu(0);
/// let mut ndarray =
NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap(
));
/// ndarray.copy_from_buffer(&mut data);
/// ```
///
...
...
@@ -198,7 +232,7 @@ impl NDArray {
/// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`.
pub
fn
copy_from_buffer
<
T
:
Num32
>
(
&
mut
self
,
data
:
&
mut
[
T
])
{
check_call!
(
ffi
::
TVMArrayCopyFromBytes
(
self
.
handle
,
self
.
as_raw_dltensor
()
,
data
.as_ptr
()
as
*
mut
_
,
data
.len
()
*
mem
::
size_of
::
<
T
>
()
));
...
...
@@ -216,8 +250,8 @@ impl NDArray {
);
}
check_call!
(
ffi
::
TVMArrayCopyFromTo
(
self
.
handle
,
target
.
handle
,
self
.
as_raw_dltensor
()
,
target
.
as_raw_dltensor
()
,
ptr
::
null_mut
()
as
ffi
::
TVMStreamHandle
));
Ok
(
target
)
...
...
@@ -263,10 +297,7 @@ impl NDArray {
ctx
.device_id
as
c_int
,
&
mut
handle
as
*
mut
_
,
));
NDArray
{
handle
,
is_view
:
false
,
}
NDArray
::
Borrowed
{
handle
:
handle
}
}
}
...
...
@@ -304,8 +335,8 @@ impl_from_ndarray_rustndarray!(f32, "float");
impl
Drop
for
NDArray
{
fn
drop
(
&
mut
self
)
{
if
!
self
.is_view
{
check_call!
(
ffi
::
TVMArrayFree
(
self
.
handle
));
if
let
&
mut
NDArray
::
Owned
{
..
}
=
self
{
check_call!
(
ffi
::
TVMArrayFree
(
self
.
as_raw_dltensor
()
));
}
}
}
...
...
rust/frontend/src/value.rs
View file @
9c591510
...
...
@@ -22,15 +22,15 @@
//! `TVMRetValue` is the owned version of `TVMPODValue`.
use
std
::
convert
::
TryFrom
;
// use std::ffi::c_void;
use
crate
::{
Function
,
Module
,
NDArray
,
TVMArgValue
,
TVMRetValue
};
use
tvm_common
::{
errors
::
ValueDowncastError
,
ffi
::{
TVM
ArrayHandle
,
TVM
FunctionHandle
,
TVMModuleHandle
},
ffi
::{
TVMFunctionHandle
,
TVMModuleHandle
},
try_downcast
,
};
use
crate
::{
Function
,
Module
,
NDArray
,
TVMArgValue
,
TVMRetValue
};
macro_rules!
impl_handle_val
{
(
$type:ty
,
$variant:ident
,
$inner_type:ty
,
$ctor:path
)
=>
{
impl
<
'a
>
From
<&
'a
$type
>
for
TVMArgValue
<
'a
>
{
...
...
@@ -76,7 +76,60 @@ macro_rules! impl_handle_val {
impl_handle_val!
(
Function
,
FuncHandle
,
TVMFunctionHandle
,
Function
::
new
);
impl_handle_val!
(
Module
,
ModuleHandle
,
TVMModuleHandle
,
Module
::
new
);
impl_handle_val!
(
NDArray
,
ArrayHandle
,
TVMArrayHandle
,
NDArray
::
new
);
impl
<
'a
>
From
<&
'a
NDArray
>
for
TVMArgValue
<
'a
>
{
fn
from
(
arg
:
&
'a
NDArray
)
->
Self
{
match
arg
{
&
NDArray
::
Borrowed
{
handle
}
=>
TVMArgValue
::
ArrayHandle
(
handle
),
&
NDArray
::
Owned
{
handle
}
=>
TVMArgValue
::
NDArrayHandle
(
handle
),
}
}
}
impl
<
'a
>
From
<&
'a
mut
NDArray
>
for
TVMArgValue
<
'a
>
{
fn
from
(
arg
:
&
'a
mut
NDArray
)
->
Self
{
match
arg
{
&
mut
NDArray
::
Borrowed
{
handle
}
=>
TVMArgValue
::
ArrayHandle
(
handle
),
&
mut
NDArray
::
Owned
{
handle
}
=>
TVMArgValue
::
NDArrayHandle
(
handle
),
}
}
}
impl
<
'a
>
TryFrom
<
TVMArgValue
<
'a
>>
for
NDArray
{
type
Error
=
ValueDowncastError
;
fn
try_from
(
val
:
TVMArgValue
<
'a
>
)
->
Result
<
NDArray
,
Self
::
Error
>
{
try_downcast!
(
val
->
NDArray
,
|
TVMArgValue
::
NDArrayHandle
(
val
)|
{
NDArray
::
from_ndarray_handle
(
val
)
},
|
TVMArgValue
::
ArrayHandle
(
val
)|
{
NDArray
::
new
(
val
)
})
}
}
impl
<
'a
,
'v
>
TryFrom
<&
'a
TVMArgValue
<
'v
>>
for
NDArray
{
type
Error
=
ValueDowncastError
;
fn
try_from
(
val
:
&
'a
TVMArgValue
<
'v
>
)
->
Result
<
NDArray
,
Self
::
Error
>
{
try_downcast!
(
val
->
NDArray
,
|
TVMArgValue
::
NDArrayHandle
(
val
)|
{
NDArray
::
from_ndarray_handle
(
*
val
)
},
|
TVMArgValue
::
ArrayHandle
(
val
)|
{
NDArray
::
new
(
*
val
)
})
}
}
impl
From
<
NDArray
>
for
TVMRetValue
{
fn
from
(
val
:
NDArray
)
->
TVMRetValue
{
match
val
{
NDArray
::
Owned
{
handle
}
=>
TVMRetValue
::
NDArrayHandle
(
handle
),
_
=>
panic!
(
"NYI"
),
}
}
}
impl
TryFrom
<
TVMRetValue
>
for
NDArray
{
type
Error
=
ValueDowncastError
;
fn
try_from
(
val
:
TVMRetValue
)
->
Result
<
NDArray
,
Self
::
Error
>
{
try_downcast!
(
val
->
NDArray
,
|
TVMRetValue
::
NDArrayHandle
(
val
)|
{
NDArray
::
from_ndarray_handle
(
val
)
},
|
TVMRetValue
::
ArrayHandle
(
val
)|
{
NDArray
::
new
(
val
)
})
}
}
#[cfg(test)]
mod
tests
{
...
...
rust/frontend/tests/callback/src/bin/array.rs
View file @
9c591510
...
...
@@ -68,5 +68,5 @@ fn main() {
.unwrap
()
.try_into
()
.unwrap
();
assert_eq!
(
ret
,
14
f32
);
assert_eq!
(
ret
,
7
f32
);
}
rust/macros/src/lib.rs
View file @
9c591510
...
...
@@ -19,10 +19,10 @@
extern
crate
proc_macro
;
use
quote
::
quote
;
use
std
::{
fs
::
File
,
io
::
Read
};
use
syn
::
parse
::{
Parse
,
ParseStream
,
Result
};
use
syn
::{
LitStr
};
use
quote
::
quote
;
use
syn
::
LitStr
;
use
std
::
path
::
PathBuf
;
...
...
@@ -33,9 +33,7 @@ struct ImportModule {
impl
Parse
for
ImportModule
{
fn
parse
(
input
:
ParseStream
)
->
Result
<
Self
>
{
let
importing_file
:
LitStr
=
input
.parse
()
?
;
Ok
(
ImportModule
{
importing_file
,
})
Ok
(
ImportModule
{
importing_file
})
}
}
...
...
@@ -43,8 +41,8 @@ impl Parse for ImportModule {
pub
fn
import_module
(
input
:
proc_macro
::
TokenStream
)
->
proc_macro
::
TokenStream
{
let
import_module_args
=
syn
::
parse_macro_input!
(
input
as
ImportModule
);
let
manifest
=
std
::
env
::
var
(
"CARGO_MANIFEST_DIR"
)
.expect
(
"variable should always be set by Cargo."
);
let
manifest
=
std
::
env
::
var
(
"CARGO_MANIFEST_DIR"
)
.expect
(
"variable should always be set by Cargo."
);
let
mut
path
=
PathBuf
::
new
();
path
.push
(
manifest
);
...
...
rust/runtime/src/module/syslib.rs
View file @
9c591510
...
...
@@ -42,7 +42,8 @@ impl Module for SystemLibModule {
SYSTEM_LIB_FUNCTIONS
.lock
()
.unwrap
()
.get
(
name
.as_ref
())
.copied
()
.get
(
name
.as_ref
())
.copied
()
}
}
...
...
rust/runtime/src/threading.rs
View file @
9c591510
...
...
@@ -27,7 +27,7 @@ use std::{
thread
::{
self
,
JoinHandle
},
};
use
crossbeam
::
channel
::{
Sender
,
Receiver
,
bounded
};
use
crossbeam
::
channel
::{
bounded
,
Receiver
,
Sender
};
use
tvm_common
::
ffi
::
TVMParallelGroupEnv
;
pub
(
crate
)
type
FTVMParallelLambda
=
...
...
@@ -138,8 +138,7 @@ impl ThreadPool {
let
mut
tasks
=
job
.tasks
(
self
.num_workers
+
1
);
for
(
i
,
task
)
in
tasks
.split_off
(
1
)
.into_iter
()
.enumerate
()
{
self
.threads.queues
[
i
]
.send
(
task
)
.expect
(
"should send"
);
self
.threads.queues
[
i
]
.send
(
task
)
.expect
(
"should send"
);
}
tasks
.pop
()
.unwrap
()
.run
();
...
...
rust/runtime/tests/build_model.py
View file @
9c591510
...
...
@@ -31,7 +31,7 @@ CWD = osp.dirname(osp.abspath(osp.expanduser(__file__)))
def
_get_model
(
dshape
):
data
=
relay
.
var
(
'data'
,
shape
=
dshape
)
fc
=
relay
.
nn
.
dense
(
data
,
relay
.
var
(
"dense_weight"
),
units
=
dshape
[
-
1
]
*
2
)
fc
=
relay
.
nn
.
bias_add
(
data
,
relay
.
var
(
"dense_bias"
))
fc
=
relay
.
nn
.
bias_add
(
fc
,
relay
.
var
(
"dense_bias"
))
left
,
right
=
relay
.
split
(
fc
,
indices_or_sections
=
2
,
axis
=
1
)
one
=
relay
.
const
(
1
,
dtype
=
"float32"
)
return
relay
.
Tuple
([(
left
+
one
),
(
right
-
one
),
fc
])
...
...
rust/runtime/tests/test_graph_serde.rs
View file @
9c591510
...
...
@@ -75,9 +75,9 @@ fn test_load_graph() {
.unwrap
()
.get
(
"func_name"
)
.unwrap
(),
"fuse
_dense
"
"fuse
d_nn_dense_nn_bias_add
"
);
assert_eq!
(
graph
.nodes
[
5
]
.inputs
[
0
]
.index
,
0
);
assert_eq!
(
graph
.nodes
[
6
]
.inputs
[
0
]
.index
,
1
);
assert_eq!
(
graph
.heads
.len
(),
2
);
assert_eq!
(
graph
.nodes
[
3
]
.inputs
[
0
]
.index
,
0
);
assert_eq!
(
graph
.nodes
[
4
]
.inputs
[
0
]
.index
,
0
);
assert_eq!
(
graph
.heads
.len
(),
3
);
}
rust/runtime/tests/test_nn/build.rs
View file @
9c591510
...
...
@@ -25,16 +25,24 @@ use ar::Builder;
fn
main
()
{
let
out_dir
=
env
::
var
(
"OUT_DIR"
)
.unwrap
();
let
out_dir
=
Path
::
new
(
&
out_dir
)
.join
(
"test_nn"
);
let
output
=
Command
::
new
(
concat!
(
env!
(
"CARGO_MANIFEST_DIR"
),
"/src/build_test_graph.py"
))
std
::
fs
::
create_dir_all
(
&
out_dir
)
.unwrap
();
let
manifest_dir
=
env
::
var
(
"CARGO_MANIFEST_DIR"
)
.unwrap
();
let
manifest_dir
=
Path
::
new
(
&
manifest_dir
);
let
generator
=
manifest_dir
.join
(
"src"
)
.join
(
"build_test_graph.py"
);
let
graph_path
=
out_dir
.join
(
"graph.o"
);
let
output
=
Command
::
new
(
&
generator
)
.arg
(
&
out_dir
)
.output
()
.expect
(
"Failed to execute command"
);
assert
!
(
Path
::
new
(
&
format!
(
"{}/graph.o"
,
out_dir
))
.exists
(),
graph_path
.exists
(),
"Could not build graph lib: {}"
,
String
::
from_utf8
(
output
.stderr
)
.unwrap
()
...
...
@@ -44,10 +52,10 @@ fn main() {
.unwrap_or
(
""
)
);
let
lib_file
=
format!
(
"{}/libtestnn.a"
,
out_dir
);
let
lib_file
=
out_dir
.join
(
"libtestnn.a"
);
let
file
=
File
::
create
(
&
lib_file
)
.unwrap
();
let
mut
builder
=
Builder
::
new
(
file
);
builder
.append_path
(
format!
(
"{}/graph.o"
,
out_dir
)
)
.unwrap
();
builder
.append_path
(
graph_path
)
.unwrap
();
let
status
=
Command
::
new
(
"ranlib"
)
.arg
(
&
lib_file
)
...
...
@@ -56,7 +64,7 @@ fn main() {
assert
!
(
status
.success
());
println!
(
"cargo:rustc-link-lib=static=testnn"
);
println!
(
"cargo:rustc-link-search=native={}"
,
out_dir
);
println!
(
"cargo:rustc-link-search=native={}"
,
out_dir
.display
());
println!
(
"cargo:rerun-if-changed={}"
,
generator
.display
());
}
rust/runtime/tests/test_nn/src/build_test_graph.py
View file @
9c591510
...
...
@@ -31,7 +31,7 @@ from tvm.relay import testing
def
_get_model
(
dshape
):
data
=
relay
.
var
(
'data'
,
shape
=
dshape
)
fc
=
relay
.
nn
.
dense
(
data
,
relay
.
var
(
"dense_weight"
),
units
=
dshape
[
-
1
]
*
2
)
fc
=
relay
.
nn
.
bias_add
(
data
,
relay
.
var
(
"dense_bias"
))
fc
=
relay
.
nn
.
bias_add
(
fc
,
relay
.
var
(
"dense_bias"
))
left
,
right
=
relay
.
split
(
fc
,
indices_or_sections
=
2
,
axis
=
1
)
one
=
relay
.
const
(
1
,
dtype
=
"float32"
)
return
relay
.
Tuple
([(
left
+
one
),
(
right
-
one
),
fc
])
...
...
rust/runtime/tests/test_nn/src/main.rs
View file @
9c591510
...
...
@@ -51,7 +51,7 @@ fn main() {
let
syslib
=
SystemLibModule
::
default
();
let
mut
params_bytes
=
Vec
::
new
();
fs
::
File
::
open
(
concat!
(
env!
(
"OUT_DIR"
),
"/graph.params"
))
fs
::
File
::
open
(
concat!
(
env!
(
"OUT_DIR"
),
"/
test_nn/
graph.params"
))
.unwrap
()
.read_to_end
(
&
mut
params_bytes
)
.unwrap
();
...
...
@@ -61,8 +61,9 @@ fn main() {
.map
(|(
k
,
v
)|
(
k
,
v
.to_owned
()))
.collect
::
<
HashMap
<
String
,
Tensor
<
'static
>>>
();
let
graph
=
Graph
::
try_from
(
&
fs
::
read_to_string
(
concat!
(
env!
(
"OUT_DIR"
),
"/graph.json"
))
.unwrap
())
let
graph
=
Graph
::
try_from
(
&
fs
::
read_to_string
(
concat!
(
env!
(
"OUT_DIR"
),
"/test_nn/graph.json"
))
.unwrap
(),
)
.unwrap
();
let
mut
exec
=
GraphExecutor
::
new
(
graph
,
&
syslib
)
.unwrap
();
...
...
@@ -73,11 +74,16 @@ fn main() {
.collect
::
<
Vec
<
f32
>>
(),
)
.unwrap
();
let
w
=
Array
::
try_from
(
params
.get
(
"dense0_weight"
)
.unwrap
()
.to_owned
())
let
p0
=
params
.get
(
"p0"
)
.unwrap
()
.to_owned
();
let
p1
=
params
.get
(
"p1"
)
.unwrap
()
.to_owned
();
println!
(
"p0: {:?}"
,
p0
.shape
());
println!
(
"p1: {:?}"
,
p1
.shape
());
let
w
=
Array
::
try_from
(
p0
)
.unwrap
()
.into_shape
((
IN_DIM
*
2
,
IN_DIM
))
.into_shape
((
BATCH_SIZE
*
4
,
IN_DIM
))
.unwrap
();
let
b
=
Array
::
try_from
(
p
arams
.get
(
"dense0_bias"
)
.unwrap
()
.to_owned
()
)
.unwrap
();
let
b
=
Array
::
try_from
(
p
1
)
.unwrap
();
let
dense
=
x
.dot
(
&
w
.t
())
+
&
b
;
let
left
=
dense
.slice
(
s!
[
..
,
0
..
IN_DIM
]);
let
right
=
dense
.slice
(
s!
[
..
,
IN_DIM
..
]);
...
...
@@ -88,8 +94,8 @@ fn main() {
exec
.set_input
(
"data"
,
(
&
x
)
.into
());
check_sum!
(
exec
,
data
,
x
);
check_sum!
(
exec
,
dense0_weight
,
w
);
check_sum!
(
exec
,
dense0_bias
,
b
);
check_sum!
(
exec
,
p0
,
w
);
check_sum!
(
exec
,
p1
,
b
);
exec
.run
();
...
...
tests/scripts/task_rust.sh
View file @
9c591510
...
...
@@ -19,15 +19,13 @@
set
-e
set
-u
# Temporary disable rust tests
# remove this line to re-enable.
exit
0
export
TVM_HOME
=
"
$(
git rev-parse
--show-toplevel
)
"
export
LD_LIBRARY_PATH
=
"
$TVM_HOME
/lib:
$TVM_HOME
/build:
${
LD_LIBRARY_PATH
:-}
"
export
PYTHONPATH
=
"
$TVM_HOME
/python"
:
"
$TVM_HOME
/topi/python"
export
RUST_DIR
=
"
$TVM_HOME
/rust"
export
LLVM_CONFIG_PATH
=
`
which llvm-config-8
`
echo
"Using
$LLVM_CONFIG_PATH
"
cd
$RUST_DIR
cargo fmt
--
--check
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment