Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
14a0ecba
Commit
14a0ecba
authored
Apr 07, 2019
by
Nick Hynes
Committed by
Tianqi Chen
Apr 07, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Rustify PackedFunc & Friends (#2969)
parent
0708c48d
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
341 additions
and
452 deletions
+341
-452
rust/common/Cargo.toml
+1
-0
rust/common/src/errors.rs
+6
-57
rust/common/src/lib.rs
+2
-2
rust/common/src/packed_func.rs
+235
-223
rust/common/src/value.rs
+15
-0
rust/frontend/Cargo.toml
+1
-0
rust/frontend/src/bytearray.rs
+5
-16
rust/frontend/src/context.rs
+5
-18
rust/frontend/src/function.rs
+11
-13
rust/frontend/src/module.rs
+5
-2
rust/frontend/src/ndarray.rs
+1
-1
rust/frontend/src/value.rs
+40
-100
rust/frontend/tests/callback/src/bin/string.rs
+3
-3
rust/runtime/Cargo.toml
+1
-0
rust/runtime/src/graph.rs
+2
-2
rust/runtime/src/module.rs
+8
-11
rust/runtime/tests/test_graph_serde.rs
+0
-2
rust/runtime/tests/test_nnvm/src/main.rs
+0
-2
No files found.
rust/common/Cargo.toml
View file @
14a0ecba
...
@@ -3,6 +3,7 @@ name = "tvm-common"
...
@@ -3,6 +3,7 @@ name = "tvm-common"
version
=
"0.1.0"
version
=
"0.1.0"
authors
=
[
"TVM Contributors"
]
authors
=
[
"TVM Contributors"
]
license
=
"Apache-2.0"
license
=
"Apache-2.0"
edition
=
"2018"
[features]
[features]
bindings
=
[]
bindings
=
[]
...
...
rust/common/src/errors.rs
View file @
14a0ecba
use
std
::
fmt
;
static
TYPE_CODE_STRS
:
[
&
str
;
15
]
=
[
"int"
,
"uint"
,
"float"
,
"handle"
,
"null"
,
"TVMType"
,
"TVMContext"
,
"ArrayHandle"
,
"NodeHandle"
,
"ModuleHandle"
,
"FuncHandle"
,
"str"
,
"bytes"
,
"NDArrayContainer"
,
"ExtBegin"
,
];
#
[
derive
(
Debug
,
Fail
)]
#
[
derive
(
Debug
,
Fail
)]
#[fail(
display
=
"Could not downcast `{}` into `{}`"
,
expected_type,
actual_type
)]
pub
struct
ValueDowncastError
{
pub
struct
ValueDowncastError
{
actual_type_code
:
i64
,
pub
actual_type
:
String
,
expected_type_code
:
i64
,
pub
expected_type
:
&
'static
str
,
}
impl
ValueDowncastError
{
pub
fn
new
(
actual_type_code
:
i64
,
expected_type_code
:
i64
)
->
Self
{
Self
{
actual_type_code
,
expected_type_code
,
}
}
}
impl
fmt
::
Display
for
ValueDowncastError
{
fn
fmt
(
&
self
,
formatter
:
&
mut
fmt
::
Formatter
)
->
fmt
::
Result
{
write!
(
formatter
,
"Could not downcast TVMValue: expected `{}` but was {}"
,
TYPE_CODE_STRS
[
self
.actual_type_code
as
usize
],
TYPE_CODE_STRS
[
self
.expected_type_code
as
usize
]
)
}
}
}
#[derive(Debug,
Fail)]
#[derive(Debug,
Fail)]
...
@@ -62,18 +26,3 @@ impl FuncCallError {
...
@@ -62,18 +26,3 @@ impl FuncCallError {
}
}
}
}
}
}
// error_chain! {
// errors {
// TryFromTVMRetValueError(expected_type: String, actual_type_code: i64) {
// description("mismatched types while downcasting TVMRetValue")
// display("invalid downcast: expected `{}` but was `{}`",
// expected_type, type_code_to_string(actual_type_code))
// }
// }
// foreign_links {
// IntoString(std::ffi::IntoStringError);
// ParseInt(std::num::ParseIntError);
// Utf8(std::str::Utf8Error);
// }
// }
rust/common/src/lib.rs
View file @
14a0ecba
//! This crate contains the refactored basic components required
//! This crate contains the refactored basic components required
//! for `runtime` and `frontend` TVM crates.
//! for `runtime` and `frontend` TVM crates.
#
!
[
feature
(
box_syntax
,
trait_alias
)]
#
!
[
feature
(
box_syntax
,
t
ype_alias_enum_variants
,
t
rait_alias
)]
#[macro_use]
#[macro_use]
extern
crate
failure
;
extern
crate
failure
;
...
@@ -25,5 +25,5 @@ pub mod packed_func;
...
@@ -25,5 +25,5 @@ pub mod packed_func;
pub
mod
value
;
pub
mod
value
;
pub
use
errors
::
*
;
pub
use
errors
::
*
;
pub
use
ffi
::{
TVMContext
,
TVMType
};
pub
use
ffi
::{
TVM
ByteArray
,
TVM
Context
,
TVMType
};
pub
use
packed_func
::{
TVMArgValue
,
TVMRetValue
};
pub
use
packed_func
::{
TVMArgValue
,
TVMRetValue
};
rust/common/src/packed_func.rs
View file @
14a0ecba
use
std
::{
any
::
Any
,
convert
::
TryFrom
,
marker
::
PhantomData
,
os
::
raw
::
c_void
};
use
std
::{
convert
::
TryFrom
,
use
failure
::
Error
;
ffi
::{
CStr
,
CString
},
os
::
raw
::
c_void
,
};
pub
use
crate
::
ffi
::
TVMValue
;
pub
use
crate
::
ffi
::
TVMValue
;
use
crate
::
ffi
::
*
;
use
crate
::
{
errors
::
ValueDowncastError
,
ffi
::
*
}
;
pub
trait
PackedFunc
=
pub
trait
PackedFunc
=
Fn
(
&
[
TVMArgValue
])
->
Result
<
TVMRetValue
,
crate
::
errors
::
FuncCallError
>
+
Send
+
Sync
;
Fn
(
&
[
TVMArgValue
])
->
Result
<
TVMRetValue
,
crate
::
errors
::
FuncCallError
>
+
Send
+
Sync
;
...
@@ -15,298 +17,308 @@ pub trait PackedFunc =
...
@@ -15,298 +17,308 @@ pub trait PackedFunc =
/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
#[macro_export]
#[macro_export]
macro_rules!
call_packed
{
macro_rules!
call_packed
{
(
$fn:expr
,
$
(
$args:expr
),
+
)
=>
{
(
$fn:expr
,
$
(
$args:expr
),
+
)
=>
{
$fn
(
&
[
$
(
$args
.into
(),)
+
])
$fn
(
&
[
$
(
$args
.into
(),)
+
])
};
};
(
$fn:expr
)
=>
{
(
$fn:expr
)
=>
{
$fn
(
&
Vec
::
new
())
$fn
(
&
Vec
::
new
())
};
};
}
}
/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way
/// Constructs a derivative of a TVMPodValue.
/// to obtain a `TVMArgValue` is automatically via `call_packed!`.
macro_rules!
TVMPODValue
{
#[derive(Clone,
Copy)]
{
pub
struct
TVMArgValue
<
'a
>
{
$
(
#
[
$m:meta
])
+
pub
_lifetime
:
PhantomData
<&
'a
()
>
,
$name:ident
$
(
<
$a:lifetime
>
)
?
{
pub
value
:
TVMValue
,
$
(
$extra_variant:ident
(
$variant_type:ty
)
),
+
$
(,)
?
pub
type_code
:
i64
,
},
match
$value:ident
{
$
(
$tvm_type:ident
=>
{
$from_tvm_type:expr
})
+
},
match
&
self
{
$
(
$self_type:ident
(
$val:ident
)
=>
{
$from_self_type:expr
})
+
}
$
(,)
?
}
=>
{
$
(
#
[
$m
])
+
#[derive(Clone,
Debug)]
pub
enum
$name
$
(
<
$a
>
)
?
{
Int
(
i64
),
UInt
(
i64
),
Float
(
f64
),
Null
,
Type
(
TVMType
),
String
(
CString
),
Context
(
TVMContext
),
Handle
(
*
mut
c_void
),
ArrayHandle
(
TVMArrayHandle
),
NodeHandle
(
*
mut
c_void
),
ModuleHandle
(
TVMModuleHandle
),
FuncHandle
(
TVMFunctionHandle
),
NDArrayContainer
(
*
mut
c_void
),
$
(
$extra_variant
(
$variant_type
)),
+
}
impl
$
(
<
$a
>
)
?
$name
$
(
<
$a
>
)
?
{
pub
fn
from_tvm_value
(
$value
:
TVMValue
,
type_code
:
u32
)
->
Self
{
use
$name
::
*
;
#[allow(non_upper_case_globals)]
unsafe
{
match
type_code
{
DLDataTypeCode_kDLInt
=>
Int
(
$value
.v_int64
),
DLDataTypeCode_kDLUInt
=>
UInt
(
$value
.v_int64
),
DLDataTypeCode_kDLFloat
=>
Float
(
$value
.v_float64
),
TVMTypeCode_kNull
=>
Null
,
TVMTypeCode_kTVMType
=>
Type
(
$value
.v_type
),
TVMTypeCode_kTVMContext
=>
Context
(
$value
.v_ctx
),
TVMTypeCode_kHandle
=>
Handle
(
$value
.v_handle
),
TVMTypeCode_kArrayHandle
=>
ArrayHandle
(
$value
.v_handle
as
TVMArrayHandle
),
TVMTypeCode_kNodeHandle
=>
NodeHandle
(
$value
.v_handle
),
TVMTypeCode_kModuleHandle
=>
ModuleHandle
(
$value
.v_handle
),
TVMTypeCode_kFuncHandle
=>
FuncHandle
(
$value
.v_handle
),
TVMTypeCode_kNDArrayContainer
=>
NDArrayContainer
(
$value
.v_handle
),
$
(
$tvm_type
=>
{
$from_tvm_type
}
),
+
_
=>
unimplemented!
(
"{}"
,
type_code
),
}
}
}
pub
fn
to_tvm_value
(
&
self
)
->
(
TVMValue
,
TVMTypeCode
)
{
use
$name
::
*
;
match
self
{
Int
(
val
)
=>
(
TVMValue
{
v_int64
:
*
val
},
DLDataTypeCode_kDLInt
),
UInt
(
val
)
=>
(
TVMValue
{
v_int64
:
*
val
as
i64
},
DLDataTypeCode_kDLUInt
),
Float
(
val
)
=>
(
TVMValue
{
v_float64
:
*
val
},
DLDataTypeCode_kDLFloat
),
Null
=>
(
TVMValue
{
v_int64
:
0
},
TVMTypeCode_kNull
),
Type
(
val
)
=>
(
TVMValue
{
v_type
:
*
val
},
TVMTypeCode_kTVMType
),
Context
(
val
)
=>
(
TVMValue
{
v_ctx
:
val
.clone
()
},
TVMTypeCode_kTVMContext
),
String
(
val
)
=>
{
(
TVMValue
{
v_handle
:
val
.as_ptr
()
as
*
mut
c_void
},
TVMTypeCode_kStr
,
)
}
Handle
(
val
)
=>
(
TVMValue
{
v_handle
:
*
val
},
TVMTypeCode_kHandle
),
ArrayHandle
(
val
)
=>
{
(
TVMValue
{
v_handle
:
*
val
as
*
const
_
as
*
mut
c_void
},
TVMTypeCode_kArrayHandle
,
)
},
NodeHandle
(
val
)
=>
(
TVMValue
{
v_handle
:
*
val
},
TVMTypeCode_kNodeHandle
),
ModuleHandle
(
val
)
=>
(
TVMValue
{
v_handle
:
*
val
},
TVMTypeCode_kModuleHandle
),
FuncHandle
(
val
)
=>
(
TVMValue
{
v_handle
:
*
val
},
TVMTypeCode_kFuncHandle
),
NDArrayContainer
(
val
)
=>
(
TVMValue
{
v_handle
:
*
val
},
TVMTypeCode_kNDArrayContainer
),
$
(
$self_type
(
$val
)
=>
{
$from_self_type
}
),
+
}
}
}
}
}
}
impl
<
'a
>
TVMArgValue
<
'a
>
{
TVMPODValue!
{
pub
fn
new
(
value
:
TVMValue
,
type_code
:
i64
)
->
Self
{
/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way
TVMArgValue
{
/// to obtain a `TVMArgValue` is automatically via `call_packed!`.
_lifetime
:
PhantomData
,
TVMArgValue
<
'a
>
{
value
:
value
,
Bytes
(
&
'a
TVMByteArray
),
type_code
:
type_code
,
Str
(
&
'a
CStr
),
},
match
value
{
TVMTypeCode_kBytes
=>
{
Bytes
(
&*
(
value
.v_handle
as
*
const
TVMByteArray
))
}
TVMTypeCode_kStr
=>
{
Str
(
CStr
::
from_ptr
(
value
.v_handle
as
*
const
i8
))
}
},
match
&
self
{
Bytes
(
val
)
=>
{
(
TVMValue
{
v_handle
:
val
.clone
()
as
*
const
_
as
*
mut
c_void
},
TVMTypeCode_kBytes
)
}
}
Str
(
val
)
=>
{
(
TVMValue
{
v_handle
:
val
.as_ptr
()
as
*
mut
c_void
},
TVMTypeCode_kStr
)}
}
}
TVMPODValue!
{
/// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
/// Can be downcasted using `try_from` if it contains the desired type.
///
/// # Example
///
/// ```
/// let a = 42u32;
/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
///
/// let s = "hello, world!";
/// let t: TVMRetValue = s.into();
/// assert_eq!(String::try_from(t).unwrap(), s);
/// ```
TVMRetValue
{
Bytes
(
TVMByteArray
),
Str
(
&
'static
CStr
),
},
match
value
{
TVMTypeCode_kBytes
=>
{
Bytes
(
*
(
value
.v_handle
as
*
const
TVMByteArray
))
}
TVMTypeCode_kStr
=>
{
Str
(
CStr
::
from_ptr
(
value
.v_handle
as
*
mut
i8
))
}
},
match
&
self
{
Bytes
(
val
)
=>
{
(
TVMValue
{
v_handle
:
val
as
*
const
_
as
*
mut
c_void
},
TVMTypeCode_kBytes
)
}
Str
(
val
)
=>
{
(
TVMValue
{
v_str
:
val
.as_ptr
()
},
TVMTypeCode_kStr
)
}
}
}
}
}
#[macro_export]
#[macro_export]
macro_rules!
ensure_type
{
macro_rules!
try_downcast
{
(
$val:ident
,
$expected_type_code:expr
)
=>
{
(
$val:ident
->
$into:ty
,
$
(
|
$pat:pat
|
{
$converter:expr
}
),
+
)
=>
{
ensure!
(
match
$val
{
$
val
.type_code
==
$expected_type_code
as
i64
,
$
(
$pat
=>
{
Ok
(
$converter
)
}
)
+
$crate
::
errors
::
ValueDowncastError
::
new
(
_
=>
Err
(
$crate
::
errors
::
ValueDowncastError
{
$val
.type_code
as
i64
,
actual_type
:
format!
(
"{:?}"
,
$val
)
,
$expected_type_code
as
i64
expected_type
:
stringify!
(
$into
),
)
}),
);
}
};
};
}
}
/// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode.
/// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode.
macro_rules!
impl_p
rim_tvm_arg
{
macro_rules!
impl_p
od_value
{
(
$
type_code:ident
,
$field:ident
,
$field_type
:ty
,
[
$
(
$type:ty
),
+
]
)
=>
{
(
$
variant:ident
,
$inner_ty
:ty
,
[
$
(
$type:ty
),
+
]
)
=>
{
$
(
$
(
impl
From
<
$type
>
for
TVMArgValue
<
'static
>
{
impl
<
'a
>
From
<
$type
>
for
TVMArgValue
<
'a
>
{
fn
from
(
val
:
$type
)
->
Self
{
fn
from
(
val
:
$type
)
->
Self
{
TVMArgValue
{
Self
::
$variant
(
val
as
$inner_ty
)
value
:
TVMValue
{
$field
:
val
as
$field_type
},
type_code
:
$type_code
as
i64
,
_lifetime
:
PhantomData
,
}
}
}
}
}
impl
<
'a
>
From
<&
'a
$type
>
for
TVMArgValue
<
'a
>
{
impl
<
'a
,
'v
>
From
<&
'a
$type
>
for
TVMArgValue
<
'v
>
{
fn
from
(
val
:
&
'a
$type
)
->
Self
{
fn
from
(
val
:
&
'a
$type
)
->
Self
{
TVMArgValue
{
Self
::
$variant
(
*
val
as
$inner_ty
)
value
:
TVMValue
{
$field
:
val
.to_owned
()
as
$field_type
,
},
type_code
:
$type_code
as
i64
,
_lifetime
:
PhantomData
,
}
}
}
}
}
impl
<
'a
>
TryFrom
<
TVMArgValue
<
'a
>>
for
$type
{
impl
<
'a
>
TryFrom
<
TVMArgValue
<
'a
>>
for
$type
{
type
Error
=
Error
;
type
Error
=
$crate
::
errors
::
ValueDowncast
Error
;
fn
try_from
(
val
:
TVMArgValue
<
'a
>
)
->
Result
<
Self
,
Self
::
Error
>
{
fn
try_from
(
val
:
TVMArgValue
<
'a
>
)
->
Result
<
Self
,
Self
::
Error
>
{
ensure_type!
(
val
,
$type_code
);
try_downcast!
(
val
->
$type
,
|
TVMArgValue
::
$variant
(
val
)|
{
val
as
$type
})
Ok
(
unsafe
{
val
.value
.
$field
as
$type
})
}
}
}
}
impl
<
'a
>
TryFrom
<&
TVMArgValue
<
'a
>>
for
$type
{
impl
<
'a
,
'v
>
TryFrom
<&
'a
TVMArgValue
<
'v
>>
for
$type
{
type
Error
=
Error
;
type
Error
=
$crate
::
errors
::
ValueDowncastError
;
fn
try_from
(
val
:
&
TVMArgValue
<
'a
>
)
->
Result
<
Self
,
Self
::
Error
>
{
fn
try_from
(
val
:
&
'a
TVMArgValue
<
'v
>
)
->
Result
<
Self
,
Self
::
Error
>
{
ensure_type!
(
val
,
$type_code
);
try_downcast!
(
val
->
$type
,
|
TVMArgValue
::
$variant
(
val
)|
{
*
val
as
$type
})
Ok
(
unsafe
{
val
.value
.
$field
as
$type
})
}
}
impl
From
<
$type
>
for
TVMRetValue
{
fn
from
(
val
:
$type
)
->
Self
{
Self
::
$variant
(
val
as
$inner_ty
)
}
}
impl
TryFrom
<
TVMRetValue
>
for
$type
{
type
Error
=
$crate
::
errors
::
ValueDowncastError
;
fn
try_from
(
val
:
TVMRetValue
)
->
Result
<
Self
,
Self
::
Error
>
{
try_downcast!
(
val
->
$type
,
|
TVMRetValue
::
$variant
(
val
)|
{
val
as
$type
})
}
}
}
}
)
+
)
+
};
};
}
}
impl_prim_tvm_arg!
(
DLDataTypeCode_kDLFloat
,
v_float64
,
f64
,
[
f32
,
f64
]);
impl_pod_value!
(
Int
,
i64
,
[
i8
,
i16
,
i32
,
i64
,
isize
]);
impl_prim_tvm_arg!
(
impl_pod_value!
(
UInt
,
i64
,
[
u8
,
u16
,
u32
,
u64
,
usize
]);
DLDataTypeCode_kDLInt
,
impl_pod_value!
(
Float
,
f64
,
[
f32
,
f64
]);
v_int64
,
impl_pod_value!
(
Type
,
TVMType
,
[
TVMType
]);
i64
,
impl_pod_value!
(
Context
,
TVMContext
,
[
TVMContext
]);
[
i8
,
i16
,
i32
,
i64
,
isize
]
);
impl_prim_tvm_arg!
(
DLDataTypeCode_kDLUInt
,
v_int64
,
i64
,
[
u8
,
u16
,
u32
,
u64
,
usize
]
);
#[cfg(feature
=
"bindings"
)]
impl
<
'a
>
From
<&
'a
str
>
for
TVMArgValue
<
'a
>
{
// only allow this in bindings because pure-rust can't take ownership of leaked CString
fn
from
(
s
:
&
'a
str
)
->
Self
{
impl
<
'a
>
From
<&
String
>
for
TVMArgValue
<
'a
>
{
Self
::
String
(
CString
::
new
(
s
)
.unwrap
())
fn
from
(
string
:
&
String
)
->
Self
{
TVMArgValue
{
value
:
TVMValue
{
v_str
:
std
::
ffi
::
CString
::
new
(
string
.clone
())
.unwrap
()
.into_raw
(),
},
type_code
:
TVMTypeCode_kStr
as
i64
,
_lifetime
:
PhantomData
,
}
}
}
}
}
impl
<
'a
>
From
<&
std
::
ffi
::
CString
>
for
TVMArgValue
<
'a
>
{
impl
<
'a
>
From
<&
'a
CStr
>
for
TVMArgValue
<
'a
>
{
fn
from
(
string
:
&
std
::
ffi
::
CString
)
->
Self
{
fn
from
(
s
:
&
'a
CStr
)
->
Self
{
TVMArgValue
{
Self
::
Str
(
s
)
value
:
TVMValue
{
v_str
:
string
.as_ptr
(),
},
type_code
:
TVMTypeCode_kStr
as
i64
,
_lifetime
:
PhantomData
,
}
}
}
}
}
impl
<
'a
>
TryFrom
<
TVMArgValue
<
'a
>>
for
&
str
{
impl
<
'a
>
TryFrom
<
TVMArgValue
<
'a
>>
for
&
'a
str
{
type
Error
=
Error
;
type
Error
=
ValueDowncastError
;
fn
try_from
(
arg
:
TVMArgValue
<
'a
>
)
->
Result
<
Self
,
Self
::
Error
>
{
fn
try_from
(
val
:
TVMArgValue
<
'a
>
)
->
Result
<
Self
,
Self
::
Error
>
{
ensure_type!
(
arg
,
TVMTypeCode_kStr
);
try_downcast!
(
val
->
&
str
,
|
TVMArgValue
::
Str
(
s
)|
{
s
.to_str
()
.unwrap
()
})
Ok
(
unsafe
{
std
::
ffi
::
CStr
::
from_ptr
(
arg
.value.v_handle
as
*
const
i8
)
}
.to_str
()
?
)
}
}
}
}
impl
<
'a
>
TryFrom
<&
TVMArgValue
<
'a
>>
for
&
str
{
impl
<
'a
,
'v
>
TryFrom
<&
'a
TVMArgValue
<
'v
>>
for
&
'v
str
{
type
Error
=
Error
;
type
Error
=
ValueDowncastError
;
fn
try_from
(
arg
:
&
TVMArgValue
<
'a
>
)
->
Result
<
Self
,
Self
::
Error
>
{
fn
try_from
(
val
:
&
'a
TVMArgValue
<
'v
>
)
->
Result
<
Self
,
Self
::
Error
>
{
ensure_type!
(
arg
,
TVMTypeCode_kStr
);
try_downcast!
(
val
->
&
str
,
|
TVMArgValue
::
Str
(
s
)|
{
s
.to_str
()
.unwrap
()
})
Ok
(
unsafe
{
std
::
ffi
::
CStr
::
from_ptr
(
arg
.value.v_handle
as
*
const
i8
)
}
.to_str
()
?
)
}
}
}
}
/// C
reates a conversion to a `TVMArgValue` for an object handl
e.
/// C
onverts an unspecialized handle to a TVMArgValu
e.
impl
<
'a
,
T
>
From
<*
const
T
>
for
TVMArgValue
<
'a
>
{
impl
<
T
>
From
<*
const
T
>
for
TVMArgValue
<
'static
>
{
fn
from
(
ptr
:
*
const
T
)
->
Self
{
fn
from
(
ptr
:
*
const
T
)
->
Self
{
TVMArgValue
{
Self
::
Handle
(
ptr
as
*
mut
c_void
)
value
:
TVMValue
{
v_handle
:
ptr
as
*
mut
T
as
*
mut
c_void
,
},
type_code
:
TVMTypeCode_kArrayHandle
as
i64
,
_lifetime
:
PhantomData
,
}
}
}
}
}
/// C
reates a conversion to a `TVMArgValue` for a mutable object handl
e.
/// C
onverts an unspecialized mutable handle to a TVMArgValu
e.
impl
<
'a
,
T
>
From
<*
mut
T
>
for
TVMArgValue
<
'a
>
{
impl
<
T
>
From
<*
mut
T
>
for
TVMArgValue
<
'static
>
{
fn
from
(
ptr
:
*
mut
T
)
->
Self
{
fn
from
(
ptr
:
*
mut
T
)
->
Self
{
TVMArgValue
{
Self
::
Handle
(
ptr
as
*
mut
c_void
)
value
:
TVMValue
{
v_handle
:
ptr
as
*
mut
c_void
,
},
type_code
:
TVMTypeCode_kHandle
as
i64
,
_lifetime
:
PhantomData
,
}
}
}
}
}
impl
<
'a
>
From
<&
'a
mut
DLTensor
>
for
TVMArgValue
<
'a
>
{
impl
<
'a
>
From
<&
'a
mut
DLTensor
>
for
TVMArgValue
<
'a
>
{
fn
from
(
arr
:
&
'a
mut
DLTensor
)
->
Self
{
fn
from
(
arr
:
&
'a
mut
DLTensor
)
->
Self
{
TVMArgValue
{
Self
::
ArrayHandle
(
arr
as
*
mut
DLTensor
)
value
:
TVMValue
{
v_handle
:
arr
as
*
mut
_
as
*
mut
c_void
,
},
type_code
:
TVMTypeCode_kArrayHandle
as
i64
,
_lifetime
:
PhantomData
,
}
}
}
}
}
impl
<
'a
>
From
<&
'a
DLTensor
>
for
TVMArgValue
<
'a
>
{
impl
<
'a
>
From
<&
'a
DLTensor
>
for
TVMArgValue
<
'a
>
{
fn
from
(
arr
:
&
'a
DLTensor
)
->
Self
{
fn
from
(
arr
:
&
'a
DLTensor
)
->
Self
{
TVMArgValue
{
Self
::
ArrayHandle
(
arr
as
*
const
_
as
*
mut
DLTensor
)
value
:
TVMValue
{
v_handle
:
arr
as
*
const
_
as
*
mut
DLTensor
as
*
mut
c_void
,
},
type_code
:
TVMTypeCode_kArrayHandle
as
i64
,
_lifetime
:
PhantomData
,
}
}
}
}
}
impl
<
'a
,
'v
>
TryFrom
<&
'a
TVMArgValue
<
'v
>>
for
TVMType
{
impl
TryFrom
<
TVMRetValue
>
for
String
{
type
Error
=
Error
;
type
Error
=
ValueDowncastError
;
fn
try_from
(
arg
:
&
'a
TVMArgValue
<
'v
>
)
->
Result
<
Self
,
Self
::
Error
>
{
fn
try_from
(
val
:
TVMRetValue
)
->
Result
<
String
,
Self
::
Error
>
{
ensure_type!
(
arg
,
TVMTypeCode_kTVMType
);
try_downcast!
(
Ok
(
unsafe
{
arg
.value.v_type
.into
()
})
val
->
String
,
|
TVMRetValue
::
String
(
s
)|
{
s
.into_string
()
.unwrap
()
},
|
TVMRetValue
::
Str
(
s
)|
{
s
.to_str
()
.unwrap
()
.to_string
()
}
)
}
}
}
}
/// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
impl
From
<
String
>
for
TVMRetValue
{
/// Can be downcasted using `try_from` if it contains the desired type.
fn
from
(
s
:
String
)
->
Self
{
///
Self
::
String
(
std
::
ffi
::
CString
::
new
(
s
)
.unwrap
())
/// # Example
///
/// ```
/// let a = 42u32;
/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
///
/// let s = "hello, world!";
/// let t: TVMRetValue = s.into();
/// assert_eq!(String::try_from(t).unwrap(), s);
/// ```
pub
struct
TVMRetValue
{
pub
value
:
TVMValue
,
pub
box_value
:
Box
<
Any
>
,
pub
type_code
:
i64
,
}
impl
TVMRetValue
{
pub
fn
from_tvm_value
(
value
:
TVMValue
,
type_code
:
i64
)
->
Self
{
Self
{
value
,
type_code
,
box_value
:
box
(),
}
}
pub
fn
into_tvm_value
(
self
)
->
(
TVMValue
,
TVMTypeCode
)
{
(
self
.value
,
self
.type_code
as
TVMTypeCode
)
}
}
}
}
impl
Default
for
TVMRetValue
{
impl
From
<
TVMByteArray
>
for
TVMRetValue
{
fn
default
()
->
Self
{
fn
from
(
arr
:
TVMByteArray
)
->
Self
{
TVMRetValue
{
Self
::
Bytes
(
arr
)
value
:
TVMValue
{
v_int64
:
0
as
i64
},
type_code
:
0
,
box_value
:
box
(),
}
}
}
}
}
macro_rules!
impl_pod_ret_value
{
impl
TryFrom
<
TVMRetValue
>
for
TVMByteArray
{
(
$code:expr
,
[
$
(
$ty:ty
),
+
]
)
=>
{
type
Error
=
ValueDowncastError
;
$
(
fn
try_from
(
val
:
TVMRetValue
)
->
Result
<
Self
,
Self
::
Error
>
{
impl
From
<
$ty
>
for
TVMRetValue
{
try_downcast!
(
val
->
TVMByteArray
,
|
TVMRetValue
::
Bytes
(
val
)|
{
val
})
fn
from
(
val
:
$ty
)
->
Self
{
Self
{
value
:
val
.into
(),
type_code
:
$code
as
i64
,
box_value
:
box
(),
}
}
}
impl
TryFrom
<
TVMRetValue
>
for
$ty
{
type
Error
=
Error
;
fn
try_from
(
ret
:
TVMRetValue
)
->
Result
<
$ty
,
Self
::
Error
>
{
ensure_type!
(
ret
,
$code
);
Ok
(
ret
.value
.into
())
}
}
)
+
};
}
impl_pod_ret_value!
(
DLDataTypeCode_kDLInt
,
[
i8
,
i16
,
i32
,
i64
,
isize
]);
impl_pod_ret_value!
(
DLDataTypeCode_kDLUInt
,
[
u8
,
u16
,
u32
,
u64
,
usize
]);
impl_pod_ret_value!
(
DLDataTypeCode_kDLFloat
,
[
f32
,
f64
]);
impl_pod_ret_value!
(
TVMTypeCode_kTVMType
,
[
TVMType
]);
impl_pod_ret_value!
(
TVMTypeCode_kTVMContext
,
[
TVMContext
]);
impl
TryFrom
<
TVMRetValue
>
for
String
{
type
Error
=
Error
;
fn
try_from
(
ret
:
TVMRetValue
)
->
Result
<
String
,
Self
::
Error
>
{
ensure_type!
(
ret
,
TVMTypeCode_kStr
);
let
cs
=
unsafe
{
std
::
ffi
::
CString
::
from_raw
(
ret
.value.v_handle
as
*
mut
i8
)
};
let
ret_str
=
cs
.clone
()
.into_string
();
if
cfg!
(
feature
=
"bindings"
)
{
std
::
mem
::
forget
(
cs
);
// TVM C++ takes ownership of CString. (@see TVMFuncCall)
}
Ok
(
ret_str
?
)
}
}
}
}
impl
From
<
String
>
for
TVMRetValue
{
impl
Default
for
TVMRetValue
{
fn
from
(
s
:
String
)
->
Self
{
fn
default
()
->
Self
{
let
cs
=
std
::
ffi
::
CString
::
new
(
s
)
.unwrap
();
Self
::
Int
(
0
)
Self
{
value
:
TVMValue
{
v_str
:
cs
.into_raw
()
as
*
mut
i8
,
},
box_value
:
box
(),
type_code
:
TVMTypeCode_kStr
as
i64
,
}
}
}
}
}
rust/common/src/value.rs
View file @
14a0ecba
...
@@ -137,3 +137,18 @@ impl_tvm_context!(
...
@@ -137,3 +137,18 @@ impl_tvm_context!(
DLDeviceType_kDLROCM
:
[
rocm
],
DLDeviceType_kDLROCM
:
[
rocm
],
DLDeviceType_kDLExtDev
:
[
ext_dev
]
DLDeviceType_kDLExtDev
:
[
ext_dev
]
);
);
impl
TVMByteArray
{
pub
fn
data
(
&
self
)
->
&
'static
[
u8
]
{
unsafe
{
std
::
slice
::
from_raw_parts
(
self
.data
as
*
const
u8
,
self
.size
)
}
}
}
impl
<
'a
>
From
<&
'a
[
u8
]
>
for
TVMByteArray
{
fn
from
(
bytes
:
&
[
u8
])
->
Self
{
Self
{
data
:
bytes
.as_ptr
()
as
*
const
i8
,
size
:
bytes
.len
(),
}
}
}
rust/frontend/Cargo.toml
View file @
14a0ecba
...
@@ -9,6 +9,7 @@ readme = "README.md"
...
@@ -9,6 +9,7 @@ readme = "README.md"
keywords
=
[
"rust"
,
"tvm"
,
"nnvm"
]
keywords
=
[
"rust"
,
"tvm"
,
"nnvm"
]
categories
=
[
"api-bindings"
,
"science"
]
categories
=
[
"api-bindings"
,
"science"
]
authors
=
[
"TVM Contributors"
]
authors
=
[
"TVM Contributors"
]
edition
=
"2018"
[lib]
[lib]
name
=
"tvm_frontend"
name
=
"tvm_frontend"
...
...
rust/frontend/src/bytearray.rs
View file @
14a0ecba
...
@@ -3,9 +3,9 @@
...
@@ -3,9 +3,9 @@
//!
//!
//! For more detail, please see the example `resnet` in `examples` repository.
//! For more detail, please see the example `resnet` in `examples` repository.
use
std
::
os
::
raw
::
{
c_char
,
c_void
}
;
use
std
::
os
::
raw
::
c_char
;
use
tvm_common
::
{
ffi
,
TVMArgValue
}
;
use
tvm_common
::
ffi
;
/// A struct holding TVM byte-array.
/// A struct holding TVM byte-array.
///
///
...
@@ -44,8 +44,9 @@ impl TVMByteArray {
...
@@ -44,8 +44,9 @@ impl TVMByteArray {
}
}
}
}
impl
<
'a
>
From
<&
'a
Vec
<
u8
>>
for
TVMByteArray
{
impl
<
'a
,
T
:
AsRef
<
[
u8
]
>>
From
<
T
>
for
TVMByteArray
{
fn
from
(
arg
:
&
Vec
<
u8
>
)
->
Self
{
fn
from
(
arg
:
T
)
->
Self
{
let
arg
=
arg
.as_ref
();
let
barr
=
ffi
::
TVMByteArray
{
let
barr
=
ffi
::
TVMByteArray
{
data
:
arg
.as_ptr
()
as
*
const
c_char
,
data
:
arg
.as_ptr
()
as
*
const
c_char
,
size
:
arg
.len
(),
size
:
arg
.len
(),
...
@@ -54,18 +55,6 @@ impl<'a> From<&'a Vec<u8>> for TVMByteArray {
...
@@ -54,18 +55,6 @@ impl<'a> From<&'a Vec<u8>> for TVMByteArray {
}
}
}
}
impl
<
'a
>
From
<&
TVMByteArray
>
for
TVMArgValue
<
'a
>
{
fn
from
(
arr
:
&
TVMByteArray
)
->
Self
{
Self
{
value
:
ffi
::
TVMValue
{
v_handle
:
&
arr
.inner
as
*
const
ffi
::
TVMByteArray
as
*
const
c_void
as
*
mut
c_void
,
},
type_code
:
ffi
::
TVMTypeCode_kBytes
as
i64
,
_lifetime
:
std
::
marker
::
PhantomData
,
}
}
}
#[cfg(test)]
#[cfg(test)]
mod
tests
{
mod
tests
{
use
super
::
*
;
use
super
::
*
;
...
...
rust/frontend/src/context.rs
View file @
14a0ecba
...
@@ -26,10 +26,7 @@ use std::{
...
@@ -26,10 +26,7 @@ use std::{
use
failure
::
Error
;
use
failure
::
Error
;
use
tvm_common
::{
use
tvm_common
::
ffi
;
ffi
::{
self
,
TVMValue
},
TVMArgValue
,
};
use
crate
::
function
;
use
crate
::
function
;
...
@@ -125,18 +122,6 @@ impl<'a> From<&'a str> for TVMDeviceType {
...
@@ -125,18 +122,6 @@ impl<'a> From<&'a str> for TVMDeviceType {
}
}
}
}
impl
<
'a
>
From
<&
'a
TVMDeviceType
>
for
TVMArgValue
<
'a
>
{
fn
from
(
dev_type
:
&
'a
TVMDeviceType
)
->
Self
{
Self
{
value
:
TVMValue
{
v_int64
:
dev_type
.
0
as
i64
,
},
type_code
:
ffi
::
DLDataTypeCode_kDLInt
as
i64
,
_lifetime
:
std
::
marker
::
PhantomData
,
}
}
}
/// Represents the underlying device context. Default is cpu.
/// Represents the underlying device context. Default is cpu.
///
///
/// ## Examples
/// ## Examples
...
@@ -209,7 +194,7 @@ impl TVMContext {
...
@@ -209,7 +194,7 @@ impl TVMContext {
let
dt
=
self
.device_type
.
0
as
usize
;
let
dt
=
self
.device_type
.
0
as
usize
;
// `unwrap` is ok here because if there is any error,
// `unwrap` is ok here because if there is any error,
// if would occure inside `call_packed!`
// if would occure inside `call_packed!`
let
ret
:
u64
=
call_packed!
(
func
,
&
dt
,
&
self
.device_id
,
&
0
)
let
ret
:
u64
=
call_packed!
(
func
,
dt
,
self
.device_id
,
0
)
.unwrap
()
.unwrap
()
.try_into
()
.try_into
()
.unwrap
();
.unwrap
();
...
@@ -238,7 +223,9 @@ macro_rules! impl_device_attrs {
...
@@ -238,7 +223,9 @@ macro_rules! impl_device_attrs {
// `unwrap` is ok here because if there is any error,
// `unwrap` is ok here because if there is any error,
// if would occur in function call.
// if would occur in function call.
function
::
Builder
::
from
(
func
)
function
::
Builder
::
from
(
func
)
.args
(
&
[
dt
,
self
.device_id
as
usize
,
$attr_kind
])
.arg
(
dt
)
.arg
(
self
.device_id
as
usize
)
.arg
(
$attr_kind
)
.invoke
()
.invoke
()
.unwrap
()
.unwrap
()
.try_into
()
.try_into
()
...
...
rust/frontend/src/function.rs
View file @
14a0ecba
...
@@ -156,9 +156,9 @@ impl<'a, 'm> Builder<'a, 'm> {
...
@@ -156,9 +156,9 @@ impl<'a, 'm> Builder<'a, 'm> {
}
}
/// Pushes a [`TVMArgValue`] into the function argument buffer.
/// Pushes a [`TVMArgValue`] into the function argument buffer.
pub
fn
arg
<
T
:
'a
>
(
&
mut
self
,
arg
:
&
'a
T
)
->
&
mut
Self
pub
fn
arg
<
T
:
'a
>
(
&
mut
self
,
arg
:
T
)
->
&
mut
Self
where
where
TVMArgValue
<
'a
>
:
From
<
&
'a
T
>
,
TVMArgValue
<
'a
>
:
From
<
T
>
,
{
{
self
.arg_buf
.push
(
arg
.into
());
self
.arg_buf
.push
(
arg
.into
());
self
self
...
@@ -192,14 +192,11 @@ impl<'a, 'm> Builder<'a, 'm> {
...
@@ -192,14 +192,11 @@ impl<'a, 'm> Builder<'a, 'm> {
ensure!
(
self
.func
.is_some
(),
errors
::
FunctionNotFoundError
);
ensure!
(
self
.func
.is_some
(),
errors
::
FunctionNotFoundError
);
let
num_args
=
self
.arg_buf
.len
();
let
num_args
=
self
.arg_buf
.len
();
let
(
mut
values
,
mut
type_codes
):
(
Vec
<
ffi
::
TVMValue
>
,
Vec
<
ffi
::
TVMTypeCode
>
)
=
self
let
(
mut
values
,
mut
type_codes
):
(
Vec
<
ffi
::
TVMValue
>
,
Vec
<
ffi
::
TVMTypeCode
>
)
=
.arg_buf
self
.arg_buf
.iter
()
.map
(|
arg
|
arg
.to_tvm_value
())
.unzip
();
.iter
()
.map
(|
tvm_arg
|
(
tvm_arg
.value
,
tvm_arg
.type_code
as
ffi
::
TVMTypeCode
))
.unzip
();
let
mut
ret_val
=
unsafe
{
std
::
mem
::
uninitialized
::
<
TVMValue
>
()
};
let
mut
ret_val
=
unsafe
{
std
::
mem
::
uninitialized
::
<
TVMValue
>
()
};
let
mut
ret_type_code
=
0
;
let
mut
ret_type_code
=
0
i32
;
check_call!
(
ffi
::
TVMFuncCall
(
check_call!
(
ffi
::
TVMFuncCall
(
self
.func
.ok_or
(
errors
::
FunctionNotFoundError
)
?
.handle
,
self
.func
.ok_or
(
errors
::
FunctionNotFoundError
)
?
.handle
,
values
.as_mut_ptr
(),
values
.as_mut_ptr
(),
...
@@ -209,7 +206,7 @@ impl<'a, 'm> Builder<'a, 'm> {
...
@@ -209,7 +206,7 @@ impl<'a, 'm> Builder<'a, 'm> {
&
mut
ret_type_code
as
*
mut
_
&
mut
ret_type_code
as
*
mut
_
));
));
Ok
(
unsafe
{
TVMRetValue
::
from_tvm_value
(
ret_val
,
ret_type_code
as
i64
)
})
Ok
(
unsafe
{
TVMRetValue
::
from_tvm_value
(
ret_val
,
ret_type_code
as
u32
)
})
}
}
}
}
...
@@ -254,7 +251,7 @@ unsafe extern "C" fn tvm_callback(
...
@@ -254,7 +251,7 @@ unsafe extern "C" fn tvm_callback(
{
{
check_call!
(
ffi
::
TVMCbArgToReturn
(
&
mut
value
as
*
mut
_
,
tcode
));
check_call!
(
ffi
::
TVMCbArgToReturn
(
&
mut
value
as
*
mut
_
,
tcode
));
}
}
local_args
.push
(
TVMArgValue
::
new
(
value
.into
(),
(
tcode
as
i64
)
.into
()
));
local_args
.push
(
TVMArgValue
::
from_tvm_value
(
value
.into
(),
tcode
as
u32
));
}
}
let
rv
=
match
rust_fn
(
local_args
.as_slice
())
{
let
rv
=
match
rust_fn
(
local_args
.as_slice
())
{
...
@@ -265,7 +262,7 @@ unsafe extern "C" fn tvm_callback(
...
@@ -265,7 +262,7 @@ unsafe extern "C" fn tvm_callback(
}
}
};
};
let
(
mut
ret_val
,
ret_tcode
)
=
rv
.
in
to_tvm_value
();
let
(
mut
ret_val
,
ret_tcode
)
=
rv
.to_tvm_value
();
let
mut
ret_type_code
=
ret_tcode
as
c_int
;
let
mut
ret_type_code
=
ret_tcode
as
c_int
;
check_call!
(
ffi
::
TVMCFuncSetReturn
(
check_call!
(
ffi
::
TVMCFuncSetReturn
(
ret
,
ret
,
...
@@ -437,8 +434,9 @@ mod tests {
...
@@ -437,8 +434,9 @@ mod tests {
let
str_arg
=
CString
::
new
(
"test"
)
.unwrap
();
let
str_arg
=
CString
::
new
(
"test"
)
.unwrap
();
let
mut
func
=
Builder
::
default
();
let
mut
func
=
Builder
::
default
();
func
.get_function
(
"tvm.graph_runtime.remote_create"
)
func
.get_function
(
"tvm.graph_runtime.remote_create"
)
.args
(
&
[
10
,
20
])
.arg
(
10
)
.arg
(
&
str_arg
);
.arg
(
20
)
.arg
(
str_arg
.as_c_str
());
assert_eq!
(
func
.arg_buf
.len
(),
3
);
assert_eq!
(
func
.arg_buf
.len
(),
3
);
}
}
}
}
rust/frontend/src/module.rs
View file @
14a0ecba
...
@@ -80,7 +80,7 @@ impl Module {
...
@@ -80,7 +80,7 @@ impl Module {
CString
::
new
(
path
.as_ref
()
.to_str
()
.ok_or_else
(||
{
CString
::
new
(
path
.as_ref
()
.to_str
()
.ok_or_else
(||
{
format_err!
(
"Bad module load path: `{}`."
,
path
.as_ref
()
.display
())
format_err!
(
"Bad module load path: `{}`."
,
path
.as_ref
()
.display
())
})
?
)
?
;
})
?
)
?
;
let
ret
:
Module
=
call_packed!
(
func
,
&
cpath
,
&
ext
)
?
.try_into
()
?
;
let
ret
:
Module
=
call_packed!
(
func
,
cpath
.as_c_str
(),
ext
.as_c_str
()
)
?
.try_into
()
?
;
Ok
(
ret
)
Ok
(
ret
)
}
}
...
@@ -90,7 +90,10 @@ impl Module {
...
@@ -90,7 +90,10 @@ impl Module {
// `unwrap` is safe here because if there is any error during the
// `unwrap` is safe here because if there is any error during the
// function call, it would occur in `call_packed!`.
// function call, it would occur in `call_packed!`.
let
tgt
=
CString
::
new
(
target
)
.unwrap
();
let
tgt
=
CString
::
new
(
target
)
.unwrap
();
let
ret
:
i64
=
call_packed!
(
func
,
&
tgt
)
.unwrap
()
.try_into
()
.unwrap
();
let
ret
:
i64
=
call_packed!
(
func
,
tgt
.as_c_str
())
.unwrap
()
.try_into
()
.unwrap
();
ret
!=
0
ret
!=
0
}
}
...
...
rust/frontend/src/ndarray.rs
View file @
14a0ecba
...
@@ -161,7 +161,7 @@ impl NDArray {
...
@@ -161,7 +161,7 @@ impl NDArray {
/// Converts the NDArray to [`TVMByteArray`].
/// Converts the NDArray to [`TVMByteArray`].
pub
fn
to_bytearray
(
&
self
)
->
Result
<
TVMByteArray
,
Error
>
{
pub
fn
to_bytearray
(
&
self
)
->
Result
<
TVMByteArray
,
Error
>
{
let
v
=
self
.to_vec
::
<
u8
>
()
?
;
let
v
=
self
.to_vec
::
<
u8
>
()
?
;
Ok
(
TVMByteArray
::
from
(
&
v
))
Ok
(
TVMByteArray
::
from
(
v
))
}
}
/// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu.
/// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu.
...
...
rust/frontend/src/value.rs
View file @
14a0ecba
...
@@ -2,140 +2,80 @@
...
@@ -2,140 +2,80 @@
//! and their conversions needed for the types used in frontend crate.
//! and their conversions needed for the types used in frontend crate.
//! `TVMRetValue` is the owned version of `TVMPODValue`.
//! `TVMRetValue` is the owned version of `TVMPODValue`.
use
std
::
{
convert
::
TryFrom
,
os
::
raw
::
c_void
}
;
use
std
::
convert
::
TryFrom
;
use
failure
::
Error
;
use
tvm_common
::{
use
tvm_common
::{
ensure_type
,
errors
::
ValueDowncastError
,
ffi
::{
self
,
TVMValue
},
ffi
::{
TVMArrayHandle
,
TVMFunctionHandle
,
TVMModuleHandle
},
try_downcast
,
};
};
use
crate
::{
use
crate
::{
Function
,
Module
,
NDArray
,
TVMArgValue
,
TVMRetValue
};
common_errors
::
*
,
context
::
TVMContext
,
Function
,
Module
,
NDArray
,
TVMArgValue
,
TVMByteArray
,
TVMRetValue
,
};
macro_rules!
impl_tvm_val_from_handle
{
macro_rules!
impl_handle_val
{
(
$ty:ident
,
$type_code:expr
,
$handle:ty
)
=>
{
(
$type:ty
,
$variant:ident
,
$inner_type:ty
,
$ctor:path
)
=>
{
impl
<
'a
>
From
<&
'a
$ty
>
for
TVMArgValue
<
'a
>
{
impl
<
'a
>
From
<&
'a
$type
>
for
TVMArgValue
<
'a
>
{
fn
from
(
arg
:
&
$ty
)
->
Self
{
fn
from
(
arg
:
&
'a
$type
)
->
Self
{
TVMArgValue
{
TVMArgValue
::
$variant
(
arg
.handle
()
as
$inner_type
)
value
:
TVMValue
{
v_handle
:
arg
.handle
as
*
mut
_
as
*
mut
c_void
,
},
type_code
:
$type_code
as
i64
,
_lifetime
:
std
::
marker
::
PhantomData
,
}
}
}
}
}
impl
<
'a
>
From
<&
'a
mut
$ty
>
for
TVMArgValue
<
'a
>
{
impl
<
'a
>
From
<&
'a
mut
$type
>
for
TVMArgValue
<
'a
>
{
fn
from
(
arg
:
&
mut
$ty
)
->
Self
{
fn
from
(
arg
:
&
'a
mut
$type
)
->
Self
{
TVMArgValue
{
TVMArgValue
::
$variant
(
arg
.handle
()
as
$inner_type
)
value
:
TVMValue
{
v_handle
:
arg
.handle
as
*
mut
_
as
*
mut
c_void
,
},
type_code
:
$type_code
as
i64
,
_lifetime
:
std
::
marker
::
PhantomData
,
}
}
}
}
}
impl
<
'a
,
'v
>
TryFrom
<&
'a
TVMArgValue
<
'v
>>
for
$ty
{
impl
<
'a
>
TryFrom
<
TVMArgValue
<
'a
>>
for
$type
{
type
Error
=
Error
;
type
Error
=
ValueDowncastError
;
fn
try_from
(
arg
:
&
TVMArgValue
<
'v
>
)
->
Result
<
$ty
,
Self
::
Error
>
{
fn
try_from
(
val
:
TVMArgValue
<
'a
>
)
->
Result
<
$type
,
Self
::
Error
>
{
ensure_type!
(
arg
,
$type_code
);
try_downcast!
(
val
->
$type
,
|
TVMArgValue
::
$variant
(
val
)|
{
$ctor
(
val
)
})
Ok
(
$ty
::
new
(
unsafe
{
arg
.value.v_handle
as
$handle
}))
}
}
}
}
impl
From
<
$ty
>
for
TVMRetValue
{
impl
<
'a
,
'v
>
TryFrom
<&
'a
TVMArgValue
<
'v
>>
for
$type
{
fn
from
(
val
:
$ty
)
->
TVMRetValue
{
type
Error
=
ValueDowncastError
;
TVMRetValue
{
fn
try_from
(
val
:
&
'a
TVMArgValue
<
'v
>
)
->
Result
<
$type
,
Self
::
Error
>
{
value
:
TVMValue
{
try_downcast!
(
val
->
$type
,
|
TVMArgValue
::
$variant
(
val
)|
{
$ctor
(
*
val
)
})
v_handle
:
val
.handle
()
as
*
mut
c_void
,
},
box_value
:
box
val
,
type_code
:
$type_code
as
i64
,
}
}
}
}
}
impl
TryFrom
<
TVMRetValue
>
for
$ty
{
type
Error
=
Error
;
fn
try_from
(
ret
:
TVMRetValue
)
->
Result
<
$ty
,
Self
::
Error
>
{
ensure_type!
(
ret
,
$type_code
);
Ok
(
$ty
::
new
(
unsafe
{
ret
.value.v_handle
as
$handle
}))
}
}
};
}
impl_tvm_val_from_handle!
(
Function
,
ffi
::
TVMTypeCode_kFuncHandle
,
ffi
::
TVMFunctionHandle
);
impl_tvm_val_from_handle!
(
Module
,
ffi
::
TVMTypeCode_kModuleHandle
,
ffi
::
TVMModuleHandle
);
impl_tvm_val_from_handle!
(
NDArray
,
ffi
::
TVMTypeCode_kArrayHandle
,
ffi
::
TVMArrayHandle
);
impl
<
'a
>
From
<&
'a
TVMByteArray
>
for
TVMValue
{
fn
from
(
barr
:
&
TVMByteArray
)
->
Self
{
TVMValue
{
v_handle
:
&
barr
.inner
as
*
const
ffi
::
TVMByteArray
as
*
mut
c_void
,
}
}
}
macro_rules!
impl_boxed_ret_value
{
(
$type:ty
,
$code:expr
)
=>
{
impl
From
<
$type
>
for
TVMRetValue
{
impl
From
<
$type
>
for
TVMRetValue
{
fn
from
(
val
:
$type
)
->
Self
{
fn
from
(
val
:
$type
)
->
TVMRetValue
{
TVMRetValue
{
TVMRetValue
::
$variant
(
val
.handle
()
as
$inner_type
)
value
:
TVMValue
{
v_int64
:
0
},
box_value
:
box
val
,
type_code
:
$code
as
i64
,
}
}
}
}
}
impl
TryFrom
<
TVMRetValue
>
for
$type
{
impl
TryFrom
<
TVMRetValue
>
for
$type
{
type
Error
=
Error
;
type
Error
=
ValueDowncastError
;
fn
try_from
(
ret
:
TVMRetValue
)
->
Result
<
$type
,
Self
::
Error
>
{
fn
try_from
(
val
:
TVMRetValue
)
->
Result
<
$type
,
Self
::
Error
>
{
if
let
Ok
(
val
)
=
ret
.box_value.downcast
::
<
$type
>
()
{
try_downcast!
(
val
->
$type
,
|
TVMRetValue
::
$variant
(
val
)|
{
$ctor
(
val
)
})
Ok
(
*
val
)
}
else
{
bail!
(
ValueDowncastError
::
new
(
$code
as
i64
,
ret
.type_code
as
i64
))
}
}
}
}
}
};
};
}
}
impl_boxed_ret_value!
(
TVMContext
,
ffi
::
TVMTypeCode_kTVMContext
);
impl_handle_val!
(
Function
,
FuncHandle
,
TVMFunctionHandle
,
Function
::
new
);
impl_boxed_ret_value!
(
TVMByteArray
,
ffi
::
TVMTypeCode_kBytes
);
impl_handle_val!
(
Module
,
ModuleHandle
,
TVMModuleHandle
,
Module
::
new
);
impl_handle_val!
(
NDArray
,
ArrayHandle
,
TVMArrayHandle
,
NDArray
::
new
);
impl
<
'a
,
'v
>
TryFrom
<&
'a
TVMArgValue
<
'v
>>
for
TVMByteArray
{
type
Error
=
Error
;
fn
try_from
(
arg
:
&
TVMArgValue
<
'v
>
)
->
Result
<
Self
,
Self
::
Error
>
{
ensure_type!
(
arg
,
ffi
::
TVMTypeCode_kBytes
);
Ok
(
TVMByteArray
::
new
(
unsafe
{
*
(
arg
.value.v_handle
as
*
mut
ffi
::
TVMByteArray
)
}))
}
}
#[cfg(test)]
#[cfg(test)]
mod
tests
{
mod
tests
{
use
super
::
*
;
use
std
::{
convert
::
TryInto
,
str
::
FromStr
};
use
std
::{
convert
::
TryInto
,
str
::
FromStr
};
use
tvm_common
::
ffi
::
TVMType
;
use
tvm_common
::{
TVMByteArray
,
TVMContext
,
TVMType
};
use
super
::
*
;
#[test]
#[test]
fn
bytearray
()
{
fn
bytearray
()
{
let
w
=
vec!
[
1u8
,
2
,
3
,
4
,
5
];
let
w
=
vec!
[
1u8
,
2
,
3
,
4
,
5
];
let
v
=
TVMByteArray
::
from
(
&
w
);
let
v
=
TVMByteArray
::
from
(
w
.as_slice
()
);
let
tvm
:
TVMByteArray
=
TVMRetValue
::
from
(
v
)
.try_into
()
.unwrap
();
let
tvm
:
TVMByteArray
=
TVMRetValue
::
from
(
v
)
.try_into
()
.unwrap
();
assert_eq!
(
tvm
.data
(),
w
.iter
()
.map
(|
e
|
*
e
as
i8
)
.collect
::
<
Vec
<
i8
>>
());
assert_eq!
(
tvm
.data
(),
w
.iter
()
.map
(|
e
|
*
e
)
.collect
::
<
Vec
<
u8
>>
()
.as_slice
()
);
}
}
#[test]
#[test]
...
@@ -147,7 +87,7 @@ mod tests {
...
@@ -147,7 +87,7 @@ mod tests {
#[test]
#[test]
fn
ctx
()
{
fn
ctx
()
{
let
c
=
TVMContext
::
from
(
"gpu"
);
let
c
=
TVMContext
::
from
_str
(
"gpu"
)
.unwrap
(
);
let
tvm
:
TVMContext
=
TVMRetValue
::
from
(
c
)
.try_into
()
.unwrap
();
let
tvm
:
TVMContext
=
TVMRetValue
::
from
(
c
)
.try_into
()
.unwrap
();
assert_eq!
(
tvm
,
c
);
assert_eq!
(
tvm
,
c
);
}
}
...
...
rust/frontend/tests/callback/src/bin/string.rs
View file @
14a0ecba
...
@@ -24,9 +24,9 @@ fn main() {
...
@@ -24,9 +24,9 @@ fn main() {
registered
.get_function
(
"concate_str"
);
registered
.get_function
(
"concate_str"
);
assert
!
(
registered
.func
.is_some
());
assert
!
(
registered
.func
.is_some
());
let
ret
:
String
=
registered
let
ret
:
String
=
registered
.arg
(
&
a
)
.arg
(
a
.as_c_str
()
)
.arg
(
&
b
)
.arg
(
b
.as_c_str
()
)
.arg
(
&
c
)
.arg
(
c
.as_c_str
()
)
.invoke
()
.invoke
()
.unwrap
()
.unwrap
()
.try_into
()
.try_into
()
...
...
rust/runtime/Cargo.toml
View file @
14a0ecba
...
@@ -8,6 +8,7 @@ readme = "README.md"
...
@@ -8,6 +8,7 @@ readme = "README.md"
keywords
=
[
"tvm"
,
"nnvm"
]
keywords
=
[
"tvm"
,
"nnvm"
]
categories
=
[
"api-bindings"
,
"science"
]
categories
=
[
"api-bindings"
,
"science"
]
authors
=
[
"TVM Contributors"
]
authors
=
[
"TVM Contributors"
]
edition
=
"2018"
[features]
[features]
default
=
["nom/std"]
default
=
["nom/std"]
...
...
rust/runtime/src/graph.rs
View file @
14a0ecba
...
@@ -265,7 +265,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
...
@@ -265,7 +265,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
.iter
()
.iter
()
.map
(|
t
|
t
.into
())
.map
(|
t
|
t
.into
())
.collect
::
<
Vec
<
TVMArgValue
>>
();
.collect
::
<
Vec
<
TVMArgValue
>>
();
func
(
args
.as_slice
()
)
.unwrap
();
func
(
&
args
)
.unwrap
();
};
};
op_execs
.push
(
op
);
op_execs
.push
(
op
);
}
}
...
@@ -283,7 +283,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
...
@@ -283,7 +283,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
// TODO: consider `new_with_params` to avoid ever allocating
// TODO: consider `new_with_params` to avoid ever allocating
let
ptr
=
self
.tensors
[
idx
]
.data
.as_ptr
();
let
ptr
=
self
.tensors
[
idx
]
.data
.as_ptr
();
let
mut
to_replace
=
self
.tensors
.iter_mut
()
.filter
(|
t
|
t
.data
.as_ptr
()
==
ptr
);
let
mut
to_replace
=
self
.tensors
.iter_mut
()
.filter
(|
t
|
t
.data
.as_ptr
()
==
ptr
);
let
mut
owner
=
to_replace
.nth
(
0
)
.unwrap
();
let
owner
=
to_replace
.nth
(
0
)
.unwrap
();
if
value
.data
.is_owned
()
{
if
value
.data
.is_owned
()
{
// FIXME: for no-copy, need setup_op_execs to not capture tensor ptr
// FIXME: for no-copy, need setup_op_execs to not capture tensor ptr
// mem::replace(&mut (*owner), value);
// mem::replace(&mut (*owner), value);
...
...
rust/runtime/src/module.rs
View file @
14a0ecba
...
@@ -40,17 +40,14 @@ pub(super) fn wrap_backend_packed_func(
...
@@ -40,17 +40,14 @@ pub(super) fn wrap_backend_packed_func(
func
:
BackendPackedCFunc
,
func
:
BackendPackedCFunc
,
)
->
Box
<
dyn
PackedFunc
>
{
)
->
Box
<
dyn
PackedFunc
>
{
box
move
|
args
:
&
[
TVMArgValue
]|
{
box
move
|
args
:
&
[
TVMArgValue
]|
{
let
exit_code
=
func
(
let
(
values
,
type_codes
):
(
Vec
<
TVMValue
>
,
Vec
<
i32
>
)
=
args
args
.iter
()
.into_iter
()
.map
(|
ref
arg
|
arg
.value
)
.map
(|
arg
|
{
.collect
::
<
Vec
<
TVMValue
>>
()
let
(
val
,
code
)
=
arg
.to_tvm_value
();
.as_ptr
(),
(
val
,
code
as
i32
)
args
.iter
()
})
.map
(|
ref
arg
|
arg
.type_code
as
i32
)
.unzip
();
.collect
::
<
Vec
<
i32
>>
()
let
exit_code
=
func
(
values
.as_ptr
(),
type_codes
.as_ptr
(),
values
.len
()
as
i32
);
.as_ptr
()
as
*
const
i32
,
args
.len
()
as
i32
,
);
if
exit_code
==
0
{
if
exit_code
==
0
{
Ok
(
TVMRetValue
::
default
())
Ok
(
TVMRetValue
::
default
())
}
else
{
}
else
{
...
...
rust/runtime/tests/test_graph_serde.rs
View file @
14a0ecba
#
!
[
feature
(
try_from
)]
extern
crate
serde
;
extern
crate
serde
;
extern
crate
serde_json
;
extern
crate
serde_json
;
...
...
rust/runtime/tests/test_nnvm/src/main.rs
View file @
14a0ecba
#
!
[
feature
(
try_from
)]
#
[
macro_use
]
#
[
macro_use
]
extern
crate
ndarray
;
extern
crate
ndarray
;
extern
crate
serde
;
extern
crate
serde
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment