Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
14a0ecba
Commit
14a0ecba
authored
Apr 07, 2019
by
Nick Hynes
Committed by
Tianqi Chen
Apr 07, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Rustify PackedFunc & Friends (#2969)
parent
0708c48d
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
341 additions
and
452 deletions
+341
-452
rust/common/Cargo.toml
+1
-0
rust/common/src/errors.rs
+6
-57
rust/common/src/lib.rs
+2
-2
rust/common/src/packed_func.rs
+235
-223
rust/common/src/value.rs
+15
-0
rust/frontend/Cargo.toml
+1
-0
rust/frontend/src/bytearray.rs
+5
-16
rust/frontend/src/context.rs
+5
-18
rust/frontend/src/function.rs
+11
-13
rust/frontend/src/module.rs
+5
-2
rust/frontend/src/ndarray.rs
+1
-1
rust/frontend/src/value.rs
+40
-100
rust/frontend/tests/callback/src/bin/string.rs
+3
-3
rust/runtime/Cargo.toml
+1
-0
rust/runtime/src/graph.rs
+2
-2
rust/runtime/src/module.rs
+8
-11
rust/runtime/tests/test_graph_serde.rs
+0
-2
rust/runtime/tests/test_nnvm/src/main.rs
+0
-2
No files found.
rust/common/Cargo.toml
View file @
14a0ecba
...
...
@@ -3,6 +3,7 @@ name = "tvm-common"
version
=
"0.1.0"
authors
=
[
"TVM Contributors"
]
license
=
"Apache-2.0"
edition
=
"2018"
[features]
bindings
=
[]
...
...
rust/common/src/errors.rs
View file @
14a0ecba
use
std
::
fmt
;
static
TYPE_CODE_STRS
:
[
&
str
;
15
]
=
[
"int"
,
"uint"
,
"float"
,
"handle"
,
"null"
,
"TVMType"
,
"TVMContext"
,
"ArrayHandle"
,
"NodeHandle"
,
"ModuleHandle"
,
"FuncHandle"
,
"str"
,
"bytes"
,
"NDArrayContainer"
,
"ExtBegin"
,
];
#
[
derive
(
Debug
,
Fail
)]
#[fail(
display
=
"Could not downcast `{}` into `{}`"
,
expected_type,
actual_type
)]
pub
struct
ValueDowncastError
{
actual_type_code
:
i64
,
expected_type_code
:
i64
,
}
impl
ValueDowncastError
{
pub
fn
new
(
actual_type_code
:
i64
,
expected_type_code
:
i64
)
->
Self
{
Self
{
actual_type_code
,
expected_type_code
,
}
}
}
impl
fmt
::
Display
for
ValueDowncastError
{
fn
fmt
(
&
self
,
formatter
:
&
mut
fmt
::
Formatter
)
->
fmt
::
Result
{
write!
(
formatter
,
"Could not downcast TVMValue: expected `{}` but was {}"
,
TYPE_CODE_STRS
[
self
.actual_type_code
as
usize
],
TYPE_CODE_STRS
[
self
.expected_type_code
as
usize
]
)
}
pub
actual_type
:
String
,
pub
expected_type
:
&
'static
str
,
}
#[derive(Debug,
Fail)]
...
...
@@ -62,18 +26,3 @@ impl FuncCallError {
}
}
}
// error_chain! {
// errors {
// TryFromTVMRetValueError(expected_type: String, actual_type_code: i64) {
// description("mismatched types while downcasting TVMRetValue")
// display("invalid downcast: expected `{}` but was `{}`",
// expected_type, type_code_to_string(actual_type_code))
// }
// }
// foreign_links {
// IntoString(std::ffi::IntoStringError);
// ParseInt(std::num::ParseIntError);
// Utf8(std::str::Utf8Error);
// }
// }
rust/common/src/lib.rs
View file @
14a0ecba
//! This crate contains the refactored basic components required
//! for `runtime` and `frontend` TVM crates.
#
!
[
feature
(
box_syntax
,
trait_alias
)]
#
!
[
feature
(
box_syntax
,
t
ype_alias_enum_variants
,
t
rait_alias
)]
#[macro_use]
extern
crate
failure
;
...
...
@@ -25,5 +25,5 @@ pub mod packed_func;
pub
mod
value
;
pub
use
errors
::
*
;
pub
use
ffi
::{
TVMContext
,
TVMType
};
pub
use
ffi
::{
TVM
ByteArray
,
TVM
Context
,
TVMType
};
pub
use
packed_func
::{
TVMArgValue
,
TVMRetValue
};
rust/common/src/packed_func.rs
View file @
14a0ecba
use
std
::{
any
::
Any
,
convert
::
TryFrom
,
marker
::
PhantomData
,
os
::
raw
::
c_void
};
use
failure
::
Error
;
use
std
::{
convert
::
TryFrom
,
ffi
::{
CStr
,
CString
},
os
::
raw
::
c_void
,
};
pub
use
crate
::
ffi
::
TVMValue
;
use
crate
::
ffi
::
*
;
use
crate
::
{
errors
::
ValueDowncastError
,
ffi
::
*
}
;
pub
trait
PackedFunc
=
Fn
(
&
[
TVMArgValue
])
->
Result
<
TVMRetValue
,
crate
::
errors
::
FuncCallError
>
+
Send
+
Sync
;
...
...
@@ -15,298 +17,308 @@ pub trait PackedFunc =
/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
#[macro_export]
macro_rules!
call_packed
{
(
$fn:expr
,
$
(
$args:expr
),
+
)
=>
{
$fn
(
&
[
$
(
$args
.into
(),)
+
])
};
(
$fn:expr
)
=>
{
$fn
(
&
Vec
::
new
())
};
(
$fn:expr
,
$
(
$args:expr
),
+
)
=>
{
$fn
(
&
[
$
(
$args
.into
(),)
+
])
};
(
$fn:expr
)
=>
{
$fn
(
&
Vec
::
new
())
};
}
/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way
/// to obtain a `TVMArgValue` is automatically via `call_packed!`.
#[derive(Clone,
Copy)]
pub
struct
TVMArgValue
<
'a
>
{
pub
_lifetime
:
PhantomData
<&
'a
()
>
,
pub
value
:
TVMValue
,
pub
type_code
:
i64
,
/// Constructs a derivative of a TVMPodValue.
macro_rules!
TVMPODValue
{
{
$
(
#
[
$m:meta
])
+
$name:ident
$
(
<
$a:lifetime
>
)
?
{
$
(
$extra_variant:ident
(
$variant_type:ty
)
),
+
$
(,)
?
},
match
$value:ident
{
$
(
$tvm_type:ident
=>
{
$from_tvm_type:expr
})
+
},
match
&
self
{
$
(
$self_type:ident
(
$val:ident
)
=>
{
$from_self_type:expr
})
+
}
$
(,)
?
}
=>
{
$
(
#
[
$m
])
+
#[derive(Clone,
Debug)]
pub
enum
$name
$
(
<
$a
>
)
?
{
Int
(
i64
),
UInt
(
i64
),
Float
(
f64
),
Null
,
Type
(
TVMType
),
String
(
CString
),
Context
(
TVMContext
),
Handle
(
*
mut
c_void
),
ArrayHandle
(
TVMArrayHandle
),
NodeHandle
(
*
mut
c_void
),
ModuleHandle
(
TVMModuleHandle
),
FuncHandle
(
TVMFunctionHandle
),
NDArrayContainer
(
*
mut
c_void
),
$
(
$extra_variant
(
$variant_type
)),
+
}
impl
$
(
<
$a
>
)
?
$name
$
(
<
$a
>
)
?
{
pub
fn
from_tvm_value
(
$value
:
TVMValue
,
type_code
:
u32
)
->
Self
{
use
$name
::
*
;
#[allow(non_upper_case_globals)]
unsafe
{
match
type_code
{
DLDataTypeCode_kDLInt
=>
Int
(
$value
.v_int64
),
DLDataTypeCode_kDLUInt
=>
UInt
(
$value
.v_int64
),
DLDataTypeCode_kDLFloat
=>
Float
(
$value
.v_float64
),
TVMTypeCode_kNull
=>
Null
,
TVMTypeCode_kTVMType
=>
Type
(
$value
.v_type
),
TVMTypeCode_kTVMContext
=>
Context
(
$value
.v_ctx
),
TVMTypeCode_kHandle
=>
Handle
(
$value
.v_handle
),
TVMTypeCode_kArrayHandle
=>
ArrayHandle
(
$value
.v_handle
as
TVMArrayHandle
),
TVMTypeCode_kNodeHandle
=>
NodeHandle
(
$value
.v_handle
),
TVMTypeCode_kModuleHandle
=>
ModuleHandle
(
$value
.v_handle
),
TVMTypeCode_kFuncHandle
=>
FuncHandle
(
$value
.v_handle
),
TVMTypeCode_kNDArrayContainer
=>
NDArrayContainer
(
$value
.v_handle
),
$
(
$tvm_type
=>
{
$from_tvm_type
}
),
+
_
=>
unimplemented!
(
"{}"
,
type_code
),
}
}
}
pub
fn
to_tvm_value
(
&
self
)
->
(
TVMValue
,
TVMTypeCode
)
{
use
$name
::
*
;
match
self
{
Int
(
val
)
=>
(
TVMValue
{
v_int64
:
*
val
},
DLDataTypeCode_kDLInt
),
UInt
(
val
)
=>
(
TVMValue
{
v_int64
:
*
val
as
i64
},
DLDataTypeCode_kDLUInt
),
Float
(
val
)
=>
(
TVMValue
{
v_float64
:
*
val
},
DLDataTypeCode_kDLFloat
),
Null
=>
(
TVMValue
{
v_int64
:
0
},
TVMTypeCode_kNull
),
Type
(
val
)
=>
(
TVMValue
{
v_type
:
*
val
},
TVMTypeCode_kTVMType
),
Context
(
val
)
=>
(
TVMValue
{
v_ctx
:
val
.clone
()
},
TVMTypeCode_kTVMContext
),
String
(
val
)
=>
{
(
TVMValue
{
v_handle
:
val
.as_ptr
()
as
*
mut
c_void
},
TVMTypeCode_kStr
,
)
}
Handle
(
val
)
=>
(
TVMValue
{
v_handle
:
*
val
},
TVMTypeCode_kHandle
),
ArrayHandle
(
val
)
=>
{
(
TVMValue
{
v_handle
:
*
val
as
*
const
_
as
*
mut
c_void
},
TVMTypeCode_kArrayHandle
,
)
},
NodeHandle
(
val
)
=>
(
TVMValue
{
v_handle
:
*
val
},
TVMTypeCode_kNodeHandle
),
ModuleHandle
(
val
)
=>
(
TVMValue
{
v_handle
:
*
val
},
TVMTypeCode_kModuleHandle
),
FuncHandle
(
val
)
=>
(
TVMValue
{
v_handle
:
*
val
},
TVMTypeCode_kFuncHandle
),
NDArrayContainer
(
val
)
=>
(
TVMValue
{
v_handle
:
*
val
},
TVMTypeCode_kNDArrayContainer
),
$
(
$self_type
(
$val
)
=>
{
$from_self_type
}
),
+
}
}
}
}
}
impl
<
'a
>
TVMArgValue
<
'a
>
{
pub
fn
new
(
value
:
TVMValue
,
type_code
:
i64
)
->
Self
{
TVMArgValue
{
_lifetime
:
PhantomData
,
value
:
value
,
type_code
:
type_code
,
TVMPODValue!
{
/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way
/// to obtain a `TVMArgValue` is automatically via `call_packed!`.
TVMArgValue
<
'a
>
{
Bytes
(
&
'a
TVMByteArray
),
Str
(
&
'a
CStr
),
},
match
value
{
TVMTypeCode_kBytes
=>
{
Bytes
(
&*
(
value
.v_handle
as
*
const
TVMByteArray
))
}
TVMTypeCode_kStr
=>
{
Str
(
CStr
::
from_ptr
(
value
.v_handle
as
*
const
i8
))
}
},
match
&
self
{
Bytes
(
val
)
=>
{
(
TVMValue
{
v_handle
:
val
.clone
()
as
*
const
_
as
*
mut
c_void
},
TVMTypeCode_kBytes
)
}
Str
(
val
)
=>
{
(
TVMValue
{
v_handle
:
val
.as_ptr
()
as
*
mut
c_void
},
TVMTypeCode_kStr
)}
}
}
TVMPODValue!
{
/// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
/// Can be downcasted using `try_from` if it contains the desired type.
///
/// # Example
///
/// ```
/// let a = 42u32;
/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
///
/// let s = "hello, world!";
/// let t: TVMRetValue = s.into();
/// assert_eq!(String::try_from(t).unwrap(), s);
/// ```
TVMRetValue
{
Bytes
(
TVMByteArray
),
Str
(
&
'static
CStr
),
},
match
value
{
TVMTypeCode_kBytes
=>
{
Bytes
(
*
(
value
.v_handle
as
*
const
TVMByteArray
))
}
TVMTypeCode_kStr
=>
{
Str
(
CStr
::
from_ptr
(
value
.v_handle
as
*
mut
i8
))
}
},
match
&
self
{
Bytes
(
val
)
=>
{
(
TVMValue
{
v_handle
:
val
as
*
const
_
as
*
mut
c_void
},
TVMTypeCode_kBytes
)
}
Str
(
val
)
=>
{
(
TVMValue
{
v_str
:
val
.as_ptr
()
},
TVMTypeCode_kStr
)
}
}
}
#[macro_export]
macro_rules!
ensure_type
{
(
$val:ident
,
$expected_type_code:expr
)
=>
{
ensure!
(
$
val
.type_code
==
$expected_type_code
as
i64
,
$crate
::
errors
::
ValueDowncastError
::
new
(
$val
.type_code
as
i64
,
$expected_type_code
as
i64
)
);
macro_rules!
try_downcast
{
(
$val:ident
->
$into:ty
,
$
(
|
$pat:pat
|
{
$converter:expr
}
),
+
)
=>
{
match
$val
{
$
(
$pat
=>
{
Ok
(
$converter
)
}
)
+
_
=>
Err
(
$crate
::
errors
::
ValueDowncastError
{
actual_type
:
format!
(
"{:?}"
,
$val
)
,
expected_type
:
stringify!
(
$into
),
}),
}
};
}
/// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode.
macro_rules!
impl_p
rim_tvm_arg
{
(
$
type_code:ident
,
$field:ident
,
$field_type
:ty
,
[
$
(
$type:ty
),
+
]
)
=>
{
macro_rules!
impl_p
od_value
{
(
$
variant:ident
,
$inner_ty
:ty
,
[
$
(
$type:ty
),
+
]
)
=>
{
$
(
impl
From
<
$type
>
for
TVMArgValue
<
'static
>
{
impl
<
'a
>
From
<
$type
>
for
TVMArgValue
<
'a
>
{
fn
from
(
val
:
$type
)
->
Self
{
TVMArgValue
{
value
:
TVMValue
{
$field
:
val
as
$field_type
},
type_code
:
$type_code
as
i64
,
_lifetime
:
PhantomData
,
}
Self
::
$variant
(
val
as
$inner_ty
)
}
}
impl
<
'a
>
From
<&
'a
$type
>
for
TVMArgValue
<
'a
>
{
impl
<
'a
,
'v
>
From
<&
'a
$type
>
for
TVMArgValue
<
'v
>
{
fn
from
(
val
:
&
'a
$type
)
->
Self
{
TVMArgValue
{
value
:
TVMValue
{
$field
:
val
.to_owned
()
as
$field_type
,
},
type_code
:
$type_code
as
i64
,
_lifetime
:
PhantomData
,
}
Self
::
$variant
(
*
val
as
$inner_ty
)
}
}
impl
<
'a
>
TryFrom
<
TVMArgValue
<
'a
>>
for
$type
{
type
Error
=
Error
;
type
Error
=
$crate
::
errors
::
ValueDowncast
Error
;
fn
try_from
(
val
:
TVMArgValue
<
'a
>
)
->
Result
<
Self
,
Self
::
Error
>
{
ensure_type!
(
val
,
$type_code
);
Ok
(
unsafe
{
val
.value
.
$field
as
$type
})
try_downcast!
(
val
->
$type
,
|
TVMArgValue
::
$variant
(
val
)|
{
val
as
$type
})
}
}
impl
<
'a
>
TryFrom
<&
TVMArgValue
<
'a
>>
for
$type
{
type
Error
=
Error
;
fn
try_from
(
val
:
&
TVMArgValue
<
'a
>
)
->
Result
<
Self
,
Self
::
Error
>
{
ensure_type!
(
val
,
$type_code
);
Ok
(
unsafe
{
val
.value
.
$field
as
$type
})
impl
<
'a
,
'v
>
TryFrom
<&
'a
TVMArgValue
<
'v
>>
for
$type
{
type
Error
=
$crate
::
errors
::
ValueDowncastError
;
fn
try_from
(
val
:
&
'a
TVMArgValue
<
'v
>
)
->
Result
<
Self
,
Self
::
Error
>
{
try_downcast!
(
val
->
$type
,
|
TVMArgValue
::
$variant
(
val
)|
{
*
val
as
$type
})
}
}
impl
From
<
$type
>
for
TVMRetValue
{
fn
from
(
val
:
$type
)
->
Self
{
Self
::
$variant
(
val
as
$inner_ty
)
}
}
impl
TryFrom
<
TVMRetValue
>
for
$type
{
type
Error
=
$crate
::
errors
::
ValueDowncastError
;
fn
try_from
(
val
:
TVMRetValue
)
->
Result
<
Self
,
Self
::
Error
>
{
try_downcast!
(
val
->
$type
,
|
TVMRetValue
::
$variant
(
val
)|
{
val
as
$type
})
}
}
)
+
};
}
impl_prim_tvm_arg!
(
DLDataTypeCode_kDLFloat
,
v_float64
,
f64
,
[
f32
,
f64
]);
impl_prim_tvm_arg!
(
DLDataTypeCode_kDLInt
,
v_int64
,
i64
,
[
i8
,
i16
,
i32
,
i64
,
isize
]
);
impl_prim_tvm_arg!
(
DLDataTypeCode_kDLUInt
,
v_int64
,
i64
,
[
u8
,
u16
,
u32
,
u64
,
usize
]
);
impl_pod_value!
(
Int
,
i64
,
[
i8
,
i16
,
i32
,
i64
,
isize
]);
impl_pod_value!
(
UInt
,
i64
,
[
u8
,
u16
,
u32
,
u64
,
usize
]);
impl_pod_value!
(
Float
,
f64
,
[
f32
,
f64
]);
impl_pod_value!
(
Type
,
TVMType
,
[
TVMType
]);
impl_pod_value!
(
Context
,
TVMContext
,
[
TVMContext
]);
#[cfg(feature
=
"bindings"
)]
// only allow this in bindings because pure-rust can't take ownership of leaked CString
impl
<
'a
>
From
<&
String
>
for
TVMArgValue
<
'a
>
{
fn
from
(
string
:
&
String
)
->
Self
{
TVMArgValue
{
value
:
TVMValue
{
v_str
:
std
::
ffi
::
CString
::
new
(
string
.clone
())
.unwrap
()
.into_raw
(),
},
type_code
:
TVMTypeCode_kStr
as
i64
,
_lifetime
:
PhantomData
,
}
impl
<
'a
>
From
<&
'a
str
>
for
TVMArgValue
<
'a
>
{
fn
from
(
s
:
&
'a
str
)
->
Self
{
Self
::
String
(
CString
::
new
(
s
)
.unwrap
())
}
}
impl
<
'a
>
From
<&
std
::
ffi
::
CString
>
for
TVMArgValue
<
'a
>
{
fn
from
(
string
:
&
std
::
ffi
::
CString
)
->
Self
{
TVMArgValue
{
value
:
TVMValue
{
v_str
:
string
.as_ptr
(),
},
type_code
:
TVMTypeCode_kStr
as
i64
,
_lifetime
:
PhantomData
,
}
impl
<
'a
>
From
<&
'a
CStr
>
for
TVMArgValue
<
'a
>
{
fn
from
(
s
:
&
'a
CStr
)
->
Self
{
Self
::
Str
(
s
)
}
}
impl
<
'a
>
TryFrom
<
TVMArgValue
<
'a
>>
for
&
str
{
type
Error
=
Error
;
fn
try_from
(
arg
:
TVMArgValue
<
'a
>
)
->
Result
<
Self
,
Self
::
Error
>
{
ensure_type!
(
arg
,
TVMTypeCode_kStr
);
Ok
(
unsafe
{
std
::
ffi
::
CStr
::
from_ptr
(
arg
.value.v_handle
as
*
const
i8
)
}
.to_str
()
?
)
impl
<
'a
>
TryFrom
<
TVMArgValue
<
'a
>>
for
&
'a
str
{
type
Error
=
ValueDowncastError
;
fn
try_from
(
val
:
TVMArgValue
<
'a
>
)
->
Result
<
Self
,
Self
::
Error
>
{
try_downcast!
(
val
->
&
str
,
|
TVMArgValue
::
Str
(
s
)|
{
s
.to_str
()
.unwrap
()
})
}
}
impl
<
'a
>
TryFrom
<&
TVMArgValue
<
'a
>>
for
&
str
{
type
Error
=
Error
;
fn
try_from
(
arg
:
&
TVMArgValue
<
'a
>
)
->
Result
<
Self
,
Self
::
Error
>
{
ensure_type!
(
arg
,
TVMTypeCode_kStr
);
Ok
(
unsafe
{
std
::
ffi
::
CStr
::
from_ptr
(
arg
.value.v_handle
as
*
const
i8
)
}
.to_str
()
?
)
impl
<
'a
,
'v
>
TryFrom
<&
'a
TVMArgValue
<
'v
>>
for
&
'v
str
{
type
Error
=
ValueDowncastError
;
fn
try_from
(
val
:
&
'a
TVMArgValue
<
'v
>
)
->
Result
<
Self
,
Self
::
Error
>
{
try_downcast!
(
val
->
&
str
,
|
TVMArgValue
::
Str
(
s
)|
{
s
.to_str
()
.unwrap
()
})
}
}
/// C
reates a conversion to a `TVMArgValue` for an object handl
e.
impl
<
'a
,
T
>
From
<*
const
T
>
for
TVMArgValue
<
'a
>
{
/// C
onverts an unspecialized handle to a TVMArgValu
e.
impl
<
T
>
From
<*
const
T
>
for
TVMArgValue
<
'static
>
{
fn
from
(
ptr
:
*
const
T
)
->
Self
{
TVMArgValue
{
value
:
TVMValue
{
v_handle
:
ptr
as
*
mut
T
as
*
mut
c_void
,
},
type_code
:
TVMTypeCode_kArrayHandle
as
i64
,
_lifetime
:
PhantomData
,
}
Self
::
Handle
(
ptr
as
*
mut
c_void
)
}
}
/// C
reates a conversion to a `TVMArgValue` for a mutable object handl
e.
impl
<
'a
,
T
>
From
<*
mut
T
>
for
TVMArgValue
<
'a
>
{
/// C
onverts an unspecialized mutable handle to a TVMArgValu
e.
impl
<
T
>
From
<*
mut
T
>
for
TVMArgValue
<
'static
>
{
fn
from
(
ptr
:
*
mut
T
)
->
Self
{
TVMArgValue
{
value
:
TVMValue
{
v_handle
:
ptr
as
*
mut
c_void
,
},
type_code
:
TVMTypeCode_kHandle
as
i64
,
_lifetime
:
PhantomData
,
}
Self
::
Handle
(
ptr
as
*
mut
c_void
)
}
}
impl
<
'a
>
From
<&
'a
mut
DLTensor
>
for
TVMArgValue
<
'a
>
{
fn
from
(
arr
:
&
'a
mut
DLTensor
)
->
Self
{
TVMArgValue
{
value
:
TVMValue
{
v_handle
:
arr
as
*
mut
_
as
*
mut
c_void
,
},
type_code
:
TVMTypeCode_kArrayHandle
as
i64
,
_lifetime
:
PhantomData
,
}
Self
::
ArrayHandle
(
arr
as
*
mut
DLTensor
)
}
}
impl
<
'a
>
From
<&
'a
DLTensor
>
for
TVMArgValue
<
'a
>
{
fn
from
(
arr
:
&
'a
DLTensor
)
->
Self
{
TVMArgValue
{
value
:
TVMValue
{
v_handle
:
arr
as
*
const
_
as
*
mut
DLTensor
as
*
mut
c_void
,
},
type_code
:
TVMTypeCode_kArrayHandle
as
i64
,
_lifetime
:
PhantomData
,
}
Self
::
ArrayHandle
(
arr
as
*
const
_
as
*
mut
DLTensor
)
}
}
impl
<
'a
,
'v
>
TryFrom
<&
'a
TVMArgValue
<
'v
>>
for
TVMType
{
type
Error
=
Error
;
fn
try_from
(
arg
:
&
'a
TVMArgValue
<
'v
>
)
->
Result
<
Self
,
Self
::
Error
>
{
ensure_type!
(
arg
,
TVMTypeCode_kTVMType
);
Ok
(
unsafe
{
arg
.value.v_type
.into
()
})
impl
TryFrom
<
TVMRetValue
>
for
String
{
type
Error
=
ValueDowncastError
;
fn
try_from
(
val
:
TVMRetValue
)
->
Result
<
String
,
Self
::
Error
>
{
try_downcast!
(
val
->
String
,
|
TVMRetValue
::
String
(
s
)|
{
s
.into_string
()
.unwrap
()
},
|
TVMRetValue
::
Str
(
s
)|
{
s
.to_str
()
.unwrap
()
.to_string
()
}
)
}
}
/// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
/// Can be downcasted using `try_from` if it contains the desired type.
///
/// # Example
///
/// ```
/// let a = 42u32;
/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
///
/// let s = "hello, world!";
/// let t: TVMRetValue = s.into();
/// assert_eq!(String::try_from(t).unwrap(), s);
/// ```
pub
struct
TVMRetValue
{
pub
value
:
TVMValue
,
pub
box_value
:
Box
<
Any
>
,
pub
type_code
:
i64
,
}
impl
TVMRetValue
{
pub
fn
from_tvm_value
(
value
:
TVMValue
,
type_code
:
i64
)
->
Self
{
Self
{
value
,
type_code
,
box_value
:
box
(),
}
}
pub
fn
into_tvm_value
(
self
)
->
(
TVMValue
,
TVMTypeCode
)
{
(
self
.value
,
self
.type_code
as
TVMTypeCode
)
impl
From
<
String
>
for
TVMRetValue
{
fn
from
(
s
:
String
)
->
Self
{
Self
::
String
(
std
::
ffi
::
CString
::
new
(
s
)
.unwrap
())
}
}
impl
Default
for
TVMRetValue
{
fn
default
()
->
Self
{
TVMRetValue
{
value
:
TVMValue
{
v_int64
:
0
as
i64
},
type_code
:
0
,
box_value
:
box
(),
}
impl
From
<
TVMByteArray
>
for
TVMRetValue
{
fn
from
(
arr
:
TVMByteArray
)
->
Self
{
Self
::
Bytes
(
arr
)
}
}
macro_rules!
impl_pod_ret_value
{
(
$code:expr
,
[
$
(
$ty:ty
),
+
]
)
=>
{
$
(
impl
From
<
$ty
>
for
TVMRetValue
{
fn
from
(
val
:
$ty
)
->
Self
{
Self
{
value
:
val
.into
(),
type_code
:
$code
as
i64
,
box_value
:
box
(),
}
}
}
impl
TryFrom
<
TVMRetValue
>
for
$ty
{
type
Error
=
Error
;
fn
try_from
(
ret
:
TVMRetValue
)
->
Result
<
$ty
,
Self
::
Error
>
{
ensure_type!
(
ret
,
$code
);
Ok
(
ret
.value
.into
())
}
}
)
+
};
}
impl_pod_ret_value!
(
DLDataTypeCode_kDLInt
,
[
i8
,
i16
,
i32
,
i64
,
isize
]);
impl_pod_ret_value!
(
DLDataTypeCode_kDLUInt
,
[
u8
,
u16
,
u32
,
u64
,
usize
]);
impl_pod_ret_value!
(
DLDataTypeCode_kDLFloat
,
[
f32
,
f64
]);
impl_pod_ret_value!
(
TVMTypeCode_kTVMType
,
[
TVMType
]);
impl_pod_ret_value!
(
TVMTypeCode_kTVMContext
,
[
TVMContext
]);
impl
TryFrom
<
TVMRetValue
>
for
String
{
type
Error
=
Error
;
fn
try_from
(
ret
:
TVMRetValue
)
->
Result
<
String
,
Self
::
Error
>
{
ensure_type!
(
ret
,
TVMTypeCode_kStr
);
let
cs
=
unsafe
{
std
::
ffi
::
CString
::
from_raw
(
ret
.value.v_handle
as
*
mut
i8
)
};
let
ret_str
=
cs
.clone
()
.into_string
();
if
cfg!
(
feature
=
"bindings"
)
{
std
::
mem
::
forget
(
cs
);
// TVM C++ takes ownership of CString. (@see TVMFuncCall)
}
Ok
(
ret_str
?
)
impl
TryFrom
<
TVMRetValue
>
for
TVMByteArray
{
type
Error
=
ValueDowncastError
;
fn
try_from
(
val
:
TVMRetValue
)
->
Result
<
Self
,
Self
::
Error
>
{
try_downcast!
(
val
->
TVMByteArray
,
|
TVMRetValue
::
Bytes
(
val
)|
{
val
})
}
}
impl
From
<
String
>
for
TVMRetValue
{
fn
from
(
s
:
String
)
->
Self
{
let
cs
=
std
::
ffi
::
CString
::
new
(
s
)
.unwrap
();
Self
{
value
:
TVMValue
{
v_str
:
cs
.into_raw
()
as
*
mut
i8
,
},
box_value
:
box
(),
type_code
:
TVMTypeCode_kStr
as
i64
,
}
impl
Default
for
TVMRetValue
{
fn
default
()
->
Self
{
Self
::
Int
(
0
)
}
}
rust/common/src/value.rs
View file @
14a0ecba
...
...
@@ -137,3 +137,18 @@ impl_tvm_context!(
DLDeviceType_kDLROCM
:
[
rocm
],
DLDeviceType_kDLExtDev
:
[
ext_dev
]
);
impl
TVMByteArray
{
pub
fn
data
(
&
self
)
->
&
'static
[
u8
]
{
unsafe
{
std
::
slice
::
from_raw_parts
(
self
.data
as
*
const
u8
,
self
.size
)
}
}
}
impl
<
'a
>
From
<&
'a
[
u8
]
>
for
TVMByteArray
{
fn
from
(
bytes
:
&
[
u8
])
->
Self
{
Self
{
data
:
bytes
.as_ptr
()
as
*
const
i8
,
size
:
bytes
.len
(),
}
}
}
rust/frontend/Cargo.toml
View file @
14a0ecba
...
...
@@ -9,6 +9,7 @@ readme = "README.md"
keywords
=
[
"rust"
,
"tvm"
,
"nnvm"
]
categories
=
[
"api-bindings"
,
"science"
]
authors
=
[
"TVM Contributors"
]
edition
=
"2018"
[lib]
name
=
"tvm_frontend"
...
...
rust/frontend/src/bytearray.rs
View file @
14a0ecba
...
...
@@ -3,9 +3,9 @@
//!
//! For more detail, please see the example `resnet` in `examples` repository.
use
std
::
os
::
raw
::
{
c_char
,
c_void
}
;
use
std
::
os
::
raw
::
c_char
;
use
tvm_common
::
{
ffi
,
TVMArgValue
}
;
use
tvm_common
::
ffi
;
/// A struct holding TVM byte-array.
///
...
...
@@ -44,8 +44,9 @@ impl TVMByteArray {
}
}
impl
<
'a
>
From
<&
'a
Vec
<
u8
>>
for
TVMByteArray
{
fn
from
(
arg
:
&
Vec
<
u8
>
)
->
Self
{
impl
<
'a
,
T
:
AsRef
<
[
u8
]
>>
From
<
T
>
for
TVMByteArray
{
fn
from
(
arg
:
T
)
->
Self
{
let
arg
=
arg
.as_ref
();
let
barr
=
ffi
::
TVMByteArray
{
data
:
arg
.as_ptr
()
as
*
const
c_char
,
size
:
arg
.len
(),
...
...
@@ -54,18 +55,6 @@ impl<'a> From<&'a Vec<u8>> for TVMByteArray {
}
}
impl
<
'a
>
From
<&
TVMByteArray
>
for
TVMArgValue
<
'a
>
{
fn
from
(
arr
:
&
TVMByteArray
)
->
Self
{
Self
{
value
:
ffi
::
TVMValue
{
v_handle
:
&
arr
.inner
as
*
const
ffi
::
TVMByteArray
as
*
const
c_void
as
*
mut
c_void
,
},
type_code
:
ffi
::
TVMTypeCode_kBytes
as
i64
,
_lifetime
:
std
::
marker
::
PhantomData
,
}
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
...
...
rust/frontend/src/context.rs
View file @
14a0ecba
...
...
@@ -26,10 +26,7 @@ use std::{
use
failure
::
Error
;
use
tvm_common
::{
ffi
::{
self
,
TVMValue
},
TVMArgValue
,
};
use
tvm_common
::
ffi
;
use
crate
::
function
;
...
...
@@ -125,18 +122,6 @@ impl<'a> From<&'a str> for TVMDeviceType {
}
}
impl
<
'a
>
From
<&
'a
TVMDeviceType
>
for
TVMArgValue
<
'a
>
{
fn
from
(
dev_type
:
&
'a
TVMDeviceType
)
->
Self
{
Self
{
value
:
TVMValue
{
v_int64
:
dev_type
.
0
as
i64
,
},
type_code
:
ffi
::
DLDataTypeCode_kDLInt
as
i64
,
_lifetime
:
std
::
marker
::
PhantomData
,
}
}
}
/// Represents the underlying device context. Default is cpu.
///
/// ## Examples
...
...
@@ -209,7 +194,7 @@ impl TVMContext {
let
dt
=
self
.device_type
.
0
as
usize
;
// `unwrap` is ok here because if there is any error,
// if would occure inside `call_packed!`
let
ret
:
u64
=
call_packed!
(
func
,
&
dt
,
&
self
.device_id
,
&
0
)
let
ret
:
u64
=
call_packed!
(
func
,
dt
,
self
.device_id
,
0
)
.unwrap
()
.try_into
()
.unwrap
();
...
...
@@ -238,7 +223,9 @@ macro_rules! impl_device_attrs {
// `unwrap` is ok here because if there is any error,
// if would occur in function call.
function
::
Builder
::
from
(
func
)
.args
(
&
[
dt
,
self
.device_id
as
usize
,
$attr_kind
])
.arg
(
dt
)
.arg
(
self
.device_id
as
usize
)
.arg
(
$attr_kind
)
.invoke
()
.unwrap
()
.try_into
()
...
...
rust/frontend/src/function.rs
View file @
14a0ecba
...
...
@@ -156,9 +156,9 @@ impl<'a, 'm> Builder<'a, 'm> {
}
/// Pushes a [`TVMArgValue`] into the function argument buffer.
pub
fn
arg
<
T
:
'a
>
(
&
mut
self
,
arg
:
&
'a
T
)
->
&
mut
Self
pub
fn
arg
<
T
:
'a
>
(
&
mut
self
,
arg
:
T
)
->
&
mut
Self
where
TVMArgValue
<
'a
>
:
From
<
&
'a
T
>
,
TVMArgValue
<
'a
>
:
From
<
T
>
,
{
self
.arg_buf
.push
(
arg
.into
());
self
...
...
@@ -192,14 +192,11 @@ impl<'a, 'm> Builder<'a, 'm> {
ensure!
(
self
.func
.is_some
(),
errors
::
FunctionNotFoundError
);
let
num_args
=
self
.arg_buf
.len
();
let
(
mut
values
,
mut
type_codes
):
(
Vec
<
ffi
::
TVMValue
>
,
Vec
<
ffi
::
TVMTypeCode
>
)
=
self
.arg_buf
.iter
()
.map
(|
tvm_arg
|
(
tvm_arg
.value
,
tvm_arg
.type_code
as
ffi
::
TVMTypeCode
))
.unzip
();
let
(
mut
values
,
mut
type_codes
):
(
Vec
<
ffi
::
TVMValue
>
,
Vec
<
ffi
::
TVMTypeCode
>
)
=
self
.arg_buf
.iter
()
.map
(|
arg
|
arg
.to_tvm_value
())
.unzip
();
let
mut
ret_val
=
unsafe
{
std
::
mem
::
uninitialized
::
<
TVMValue
>
()
};
let
mut
ret_type_code
=
0
;
let
mut
ret_type_code
=
0
i32
;
check_call!
(
ffi
::
TVMFuncCall
(
self
.func
.ok_or
(
errors
::
FunctionNotFoundError
)
?
.handle
,
values
.as_mut_ptr
(),
...
...
@@ -209,7 +206,7 @@ impl<'a, 'm> Builder<'a, 'm> {
&
mut
ret_type_code
as
*
mut
_
));
Ok
(
unsafe
{
TVMRetValue
::
from_tvm_value
(
ret_val
,
ret_type_code
as
i64
)
})
Ok
(
unsafe
{
TVMRetValue
::
from_tvm_value
(
ret_val
,
ret_type_code
as
u32
)
})
}
}
...
...
@@ -254,7 +251,7 @@ unsafe extern "C" fn tvm_callback(
{
check_call!
(
ffi
::
TVMCbArgToReturn
(
&
mut
value
as
*
mut
_
,
tcode
));
}
local_args
.push
(
TVMArgValue
::
new
(
value
.into
(),
(
tcode
as
i64
)
.into
()
));
local_args
.push
(
TVMArgValue
::
from_tvm_value
(
value
.into
(),
tcode
as
u32
));
}
let
rv
=
match
rust_fn
(
local_args
.as_slice
())
{
...
...
@@ -265,7 +262,7 @@ unsafe extern "C" fn tvm_callback(
}
};
let
(
mut
ret_val
,
ret_tcode
)
=
rv
.
in
to_tvm_value
();
let
(
mut
ret_val
,
ret_tcode
)
=
rv
.to_tvm_value
();
let
mut
ret_type_code
=
ret_tcode
as
c_int
;
check_call!
(
ffi
::
TVMCFuncSetReturn
(
ret
,
...
...
@@ -437,8 +434,9 @@ mod tests {
let
str_arg
=
CString
::
new
(
"test"
)
.unwrap
();
let
mut
func
=
Builder
::
default
();
func
.get_function
(
"tvm.graph_runtime.remote_create"
)
.args
(
&
[
10
,
20
])
.arg
(
&
str_arg
);
.arg
(
10
)
.arg
(
20
)
.arg
(
str_arg
.as_c_str
());
assert_eq!
(
func
.arg_buf
.len
(),
3
);
}
}
rust/frontend/src/module.rs
View file @
14a0ecba
...
...
@@ -80,7 +80,7 @@ impl Module {
CString
::
new
(
path
.as_ref
()
.to_str
()
.ok_or_else
(||
{
format_err!
(
"Bad module load path: `{}`."
,
path
.as_ref
()
.display
())
})
?
)
?
;
let
ret
:
Module
=
call_packed!
(
func
,
&
cpath
,
&
ext
)
?
.try_into
()
?
;
let
ret
:
Module
=
call_packed!
(
func
,
cpath
.as_c_str
(),
ext
.as_c_str
()
)
?
.try_into
()
?
;
Ok
(
ret
)
}
...
...
@@ -90,7 +90,10 @@ impl Module {
// `unwrap` is safe here because if there is any error during the
// function call, it would occur in `call_packed!`.
let
tgt
=
CString
::
new
(
target
)
.unwrap
();
let
ret
:
i64
=
call_packed!
(
func
,
&
tgt
)
.unwrap
()
.try_into
()
.unwrap
();
let
ret
:
i64
=
call_packed!
(
func
,
tgt
.as_c_str
())
.unwrap
()
.try_into
()
.unwrap
();
ret
!=
0
}
...
...
rust/frontend/src/ndarray.rs
View file @
14a0ecba
...
...
@@ -161,7 +161,7 @@ impl NDArray {
/// Converts the NDArray to [`TVMByteArray`].
pub
fn
to_bytearray
(
&
self
)
->
Result
<
TVMByteArray
,
Error
>
{
let
v
=
self
.to_vec
::
<
u8
>
()
?
;
Ok
(
TVMByteArray
::
from
(
&
v
))
Ok
(
TVMByteArray
::
from
(
v
))
}
/// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu.
...
...
rust/frontend/src/value.rs
View file @
14a0ecba
...
...
@@ -2,140 +2,80 @@
//! and their conversions needed for the types used in frontend crate.
//! `TVMRetValue` is the owned version of `TVMPODValue`.
use
std
::
{
convert
::
TryFrom
,
os
::
raw
::
c_void
}
;
use
std
::
convert
::
TryFrom
;
use
failure
::
Error
;
use
tvm_common
::{
ensure_type
,
ffi
::{
self
,
TVMValue
},
errors
::
ValueDowncastError
,
ffi
::{
TVMArrayHandle
,
TVMFunctionHandle
,
TVMModuleHandle
},
try_downcast
,
};
use
crate
::{
common_errors
::
*
,
context
::
TVMContext
,
Function
,
Module
,
NDArray
,
TVMArgValue
,
TVMByteArray
,
TVMRetValue
,
};
use
crate
::{
Function
,
Module
,
NDArray
,
TVMArgValue
,
TVMRetValue
};
macro_rules!
impl_tvm_val_from_handle
{
(
$ty:ident
,
$type_code:expr
,
$handle:ty
)
=>
{
impl
<
'a
>
From
<&
'a
$ty
>
for
TVMArgValue
<
'a
>
{
fn
from
(
arg
:
&
$ty
)
->
Self
{
TVMArgValue
{
value
:
TVMValue
{
v_handle
:
arg
.handle
as
*
mut
_
as
*
mut
c_void
,
},
type_code
:
$type_code
as
i64
,
_lifetime
:
std
::
marker
::
PhantomData
,
}
macro_rules!
impl_handle_val
{
(
$type:ty
,
$variant:ident
,
$inner_type:ty
,
$ctor:path
)
=>
{
impl
<
'a
>
From
<&
'a
$type
>
for
TVMArgValue
<
'a
>
{
fn
from
(
arg
:
&
'a
$type
)
->
Self
{
TVMArgValue
::
$variant
(
arg
.handle
()
as
$inner_type
)
}
}
impl
<
'a
>
From
<&
'a
mut
$ty
>
for
TVMArgValue
<
'a
>
{
fn
from
(
arg
:
&
mut
$ty
)
->
Self
{
TVMArgValue
{
value
:
TVMValue
{
v_handle
:
arg
.handle
as
*
mut
_
as
*
mut
c_void
,
},
type_code
:
$type_code
as
i64
,
_lifetime
:
std
::
marker
::
PhantomData
,
}
impl
<
'a
>
From
<&
'a
mut
$type
>
for
TVMArgValue
<
'a
>
{
fn
from
(
arg
:
&
'a
mut
$type
)
->
Self
{
TVMArgValue
::
$variant
(
arg
.handle
()
as
$inner_type
)
}
}
impl
<
'a
,
'v
>
TryFrom
<&
'a
TVMArgValue
<
'v
>>
for
$ty
{
type
Error
=
Error
;
fn
try_from
(
arg
:
&
TVMArgValue
<
'v
>
)
->
Result
<
$ty
,
Self
::
Error
>
{
ensure_type!
(
arg
,
$type_code
);
Ok
(
$ty
::
new
(
unsafe
{
arg
.value.v_handle
as
$handle
}))
impl
<
'a
>
TryFrom
<
TVMArgValue
<
'a
>>
for
$type
{
type
Error
=
ValueDowncastError
;
fn
try_from
(
val
:
TVMArgValue
<
'a
>
)
->
Result
<
$type
,
Self
::
Error
>
{
try_downcast!
(
val
->
$type
,
|
TVMArgValue
::
$variant
(
val
)|
{
$ctor
(
val
)
})
}
}
impl
From
<
$ty
>
for
TVMRetValue
{
fn
from
(
val
:
$ty
)
->
TVMRetValue
{
TVMRetValue
{
value
:
TVMValue
{
v_handle
:
val
.handle
()
as
*
mut
c_void
,
},
box_value
:
box
val
,
type_code
:
$type_code
as
i64
,
}
impl
<
'a
,
'v
>
TryFrom
<&
'a
TVMArgValue
<
'v
>>
for
$type
{
type
Error
=
ValueDowncastError
;
fn
try_from
(
val
:
&
'a
TVMArgValue
<
'v
>
)
->
Result
<
$type
,
Self
::
Error
>
{
try_downcast!
(
val
->
$type
,
|
TVMArgValue
::
$variant
(
val
)|
{
$ctor
(
*
val
)
})
}
}
impl
TryFrom
<
TVMRetValue
>
for
$ty
{
type
Error
=
Error
;
fn
try_from
(
ret
:
TVMRetValue
)
->
Result
<
$ty
,
Self
::
Error
>
{
ensure_type!
(
ret
,
$type_code
);
Ok
(
$ty
::
new
(
unsafe
{
ret
.value.v_handle
as
$handle
}))
}
}
};
}
impl_tvm_val_from_handle!
(
Function
,
ffi
::
TVMTypeCode_kFuncHandle
,
ffi
::
TVMFunctionHandle
);
impl_tvm_val_from_handle!
(
Module
,
ffi
::
TVMTypeCode_kModuleHandle
,
ffi
::
TVMModuleHandle
);
impl_tvm_val_from_handle!
(
NDArray
,
ffi
::
TVMTypeCode_kArrayHandle
,
ffi
::
TVMArrayHandle
);
impl
<
'a
>
From
<&
'a
TVMByteArray
>
for
TVMValue
{
fn
from
(
barr
:
&
TVMByteArray
)
->
Self
{
TVMValue
{
v_handle
:
&
barr
.inner
as
*
const
ffi
::
TVMByteArray
as
*
mut
c_void
,
}
}
}
macro_rules!
impl_boxed_ret_value
{
(
$type:ty
,
$code:expr
)
=>
{
impl
From
<
$type
>
for
TVMRetValue
{
fn
from
(
val
:
$type
)
->
Self
{
TVMRetValue
{
value
:
TVMValue
{
v_int64
:
0
},
box_value
:
box
val
,
type_code
:
$code
as
i64
,
}
fn
from
(
val
:
$type
)
->
TVMRetValue
{
TVMRetValue
::
$variant
(
val
.handle
()
as
$inner_type
)
}
}
impl
TryFrom
<
TVMRetValue
>
for
$type
{
type
Error
=
Error
;
fn
try_from
(
ret
:
TVMRetValue
)
->
Result
<
$type
,
Self
::
Error
>
{
if
let
Ok
(
val
)
=
ret
.box_value.downcast
::
<
$type
>
()
{
Ok
(
*
val
)
}
else
{
bail!
(
ValueDowncastError
::
new
(
$code
as
i64
,
ret
.type_code
as
i64
))
}
type
Error
=
ValueDowncastError
;
fn
try_from
(
val
:
TVMRetValue
)
->
Result
<
$type
,
Self
::
Error
>
{
try_downcast!
(
val
->
$type
,
|
TVMRetValue
::
$variant
(
val
)|
{
$ctor
(
val
)
})
}
}
};
}
impl_boxed_ret_value!
(
TVMContext
,
ffi
::
TVMTypeCode_kTVMContext
);
impl_boxed_ret_value!
(
TVMByteArray
,
ffi
::
TVMTypeCode_kBytes
);
impl
<
'a
,
'v
>
TryFrom
<&
'a
TVMArgValue
<
'v
>>
for
TVMByteArray
{
type
Error
=
Error
;
fn
try_from
(
arg
:
&
TVMArgValue
<
'v
>
)
->
Result
<
Self
,
Self
::
Error
>
{
ensure_type!
(
arg
,
ffi
::
TVMTypeCode_kBytes
);
Ok
(
TVMByteArray
::
new
(
unsafe
{
*
(
arg
.value.v_handle
as
*
mut
ffi
::
TVMByteArray
)
}))
}
}
impl_handle_val!
(
Function
,
FuncHandle
,
TVMFunctionHandle
,
Function
::
new
);
impl_handle_val!
(
Module
,
ModuleHandle
,
TVMModuleHandle
,
Module
::
new
);
impl_handle_val!
(
NDArray
,
ArrayHandle
,
TVMArrayHandle
,
NDArray
::
new
);
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
std
::{
convert
::
TryInto
,
str
::
FromStr
};
use
tvm_common
::
ffi
::
TVMType
;
use
tvm_common
::{
TVMByteArray
,
TVMContext
,
TVMType
};
use
super
::
*
;
#[test]
fn
bytearray
()
{
let
w
=
vec!
[
1u8
,
2
,
3
,
4
,
5
];
let
v
=
TVMByteArray
::
from
(
&
w
);
let
v
=
TVMByteArray
::
from
(
w
.as_slice
()
);
let
tvm
:
TVMByteArray
=
TVMRetValue
::
from
(
v
)
.try_into
()
.unwrap
();
assert_eq!
(
tvm
.data
(),
w
.iter
()
.map
(|
e
|
*
e
as
i8
)
.collect
::
<
Vec
<
i8
>>
());
assert_eq!
(
tvm
.data
(),
w
.iter
()
.map
(|
e
|
*
e
)
.collect
::
<
Vec
<
u8
>>
()
.as_slice
()
);
}
#[test]
...
...
@@ -147,7 +87,7 @@ mod tests {
#[test]
fn
ctx
()
{
let
c
=
TVMContext
::
from
(
"gpu"
);
let
c
=
TVMContext
::
from
_str
(
"gpu"
)
.unwrap
(
);
let
tvm
:
TVMContext
=
TVMRetValue
::
from
(
c
)
.try_into
()
.unwrap
();
assert_eq!
(
tvm
,
c
);
}
...
...
rust/frontend/tests/callback/src/bin/string.rs
View file @
14a0ecba
...
...
@@ -24,9 +24,9 @@ fn main() {
registered
.get_function
(
"concate_str"
);
assert
!
(
registered
.func
.is_some
());
let
ret
:
String
=
registered
.arg
(
&
a
)
.arg
(
&
b
)
.arg
(
&
c
)
.arg
(
a
.as_c_str
()
)
.arg
(
b
.as_c_str
()
)
.arg
(
c
.as_c_str
()
)
.invoke
()
.unwrap
()
.try_into
()
...
...
rust/runtime/Cargo.toml
View file @
14a0ecba
...
...
@@ -8,6 +8,7 @@ readme = "README.md"
keywords
=
[
"tvm"
,
"nnvm"
]
categories
=
[
"api-bindings"
,
"science"
]
authors
=
[
"TVM Contributors"
]
edition
=
"2018"
[features]
default
=
["nom/std"]
...
...
rust/runtime/src/graph.rs
View file @
14a0ecba
...
...
@@ -265,7 +265,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
.iter
()
.map
(|
t
|
t
.into
())
.collect
::
<
Vec
<
TVMArgValue
>>
();
func
(
args
.as_slice
()
)
.unwrap
();
func
(
&
args
)
.unwrap
();
};
op_execs
.push
(
op
);
}
...
...
@@ -283,7 +283,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
// TODO: consider `new_with_params` to avoid ever allocating
let
ptr
=
self
.tensors
[
idx
]
.data
.as_ptr
();
let
mut
to_replace
=
self
.tensors
.iter_mut
()
.filter
(|
t
|
t
.data
.as_ptr
()
==
ptr
);
let
mut
owner
=
to_replace
.nth
(
0
)
.unwrap
();
let
owner
=
to_replace
.nth
(
0
)
.unwrap
();
if
value
.data
.is_owned
()
{
// FIXME: for no-copy, need setup_op_execs to not capture tensor ptr
// mem::replace(&mut (*owner), value);
...
...
rust/runtime/src/module.rs
View file @
14a0ecba
...
...
@@ -40,17 +40,14 @@ pub(super) fn wrap_backend_packed_func(
func
:
BackendPackedCFunc
,
)
->
Box
<
dyn
PackedFunc
>
{
box
move
|
args
:
&
[
TVMArgValue
]|
{
let
exit_code
=
func
(
args
.iter
()
.map
(|
ref
arg
|
arg
.value
)
.collect
::
<
Vec
<
TVMValue
>>
()
.as_ptr
(),
args
.iter
()
.map
(|
ref
arg
|
arg
.type_code
as
i32
)
.collect
::
<
Vec
<
i32
>>
()
.as_ptr
()
as
*
const
i32
,
args
.len
()
as
i32
,
);
let
(
values
,
type_codes
):
(
Vec
<
TVMValue
>
,
Vec
<
i32
>
)
=
args
.into_iter
()
.map
(|
arg
|
{
let
(
val
,
code
)
=
arg
.to_tvm_value
();
(
val
,
code
as
i32
)
})
.unzip
();
let
exit_code
=
func
(
values
.as_ptr
(),
type_codes
.as_ptr
(),
values
.len
()
as
i32
);
if
exit_code
==
0
{
Ok
(
TVMRetValue
::
default
())
}
else
{
...
...
rust/runtime/tests/test_graph_serde.rs
View file @
14a0ecba
#
!
[
feature
(
try_from
)]
extern
crate
serde
;
extern
crate
serde_json
;
...
...
rust/runtime/tests/test_nnvm/src/main.rs
View file @
14a0ecba
#
!
[
feature
(
try_from
)]
#
[
macro_use
]
extern
crate
ndarray
;
extern
crate
serde
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment