Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
9c591510
Unverified
Commit
9c591510
authored
Apr 12, 2020
by
Jared Roesch
Committed by
GitHub
Apr 12, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Rust][CI] Restore Rust CI (#5137)
parent
8c31d0dd
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
248 additions
and
178 deletions
+248
-178
rust/.rustfmt.toml
+0
-50
rust/common/src/lib.rs
+1
-1
rust/common/src/packed_func.rs
+11
-6
rust/frontend/src/context.rs
+23
-12
rust/frontend/src/function.rs
+33
-25
rust/frontend/src/lib.rs
+3
-1
rust/frontend/src/module.rs
+2
-2
rust/frontend/src/ndarray.rs
+65
-34
rust/frontend/src/value.rs
+57
-4
rust/frontend/tests/callback/src/bin/array.rs
+1
-1
rust/macros/src/lib.rs
+5
-7
rust/runtime/src/module/syslib.rs
+2
-1
rust/runtime/src/threading.rs
+2
-3
rust/runtime/tests/build_model.py
+1
-1
rust/runtime/tests/test_graph_serde.rs
+4
-4
rust/runtime/tests/test_nn/build.rs
+20
-12
rust/runtime/tests/test_nn/src/build_test_graph.py
+1
-1
rust/runtime/tests/test_nn/src/main.rs
+15
-9
tests/scripts/task_rust.sh
+2
-4
No files found.
rust/.rustfmt.toml
View file @
9c591510
...
@@ -20,62 +20,12 @@ hard_tabs = false
...
@@ -20,62 +20,12 @@ hard_tabs = false
tab_spaces
=
4
tab_spaces
=
4
newline_style
=
"Auto"
newline_style
=
"Auto"
use_small_heuristics
=
"Default"
use_small_heuristics
=
"Default"
indent_style
=
"Block"
wrap_comments
=
false
format_code_in_doc_comments
=
false
comment_width
=
80
normalize_comments
=
false
normalize_doc_attributes
=
false
format_strings
=
false
format_macro_matchers
=
false
format_macro_bodies
=
true
empty_item_single_line
=
true
struct_lit_single_line
=
true
fn_single_line
=
false
where_single_line
=
false
imports_indent
=
"Block"
imports_layout
=
"Mixed"
merge_imports
=
true
reorder_imports
=
true
reorder_imports
=
true
reorder_modules
=
true
reorder_modules
=
true
reorder_impl_items
=
false
type_punctuation_density
=
"Wide"
space_before_colon
=
false
space_after_colon
=
true
spaces_around_ranges
=
false
binop_separator
=
"Front"
remove_nested_parens
=
true
remove_nested_parens
=
true
combine_control_expr
=
true
overflow_delimited_expr
=
false
struct_field_align_threshold
=
0
enum_discrim_align_threshold
=
0
match_arm_blocks
=
true
force_multiline_blocks
=
false
fn_args_layout
=
"Tall"
fn_args_layout
=
"Tall"
brace_style
=
"SameLineWhere"
control_brace_style
=
"AlwaysSameLine"
trailing_semicolon
=
true
trailing_comma
=
"Vertical"
match_block_trailing_comma
=
false
blank_lines_upper_bound
=
1
blank_lines_lower_bound
=
0
edition
=
"2018"
edition
=
"2018"
version
=
"One"
inline_attribute_width
=
0
merge_derives
=
true
merge_derives
=
true
use_try_shorthand
=
false
use_try_shorthand
=
false
use_field_init_shorthand
=
false
use_field_init_shorthand
=
false
force_explicit_abi
=
true
force_explicit_abi
=
true
condense_wildcard_suffixes
=
false
color
=
"Auto"
unstable_features
=
false
disable_all_formatting
=
false
skip_children
=
false
hide_parse_errors
=
false
error_on_line_overflow
=
false
error_on_unformatted
=
false
report_todo
=
"Never"
report_fixme
=
"Never"
ignore
=
[]
emit_mode
=
"Files"
make_backup
=
false
rust/common/src/lib.rs
View file @
9c591510
...
@@ -42,5 +42,5 @@ pub mod packed_func;
...
@@ -42,5 +42,5 @@ pub mod packed_func;
pub
mod
value
;
pub
mod
value
;
pub
use
errors
::
*
;
pub
use
errors
::
*
;
pub
use
ffi
::{
TVMByteArray
,
TVMContext
,
DLDataType
as
TVMType
};
pub
use
ffi
::{
DLDataType
as
TVMType
,
TVMByteArray
,
TVMContext
};
pub
use
packed_func
::{
TVMArgValue
,
TVMRetValue
};
pub
use
packed_func
::{
TVMArgValue
,
TVMRetValue
};
rust/common/src/packed_func.rs
View file @
9c591510
...
@@ -26,10 +26,15 @@ use std::{
...
@@ -26,10 +26,15 @@ use std::{
pub
use
crate
::
ffi
::
TVMValue
;
pub
use
crate
::
ffi
::
TVMValue
;
use
crate
::{
errors
::
ValueDowncastError
,
ffi
::
*
};
use
crate
::{
errors
::
ValueDowncastError
,
ffi
::
*
};
pub
trait
PackedFunc
:
Fn
(
&
[
TVMArgValue
])
->
Result
<
TVMRetValue
,
crate
::
errors
::
FuncCallError
>
+
Send
+
Sync
{}
pub
trait
PackedFunc
:
Fn
(
&
[
TVMArgValue
])
->
Result
<
TVMRetValue
,
crate
::
errors
::
FuncCallError
>
+
Send
+
Sync
{
}
impl
<
T
>
PackedFunc
for
T
impl
<
T
>
PackedFunc
for
T
where
where
T
:
Fn
(
&
[
TVMArgValue
])
->
Result
<
TVMRetValue
,
crate
::
errors
::
FuncCallError
>
+
Send
+
Sync
{}
T
:
Fn
(
&
[
TVMArgValue
])
->
Result
<
TVMRetValue
,
crate
::
errors
::
FuncCallError
>
+
Send
+
Sync
{
}
/// Calls a packed function and returns a `TVMRetValue`.
/// Calls a packed function and returns a `TVMRetValue`.
///
///
...
@@ -76,7 +81,7 @@ macro_rules! TVMPODValue {
...
@@ -76,7 +81,7 @@ macro_rules! TVMPODValue {
ObjectHandle
(
*
mut
c_void
),
ObjectHandle
(
*
mut
c_void
),
ModuleHandle
(
TVMModuleHandle
),
ModuleHandle
(
TVMModuleHandle
),
FuncHandle
(
TVMFunctionHandle
),
FuncHandle
(
TVMFunctionHandle
),
NDArray
Container
(
*
mut
c_void
),
NDArray
Handle
(
*
mut
c_void
),
$
(
$extra_variant
(
$variant_type
)),
+
$
(
$extra_variant
(
$variant_type
)),
+
}
}
...
@@ -97,7 +102,7 @@ macro_rules! TVMPODValue {
...
@@ -97,7 +102,7 @@ macro_rules! TVMPODValue {
TVMTypeCode_kTVMObjectHandle
=>
ObjectHandle
(
$value
.v_handle
),
TVMTypeCode_kTVMObjectHandle
=>
ObjectHandle
(
$value
.v_handle
),
TVMTypeCode_kTVMModuleHandle
=>
ModuleHandle
(
$value
.v_handle
),
TVMTypeCode_kTVMModuleHandle
=>
ModuleHandle
(
$value
.v_handle
),
TVMTypeCode_kTVMPackedFuncHandle
=>
FuncHandle
(
$value
.v_handle
),
TVMTypeCode_kTVMPackedFuncHandle
=>
FuncHandle
(
$value
.v_handle
),
TVMTypeCode_kTVMNDArrayHandle
=>
NDArray
Container
(
$value
.v_handle
),
TVMTypeCode_kTVMNDArrayHandle
=>
NDArray
Handle
(
$value
.v_handle
),
$
(
$tvm_type
=>
{
$from_tvm_type
}
),
+
$
(
$tvm_type
=>
{
$from_tvm_type
}
),
+
_
=>
unimplemented!
(
"{}"
,
type_code
),
_
=>
unimplemented!
(
"{}"
,
type_code
),
}
}
...
@@ -133,7 +138,7 @@ macro_rules! TVMPODValue {
...
@@ -133,7 +138,7 @@ macro_rules! TVMPODValue {
TVMValue
{
v_handle
:
*
val
},
TVMValue
{
v_handle
:
*
val
},
TVMTypeCode_kTVMPackedFuncHandle
TVMTypeCode_kTVMPackedFuncHandle
),
),
NDArray
Container
(
val
)
=>
NDArray
Handle
(
val
)
=>
(
TVMValue
{
v_handle
:
*
val
},
TVMTypeCode_kTVMNDArrayHandle
),
(
TVMValue
{
v_handle
:
*
val
},
TVMTypeCode_kTVMNDArrayHandle
),
$
(
$self_type
(
$val
)
=>
{
$from_self_type
}
),
+
$
(
$self_type
(
$val
)
=>
{
$from_self_type
}
),
+
}
}
...
...
rust/frontend/src/context.rs
View file @
9c591510
...
@@ -24,7 +24,9 @@
...
@@ -24,7 +24,9 @@
//! # Example
//! # Example
//!
//!
//! ```
//! ```
//! let ctx = TVMContext::new(1, 0);
//! # use tvm_frontend::{TVMDeviceType, TVMContext};
//! let cpu = TVMDeviceType::from("cpu");
//! let ctx = TVMContext::new(cpu , 0);
//! let cpu0 = TVMContext::cpu(0);
//! let cpu0 = TVMContext::cpu(0);
//! assert_eq!(ctx, cpu0);
//! assert_eq!(ctx, cpu0);
//! ```
//! ```
...
@@ -32,6 +34,7 @@
...
@@ -32,6 +34,7 @@
//! Or from a supported device name.
//! Or from a supported device name.
//!
//!
//! ```
//! ```
//! use tvm_frontend::TVMContext;
//! let cpu0 = TVMContext::from("cpu");
//! let cpu0 = TVMContext::from("cpu");
//! println!("{}", cpu0);
//! println!("{}", cpu0);
//! ```
//! ```
...
@@ -55,6 +58,7 @@ use crate::{function, TVMArgValue};
...
@@ -55,6 +58,7 @@ use crate::{function, TVMArgValue};
/// ## Example
/// ## Example
///
///
/// ```
/// ```
/// use tvm_frontend::TVMDeviceType;
/// let cpu = TVMDeviceType::from("cpu");
/// let cpu = TVMDeviceType::from("cpu");
/// println!("device is: {}", cpu);
/// println!("device is: {}", cpu);
///```
///```
...
@@ -152,7 +156,8 @@ impl<'a> From<&TVMDeviceType> for TVMArgValue<'a> {
...
@@ -152,7 +156,8 @@ impl<'a> From<&TVMDeviceType> for TVMArgValue<'a> {
/// ## Examples
/// ## Examples
///
///
/// ```
/// ```
/// let ctx = TVMContext::from("gpu");
/// use tvm_frontend::TVMContext;
/// let ctx = TVMContext::from("cpu");
/// assert!(ctx.exist());
/// assert!(ctx.exist());
///
///
/// ```
/// ```
...
@@ -160,9 +165,12 @@ impl<'a> From<&TVMDeviceType> for TVMArgValue<'a> {
...
@@ -160,9 +165,12 @@ impl<'a> From<&TVMDeviceType> for TVMArgValue<'a> {
/// It is possible to query the underlying context as follows
/// It is possible to query the underlying context as follows
///
///
/// ```
/// ```
/// println!("maximun threads per block: {}", ctx.max_threads_per_block());
/// # use tvm_frontend::TVMContext;
/// println!("compute version: {}", ctx.compute_version());
/// # let ctx = TVMContext::from("cpu");
/// println!("maximun threads per block: {}", ctx.exist());
/// ```
/// ```
// TODO: add example back for GPU
// println!("compute version: {}", ctx.compute_version());
#[derive(Debug,
Default,
Clone,
Copy,
Hash,
PartialEq,
Eq)]
#[derive(Debug,
Default,
Clone,
Copy,
Hash,
PartialEq,
Eq)]
pub
struct
TVMContext
{
pub
struct
TVMContext
{
/// Supported device types
/// Supported device types
...
@@ -215,11 +223,12 @@ impl<'a> From<&'a str> for TVMContext {
...
@@ -215,11 +223,12 @@ impl<'a> From<&'a str> for TVMContext {
impl
TVMContext
{
impl
TVMContext
{
/// Checks whether the context exists or not.
/// Checks whether the context exists or not.
pub
fn
exist
(
&
self
)
->
bool
{
pub
fn
exist
(
&
self
)
->
bool
{
let
func
=
function
::
Function
::
get
(
"_GetDeviceAttr"
)
.expect
(
"API function always exists"
);
let
func
=
function
::
Function
::
get
(
"runtime.GetDeviceAttr"
)
let
dt
=
self
.device_type
.
0
as
usize
;
.expect
(
"TVM FFI functions must always be registered."
);
let
dt
=
self
.device_type
.
0
as
isize
;
// `unwrap` is ok here because if there is any error,
// `unwrap` is ok here because if there is any error,
// if would occure inside `call_packed!`
// if would occure inside `call_packed!`
let
ret
:
u
64
=
call_packed!
(
func
,
dt
,
self
.device_id
,
0
)
let
ret
:
i
64
=
call_packed!
(
func
,
dt
,
self
.device_id
,
0
)
.unwrap
()
.unwrap
()
.try_into
()
.try_into
()
.unwrap
();
.unwrap
();
...
@@ -241,15 +250,17 @@ macro_rules! impl_device_attrs {
...
@@ -241,15 +250,17 @@ macro_rules! impl_device_attrs {
(
$
((
$attr_name:ident
,
$attr_kind:expr
));
+
)
=>
{
(
$
((
$attr_name:ident
,
$attr_kind:expr
));
+
)
=>
{
$
(
$
(
impl
TVMContext
{
impl
TVMContext
{
pub
fn
$attr_name
(
&
self
)
->
usize
{
pub
fn
$attr_name
(
&
self
)
->
isize
{
let
func
=
function
::
Function
::
get
(
"_GetDeviceAttr"
)
let
func
=
function
::
Function
::
get
(
"runtime.GetDeviceAttr"
)
.expect
(
"API function always exists"
);
.expect
(
"TVM FFI functions must always be registered."
);
let
dt
=
self
.device_type
.
0
as
usize
;
let
dt
=
self
.device_type
.
0
as
isize
;
// TODO(@jroesch): these functions CAN and WILL return NULL
// we should make these optional or somesuch to handle this.
// `unwrap` is ok here because if there is any error,
// `unwrap` is ok here because if there is any error,
// if would occur in function call.
// if would occur in function call.
function
::
Builder
::
from
(
func
)
function
::
Builder
::
from
(
func
)
.arg
(
dt
)
.arg
(
dt
)
.arg
(
self
.device_id
as
u
size
)
.arg
(
self
.device_id
as
i
size
)
.arg
(
$attr_kind
)
.arg
(
$attr_kind
)
.invoke
()
.invoke
()
.unwrap
()
.unwrap
()
...
...
rust/frontend/src/function.rs
View file @
9c591510
...
@@ -47,12 +47,12 @@ lazy_static! {
...
@@ -47,12 +47,12 @@ lazy_static! {
&
mut
names_ptr
as
*
mut
_
,
&
mut
names_ptr
as
*
mut
_
,
));
));
let
names_list
=
unsafe
{
slice
::
from_raw_parts
(
names_ptr
,
out_size
as
usize
)
};
let
names_list
=
unsafe
{
slice
::
from_raw_parts
(
names_ptr
,
out_size
as
usize
)
};
Mutex
::
new
(
let
names_list
=
names_list
names_list
.iter
()
.iter
(
)
.map
(|
&
p
|
(
unsafe
{
CStr
::
from_ptr
(
p
)
.to_str
()
.unwrap
()
},
None
)
)
.map
(|
&
p
|
(
unsafe
{
CStr
::
from_ptr
(
p
)
.to_str
()
.unwrap
()
},
None
))
.collect
();
.collect
(),
)
Mutex
::
new
(
names_list
)
};
};
}
}
...
@@ -261,7 +261,10 @@ unsafe extern "C" fn tvm_callback(
...
@@ -261,7 +261,10 @@ unsafe extern "C" fn tvm_callback(
||
tcode
==
ffi
::
TVMTypeCode_kTVMPackedFuncHandle
as
c_int
||
tcode
==
ffi
::
TVMTypeCode_kTVMPackedFuncHandle
as
c_int
||
tcode
==
ffi
::
TVMTypeCode_kTVMModuleHandle
as
c_int
||
tcode
==
ffi
::
TVMTypeCode_kTVMModuleHandle
as
c_int
{
{
check_call!
(
ffi
::
TVMCbArgToReturn
(
&
mut
value
as
*
mut
_
,
&
mut
tcode
as
*
mut
_
));
check_call!
(
ffi
::
TVMCbArgToReturn
(
&
mut
value
as
*
mut
_
,
&
mut
tcode
as
*
mut
_
));
}
}
local_args
.push
(
TVMArgValue
::
from_tvm_value
(
value
,
tcode
as
u32
));
local_args
.push
(
TVMArgValue
::
from_tvm_value
(
value
,
tcode
as
u32
));
}
}
...
@@ -313,6 +316,9 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> F
...
@@ -313,6 +316,9 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> F
/// ## Example
/// ## Example
///
///
/// ```
/// ```
/// # use tvm_frontend::{TVMArgValue, function, TVMRetValue};
/// # use tvm_frontend::function::Builder;
/// # use failure::Error;
/// use std::convert::TryInto;
/// use std::convert::TryInto;
///
///
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
...
@@ -321,13 +327,13 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> F
...
@@ -321,13 +327,13 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> F
/// let arg: i64 = arg.try_into()?;
/// let arg: i64 = arg.try_into()?;
/// ret += arg;
/// ret += arg;
/// }
/// }
/// let ret_val = TVMRetValue::from(
&
ret);
/// let ret_val = TVMRetValue::from(ret);
/// Ok(ret_val)
/// Ok(ret_val)
/// }
/// }
///
///
///
tvm::
function::register(sum, "mysum".to_owned(), false).unwrap();
/// function::register(sum, "mysum".to_owned(), false).unwrap();
/// let mut registered =
function::
Builder::default();
/// let mut registered = Builder::default();
/// registered.get_function("mysum"
, true
);
/// registered.get_function("mysum");
/// assert!(registered.func.is_some());
/// assert!(registered.func.is_some());
/// let ret: i64 = registered.args(&[10, 20, 30]).invoke().unwrap().try_into().unwrap();
/// let ret: i64 = registered.args(&[10, 20, 30]).invoke().unwrap().try_into().unwrap();
/// assert_eq!(ret, 60);
/// assert_eq!(ret, 60);
...
@@ -354,7 +360,10 @@ pub fn register<S: AsRef<str>>(
...
@@ -354,7 +360,10 @@ pub fn register<S: AsRef<str>>(
/// ## Example
/// ## Example
///
///
/// ```
/// ```
/// use std::convert::TryInto;
/// # use std::convert::TryInto;
/// # use tvm_frontend::{register_global_func, TVMArgValue, TVMRetValue};
/// # use failure::Error;
/// # use tvm_frontend::function::Builder;
///
///
/// register_global_func! {
/// register_global_func! {
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
...
@@ -363,13 +372,13 @@ pub fn register<S: AsRef<str>>(
...
@@ -363,13 +372,13 @@ pub fn register<S: AsRef<str>>(
/// let arg: f64 = arg.try_into()?;
/// let arg: f64 = arg.try_into()?;
/// ret += arg;
/// ret += arg;
/// }
/// }
/// let ret_val = TVMRetValue::from(
&
ret);
/// let ret_val = TVMRetValue::from(ret);
/// Ok(ret_val)
/// Ok(ret_val)
/// }
/// }
/// }
/// }
///
///
/// let mut registered =
function::
Builder::default();
/// let mut registered = Builder::default();
/// registered.get_function("sum"
, true
);
/// registered.get_function("sum");
/// assert!(registered.func.is_some());
/// assert!(registered.func.is_some());
/// let ret: f64 = registered.args(&[10f64, 20f64, 30f64]).invoke().unwrap().try_into().unwrap();
/// let ret: f64 = registered.args(&[10f64, 20f64, 30f64]).invoke().unwrap().try_into().unwrap();
/// assert_eq!(ret, 60f64);
/// assert_eq!(ret, 60f64);
...
@@ -404,15 +413,14 @@ macro_rules! register_global_func {
...
@@ -404,15 +413,14 @@ macro_rules! register_global_func {
///
///
/// Instead of
/// Instead of
///
///
///
```
///
# TODO(@jroesch): replace with working example
///
function::Builder::from(func).arg(&a).arg(&b).invoke()
;
///
# use tvm_frontend::function::Builder
;
///
```
///
Builder::from(func).arg(&a).arg(&b).invoke();
///
///
/// one can use
/// one can use
///
///
///
```
///
# use tvm_frontend::call_packed;
/// call_packed!(func, &a, &b);
/// call_packed!(func, &a, &b);
/// ```
#[macro_export]
#[macro_export]
macro_rules!
call_packed
{
macro_rules!
call_packed
{
(
$fn_name:expr
,
$
(
$arg:expr
),
*
)
=>
{{
(
$fn_name:expr
,
$
(
$arg:expr
),
*
)
=>
{{
...
@@ -428,12 +436,12 @@ macro_rules! call_packed {
...
@@ -428,12 +436,12 @@ macro_rules! call_packed {
mod
tests
{
mod
tests
{
use
super
::
*
;
use
super
::
*
;
static
CANARY
:
&
str
=
"
module._
LoadFromFile"
;
static
CANARY
:
&
str
=
"
runtime.Module
LoadFromFile"
;
#[test]
//
#[test]
fn
list_global_func
()
{
//
fn list_global_func() {
assert
!
(
GLOBAL_FUNCTIONS
.lock
()
.unwrap
()
.contains_key
(
CANARY
));
//
assert!(GLOBAL_FUNCTIONS.lock().unwrap().contains_key(CANARY));
}
//
}
#[test]
#[test]
fn
get_fn
()
{
fn
get_fn
()
{
...
...
rust/frontend/src/lib.rs
View file @
9c591510
...
@@ -53,11 +53,13 @@ pub use crate::{
...
@@ -53,11 +53,13 @@ pub use crate::{
ndarray
::
NDArray
,
ndarray
::
NDArray
,
tvm_common
::{
tvm_common
::{
errors
as
common_errors
,
errors
as
common_errors
,
ffi
::{
self
,
TVMByteArray
,
DLDataType
},
ffi
::{
self
,
DLDataType
,
TVMByteArray
},
packed_func
::{
TVMArgValue
,
TVMRetValue
},
packed_func
::{
TVMArgValue
,
TVMRetValue
},
},
},
};
};
pub
type
DataType
=
DLDataType
;
// Macro to check the return call to TVM runtime shared library.
// Macro to check the return call to TVM runtime shared library.
macro_rules!
check_call
{
macro_rules!
check_call
{
(
$e:expr
)
=>
{{
(
$e:expr
)
=>
{{
...
...
rust/frontend/src/module.rs
View file @
9c591510
...
@@ -94,7 +94,7 @@ impl Module {
...
@@ -94,7 +94,7 @@ impl Module {
format_err!
(
"Bad module load path: `{}`."
,
path
.as_ref
()
.display
())
format_err!
(
"Bad module load path: `{}`."
,
path
.as_ref
()
.display
())
})
?
,
})
?
,
)
?
;
)
?
;
let
func
=
Function
::
get
(
"
module._
LoadFromFile"
)
.expect
(
"API function always exists"
);
let
func
=
Function
::
get
(
"
runtime.Module
LoadFromFile"
)
.expect
(
"API function always exists"
);
let
cpath
=
let
cpath
=
CString
::
new
(
path
.as_ref
()
.to_str
()
.ok_or_else
(||
{
CString
::
new
(
path
.as_ref
()
.to_str
()
.ok_or_else
(||
{
format_err!
(
"Bad module load path: `{}`."
,
path
.as_ref
()
.display
())
format_err!
(
"Bad module load path: `{}`."
,
path
.as_ref
()
.display
())
...
@@ -105,7 +105,7 @@ impl Module {
...
@@ -105,7 +105,7 @@ impl Module {
/// Checks if a target device is enabled for a module.
/// Checks if a target device is enabled for a module.
pub
fn
enabled
(
&
self
,
target
:
&
str
)
->
bool
{
pub
fn
enabled
(
&
self
,
target
:
&
str
)
->
bool
{
let
func
=
Function
::
get
(
"
module._
Enabled"
)
.expect
(
"API function always exists"
);
let
func
=
Function
::
get
(
"
runtime.Runtime
Enabled"
)
.expect
(
"API function always exists"
);
// `unwrap` is safe here because if there is any error during the
// `unwrap` is safe here because if there is any error during the
// function call, it would occur in `call_packed!`.
// function call, it would occur in `call_packed!`.
let
tgt
=
CString
::
new
(
target
)
.unwrap
();
let
tgt
=
CString
::
new
(
target
)
.unwrap
();
...
...
rust/frontend/src/ndarray.rs
View file @
9c591510
...
@@ -29,11 +29,16 @@
...
@@ -29,11 +29,16 @@
//! # Example
//! # Example
//!
//!
//! ```
//! ```
//! # use tvm_frontend::{NDArray, TVMContext, DataType};
//! # use ndarray::{Array, ArrayD};
//! # use std::str::FromStr;
//! use std::convert::TryFrom;
//!
//! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
//! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
//! .unwrap()
//! .unwrap()
//! .into_dyn(); // Rust's ndarray
//! .into_dyn(); // Rust's ndarray
//! let nd = NDArray::from_rust_ndarray(&a, TVMContext::cpu(0),
TVMType::from("float32"
)).unwrap();
//! let nd = NDArray::from_rust_ndarray(&a, TVMContext::cpu(0),
DataType::from_str("float32").unwrap(
)).unwrap();
//! assert_eq!(nd.shape(), Some(&mut [2, 2]));
//! assert_eq!(nd.shape(), Some(&mut [2, 2]
[..]
));
//! let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
//! let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
//! assert!(rnd.all_close(&a, 1e-8f32));
//! assert!(rnd.all_close(&a, 1e-8f32));
//! ```
//! ```
...
@@ -47,6 +52,9 @@ use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr};
...
@@ -47,6 +52,9 @@ use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr};
use
failure
::
Error
;
use
failure
::
Error
;
use
num_traits
::
Num
;
use
num_traits
::
Num
;
use
rust_ndarray
::{
Array
,
ArrayD
};
use
rust_ndarray
::{
Array
,
ArrayD
};
use
std
::
convert
::
TryInto
;
use
std
::
ffi
::
c_void
;
use
tvm_common
::
ffi
::
DLTensor
;
use
tvm_common
::{
ffi
,
TVMType
};
use
tvm_common
::{
ffi
,
TVMType
};
use
crate
::{
errors
,
TVMByteArray
,
TVMContext
};
use
crate
::{
errors
,
TVMByteArray
,
TVMContext
};
...
@@ -55,31 +63,49 @@ use crate::{errors, TVMByteArray, TVMContext};
...
@@ -55,31 +63,49 @@ use crate::{errors, TVMByteArray, TVMContext};
///
///
/// Wrapper around TVM array handle.
/// Wrapper around TVM array handle.
#[derive(Debug)]
#[derive(Debug)]
pub
struct
NDArray
{
pub
enum
NDArray
{
pub
(
crate
)
handle
:
ffi
::
TVMArrayHandle
,
Borrowed
{
handle
:
ffi
::
TVMArrayHandle
}
,
is_view
:
bool
,
Owned
{
handle
:
*
mut
c_void
}
,
}
}
impl
NDArray
{
impl
NDArray
{
pub
(
crate
)
fn
new
(
handle
:
ffi
::
TVMArrayHandle
)
->
Self
{
pub
(
crate
)
fn
new
(
handle
:
ffi
::
TVMArrayHandle
)
->
Self
{
NDArray
{
NDArray
::
Borrowed
{
handle
}
handle
,
}
is_view
:
true
,
pub
(
crate
)
fn
from_ndarray_handle
(
handle
:
*
mut
c_void
)
->
Self
{
NDArray
::
Owned
{
handle
}
}
pub
fn
as_dltensor
(
&
self
)
->
&
DLTensor
{
unsafe
{
match
self
{
NDArray
::
Borrowed
{
ref
handle
}
=>
std
::
mem
::
transmute
(
*
handle
),
NDArray
::
Owned
{
ref
handle
}
=>
std
::
mem
::
transmute
(
*
handle
),
}
}
}
}
}
/// Returns the underlying array handle.
pub
(
crate
)
fn
as_raw_dltensor
(
&
self
)
->
*
mut
DLTensor
{
pub
fn
handle
(
&
self
)
->
ffi
::
TVMArrayHandle
{
unsafe
{
self
.handle
match
self
{
NDArray
::
Borrowed
{
ref
handle
}
=>
std
::
mem
::
transmute
(
*
handle
),
NDArray
::
Owned
{
ref
handle
}
=>
std
::
mem
::
transmute
(
*
handle
),
}
}
}
}
pub
fn
is_view
(
&
self
)
->
bool
{
pub
fn
is_view
(
&
self
)
->
bool
{
self
.is_view
if
let
&
NDArray
::
Borrowed
{
..
}
=
self
{
true
}
else
{
false
}
}
}
/// Returns the shape of the NDArray.
/// Returns the shape of the NDArray.
pub
fn
shape
(
&
self
)
->
Option
<&
mut
[
usize
]
>
{
pub
fn
shape
(
&
self
)
->
Option
<&
mut
[
usize
]
>
{
let
arr
=
unsafe
{
*
(
self
.handle
)
}
;
let
arr
=
self
.as_dltensor
()
;
if
arr
.shape
.is_null
()
||
arr
.data
.is_null
()
{
if
arr
.shape
.is_null
()
||
arr
.data
.is_null
()
{
return
None
;
return
None
;
};
};
...
@@ -94,24 +120,28 @@ impl NDArray {
...
@@ -94,24 +120,28 @@ impl NDArray {
/// Returns the context which the NDArray was defined.
/// Returns the context which the NDArray was defined.
pub
fn
ctx
(
&
self
)
->
TVMContext
{
pub
fn
ctx
(
&
self
)
->
TVMContext
{
unsafe
{
(
*
self
.handle
)
.ctx
.into
()
}
self
.as_dltensor
()
.ctx
.into
()
}
}
/// Returns the type of the entries of the NDArray.
/// Returns the type of the entries of the NDArray.
pub
fn
dtype
(
&
self
)
->
TVMType
{
pub
fn
dtype
(
&
self
)
->
TVMType
{
unsafe
{
(
*
self
.handle
)
.dtype
}
self
.as_dltensor
()
.dtype
}
}
/// Returns the number of dimensions of the NDArray.
/// Returns the number of dimensions of the NDArray.
pub
fn
ndim
(
&
self
)
->
usize
{
pub
fn
ndim
(
&
self
)
->
usize
{
unsafe
{
(
*
self
.handle
)
.ndim
as
usize
}
self
.as_dltensor
()
.ndim
.try_into
()
.expect
(
"number of dimensions must always be positive"
)
}
}
/// Returns the strides of the underlying NDArray.
/// Returns the strides of the underlying NDArray.
pub
fn
strides
(
&
self
)
->
Option
<&
[
usize
]
>
{
pub
fn
strides
(
&
self
)
->
Option
<&
[
usize
]
>
{
unsafe
{
unsafe
{
let
sz
=
self
.ndim
()
*
mem
::
size_of
::
<
usize
>
();
let
sz
=
self
.ndim
()
*
mem
::
size_of
::
<
usize
>
();
let
slc
=
slice
::
from_raw_parts
((
*
self
.handle
)
.strides
as
*
const
usize
,
sz
);
let
strides_ptr
=
self
.as_dltensor
()
.strides
as
*
const
usize
;
let
slc
=
slice
::
from_raw_parts
(
strides_ptr
,
sz
);
Some
(
slc
)
Some
(
slc
)
}
}
}
}
...
@@ -141,7 +171,7 @@ impl NDArray {
...
@@ -141,7 +171,7 @@ impl NDArray {
}
}
pub
fn
byte_offset
(
&
self
)
->
isize
{
pub
fn
byte_offset
(
&
self
)
->
isize
{
unsafe
{
(
*
self
.handle
)
.byte_offset
as
isize
}
self
.as_dltensor
()
.byte_offset
as
isize
}
}
/// Flattens the NDArray to a `Vec` of the same type in cpu.
/// Flattens the NDArray to a `Vec` of the same type in cpu.
...
@@ -149,12 +179,14 @@ impl NDArray {
...
@@ -149,12 +179,14 @@ impl NDArray {
/// ## Example
/// ## Example
///
///
/// ```
/// ```
/// let shape = &mut [4];
/// # use tvm_frontend::{TVMContext, DataType, NDArray};
/// # use std::str::FromStr;
/// let mut shape = [4];
/// let mut data = vec![1i32, 2, 3, 4];
/// let mut data = vec![1i32, 2, 3, 4];
/// let ctx = TVMContext::cpu(0);
/// let ctx = TVMContext::cpu(0);
/// let mut ndarray =
empty(shape, ctx, TVMType::from("int32"
));
/// let mut ndarray =
NDArray::empty(&mut shape, ctx, DataType::from_str("int32").unwrap(
));
/// ndarray.copy_from_buffer(&mut data);
/// ndarray.copy_from_buffer(&mut data);
/// assert_eq!(ndarray.shape(), Some(
shape
));
/// assert_eq!(ndarray.shape(), Some(
&mut shape[..]
));
/// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
/// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
/// ```
/// ```
pub
fn
to_vec
<
T
>
(
&
self
)
->
Result
<
Vec
<
T
>
,
Error
>
{
pub
fn
to_vec
<
T
>
(
&
self
)
->
Result
<
Vec
<
T
>
,
Error
>
{
...
@@ -165,7 +197,7 @@ impl NDArray {
...
@@ -165,7 +197,7 @@ impl NDArray {
self
.dtype
(),
self
.dtype
(),
);
);
let
target
=
self
.copy_to_ndarray
(
earr
)
?
;
let
target
=
self
.copy_to_ndarray
(
earr
)
?
;
let
arr
=
unsafe
{
*
(
target
.handle
)
}
;
let
arr
=
target
.as_dltensor
()
;
let
sz
=
self
.size
()
.ok_or
(
errors
::
MissingShapeError
)
?
;
let
sz
=
self
.size
()
.ok_or
(
errors
::
MissingShapeError
)
?
;
let
mut
v
:
Vec
<
T
>
=
Vec
::
with_capacity
(
sz
*
mem
::
size_of
::
<
T
>
());
let
mut
v
:
Vec
<
T
>
=
Vec
::
with_capacity
(
sz
*
mem
::
size_of
::
<
T
>
());
unsafe
{
unsafe
{
...
@@ -187,10 +219,12 @@ impl NDArray {
...
@@ -187,10 +219,12 @@ impl NDArray {
/// ## Example
/// ## Example
///
///
/// ```
/// ```
/// # use tvm_frontend::{TVMContext, DataType, NDArray};
/// # use std::str::FromStr;
/// let shape = &mut [2];
/// let shape = &mut [2];
/// let mut data = vec![1f32, 2];
/// let mut data = vec![1f32, 2
.0
];
/// let ctx = TVMContext::
g
pu(0);
/// let ctx = TVMContext::
c
pu(0);
/// let mut ndarray =
empty(shape, ctx, TVMType::from("int32"
));
/// let mut ndarray =
NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap(
));
/// ndarray.copy_from_buffer(&mut data);
/// ndarray.copy_from_buffer(&mut data);
/// ```
/// ```
///
///
...
@@ -198,7 +232,7 @@ impl NDArray {
...
@@ -198,7 +232,7 @@ impl NDArray {
/// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`.
/// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`.
pub
fn
copy_from_buffer
<
T
:
Num32
>
(
&
mut
self
,
data
:
&
mut
[
T
])
{
pub
fn
copy_from_buffer
<
T
:
Num32
>
(
&
mut
self
,
data
:
&
mut
[
T
])
{
check_call!
(
ffi
::
TVMArrayCopyFromBytes
(
check_call!
(
ffi
::
TVMArrayCopyFromBytes
(
self
.
handle
,
self
.
as_raw_dltensor
()
,
data
.as_ptr
()
as
*
mut
_
,
data
.as_ptr
()
as
*
mut
_
,
data
.len
()
*
mem
::
size_of
::
<
T
>
()
data
.len
()
*
mem
::
size_of
::
<
T
>
()
));
));
...
@@ -216,8 +250,8 @@ impl NDArray {
...
@@ -216,8 +250,8 @@ impl NDArray {
);
);
}
}
check_call!
(
ffi
::
TVMArrayCopyFromTo
(
check_call!
(
ffi
::
TVMArrayCopyFromTo
(
self
.
handle
,
self
.
as_raw_dltensor
()
,
target
.
handle
,
target
.
as_raw_dltensor
()
,
ptr
::
null_mut
()
as
ffi
::
TVMStreamHandle
ptr
::
null_mut
()
as
ffi
::
TVMStreamHandle
));
));
Ok
(
target
)
Ok
(
target
)
...
@@ -263,10 +297,7 @@ impl NDArray {
...
@@ -263,10 +297,7 @@ impl NDArray {
ctx
.device_id
as
c_int
,
ctx
.device_id
as
c_int
,
&
mut
handle
as
*
mut
_
,
&
mut
handle
as
*
mut
_
,
));
));
NDArray
{
NDArray
::
Borrowed
{
handle
:
handle
}
handle
,
is_view
:
false
,
}
}
}
}
}
...
@@ -304,8 +335,8 @@ impl_from_ndarray_rustndarray!(f32, "float");
...
@@ -304,8 +335,8 @@ impl_from_ndarray_rustndarray!(f32, "float");
impl
Drop
for
NDArray
{
impl
Drop
for
NDArray
{
fn
drop
(
&
mut
self
)
{
fn
drop
(
&
mut
self
)
{
if
!
self
.is_view
{
if
let
&
mut
NDArray
::
Owned
{
..
}
=
self
{
check_call!
(
ffi
::
TVMArrayFree
(
self
.
handle
));
check_call!
(
ffi
::
TVMArrayFree
(
self
.
as_raw_dltensor
()
));
}
}
}
}
}
}
...
...
rust/frontend/src/value.rs
View file @
9c591510
...
@@ -22,15 +22,15 @@
...
@@ -22,15 +22,15 @@
//! `TVMRetValue` is the owned version of `TVMPODValue`.
//! `TVMRetValue` is the owned version of `TVMPODValue`.
use
std
::
convert
::
TryFrom
;
use
std
::
convert
::
TryFrom
;
// use std::ffi::c_void;
use
crate
::{
Function
,
Module
,
NDArray
,
TVMArgValue
,
TVMRetValue
};
use
tvm_common
::{
use
tvm_common
::{
errors
::
ValueDowncastError
,
errors
::
ValueDowncastError
,
ffi
::{
TVM
ArrayHandle
,
TVM
FunctionHandle
,
TVMModuleHandle
},
ffi
::{
TVMFunctionHandle
,
TVMModuleHandle
},
try_downcast
,
try_downcast
,
};
};
use
crate
::{
Function
,
Module
,
NDArray
,
TVMArgValue
,
TVMRetValue
};
macro_rules!
impl_handle_val
{
macro_rules!
impl_handle_val
{
(
$type:ty
,
$variant:ident
,
$inner_type:ty
,
$ctor:path
)
=>
{
(
$type:ty
,
$variant:ident
,
$inner_type:ty
,
$ctor:path
)
=>
{
impl
<
'a
>
From
<&
'a
$type
>
for
TVMArgValue
<
'a
>
{
impl
<
'a
>
From
<&
'a
$type
>
for
TVMArgValue
<
'a
>
{
...
@@ -76,7 +76,60 @@ macro_rules! impl_handle_val {
...
@@ -76,7 +76,60 @@ macro_rules! impl_handle_val {
impl_handle_val!
(
Function
,
FuncHandle
,
TVMFunctionHandle
,
Function
::
new
);
impl_handle_val!
(
Function
,
FuncHandle
,
TVMFunctionHandle
,
Function
::
new
);
impl_handle_val!
(
Module
,
ModuleHandle
,
TVMModuleHandle
,
Module
::
new
);
impl_handle_val!
(
Module
,
ModuleHandle
,
TVMModuleHandle
,
Module
::
new
);
impl_handle_val!
(
NDArray
,
ArrayHandle
,
TVMArrayHandle
,
NDArray
::
new
);
impl
<
'a
>
From
<&
'a
NDArray
>
for
TVMArgValue
<
'a
>
{
fn
from
(
arg
:
&
'a
NDArray
)
->
Self
{
match
arg
{
&
NDArray
::
Borrowed
{
handle
}
=>
TVMArgValue
::
ArrayHandle
(
handle
),
&
NDArray
::
Owned
{
handle
}
=>
TVMArgValue
::
NDArrayHandle
(
handle
),
}
}
}
impl
<
'a
>
From
<&
'a
mut
NDArray
>
for
TVMArgValue
<
'a
>
{
fn
from
(
arg
:
&
'a
mut
NDArray
)
->
Self
{
match
arg
{
&
mut
NDArray
::
Borrowed
{
handle
}
=>
TVMArgValue
::
ArrayHandle
(
handle
),
&
mut
NDArray
::
Owned
{
handle
}
=>
TVMArgValue
::
NDArrayHandle
(
handle
),
}
}
}
impl
<
'a
>
TryFrom
<
TVMArgValue
<
'a
>>
for
NDArray
{
type
Error
=
ValueDowncastError
;
fn
try_from
(
val
:
TVMArgValue
<
'a
>
)
->
Result
<
NDArray
,
Self
::
Error
>
{
try_downcast!
(
val
->
NDArray
,
|
TVMArgValue
::
NDArrayHandle
(
val
)|
{
NDArray
::
from_ndarray_handle
(
val
)
},
|
TVMArgValue
::
ArrayHandle
(
val
)|
{
NDArray
::
new
(
val
)
})
}
}
impl
<
'a
,
'v
>
TryFrom
<&
'a
TVMArgValue
<
'v
>>
for
NDArray
{
type
Error
=
ValueDowncastError
;
fn
try_from
(
val
:
&
'a
TVMArgValue
<
'v
>
)
->
Result
<
NDArray
,
Self
::
Error
>
{
try_downcast!
(
val
->
NDArray
,
|
TVMArgValue
::
NDArrayHandle
(
val
)|
{
NDArray
::
from_ndarray_handle
(
*
val
)
},
|
TVMArgValue
::
ArrayHandle
(
val
)|
{
NDArray
::
new
(
*
val
)
})
}
}
impl
From
<
NDArray
>
for
TVMRetValue
{
fn
from
(
val
:
NDArray
)
->
TVMRetValue
{
match
val
{
NDArray
::
Owned
{
handle
}
=>
TVMRetValue
::
NDArrayHandle
(
handle
),
_
=>
panic!
(
"NYI"
),
}
}
}
impl
TryFrom
<
TVMRetValue
>
for
NDArray
{
type
Error
=
ValueDowncastError
;
fn
try_from
(
val
:
TVMRetValue
)
->
Result
<
NDArray
,
Self
::
Error
>
{
try_downcast!
(
val
->
NDArray
,
|
TVMRetValue
::
NDArrayHandle
(
val
)|
{
NDArray
::
from_ndarray_handle
(
val
)
},
|
TVMRetValue
::
ArrayHandle
(
val
)|
{
NDArray
::
new
(
val
)
})
}
}
#[cfg(test)]
#[cfg(test)]
mod
tests
{
mod
tests
{
...
...
rust/frontend/tests/callback/src/bin/array.rs
View file @
9c591510
...
@@ -68,5 +68,5 @@ fn main() {
...
@@ -68,5 +68,5 @@ fn main() {
.unwrap
()
.unwrap
()
.try_into
()
.try_into
()
.unwrap
();
.unwrap
();
assert_eq!
(
ret
,
14
f32
);
assert_eq!
(
ret
,
7
f32
);
}
}
rust/macros/src/lib.rs
View file @
9c591510
...
@@ -19,10 +19,10 @@
...
@@ -19,10 +19,10 @@
extern
crate
proc_macro
;
extern
crate
proc_macro
;
use
quote
::
quote
;
use
std
::{
fs
::
File
,
io
::
Read
};
use
std
::{
fs
::
File
,
io
::
Read
};
use
syn
::
parse
::{
Parse
,
ParseStream
,
Result
};
use
syn
::
parse
::{
Parse
,
ParseStream
,
Result
};
use
syn
::{
LitStr
};
use
syn
::
LitStr
;
use
quote
::
quote
;
use
std
::
path
::
PathBuf
;
use
std
::
path
::
PathBuf
;
...
@@ -33,9 +33,7 @@ struct ImportModule {
...
@@ -33,9 +33,7 @@ struct ImportModule {
impl
Parse
for
ImportModule
{
impl
Parse
for
ImportModule
{
fn
parse
(
input
:
ParseStream
)
->
Result
<
Self
>
{
fn
parse
(
input
:
ParseStream
)
->
Result
<
Self
>
{
let
importing_file
:
LitStr
=
input
.parse
()
?
;
let
importing_file
:
LitStr
=
input
.parse
()
?
;
Ok
(
ImportModule
{
Ok
(
ImportModule
{
importing_file
})
importing_file
,
})
}
}
}
}
...
@@ -43,8 +41,8 @@ impl Parse for ImportModule {
...
@@ -43,8 +41,8 @@ impl Parse for ImportModule {
pub
fn
import_module
(
input
:
proc_macro
::
TokenStream
)
->
proc_macro
::
TokenStream
{
pub
fn
import_module
(
input
:
proc_macro
::
TokenStream
)
->
proc_macro
::
TokenStream
{
let
import_module_args
=
syn
::
parse_macro_input!
(
input
as
ImportModule
);
let
import_module_args
=
syn
::
parse_macro_input!
(
input
as
ImportModule
);
let
manifest
=
std
::
env
::
var
(
"CARGO_MANIFEST_DIR"
)
let
manifest
=
.expect
(
"variable should always be set by Cargo."
);
std
::
env
::
var
(
"CARGO_MANIFEST_DIR"
)
.expect
(
"variable should always be set by Cargo."
);
let
mut
path
=
PathBuf
::
new
();
let
mut
path
=
PathBuf
::
new
();
path
.push
(
manifest
);
path
.push
(
manifest
);
...
...
rust/runtime/src/module/syslib.rs
View file @
9c591510
...
@@ -42,7 +42,8 @@ impl Module for SystemLibModule {
...
@@ -42,7 +42,8 @@ impl Module for SystemLibModule {
SYSTEM_LIB_FUNCTIONS
SYSTEM_LIB_FUNCTIONS
.lock
()
.lock
()
.unwrap
()
.unwrap
()
.get
(
name
.as_ref
())
.copied
()
.get
(
name
.as_ref
())
.copied
()
}
}
}
}
...
...
rust/runtime/src/threading.rs
View file @
9c591510
...
@@ -27,7 +27,7 @@ use std::{
...
@@ -27,7 +27,7 @@ use std::{
thread
::{
self
,
JoinHandle
},
thread
::{
self
,
JoinHandle
},
};
};
use
crossbeam
::
channel
::{
Sender
,
Receiver
,
bounded
};
use
crossbeam
::
channel
::{
bounded
,
Receiver
,
Sender
};
use
tvm_common
::
ffi
::
TVMParallelGroupEnv
;
use
tvm_common
::
ffi
::
TVMParallelGroupEnv
;
pub
(
crate
)
type
FTVMParallelLambda
=
pub
(
crate
)
type
FTVMParallelLambda
=
...
@@ -138,8 +138,7 @@ impl ThreadPool {
...
@@ -138,8 +138,7 @@ impl ThreadPool {
let
mut
tasks
=
job
.tasks
(
self
.num_workers
+
1
);
let
mut
tasks
=
job
.tasks
(
self
.num_workers
+
1
);
for
(
i
,
task
)
in
tasks
.split_off
(
1
)
.into_iter
()
.enumerate
()
{
for
(
i
,
task
)
in
tasks
.split_off
(
1
)
.into_iter
()
.enumerate
()
{
self
.threads.queues
[
i
]
.send
(
task
)
self
.threads.queues
[
i
]
.send
(
task
)
.expect
(
"should send"
);
.expect
(
"should send"
);
}
}
tasks
.pop
()
.unwrap
()
.run
();
tasks
.pop
()
.unwrap
()
.run
();
...
...
rust/runtime/tests/build_model.py
View file @
9c591510
...
@@ -31,7 +31,7 @@ CWD = osp.dirname(osp.abspath(osp.expanduser(__file__)))
...
@@ -31,7 +31,7 @@ CWD = osp.dirname(osp.abspath(osp.expanduser(__file__)))
def
_get_model
(
dshape
):
def
_get_model
(
dshape
):
data
=
relay
.
var
(
'data'
,
shape
=
dshape
)
data
=
relay
.
var
(
'data'
,
shape
=
dshape
)
fc
=
relay
.
nn
.
dense
(
data
,
relay
.
var
(
"dense_weight"
),
units
=
dshape
[
-
1
]
*
2
)
fc
=
relay
.
nn
.
dense
(
data
,
relay
.
var
(
"dense_weight"
),
units
=
dshape
[
-
1
]
*
2
)
fc
=
relay
.
nn
.
bias_add
(
data
,
relay
.
var
(
"dense_bias"
))
fc
=
relay
.
nn
.
bias_add
(
fc
,
relay
.
var
(
"dense_bias"
))
left
,
right
=
relay
.
split
(
fc
,
indices_or_sections
=
2
,
axis
=
1
)
left
,
right
=
relay
.
split
(
fc
,
indices_or_sections
=
2
,
axis
=
1
)
one
=
relay
.
const
(
1
,
dtype
=
"float32"
)
one
=
relay
.
const
(
1
,
dtype
=
"float32"
)
return
relay
.
Tuple
([(
left
+
one
),
(
right
-
one
),
fc
])
return
relay
.
Tuple
([(
left
+
one
),
(
right
-
one
),
fc
])
...
...
rust/runtime/tests/test_graph_serde.rs
View file @
9c591510
...
@@ -75,9 +75,9 @@ fn test_load_graph() {
...
@@ -75,9 +75,9 @@ fn test_load_graph() {
.unwrap
()
.unwrap
()
.get
(
"func_name"
)
.get
(
"func_name"
)
.unwrap
(),
.unwrap
(),
"fuse
_dense
"
"fuse
d_nn_dense_nn_bias_add
"
);
);
assert_eq!
(
graph
.nodes
[
5
]
.inputs
[
0
]
.index
,
0
);
assert_eq!
(
graph
.nodes
[
3
]
.inputs
[
0
]
.index
,
0
);
assert_eq!
(
graph
.nodes
[
6
]
.inputs
[
0
]
.index
,
1
);
assert_eq!
(
graph
.nodes
[
4
]
.inputs
[
0
]
.index
,
0
);
assert_eq!
(
graph
.heads
.len
(),
2
);
assert_eq!
(
graph
.heads
.len
(),
3
);
}
}
rust/runtime/tests/test_nn/build.rs
View file @
9c591510
...
@@ -25,16 +25,24 @@ use ar::Builder;
...
@@ -25,16 +25,24 @@ use ar::Builder;
fn
main
()
{
fn
main
()
{
let
out_dir
=
env
::
var
(
"OUT_DIR"
)
.unwrap
();
let
out_dir
=
env
::
var
(
"OUT_DIR"
)
.unwrap
();
let
out_dir
=
Path
::
new
(
&
out_dir
)
.join
(
"test_nn"
);
std
::
fs
::
create_dir_all
(
&
out_dir
)
.unwrap
();
let
manifest_dir
=
env
::
var
(
"CARGO_MANIFEST_DIR"
)
.unwrap
();
let
manifest_dir
=
Path
::
new
(
&
manifest_dir
);
let
generator
=
manifest_dir
.join
(
"src"
)
.join
(
"build_test_graph.py"
);
let
graph_path
=
out_dir
.join
(
"graph.o"
);
let
output
=
Command
::
new
(
&
generator
)
.arg
(
&
out_dir
)
.output
()
.expect
(
"Failed to execute command"
);
let
output
=
Command
::
new
(
concat!
(
env!
(
"CARGO_MANIFEST_DIR"
),
"/src/build_test_graph.py"
))
.arg
(
&
out_dir
)
.output
()
.expect
(
"Failed to execute command"
);
assert
!
(
assert
!
(
Path
::
new
(
&
format!
(
"{}/graph.o"
,
out_dir
))
.exists
(),
graph_path
.exists
(),
"Could not build graph lib: {}"
,
"Could not build graph lib: {}"
,
String
::
from_utf8
(
output
.stderr
)
String
::
from_utf8
(
output
.stderr
)
.unwrap
()
.unwrap
()
...
@@ -44,10 +52,10 @@ fn main() {
...
@@ -44,10 +52,10 @@ fn main() {
.unwrap_or
(
""
)
.unwrap_or
(
""
)
);
);
let
lib_file
=
format!
(
"{}/libtestnn.a"
,
out_dir
);
let
lib_file
=
out_dir
.join
(
"libtestnn.a"
);
let
file
=
File
::
create
(
&
lib_file
)
.unwrap
();
let
file
=
File
::
create
(
&
lib_file
)
.unwrap
();
let
mut
builder
=
Builder
::
new
(
file
);
let
mut
builder
=
Builder
::
new
(
file
);
builder
.append_path
(
format!
(
"{}/graph.o"
,
out_dir
)
)
.unwrap
();
builder
.append_path
(
graph_path
)
.unwrap
();
let
status
=
Command
::
new
(
"ranlib"
)
let
status
=
Command
::
new
(
"ranlib"
)
.arg
(
&
lib_file
)
.arg
(
&
lib_file
)
...
@@ -56,7 +64,7 @@ fn main() {
...
@@ -56,7 +64,7 @@ fn main() {
assert
!
(
status
.success
());
assert
!
(
status
.success
());
println!
(
"cargo:rustc-link-lib=static=testnn"
);
println!
(
"cargo:rustc-link-lib=static=testnn"
);
println!
(
"cargo:rustc-link-search=native={}"
,
out_dir
);
println!
(
"cargo:rustc-link-search=native={}"
,
out_dir
.display
());
println!
(
"cargo:rerun-if-changed={}"
,
generator
.display
());
}
}
rust/runtime/tests/test_nn/src/build_test_graph.py
View file @
9c591510
...
@@ -31,7 +31,7 @@ from tvm.relay import testing
...
@@ -31,7 +31,7 @@ from tvm.relay import testing
def
_get_model
(
dshape
):
def
_get_model
(
dshape
):
data
=
relay
.
var
(
'data'
,
shape
=
dshape
)
data
=
relay
.
var
(
'data'
,
shape
=
dshape
)
fc
=
relay
.
nn
.
dense
(
data
,
relay
.
var
(
"dense_weight"
),
units
=
dshape
[
-
1
]
*
2
)
fc
=
relay
.
nn
.
dense
(
data
,
relay
.
var
(
"dense_weight"
),
units
=
dshape
[
-
1
]
*
2
)
fc
=
relay
.
nn
.
bias_add
(
data
,
relay
.
var
(
"dense_bias"
))
fc
=
relay
.
nn
.
bias_add
(
fc
,
relay
.
var
(
"dense_bias"
))
left
,
right
=
relay
.
split
(
fc
,
indices_or_sections
=
2
,
axis
=
1
)
left
,
right
=
relay
.
split
(
fc
,
indices_or_sections
=
2
,
axis
=
1
)
one
=
relay
.
const
(
1
,
dtype
=
"float32"
)
one
=
relay
.
const
(
1
,
dtype
=
"float32"
)
return
relay
.
Tuple
([(
left
+
one
),
(
right
-
one
),
fc
])
return
relay
.
Tuple
([(
left
+
one
),
(
right
-
one
),
fc
])
...
...
rust/runtime/tests/test_nn/src/main.rs
View file @
9c591510
...
@@ -51,7 +51,7 @@ fn main() {
...
@@ -51,7 +51,7 @@ fn main() {
let
syslib
=
SystemLibModule
::
default
();
let
syslib
=
SystemLibModule
::
default
();
let
mut
params_bytes
=
Vec
::
new
();
let
mut
params_bytes
=
Vec
::
new
();
fs
::
File
::
open
(
concat!
(
env!
(
"OUT_DIR"
),
"/graph.params"
))
fs
::
File
::
open
(
concat!
(
env!
(
"OUT_DIR"
),
"/
test_nn/
graph.params"
))
.unwrap
()
.unwrap
()
.read_to_end
(
&
mut
params_bytes
)
.read_to_end
(
&
mut
params_bytes
)
.unwrap
();
.unwrap
();
...
@@ -61,9 +61,10 @@ fn main() {
...
@@ -61,9 +61,10 @@ fn main() {
.map
(|(
k
,
v
)|
(
k
,
v
.to_owned
()))
.map
(|(
k
,
v
)|
(
k
,
v
.to_owned
()))
.collect
::
<
HashMap
<
String
,
Tensor
<
'static
>>>
();
.collect
::
<
HashMap
<
String
,
Tensor
<
'static
>>>
();
let
graph
=
let
graph
=
Graph
::
try_from
(
Graph
::
try_from
(
&
fs
::
read_to_string
(
concat!
(
env!
(
"OUT_DIR"
),
"/graph.json"
))
.unwrap
())
&
fs
::
read_to_string
(
concat!
(
env!
(
"OUT_DIR"
),
"/test_nn/graph.json"
))
.unwrap
(),
.unwrap
();
)
.unwrap
();
let
mut
exec
=
GraphExecutor
::
new
(
graph
,
&
syslib
)
.unwrap
();
let
mut
exec
=
GraphExecutor
::
new
(
graph
,
&
syslib
)
.unwrap
();
let
x
=
Array
::
from_shape_vec
(
let
x
=
Array
::
from_shape_vec
(
...
@@ -73,11 +74,16 @@ fn main() {
...
@@ -73,11 +74,16 @@ fn main() {
.collect
::
<
Vec
<
f32
>>
(),
.collect
::
<
Vec
<
f32
>>
(),
)
)
.unwrap
();
.unwrap
();
let
w
=
Array
::
try_from
(
params
.get
(
"dense0_weight"
)
.unwrap
()
.to_owned
())
let
p0
=
params
.get
(
"p0"
)
.unwrap
()
.to_owned
();
let
p1
=
params
.get
(
"p1"
)
.unwrap
()
.to_owned
();
println!
(
"p0: {:?}"
,
p0
.shape
());
println!
(
"p1: {:?}"
,
p1
.shape
());
let
w
=
Array
::
try_from
(
p0
)
.unwrap
()
.unwrap
()
.into_shape
((
IN_DIM
*
2
,
IN_DIM
))
.into_shape
((
BATCH_SIZE
*
4
,
IN_DIM
))
.unwrap
();
.unwrap
();
let
b
=
Array
::
try_from
(
p
arams
.get
(
"dense0_bias"
)
.unwrap
()
.to_owned
()
)
.unwrap
();
let
b
=
Array
::
try_from
(
p
1
)
.unwrap
();
let
dense
=
x
.dot
(
&
w
.t
())
+
&
b
;
let
dense
=
x
.dot
(
&
w
.t
())
+
&
b
;
let
left
=
dense
.slice
(
s!
[
..
,
0
..
IN_DIM
]);
let
left
=
dense
.slice
(
s!
[
..
,
0
..
IN_DIM
]);
let
right
=
dense
.slice
(
s!
[
..
,
IN_DIM
..
]);
let
right
=
dense
.slice
(
s!
[
..
,
IN_DIM
..
]);
...
@@ -88,8 +94,8 @@ fn main() {
...
@@ -88,8 +94,8 @@ fn main() {
exec
.set_input
(
"data"
,
(
&
x
)
.into
());
exec
.set_input
(
"data"
,
(
&
x
)
.into
());
check_sum!
(
exec
,
data
,
x
);
check_sum!
(
exec
,
data
,
x
);
check_sum!
(
exec
,
dense0_weight
,
w
);
check_sum!
(
exec
,
p0
,
w
);
check_sum!
(
exec
,
dense0_bias
,
b
);
check_sum!
(
exec
,
p1
,
b
);
exec
.run
();
exec
.run
();
...
...
tests/scripts/task_rust.sh
View file @
9c591510
...
@@ -19,15 +19,13 @@
...
@@ -19,15 +19,13 @@
set
-e
set
-e
set
-u
set
-u
# Temporary disable rust tests
# remove this line to re-enable.
exit
0
export
TVM_HOME
=
"
$(
git rev-parse
--show-toplevel
)
"
export
TVM_HOME
=
"
$(
git rev-parse
--show-toplevel
)
"
export
LD_LIBRARY_PATH
=
"
$TVM_HOME
/lib:
$TVM_HOME
/build:
${
LD_LIBRARY_PATH
:-}
"
export
LD_LIBRARY_PATH
=
"
$TVM_HOME
/lib:
$TVM_HOME
/build:
${
LD_LIBRARY_PATH
:-}
"
export
PYTHONPATH
=
"
$TVM_HOME
/python"
:
"
$TVM_HOME
/topi/python"
export
PYTHONPATH
=
"
$TVM_HOME
/python"
:
"
$TVM_HOME
/topi/python"
export
RUST_DIR
=
"
$TVM_HOME
/rust"
export
RUST_DIR
=
"
$TVM_HOME
/rust"
export
LLVM_CONFIG_PATH
=
`
which llvm-config-8
`
echo
"Using
$LLVM_CONFIG_PATH
"
cd
$RUST_DIR
cd
$RUST_DIR
cargo fmt
--
--check
cargo fmt
--
--check
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment