Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
14a0ecba
Commit
14a0ecba
authored
Apr 07, 2019
by
Nick Hynes
Committed by
Tianqi Chen
Apr 07, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Rustify PackedFunc & Friends (#2969)
parent
0708c48d
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
106 additions
and
229 deletions
+106
-229
rust/common/Cargo.toml
+1
-0
rust/common/src/errors.rs
+6
-57
rust/common/src/lib.rs
+2
-2
rust/common/src/packed_func.rs
+0
-0
rust/common/src/value.rs
+15
-0
rust/frontend/Cargo.toml
+1
-0
rust/frontend/src/bytearray.rs
+5
-16
rust/frontend/src/context.rs
+5
-18
rust/frontend/src/function.rs
+11
-13
rust/frontend/src/module.rs
+5
-2
rust/frontend/src/ndarray.rs
+1
-1
rust/frontend/src/value.rs
+40
-100
rust/frontend/tests/callback/src/bin/string.rs
+3
-3
rust/runtime/Cargo.toml
+1
-0
rust/runtime/src/graph.rs
+2
-2
rust/runtime/src/module.rs
+8
-11
rust/runtime/tests/test_graph_serde.rs
+0
-2
rust/runtime/tests/test_nnvm/src/main.rs
+0
-2
No files found.
rust/common/Cargo.toml
View file @
14a0ecba
...
...
@@ -3,6 +3,7 @@ name = "tvm-common"
version
=
"0.1.0"
authors
=
[
"TVM Contributors"
]
license
=
"Apache-2.0"
edition
=
"2018"
[features]
bindings
=
[]
...
...
rust/common/src/errors.rs
View file @
14a0ecba
use
std
::
fmt
;
static
TYPE_CODE_STRS
:
[
&
str
;
15
]
=
[
"int"
,
"uint"
,
"float"
,
"handle"
,
"null"
,
"TVMType"
,
"TVMContext"
,
"ArrayHandle"
,
"NodeHandle"
,
"ModuleHandle"
,
"FuncHandle"
,
"str"
,
"bytes"
,
"NDArrayContainer"
,
"ExtBegin"
,
];
#
[
derive
(
Debug
,
Fail
)]
#[fail(
display
=
"Could not downcast `{}` into `{}`"
,
expected_type,
actual_type
)]
pub
struct
ValueDowncastError
{
actual_type_code
:
i64
,
expected_type_code
:
i64
,
}
impl
ValueDowncastError
{
pub
fn
new
(
actual_type_code
:
i64
,
expected_type_code
:
i64
)
->
Self
{
Self
{
actual_type_code
,
expected_type_code
,
}
}
}
impl
fmt
::
Display
for
ValueDowncastError
{
fn
fmt
(
&
self
,
formatter
:
&
mut
fmt
::
Formatter
)
->
fmt
::
Result
{
write!
(
formatter
,
"Could not downcast TVMValue: expected `{}` but was {}"
,
TYPE_CODE_STRS
[
self
.actual_type_code
as
usize
],
TYPE_CODE_STRS
[
self
.expected_type_code
as
usize
]
)
}
pub
actual_type
:
String
,
pub
expected_type
:
&
'static
str
,
}
#[derive(Debug,
Fail)]
...
...
@@ -62,18 +26,3 @@ impl FuncCallError {
}
}
}
// error_chain! {
// errors {
// TryFromTVMRetValueError(expected_type: String, actual_type_code: i64) {
// description("mismatched types while downcasting TVMRetValue")
// display("invalid downcast: expected `{}` but was `{}`",
// expected_type, type_code_to_string(actual_type_code))
// }
// }
// foreign_links {
// IntoString(std::ffi::IntoStringError);
// ParseInt(std::num::ParseIntError);
// Utf8(std::str::Utf8Error);
// }
// }
rust/common/src/lib.rs
View file @
14a0ecba
//! This crate contains the refactored basic components required
//! for `runtime` and `frontend` TVM crates.
#
!
[
feature
(
box_syntax
,
trait_alias
)]
#
!
[
feature
(
box_syntax
,
t
ype_alias_enum_variants
,
t
rait_alias
)]
#[macro_use]
extern
crate
failure
;
...
...
@@ -25,5 +25,5 @@ pub mod packed_func;
pub
mod
value
;
pub
use
errors
::
*
;
pub
use
ffi
::{
TVMContext
,
TVMType
};
pub
use
ffi
::{
TVM
ByteArray
,
TVM
Context
,
TVMType
};
pub
use
packed_func
::{
TVMArgValue
,
TVMRetValue
};
rust/common/src/packed_func.rs
View file @
14a0ecba
This diff is collapsed.
Click to expand it.
rust/common/src/value.rs
View file @
14a0ecba
...
...
@@ -137,3 +137,18 @@ impl_tvm_context!(
DLDeviceType_kDLROCM
:
[
rocm
],
DLDeviceType_kDLExtDev
:
[
ext_dev
]
);
impl
TVMByteArray
{
pub
fn
data
(
&
self
)
->
&
'static
[
u8
]
{
unsafe
{
std
::
slice
::
from_raw_parts
(
self
.data
as
*
const
u8
,
self
.size
)
}
}
}
impl
<
'a
>
From
<&
'a
[
u8
]
>
for
TVMByteArray
{
fn
from
(
bytes
:
&
[
u8
])
->
Self
{
Self
{
data
:
bytes
.as_ptr
()
as
*
const
i8
,
size
:
bytes
.len
(),
}
}
}
rust/frontend/Cargo.toml
View file @
14a0ecba
...
...
@@ -9,6 +9,7 @@ readme = "README.md"
keywords
=
[
"rust"
,
"tvm"
,
"nnvm"
]
categories
=
[
"api-bindings"
,
"science"
]
authors
=
[
"TVM Contributors"
]
edition
=
"2018"
[lib]
name
=
"tvm_frontend"
...
...
rust/frontend/src/bytearray.rs
View file @
14a0ecba
...
...
@@ -3,9 +3,9 @@
//!
//! For more detail, please see the example `resnet` in `examples` repository.
use
std
::
os
::
raw
::
{
c_char
,
c_void
}
;
use
std
::
os
::
raw
::
c_char
;
use
tvm_common
::
{
ffi
,
TVMArgValue
}
;
use
tvm_common
::
ffi
;
/// A struct holding TVM byte-array.
///
...
...
@@ -44,8 +44,9 @@ impl TVMByteArray {
}
}
impl
<
'a
>
From
<&
'a
Vec
<
u8
>>
for
TVMByteArray
{
fn
from
(
arg
:
&
Vec
<
u8
>
)
->
Self
{
impl
<
'a
,
T
:
AsRef
<
[
u8
]
>>
From
<
T
>
for
TVMByteArray
{
fn
from
(
arg
:
T
)
->
Self
{
let
arg
=
arg
.as_ref
();
let
barr
=
ffi
::
TVMByteArray
{
data
:
arg
.as_ptr
()
as
*
const
c_char
,
size
:
arg
.len
(),
...
...
@@ -54,18 +55,6 @@ impl<'a> From<&'a Vec<u8>> for TVMByteArray {
}
}
impl
<
'a
>
From
<&
TVMByteArray
>
for
TVMArgValue
<
'a
>
{
fn
from
(
arr
:
&
TVMByteArray
)
->
Self
{
Self
{
value
:
ffi
::
TVMValue
{
v_handle
:
&
arr
.inner
as
*
const
ffi
::
TVMByteArray
as
*
const
c_void
as
*
mut
c_void
,
},
type_code
:
ffi
::
TVMTypeCode_kBytes
as
i64
,
_lifetime
:
std
::
marker
::
PhantomData
,
}
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
...
...
rust/frontend/src/context.rs
View file @
14a0ecba
...
...
@@ -26,10 +26,7 @@ use std::{
use
failure
::
Error
;
use
tvm_common
::{
ffi
::{
self
,
TVMValue
},
TVMArgValue
,
};
use
tvm_common
::
ffi
;
use
crate
::
function
;
...
...
@@ -125,18 +122,6 @@ impl<'a> From<&'a str> for TVMDeviceType {
}
}
impl
<
'a
>
From
<&
'a
TVMDeviceType
>
for
TVMArgValue
<
'a
>
{
fn
from
(
dev_type
:
&
'a
TVMDeviceType
)
->
Self
{
Self
{
value
:
TVMValue
{
v_int64
:
dev_type
.
0
as
i64
,
},
type_code
:
ffi
::
DLDataTypeCode_kDLInt
as
i64
,
_lifetime
:
std
::
marker
::
PhantomData
,
}
}
}
/// Represents the underlying device context. Default is cpu.
///
/// ## Examples
...
...
@@ -209,7 +194,7 @@ impl TVMContext {
let
dt
=
self
.device_type
.
0
as
usize
;
// `unwrap` is ok here because if there is any error,
// if would occure inside `call_packed!`
let
ret
:
u64
=
call_packed!
(
func
,
&
dt
,
&
self
.device_id
,
&
0
)
let
ret
:
u64
=
call_packed!
(
func
,
dt
,
self
.device_id
,
0
)
.unwrap
()
.try_into
()
.unwrap
();
...
...
@@ -238,7 +223,9 @@ macro_rules! impl_device_attrs {
// `unwrap` is ok here because if there is any error,
// if would occur in function call.
function
::
Builder
::
from
(
func
)
.args
(
&
[
dt
,
self
.device_id
as
usize
,
$attr_kind
])
.arg
(
dt
)
.arg
(
self
.device_id
as
usize
)
.arg
(
$attr_kind
)
.invoke
()
.unwrap
()
.try_into
()
...
...
rust/frontend/src/function.rs
View file @
14a0ecba
...
...
@@ -156,9 +156,9 @@ impl<'a, 'm> Builder<'a, 'm> {
}
/// Pushes a [`TVMArgValue`] into the function argument buffer.
pub
fn
arg
<
T
:
'a
>
(
&
mut
self
,
arg
:
&
'a
T
)
->
&
mut
Self
pub
fn
arg
<
T
:
'a
>
(
&
mut
self
,
arg
:
T
)
->
&
mut
Self
where
TVMArgValue
<
'a
>
:
From
<
&
'a
T
>
,
TVMArgValue
<
'a
>
:
From
<
T
>
,
{
self
.arg_buf
.push
(
arg
.into
());
self
...
...
@@ -192,14 +192,11 @@ impl<'a, 'm> Builder<'a, 'm> {
ensure!
(
self
.func
.is_some
(),
errors
::
FunctionNotFoundError
);
let
num_args
=
self
.arg_buf
.len
();
let
(
mut
values
,
mut
type_codes
):
(
Vec
<
ffi
::
TVMValue
>
,
Vec
<
ffi
::
TVMTypeCode
>
)
=
self
.arg_buf
.iter
()
.map
(|
tvm_arg
|
(
tvm_arg
.value
,
tvm_arg
.type_code
as
ffi
::
TVMTypeCode
))
.unzip
();
let
(
mut
values
,
mut
type_codes
):
(
Vec
<
ffi
::
TVMValue
>
,
Vec
<
ffi
::
TVMTypeCode
>
)
=
self
.arg_buf
.iter
()
.map
(|
arg
|
arg
.to_tvm_value
())
.unzip
();
let
mut
ret_val
=
unsafe
{
std
::
mem
::
uninitialized
::
<
TVMValue
>
()
};
let
mut
ret_type_code
=
0
;
let
mut
ret_type_code
=
0
i32
;
check_call!
(
ffi
::
TVMFuncCall
(
self
.func
.ok_or
(
errors
::
FunctionNotFoundError
)
?
.handle
,
values
.as_mut_ptr
(),
...
...
@@ -209,7 +206,7 @@ impl<'a, 'm> Builder<'a, 'm> {
&
mut
ret_type_code
as
*
mut
_
));
Ok
(
unsafe
{
TVMRetValue
::
from_tvm_value
(
ret_val
,
ret_type_code
as
i64
)
})
Ok
(
unsafe
{
TVMRetValue
::
from_tvm_value
(
ret_val
,
ret_type_code
as
u32
)
})
}
}
...
...
@@ -254,7 +251,7 @@ unsafe extern "C" fn tvm_callback(
{
check_call!
(
ffi
::
TVMCbArgToReturn
(
&
mut
value
as
*
mut
_
,
tcode
));
}
local_args
.push
(
TVMArgValue
::
new
(
value
.into
(),
(
tcode
as
i64
)
.into
()
));
local_args
.push
(
TVMArgValue
::
from_tvm_value
(
value
.into
(),
tcode
as
u32
));
}
let
rv
=
match
rust_fn
(
local_args
.as_slice
())
{
...
...
@@ -265,7 +262,7 @@ unsafe extern "C" fn tvm_callback(
}
};
let
(
mut
ret_val
,
ret_tcode
)
=
rv
.
in
to_tvm_value
();
let
(
mut
ret_val
,
ret_tcode
)
=
rv
.to_tvm_value
();
let
mut
ret_type_code
=
ret_tcode
as
c_int
;
check_call!
(
ffi
::
TVMCFuncSetReturn
(
ret
,
...
...
@@ -437,8 +434,9 @@ mod tests {
let
str_arg
=
CString
::
new
(
"test"
)
.unwrap
();
let
mut
func
=
Builder
::
default
();
func
.get_function
(
"tvm.graph_runtime.remote_create"
)
.args
(
&
[
10
,
20
])
.arg
(
&
str_arg
);
.arg
(
10
)
.arg
(
20
)
.arg
(
str_arg
.as_c_str
());
assert_eq!
(
func
.arg_buf
.len
(),
3
);
}
}
rust/frontend/src/module.rs
View file @
14a0ecba
...
...
@@ -80,7 +80,7 @@ impl Module {
CString
::
new
(
path
.as_ref
()
.to_str
()
.ok_or_else
(||
{
format_err!
(
"Bad module load path: `{}`."
,
path
.as_ref
()
.display
())
})
?
)
?
;
let
ret
:
Module
=
call_packed!
(
func
,
&
cpath
,
&
ext
)
?
.try_into
()
?
;
let
ret
:
Module
=
call_packed!
(
func
,
cpath
.as_c_str
(),
ext
.as_c_str
()
)
?
.try_into
()
?
;
Ok
(
ret
)
}
...
...
@@ -90,7 +90,10 @@ impl Module {
// `unwrap` is safe here because if there is any error during the
// function call, it would occur in `call_packed!`.
let
tgt
=
CString
::
new
(
target
)
.unwrap
();
let
ret
:
i64
=
call_packed!
(
func
,
&
tgt
)
.unwrap
()
.try_into
()
.unwrap
();
let
ret
:
i64
=
call_packed!
(
func
,
tgt
.as_c_str
())
.unwrap
()
.try_into
()
.unwrap
();
ret
!=
0
}
...
...
rust/frontend/src/ndarray.rs
View file @
14a0ecba
...
...
@@ -161,7 +161,7 @@ impl NDArray {
/// Converts the NDArray to [`TVMByteArray`].
pub
fn
to_bytearray
(
&
self
)
->
Result
<
TVMByteArray
,
Error
>
{
let
v
=
self
.to_vec
::
<
u8
>
()
?
;
Ok
(
TVMByteArray
::
from
(
&
v
))
Ok
(
TVMByteArray
::
from
(
v
))
}
/// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu.
...
...
rust/frontend/src/value.rs
View file @
14a0ecba
...
...
@@ -2,140 +2,80 @@
//! and their conversions needed for the types used in frontend crate.
//! `TVMRetValue` is the owned version of `TVMPODValue`.
use
std
::
{
convert
::
TryFrom
,
os
::
raw
::
c_void
}
;
use
std
::
convert
::
TryFrom
;
use
failure
::
Error
;
use
tvm_common
::{
ensure_type
,
ffi
::{
self
,
TVMValue
},
errors
::
ValueDowncastError
,
ffi
::{
TVMArrayHandle
,
TVMFunctionHandle
,
TVMModuleHandle
},
try_downcast
,
};
use
crate
::{
common_errors
::
*
,
context
::
TVMContext
,
Function
,
Module
,
NDArray
,
TVMArgValue
,
TVMByteArray
,
TVMRetValue
,
};
use
crate
::{
Function
,
Module
,
NDArray
,
TVMArgValue
,
TVMRetValue
};
macro_rules!
impl_tvm_val_from_handle
{
(
$ty:ident
,
$type_code:expr
,
$handle:ty
)
=>
{
impl
<
'a
>
From
<&
'a
$ty
>
for
TVMArgValue
<
'a
>
{
fn
from
(
arg
:
&
$ty
)
->
Self
{
TVMArgValue
{
value
:
TVMValue
{
v_handle
:
arg
.handle
as
*
mut
_
as
*
mut
c_void
,
},
type_code
:
$type_code
as
i64
,
_lifetime
:
std
::
marker
::
PhantomData
,
}
macro_rules!
impl_handle_val
{
(
$type:ty
,
$variant:ident
,
$inner_type:ty
,
$ctor:path
)
=>
{
impl
<
'a
>
From
<&
'a
$type
>
for
TVMArgValue
<
'a
>
{
fn
from
(
arg
:
&
'a
$type
)
->
Self
{
TVMArgValue
::
$variant
(
arg
.handle
()
as
$inner_type
)
}
}
impl
<
'a
>
From
<&
'a
mut
$ty
>
for
TVMArgValue
<
'a
>
{
fn
from
(
arg
:
&
mut
$ty
)
->
Self
{
TVMArgValue
{
value
:
TVMValue
{
v_handle
:
arg
.handle
as
*
mut
_
as
*
mut
c_void
,
},
type_code
:
$type_code
as
i64
,
_lifetime
:
std
::
marker
::
PhantomData
,
}
impl
<
'a
>
From
<&
'a
mut
$type
>
for
TVMArgValue
<
'a
>
{
fn
from
(
arg
:
&
'a
mut
$type
)
->
Self
{
TVMArgValue
::
$variant
(
arg
.handle
()
as
$inner_type
)
}
}
impl
<
'a
,
'v
>
TryFrom
<&
'a
TVMArgValue
<
'v
>>
for
$ty
{
type
Error
=
Error
;
fn
try_from
(
arg
:
&
TVMArgValue
<
'v
>
)
->
Result
<
$ty
,
Self
::
Error
>
{
ensure_type!
(
arg
,
$type_code
);
Ok
(
$ty
::
new
(
unsafe
{
arg
.value.v_handle
as
$handle
}))
impl
<
'a
>
TryFrom
<
TVMArgValue
<
'a
>>
for
$type
{
type
Error
=
ValueDowncastError
;
fn
try_from
(
val
:
TVMArgValue
<
'a
>
)
->
Result
<
$type
,
Self
::
Error
>
{
try_downcast!
(
val
->
$type
,
|
TVMArgValue
::
$variant
(
val
)|
{
$ctor
(
val
)
})
}
}
impl
From
<
$ty
>
for
TVMRetValue
{
fn
from
(
val
:
$ty
)
->
TVMRetValue
{
TVMRetValue
{
value
:
TVMValue
{
v_handle
:
val
.handle
()
as
*
mut
c_void
,
},
box_value
:
box
val
,
type_code
:
$type_code
as
i64
,
}
impl
<
'a
,
'v
>
TryFrom
<&
'a
TVMArgValue
<
'v
>>
for
$type
{
type
Error
=
ValueDowncastError
;
fn
try_from
(
val
:
&
'a
TVMArgValue
<
'v
>
)
->
Result
<
$type
,
Self
::
Error
>
{
try_downcast!
(
val
->
$type
,
|
TVMArgValue
::
$variant
(
val
)|
{
$ctor
(
*
val
)
})
}
}
impl
TryFrom
<
TVMRetValue
>
for
$ty
{
type
Error
=
Error
;
fn
try_from
(
ret
:
TVMRetValue
)
->
Result
<
$ty
,
Self
::
Error
>
{
ensure_type!
(
ret
,
$type_code
);
Ok
(
$ty
::
new
(
unsafe
{
ret
.value.v_handle
as
$handle
}))
}
}
};
}
impl_tvm_val_from_handle!
(
Function
,
ffi
::
TVMTypeCode_kFuncHandle
,
ffi
::
TVMFunctionHandle
);
impl_tvm_val_from_handle!
(
Module
,
ffi
::
TVMTypeCode_kModuleHandle
,
ffi
::
TVMModuleHandle
);
impl_tvm_val_from_handle!
(
NDArray
,
ffi
::
TVMTypeCode_kArrayHandle
,
ffi
::
TVMArrayHandle
);
impl
<
'a
>
From
<&
'a
TVMByteArray
>
for
TVMValue
{
fn
from
(
barr
:
&
TVMByteArray
)
->
Self
{
TVMValue
{
v_handle
:
&
barr
.inner
as
*
const
ffi
::
TVMByteArray
as
*
mut
c_void
,
}
}
}
macro_rules!
impl_boxed_ret_value
{
(
$type:ty
,
$code:expr
)
=>
{
impl
From
<
$type
>
for
TVMRetValue
{
fn
from
(
val
:
$type
)
->
Self
{
TVMRetValue
{
value
:
TVMValue
{
v_int64
:
0
},
box_value
:
box
val
,
type_code
:
$code
as
i64
,
}
fn
from
(
val
:
$type
)
->
TVMRetValue
{
TVMRetValue
::
$variant
(
val
.handle
()
as
$inner_type
)
}
}
impl
TryFrom
<
TVMRetValue
>
for
$type
{
type
Error
=
Error
;
fn
try_from
(
ret
:
TVMRetValue
)
->
Result
<
$type
,
Self
::
Error
>
{
if
let
Ok
(
val
)
=
ret
.box_value.downcast
::
<
$type
>
()
{
Ok
(
*
val
)
}
else
{
bail!
(
ValueDowncastError
::
new
(
$code
as
i64
,
ret
.type_code
as
i64
))
}
type
Error
=
ValueDowncastError
;
fn
try_from
(
val
:
TVMRetValue
)
->
Result
<
$type
,
Self
::
Error
>
{
try_downcast!
(
val
->
$type
,
|
TVMRetValue
::
$variant
(
val
)|
{
$ctor
(
val
)
})
}
}
};
}
impl_boxed_ret_value!
(
TVMContext
,
ffi
::
TVMTypeCode_kTVMContext
);
impl_boxed_ret_value!
(
TVMByteArray
,
ffi
::
TVMTypeCode_kBytes
);
impl
<
'a
,
'v
>
TryFrom
<&
'a
TVMArgValue
<
'v
>>
for
TVMByteArray
{
type
Error
=
Error
;
fn
try_from
(
arg
:
&
TVMArgValue
<
'v
>
)
->
Result
<
Self
,
Self
::
Error
>
{
ensure_type!
(
arg
,
ffi
::
TVMTypeCode_kBytes
);
Ok
(
TVMByteArray
::
new
(
unsafe
{
*
(
arg
.value.v_handle
as
*
mut
ffi
::
TVMByteArray
)
}))
}
}
impl_handle_val!
(
Function
,
FuncHandle
,
TVMFunctionHandle
,
Function
::
new
);
impl_handle_val!
(
Module
,
ModuleHandle
,
TVMModuleHandle
,
Module
::
new
);
impl_handle_val!
(
NDArray
,
ArrayHandle
,
TVMArrayHandle
,
NDArray
::
new
);
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
std
::{
convert
::
TryInto
,
str
::
FromStr
};
use
tvm_common
::
ffi
::
TVMType
;
use
tvm_common
::{
TVMByteArray
,
TVMContext
,
TVMType
};
use
super
::
*
;
#[test]
fn
bytearray
()
{
let
w
=
vec!
[
1u8
,
2
,
3
,
4
,
5
];
let
v
=
TVMByteArray
::
from
(
&
w
);
let
v
=
TVMByteArray
::
from
(
w
.as_slice
()
);
let
tvm
:
TVMByteArray
=
TVMRetValue
::
from
(
v
)
.try_into
()
.unwrap
();
assert_eq!
(
tvm
.data
(),
w
.iter
()
.map
(|
e
|
*
e
as
i8
)
.collect
::
<
Vec
<
i8
>>
());
assert_eq!
(
tvm
.data
(),
w
.iter
()
.map
(|
e
|
*
e
)
.collect
::
<
Vec
<
u8
>>
()
.as_slice
()
);
}
#[test]
...
...
@@ -147,7 +87,7 @@ mod tests {
#[test]
fn
ctx
()
{
let
c
=
TVMContext
::
from
(
"gpu"
);
let
c
=
TVMContext
::
from
_str
(
"gpu"
)
.unwrap
(
);
let
tvm
:
TVMContext
=
TVMRetValue
::
from
(
c
)
.try_into
()
.unwrap
();
assert_eq!
(
tvm
,
c
);
}
...
...
rust/frontend/tests/callback/src/bin/string.rs
View file @
14a0ecba
...
...
@@ -24,9 +24,9 @@ fn main() {
registered
.get_function
(
"concate_str"
);
assert
!
(
registered
.func
.is_some
());
let
ret
:
String
=
registered
.arg
(
&
a
)
.arg
(
&
b
)
.arg
(
&
c
)
.arg
(
a
.as_c_str
()
)
.arg
(
b
.as_c_str
()
)
.arg
(
c
.as_c_str
()
)
.invoke
()
.unwrap
()
.try_into
()
...
...
rust/runtime/Cargo.toml
View file @
14a0ecba
...
...
@@ -8,6 +8,7 @@ readme = "README.md"
keywords
=
[
"tvm"
,
"nnvm"
]
categories
=
[
"api-bindings"
,
"science"
]
authors
=
[
"TVM Contributors"
]
edition
=
"2018"
[features]
default
=
["nom/std"]
...
...
rust/runtime/src/graph.rs
View file @
14a0ecba
...
...
@@ -265,7 +265,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
.iter
()
.map
(|
t
|
t
.into
())
.collect
::
<
Vec
<
TVMArgValue
>>
();
func
(
args
.as_slice
()
)
.unwrap
();
func
(
&
args
)
.unwrap
();
};
op_execs
.push
(
op
);
}
...
...
@@ -283,7 +283,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
// TODO: consider `new_with_params` to avoid ever allocating
let
ptr
=
self
.tensors
[
idx
]
.data
.as_ptr
();
let
mut
to_replace
=
self
.tensors
.iter_mut
()
.filter
(|
t
|
t
.data
.as_ptr
()
==
ptr
);
let
mut
owner
=
to_replace
.nth
(
0
)
.unwrap
();
let
owner
=
to_replace
.nth
(
0
)
.unwrap
();
if
value
.data
.is_owned
()
{
// FIXME: for no-copy, need setup_op_execs to not capture tensor ptr
// mem::replace(&mut (*owner), value);
...
...
rust/runtime/src/module.rs
View file @
14a0ecba
...
...
@@ -40,17 +40,14 @@ pub(super) fn wrap_backend_packed_func(
func
:
BackendPackedCFunc
,
)
->
Box
<
dyn
PackedFunc
>
{
box
move
|
args
:
&
[
TVMArgValue
]|
{
let
exit_code
=
func
(
args
.iter
()
.map
(|
ref
arg
|
arg
.value
)
.collect
::
<
Vec
<
TVMValue
>>
()
.as_ptr
(),
args
.iter
()
.map
(|
ref
arg
|
arg
.type_code
as
i32
)
.collect
::
<
Vec
<
i32
>>
()
.as_ptr
()
as
*
const
i32
,
args
.len
()
as
i32
,
);
let
(
values
,
type_codes
):
(
Vec
<
TVMValue
>
,
Vec
<
i32
>
)
=
args
.into_iter
()
.map
(|
arg
|
{
let
(
val
,
code
)
=
arg
.to_tvm_value
();
(
val
,
code
as
i32
)
})
.unzip
();
let
exit_code
=
func
(
values
.as_ptr
(),
type_codes
.as_ptr
(),
values
.len
()
as
i32
);
if
exit_code
==
0
{
Ok
(
TVMRetValue
::
default
())
}
else
{
...
...
rust/runtime/tests/test_graph_serde.rs
View file @
14a0ecba
#
!
[
feature
(
try_from
)]
extern
crate
serde
;
extern
crate
serde_json
;
...
...
rust/runtime/tests/test_nnvm/src/main.rs
View file @
14a0ecba
#
!
[
feature
(
try_from
)]
#
[
macro_use
]
extern
crate
ndarray
;
extern
crate
serde
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment