Commit 5563b72b by nhynes Committed by Tianqi Chen

Add rust runtime (#1597)

parent 6330797d
...@@ -91,10 +91,8 @@ ENV/ ...@@ -91,10 +91,8 @@ ENV/
*~ *~
*.pyc *.pyc
*~ *~
build
config.mk config.mk
config.cmake config.cmake
build_*
Win32 Win32
*.dir *.dir
perf perf
...@@ -187,7 +185,6 @@ tvm_u.* ...@@ -187,7 +185,6 @@ tvm_u.*
tvm_t.* tvm_t.*
# Mac OS X # Mac OS X
.DS_Store .DS_Store
build*
# Jetbrain # Jetbrain
.idea .idea
......
Cargo.lock
target/
**/*.rs.bk
max_width = 100
hard_tabs = false
tab_spaces = 2
newline_style = "Auto"
use_small_heuristics = "Default"
indent_style = "Block"
wrap_comments = false
comment_width = 80
normalize_comments = false
format_strings = false
format_macro_matchers = false
format_macro_bodies = true
empty_item_single_line = true
struct_lit_single_line = true
fn_single_line = false
where_single_line = false
imports_indent = "Block"
imports_layout = "Mixed"
merge_imports = true
reorder_imports = true
reorder_modules = true
reorder_impl_items = false
type_punctuation_density = "Wide"
space_before_colon = false
space_after_colon = true
spaces_around_ranges = false
binop_separator = "Front"
remove_nested_parens = true
combine_control_expr = true
struct_field_align_threshold = 0
match_arm_blocks = true
force_multiline_blocks = false
fn_args_density = "Tall"
brace_style = "SameLineWhere"
control_brace_style = "AlwaysSameLine"
trailing_semicolon = true
trailing_comma = "Vertical"
match_block_trailing_comma = false
blank_lines_upper_bound = 1
blank_lines_lower_bound = 0
edition = "Edition2015"
merge_derives = true
use_try_shorthand = true
use_field_init_shorthand = false
force_explicit_abi = true
condense_wildcard_suffixes = false
color = "Auto"
required_version = "0.99.4"
unstable_features = false
disable_all_formatting = false
skip_children = false
hide_parse_errors = false
error_on_line_overflow = false
error_on_unformatted = false
report_todo = "Never"
report_fixme = "Never"
ignore = []
emit_mode = "Files"
make_backup = false
language: rust
rust:
- nightly
matrix:
fast_finish: true
[package]
name = "tvm"
version = "0.1.0"
license = "Apache-2.0"
description = "TVM Rust runtime"
repository = "https://github.com/dmlc/tvm"
readme = "README.md"
keywords = ["tvm", "nnvm"]
categories = ["api-bindings", "science"]
authors = ["Nick Hynes <nhynes@berkeley.edu>"]
[features]
default = ["nom/std"]
sgx = ["nom/alloc"]
[dependencies]
bounded-spsc-queue = "0.4.0"
error-chain = { version = "0.12.0", default-features = false }
itertools = "0.7.8"
lazy_static = "1.1.0"
ndarray = "0.11.2"
nom = {version = "4.0.0", default-features = false }
serde = "1.0.59"
serde_derive = "1.0.79"
serde_json = "1.0.17"
[target.'cfg(not(target_env = "sgx"))'.dependencies]
num_cpus = "1.8.0"
#[cfg(target_env = "sgx")]
use alloc::alloc;
#[cfg(not(target_env = "sgx"))]
use std::alloc;
use std::num;
use ndarray;
use serde_json;
error_chain! {
errors {
TryFromTVMRetValueError(expected: String, actual: i64) {
description("mismatched types while downcasting TVMRetValue")
display("invalid downcast: expected `{}` but was `{}`", expected, actual)
}
GraphFormatError(msg: String) {
description("unable to load graph")
display("could not load graph json: {}", msg)
}
LoadGraphParamsError(msg: String) {
description("unable to load graph params")
display("could not load graph params: {}", msg)
}
}
foreign_links {
Alloc(alloc::AllocErr);
GraphDeserialize(serde_json::Error);
ParseInt(num::ParseIntError);
ShapeError(ndarray::ShapeError);
}
}
impl From<alloc::LayoutErr> for Error {
fn from(_err: alloc::LayoutErr) -> Error {
Error::from_kind(ErrorKind::Msg("Layout error".to_string()))
}
}
//! This crate is an implementation of the TVM runtime for modules compiled with `--system-lib`.
//! It's mainly useful for compiling to WebAssembly and SGX,
//! but also native if you prefer Rust to C++.
//!
//! For TVM graphs, the entrypoint to this crate is `runtime::GraphExecutor`.
//! Single-function modules are used via the `packed_func!` macro after obtaining
//! the function from `runtime::SystemLibModule`
//!
//! The main entrypoints to this crate are `GraphExecutor`
//! For examples of use, please refer to the multi-file tests in the `tests` directory.
#![feature(
alloc,
allocator_api,
box_syntax,
extern_prelude,
fn_traits,
try_from,
unboxed_closures,
vec_remove_item
)]
#[cfg(target_env = "sgx")]
extern crate alloc;
extern crate bounded_spsc_queue;
#[cfg(target_env = "sgx")]
extern crate core;
#[macro_use]
extern crate error_chain;
#[macro_use]
extern crate itertools;
#[macro_use]
extern crate lazy_static;
extern crate ndarray;
#[macro_use]
extern crate nom;
#[cfg(not(target_env = "sgx"))]
extern crate num_cpus;
extern crate serde;
#[macro_use]
extern crate serde_derive;
extern crate serde_json;
pub mod ffi {
#![allow(
non_camel_case_types,
non_snake_case,
non_upper_case_globals,
unused
)]
pub mod runtime {
use std::os::raw::{c_char, c_int, c_void};
include!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/src/runtime/c_runtime_api.rs"
));
pub type BackendPackedCFunc =
extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int;
}
}
pub mod errors;
pub mod runtime;
pub use errors::*;
#[cfg(target_env = "sgx")]
use alloc::alloc::{self, Layout};
#[cfg(not(target_env = "sgx"))]
use std::alloc::{self, Layout};
use errors::*;
const DEFAULT_ALIGN_BYTES: usize = 4;
#[derive(PartialEq, Eq)]
pub struct Allocation {
layout: Layout,
ptr: *mut u8,
}
impl Allocation {
/// Allocates a chunk of memory of `size` bytes with optional alignment.
pub fn new(size: usize, align: Option<usize>) -> Result<Self> {
let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES);
let layout = Layout::from_size_align(size, alignment)?;
let ptr = unsafe { alloc::alloc(layout.clone()) };
if ptr.is_null() {
alloc::handle_alloc_error(layout);
}
Ok(Self {
ptr: ptr,
layout: layout,
})
}
pub fn as_mut_ptr(&self) -> *mut u8 {
self.ptr
}
/// Returns the size of the Allocation in bytes.
pub fn size(&self) -> usize {
self.layout.size()
}
/// Returns the byte alignment of the Allocation.
pub fn align(&self) -> usize {
self.layout.align()
}
}
impl Drop for Allocation {
fn drop(&mut self) {
unsafe {
alloc::dealloc(self.ptr, self.layout.clone());
}
}
}
mod allocator;
mod array;
mod module;
#[macro_use]
mod packed_func;
mod graph;
#[cfg(target_env = "sgx")]
#[macro_use]
pub mod sgx;
mod threading;
mod workspace;
use std::os::raw::c_char;
pub use self::{array::*, graph::*, module::*, packed_func::*, threading::*, workspace::*};
#[no_mangle]
pub extern "C" fn TVMAPISetLastError(cmsg: *const c_char) {
#[cfg(not(target_env = "sgx"))]
unsafe {
panic!(std::ffi::CStr::from_ptr(cmsg).to_str().unwrap());
}
#[cfg(target_env = "sgx")]
ocall_packed!("__sgx_set_last_error__", cmsg);
}
use std::{
collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
};
use ffi::runtime::BackendPackedCFunc;
use runtime::packed_func::{wrap_backend_packed_func, PackedFunc};
pub trait Module {
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc>;
}
pub struct SystemLibModule;
lazy_static! {
static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, BackendPackedCFunc>> =
Mutex::new(HashMap::new());
}
impl Module for SystemLibModule {
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc> {
SYSTEM_LIB_FUNCTIONS
.lock()
.unwrap()
.get(name.as_ref())
.map(|func| wrap_backend_packed_func(func.to_owned()))
}
}
impl Default for SystemLibModule {
fn default() -> Self {
SystemLibModule {}
}
}
#[no_mangle]
pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
cname: *const c_char,
func: BackendPackedCFunc,
) -> i32 {
let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
SYSTEM_LIB_FUNCTIONS
.lock()
.unwrap()
.insert(name.to_string(), func);
return 0;
}
use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void};
use ffi::runtime::{
BackendPackedCFunc, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLTensor,
TVMTypeCode_kArrayHandle, TVMTypeCode_kHandle, TVMValue,
};
use errors::*;
pub type PackedFunc = Box<Fn(&[TVMArgValue]) -> TVMRetValue + Send + Sync>;
/// Calls a packed function and returns a `TVMRetValue`.
///
/// # Example
///
/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
#[macro_export]
macro_rules! call_packed {
($fn:expr, $($args:expr),+) => {
$fn(&[$($args.into(),)+])
};
($fn:expr) => {
$fn(&Vec::new())
};
}
/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way
/// to obtain a `TVMArgValue` is automatically via `call_packed!`.
#[derive(Clone, Copy)]
pub struct TVMArgValue<'a> {
_lifetime: PhantomData<&'a ()>,
pub(crate) value: TVMValue,
pub(crate) type_code: i64,
}
impl<'a> TVMArgValue<'a> {
pub fn new(value: TVMValue, type_code: i64) -> Self {
TVMArgValue {
_lifetime: PhantomData,
value: value,
type_code: type_code,
}
}
}
/// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode.
macro_rules! impl_prim_tvm_arg {
($type:ty, $field:ident, $code:expr, $as:ty) => {
impl<'a> From<$type> for TVMArgValue<'a> {
fn from(val: $type) -> Self {
TVMArgValue {
value: TVMValue { $field: val as $as },
type_code: $code as i64,
_lifetime: PhantomData,
}
}
}
};
($type:ty, $field:ident, $code:expr) => {
impl_prim_tvm_arg!($type, $field, $code, $type);
};
($type:ty,v_int64) => {
impl_prim_tvm_arg!($type, v_int64, DLDataTypeCode_kDLInt, i64);
};
($type:ty,v_float64) => {
impl_prim_tvm_arg!($type, v_float64, DLDataTypeCode_kDLFloat, f64);
};
}
impl_prim_tvm_arg!(f32, v_float64);
impl_prim_tvm_arg!(f64, v_float64);
impl_prim_tvm_arg!(i8, v_int64);
impl_prim_tvm_arg!(u8, v_int64);
impl_prim_tvm_arg!(i32, v_int64);
impl_prim_tvm_arg!(u32, v_int64);
impl_prim_tvm_arg!(i64, v_int64);
impl_prim_tvm_arg!(u64, v_int64);
impl_prim_tvm_arg!(bool, v_int64);
/// Creates a conversion to a `TVMArgValue` for an object handle.
impl<'a, T> From<*const T> for TVMArgValue<'a> {
fn from(ptr: *const T) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: ptr as *mut T as *mut c_void,
},
type_code: TVMTypeCode_kArrayHandle as i64,
_lifetime: PhantomData,
}
}
}
/// Creates a conversion to a `TVMArgValue` for a mutable object handle.
impl<'a, T> From<*mut T> for TVMArgValue<'a> {
fn from(ptr: *mut T) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: ptr as *mut c_void,
},
type_code: TVMTypeCode_kHandle as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
fn from(arr: &'a mut DLTensor) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: arr as *mut _ as *mut c_void,
},
type_code: TVMTypeCode_kArrayHandle as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
fn from(arr: &'a DLTensor) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: arr as *const _ as *mut DLTensor as *mut c_void,
},
type_code: TVMTypeCode_kArrayHandle as i64,
_lifetime: PhantomData,
}
}
}
/// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
/// Can be downcasted using `try_from` if it contains the desired type.
///
/// # Example
///
/// ```
/// let a = 42u32;
/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
///
/// let s = "hello, world!";
/// let t: TVMRetValue = s.into();
/// assert_eq!(String::try_from(t).unwrap(), s);
/// ```
pub struct TVMRetValue {
/// A primitive return value, if any.
prim_value: u64,
/// An object return value, if any.
box_value: Box<Any>,
/// The DLDataTypeCode which determines whether `prim_value` or `box_value` is in use.
type_code: i64,
}
#[cfg(target_env = "sgx")]
impl TVMRetValue {
pub(crate) fn from_tvm_value(value: TVMValue, type_code: i64) -> Self {
unsafe {
Self {
prim_value: match type_code {
0 | 1 => value.v_int64 as u64,
2 => value.v_float64 as u64,
3 | 7 | 8 | 9 | 10 => value.v_handle as u64,
11 | 12 => value.v_str as u64,
_ => 0,
} as u64,
box_value: box (),
type_code: type_code,
}
}
}
pub fn into_tvm_value(self) -> (TVMValue, i64) {
let val = match self.type_code {
0 | 1 => TVMValue {
v_int64: self.prim_value.clone() as i64,
},
2 => TVMValue {
v_float64: self.prim_value.clone() as f64,
},
3 | 7 | 8 | 9 | 10 => TVMValue {
v_handle: Box::into_raw(self.box_value) as *mut c_void,
},
11 | 12 => TVMValue {
v_str: Box::into_raw(self.box_value) as *const _,
},
_ => unreachable!(),
};
(val, self.type_code)
}
}
impl Default for TVMRetValue {
fn default() -> Self {
TVMRetValue {
prim_value: 0,
box_value: box (),
type_code: 0,
}
}
}
macro_rules! impl_prim_ret_value {
($type:ty, $code:expr) => {
impl From<$type> for TVMRetValue {
fn from(val: $type) -> Self {
TVMRetValue {
prim_value: val as u64,
box_value: box (),
type_code: $code,
}
}
}
impl TryFrom<TVMRetValue> for $type {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<$type> {
if ret.type_code == $code {
Ok(ret.prim_value as $type)
} else {
bail!(ErrorKind::TryFromTVMRetValueError(
stringify!($type).to_string(),
ret.type_code
))
}
}
}
};
}
macro_rules! impl_boxed_ret_value {
($type:ty, $code:expr) => {
impl From<$type> for TVMRetValue {
fn from(val: $type) -> Self {
TVMRetValue {
prim_value: 0,
box_value: box val,
type_code: $code,
}
}
}
impl TryFrom<TVMRetValue> for $type {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<$type> {
if let Ok(val) = ret.box_value.downcast::<$type>() {
Ok(*val)
} else {
bail!(ErrorKind::TryFromTVMRetValueError(
stringify!($type).to_string(),
ret.type_code
))
}
}
}
};
}
impl_prim_ret_value!(i8, 0);
impl_prim_ret_value!(u8, 1);
impl_prim_ret_value!(i16, 0);
impl_prim_ret_value!(u16, 1);
impl_prim_ret_value!(i32, 0);
impl_prim_ret_value!(u32, 1);
impl_prim_ret_value!(f32, 2);
impl_prim_ret_value!(i64, 0);
impl_prim_ret_value!(u64, 1);
impl_prim_ret_value!(f64, 2);
impl_prim_ret_value!(isize, 0);
impl_prim_ret_value!(usize, 1);
impl_boxed_ret_value!(String, 11);
// @see `WrapPackedFunc` in `llvm_module.cc`.
pub(super) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc {
box move |args: &[TVMArgValue]| {
func(
args
.iter()
.map(|ref arg| arg.value)
.collect::<Vec<TVMValue>>()
.as_ptr(),
args
.iter()
.map(|ref arg| arg.type_code as i32)
.collect::<Vec<i32>>()
.as_ptr() as *const i32,
args.len() as i32,
);
TVMRetValue::default()
}
}
use std::{
ffi::CString,
os::raw::{c_char, c_int},
};
use errors::Result;
use ffi::runtime::TVMValue;
use runtime::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue};
pub use runtime::threading::tvm_run_worker as run_worker;
#[macro_export]
macro_rules! tvm_ocall {
($func: expr) => {
match $func {
0 => Ok(()),
err => Err(format!("SGX error: {}", err)),
}
};
}
pub type SgxStatus = u32;
#[cfg(target_env = "sgx")]
extern "C" {
fn tvm_ocall_packed_func(
name: *const c_char,
arg_values: *const TVMValue,
type_codes: *const c_int,
num_args: c_int,
ret_val: *mut TVMValue,
ret_type_code: *mut c_int,
) -> SgxStatus;
}
pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Result<TVMRetValue> {
let mut ret_val = TVMValue { v_int64: 0 };
let ret_type_code = 0i64;
unsafe {
tvm_ocall!(tvm_ocall_packed_func(
CString::new(fn_name.as_ref()).unwrap().as_ptr(),
args
.iter()
.map(|ref arg| arg.value)
.collect::<Vec<TVMValue>>()
.as_ptr(),
args
.iter()
.map(|ref arg| arg.type_code as i32)
.collect::<Vec<i32>>()
.as_ptr() as *const i32,
args.len() as i32,
&mut ret_val as *mut TVMValue,
&mut (ret_type_code as i32) as *mut c_int,
))?;
}
Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64))
}
#[macro_export]
macro_rules! ocall_packed {
($fn_name:expr, $($args:expr),+) => {
::runtime::sgx::ocall_packed_func($fn_name, &[$($args.into(),)+])
.expect(concat!("Error calling `", $fn_name, "`"))
};
($fn_name:expr) => {
::runtime::sgx::ocall_packed_func($fn_name, &Vec::new())
.expect(concat!("Error calling `", $fn_name, "`"))
}
}
pub fn shutdown() {
if env!("TVM_NUM_THREADS") != "0" {
sgx_join_threads()
}
}
impl Drop for SystemLibModule {
fn drop(&mut self) {
shutdown()
}
}
use std::{
os::raw::{c_int, c_void},
sync::{
atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT},
Arc, Barrier,
},
};
#[cfg(not(target_env = "sgx"))]
use num_cpus;
#[cfg(not(target_env = "sgx"))]
use std::{
env,
thread::{self, JoinHandle},
};
#[cfg(target_env = "sgx")]
use std::{collections::VecDeque, ptr, sync::Mutex};
use bounded_spsc_queue::{self, Producer};
use super::super::errors::*;
use ffi::runtime::TVMParallelGroupEnv;
#[cfg(target_env = "sgx")]
use super::{TVMArgValue, TVMRetValue};
type FTVMParallelLambda =
extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;
/// Holds a parallel job request made by a TVM library function.
struct Job {
cb: FTVMParallelLambda,
cdata: *const c_void,
req_num_tasks: usize,
pending: Arc<AtomicUsize>,
}
impl Job {
/// Splits this job into a number of `Task`s which can be scheduled.
fn tasks(&self, num_workers: usize) -> Vec<Task> {
let num_tasks = if self.req_num_tasks == 0 {
num_workers
} else {
self.req_num_tasks.min(num_workers)
};
self.pending.store(num_tasks, Ordering::SeqCst);
let barrier = Arc::new(Barrier::new(num_tasks));
(0..num_tasks)
.map(move |i| Task {
id: i,
flambda: self.cb,
penv: TVMParallelGroupEnv {
sync_handle: &Arc::clone(&barrier) as *const _ as *mut c_void,
num_task: num_tasks as i32,
},
cdata: self.cdata,
pending: Arc::clone(&self.pending),
}).collect()
}
/// Waits for all tasks in this `Job` to be completed.
fn wait(&self) -> Result<()> {
while self.pending.load(Ordering::Acquire) > 0 {
#[cfg(not(target_env = "sgx"))]
thread::yield_now();
}
Ok(())
}
}
/// A chunk of work requested by a TVM function.
struct Task {
id: usize,
flambda: FTVMParallelLambda,
penv: TVMParallelGroupEnv,
cdata: *const c_void,
pending: Arc<AtomicUsize>,
}
unsafe impl Send for Task {}
unsafe impl Sync for Task {}
impl FnOnce<()> for Task {
type Output = i32;
extern "rust-call" fn call_once(self, _args: ()) -> Self::Output {
let status = (self.flambda)(self.id, &self.penv as *const _, self.cdata);
self.pending.fetch_sub(1, Ordering::AcqRel);
status
}
}
#[derive(Default)]
struct Threads {
#[allow(unused)]
#[cfg(not(target_env = "sgx"))]
handles: Vec<JoinHandle<()>>,
queues: Vec<Producer<Task>>,
}
impl<'a> Threads {
#[cfg(not(target_env = "sgx"))]
fn launch<F: Sync + Send + FnOnce(Consumer<Task>) + 'static + Copy>(
num_threads: usize,
cb: F,
) -> Self {
let (handles, queues) = (0..num_threads)
.map(|_| {
let (p, c) = bounded_spsc_queue::make(2);
let handle = thread::spawn(move || cb(c.into()));
(handle, p)
}).unzip();
Threads {
handles: handles,
queues: queues,
}
}
#[cfg(target_env = "sgx")]
fn launch<F: Sync + Send + FnOnce(Consumer<Task>) + 'static + Copy>(
num_threads: usize,
_cb: F,
) -> Self {
let mut consumer_queues = SGX_QUEUES.lock().unwrap();
let queues = (0..num_threads)
.map(|_| {
let (p, c) = bounded_spsc_queue::make(2);
consumer_queues.push_back(c.into());
p
}).collect();
ocall_packed!("__sgx_thread_group_launch__", num_threads as u64);
Threads { queues: queues }
}
}
struct ThreadPool {
num_workers: usize,
#[allow(unused)]
threads: Threads,
}
thread_local!(static THREAD_POOL: ThreadPool = ThreadPool::new());
impl ThreadPool {
fn new() -> Self {
let num_workers = max_concurrency();
ThreadPool {
num_workers: num_workers,
threads: Threads::launch(num_workers, ThreadPool::run_worker),
}
}
fn launch(&self, job: Job) {
let mut tasks = job.tasks(self.num_workers + 1);
for (i, task) in tasks.split_off(1).into_iter().enumerate() {
self.threads.queues[i].push(task);
}
tasks.pop().unwrap()();
job.wait().unwrap();
}
fn run_worker(queue: Consumer<Task>) {
loop {
let task = queue.pop();
let result = task();
if result == <i32>::min_value() {
break;
} else if result != 0 {
panic!("Error running task.");
}
}
}
}
// Send + Sync wrapper for bounded_spsc_queue::Consumer
struct Consumer<T> {
consumer: bounded_spsc_queue::Consumer<T>,
}
impl<T> From<bounded_spsc_queue::Consumer<T>> for Consumer<T> {
fn from(c: bounded_spsc_queue::Consumer<T>) -> Self {
Consumer { consumer: c }
}
}
impl<T> Consumer<T> {
fn pop(&self) -> T {
self.consumer.pop()
}
}
unsafe impl<T> Send for Consumer<T> {}
unsafe impl<T> Sync for Consumer<T> {}
#[cfg(target_env = "sgx")]
lazy_static! {
/// Holds tasks for untrusted threads which re-enter the enclave to execute.
static ref SGX_QUEUES: Mutex<VecDeque<Consumer<Task>>> = Mutex::new(VecDeque::new());
}
#[cfg(all(not(target_arch = "wasm32"), not(target_env = "sgx")))]
fn max_concurrency() -> usize {
if let Ok(threads_str) = env::var("TVM_NUM_THREADS").or(env::var("OMP_NUM_THREADS")) {
if let Ok(threads) = usize::from_str_radix(&threads_str, 10) {
return threads;
}
}
num_cpus::get_physical()
}
#[cfg(target_env = "sgx")]
fn max_concurrency() -> usize {
usize::from_str_radix(env!("TVM_NUM_THREADS"), 10).unwrap_or(1)
}
#[cfg(target_arch = "wasm32")]
fn max_concurrency() -> usize {
0 // wasm doesn't support threads yet
}
#[cfg(target_env = "sgx")]
pub fn tvm_run_worker(_args: &[TVMArgValue]) -> TVMRetValue {
let q = {
let mut qs = SGX_QUEUES.lock().unwrap();
qs.pop_front()
// `qs: MutexGuard` needs to be dropped here since `run_worker` won't return
};
if let Some(q) = q {
ThreadPool::run_worker(q);
}
TVMRetValue::default()
}
#[no_mangle]
pub extern "C" fn TVMBackendParallelLaunch(
cb: FTVMParallelLambda,
cdata: *const c_void,
num_task: usize,
) -> c_int {
if max_concurrency() == 0 {
let penv = TVMParallelGroupEnv {
sync_handle: 0 as *mut c_void,
num_task: 1,
};
cb(0, &penv as *const _, cdata);
} else {
THREAD_POOL.with(|pool| {
pool.launch(Job {
cb: cb,
cdata: cdata,
req_num_tasks: num_task,
pending: Arc::new(ATOMIC_USIZE_INIT),
});
});
}
return 0;
}
#[cfg(target_env = "sgx")]
pub(crate) fn sgx_join_threads() {
extern "C" fn poison_pill(
_task_id: usize,
_penv: *const TVMParallelGroupEnv,
_cdata: *const c_void,
) -> i32 {
<i32>::min_value()
}
THREAD_POOL.with(|pool| {
pool.launch(Job {
cb: poison_pill,
cdata: ptr::null(),
req_num_tasks: 0,
pending: Arc::new(ATOMIC_USIZE_INIT),
});
});
ocall_packed!("__sgx_thread_group_join__", 0);
}
// @see https://github.com/dmlc/tvm/issues/988 for information on why this function is used.
#[no_mangle]
pub extern "C" fn TVMBackendParallelBarrier(_task_id: usize, penv: *const TVMParallelGroupEnv) {
let barrier: &Arc<Barrier> = unsafe { &*((*penv).sync_handle as *const Arc<Barrier>) };
barrier.wait();
}
#[cfg(test)]
mod tests {
use std::{ptr, thread, time::Duration};
use super::*;
#[test]
fn test_max_concurrency() {
env::set_var("TVM_NUM_THREADS", "42");
env::set_var("OMP_NUM_THREADS", "24");
assert_eq!(max_concurrency(), 42);
env::remove_var("TVM_NUM_THREADS");
assert_eq!(max_concurrency(), 24);
}
extern "C" fn flambda(
task_id: usize,
penv: *const TVMParallelGroupEnv,
cdata: *const c_void,
) -> i32 {
if cdata == ptr::null() {
return 0;
}
unsafe {
let &(ref counter, ref task_ids_sum) = &*(cdata as *const (AtomicUsize, AtomicUsize));
thread::sleep(Duration::from_millis(50 * task_id as u64));
counter.fetch_add(1, Ordering::SeqCst);
task_ids_sum.fetch_add(task_id, Ordering::SeqCst);
assert_eq!((*penv).num_task, 3);
}
0
}
#[test]
fn test_parallel_launch() {
TVMBackendParallelLaunch(flambda, ptr::null(), 6);
let counter = ATOMIC_USIZE_INIT;
let task_ids_sum = ATOMIC_USIZE_INIT;
let cdata = (counter, task_ids_sum);
let num_tasks = 3;
TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks);
assert_eq!(cdata.0.load(Ordering::SeqCst), num_tasks);
assert_eq!(
cdata.1.load(Ordering::SeqCst),
(0..num_tasks).sum::<usize>()
);
}
}
use std::{
cell::RefCell,
os::raw::{c_int, c_void},
ptr,
};
use super::allocator::Allocation;
use errors::*;
const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`
struct WorkspacePool {
workspaces: Vec<Allocation>,
free: Vec<usize>,
in_use: Vec<usize>,
}
impl WorkspacePool {
fn new() -> Self {
WorkspacePool {
workspaces: Vec::new(),
free: Vec::new(),
in_use: Vec::new(),
}
}
fn alloc_new(&mut self, size: usize) -> Result<*mut u8> {
self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
self.in_use.push(self.workspaces.len() - 1);
Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
}
fn alloc(&mut self, size: usize) -> Result<*mut u8> {
if self.free.len() == 0 {
return self.alloc_new(size);
}
let idx = self
.free
.iter()
.fold(None, |cur_ws_idx: Option<usize>, &idx| {
let ws_size = self.workspaces[idx].size();
if !ws_size >= size {
return cur_ws_idx;
}
cur_ws_idx.or(Some(idx)).and_then(|cur_idx| {
let cur_size = self.workspaces[cur_idx].size();
Some(match ws_size <= cur_size {
true => idx,
false => cur_idx,
})
})
});
match idx {
Some(idx) => {
self.free.remove_item(&idx).unwrap();
self.in_use.push(idx);
Ok(self.workspaces[idx].as_mut_ptr())
}
None => self.alloc_new(size),
}
}
fn free(&mut self, ptr: *mut u8) -> Result<()> {
let mut ws_idx = None;
for i in 0..self.in_use.len() {
let idx = self.in_use[i];
if self.workspaces[idx].as_mut_ptr() == ptr {
self.in_use.remove(i);
ws_idx = Some(idx);
break;
}
}
Ok(
self
.free
.push(ws_idx.ok_or("Tried to free nonexistent workspace.")?),
)
}
}
thread_local!(static WORKSPACE_POOL: RefCell<WorkspacePool> = RefCell::new(WorkspacePool::new()));
const WORKSPACE_PAGE_SIZE: usize = 4 << 10;
#[no_mangle]
pub extern "C" fn TVMBackendAllocWorkspace(
_device_type: c_int,
_device_id: c_int,
size: u64,
_dtype_code_hint: c_int,
_dtype_bits_hint: c_int,
) -> *mut c_void {
let nbytes = if size == 0 {
WORKSPACE_PAGE_SIZE
} else {
size as usize
};
WORKSPACE_POOL.with(|pool_cell| {
pool_cell
.borrow_mut()
.alloc(nbytes as usize)
.unwrap_or(ptr::null_mut()) as *mut c_void
})
}
#[no_mangle]
pub extern "C" fn TVMBackendFreeWorkspace(
_device_type: c_int,
_device_id: c_int,
ptr: *mut c_void,
) -> c_int {
WORKSPACE_POOL.with(|pool_cell| {
(match pool_cell.borrow_mut().free(ptr as *mut u8) {
Ok(()) => 0,
Err(_) => -1,
}) as c_int
});
return 0;
}
*.json
*.params
*.o
"""Builds a simple NNVM graph for testing."""
from os import path as osp
import nnvm
from nnvm import sym
from nnvm.compiler import graph_util
from nnvm.testing import init
import numpy as np
import tvm
CWD = osp.dirname(osp.abspath(osp.expanduser(__file__)))
def _get_model(dshape):
data = sym.Variable('data', shape=dshape)
fc1 = sym.dense(data, units=dshape[-1]*2, use_bias=True)
left, right = sym.split(fc1, indices_or_sections=2, axis=1)
return sym.Group(((left + 1), (right - 1)))
def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10):
if isinstance(graph, sym.Symbol):
graph = nnvm.graph.create(graph)
ishapes, _ = graph_util.infer_shape(graph, **input_shapes)
param_shapes = dict(zip(graph.index.input_names, ishapes))
np.random.seed(seed)
params = {}
for param, shape in param_shapes.items():
if param in {'data', 'label'} or not shape:
continue
init_value = np.empty(shape).astype('float32')
initializer(param, init_value)
params[param] = tvm.nd.array(init_value)
return params
def main():
dshape = (32, 16)
net = _get_model(dshape)
ishape_dict = {'data': dshape}
params = _init_params(net, ishape_dict)
graph, lib, params = nnvm.compiler.build(net, 'llvm',
shape=ishape_dict,
params=params,
dtype='float32')
with open(osp.join(CWD, 'graph.json'), 'w') as f_resnet:
f_resnet.write(graph.json())
with open(osp.join(CWD, 'graph.params'), 'wb') as f_params:
f_params.write(nnvm.compiler.save_param_dict(params))
if __name__ == '__main__':
main()
#![feature(try_from)]
extern crate serde;
extern crate serde_json;
extern crate tvm;
use std::{convert::TryFrom, fs, io::Read};
use tvm::runtime::Graph;
#[test]
fn test_load_graph() {
let mut params_bytes = Vec::new();
fs::File::open(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.params"))
.expect("Could not find TVM graph. Did you run `tests/build_model.py`?")
.read_to_end(&mut params_bytes)
.unwrap();
let _params = tvm::runtime::load_param_dict(&params_bytes);
let graph = Graph::try_from(
&fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json")).unwrap(),
).unwrap();
assert_eq!(graph.nodes[3].op, "tvm_op");
assert_eq!(
graph.nodes[3]
.attrs
.as_ref()
.unwrap()
.get("func_name")
.unwrap(),
"fuse_dense"
);
assert_eq!(graph.nodes[5].inputs[0].index, 0);
assert_eq!(graph.nodes[6].inputs[0].index, 1);
assert_eq!(graph.heads.len(), 2);
}
[package]
name = "test-nnvm"
version = "0.0.0"
license = "Apache-2.0"
authors = ["Nick Hynes <nhynes@berkeley.edu>"]
[dependencies]
ndarray = "0.11.2"
tvm = { path = "../../" }
serde = "1.0.59"
serde_json = "1.0.17"
[build-dependencies]
ar = "0.6.0"
extern crate ar;
use std::{env, path::PathBuf, process::Command};
use ar::Builder;
use std::fs::File;
fn main() {
let out_dir = env::var("OUT_DIR").unwrap();
let output = Command::new(concat!(
env!("CARGO_MANIFEST_DIR"),
"/src/build_test_graph.py"
)).arg(&out_dir)
.output()
.expect("Failed to execute command");
if output.stderr.len() > 0 {
panic!(String::from_utf8(output.stderr).unwrap());
}
let in_path: PathBuf = [&out_dir, "graph.o"].iter().collect();
let out_path: PathBuf = [&out_dir, "libgraph.a"].iter().collect();
let mut builder = Builder::new(File::create(out_path.to_str().unwrap()).unwrap());
builder.append_path(in_path.to_str().unwrap()).unwrap();
println!("cargo:rustc-link-lib=static=graph");
println!("cargo:rustc-link-search=native={}", out_dir);
}
#!/usr/bin/env python3
"""Builds a simple NNVM graph for testing."""
from os import path as osp
import sys
import nnvm
from nnvm import sym
from nnvm.compiler import graph_util
from nnvm.testing import init
import numpy as np
import tvm
def _get_model(dshape):
data = sym.Variable('data', shape=dshape)
fc = sym.dense(data, units=dshape[-1]*2, use_bias=True)
left, right = sym.split(fc, indices_or_sections=2, axis=1)
return sym.Group(((left + 1), (right - 1), fc))
def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10):
if isinstance(graph, sym.Symbol):
graph = nnvm.graph.create(graph)
ishapes, _ = graph_util.infer_shape(graph, **input_shapes)
param_shapes = dict(zip(graph.index.input_names, ishapes))
np.random.seed(seed)
params = {}
for param, shape in param_shapes.items():
if param in {'data', 'label'} or not shape:
continue
init_value = np.arange(np.product(shape), 0, -1).reshape(*shape).astype('float32')
if param.endswith('_bias'):
params[param] = tvm.nd.array(init_value)
continue
init_value = np.empty(shape).astype('float32')
initializer(param, init_value)
# init_value /= init_value.sum() + 1e-10
params[param] = tvm.nd.array(init_value)
return params
def main():
dshape = (4, 8)
net = _get_model(dshape)
ishape_dict = {'data': dshape}
params = _init_params(net, ishape_dict)
graph, lib, params = nnvm.compiler.build(net, 'llvm --system-lib',
shape=ishape_dict,
params=params,
dtype='float32')
out_dir = sys.argv[1]
lib.save(osp.join(sys.argv[1], 'graph.o'))
with open(osp.join(out_dir, 'graph.json'), 'w') as f_resnet:
f_resnet.write(graph.json())
with open(osp.join(out_dir, 'graph.params'), 'wb') as f_params:
f_params.write(nnvm.compiler.save_param_dict(params))
if __name__ == '__main__':
main()
#![feature(try_from)]
#[macro_use]
extern crate ndarray;
extern crate serde;
extern crate serde_json;
extern crate tvm;
use std::{collections::HashMap, convert::TryFrom, fs, io::Read};
use ndarray::Array;
use tvm::runtime::{Graph, GraphExecutor, SystemLibModule, Tensor};
const BATCH_SIZE: usize = 4;
const IN_DIM: usize = 8;
macro_rules! check_sum {
($e:expr, $a:ident, $b:ident) => {
let a = Array::try_from($e.get_input(stringify!($a)).unwrap()).unwrap();
check_sum!(a, $b);
};
($e:expr, $a:expr, $b:ident) => {
let a = Array::try_from($e.get_output($a).unwrap()).unwrap();
check_sum!(a, $b);
};
($a:ident, $b:ident) => {
let a_sum: f32 = $a.scalar_sum();
let b_sum: f32 = $b.scalar_sum();
assert!((a_sum - b_sum).abs() < 1e-2, "{} != {}", a_sum, b_sum);
};
}
fn main() {
let syslib = SystemLibModule::default();
let mut params_bytes = Vec::new();
fs::File::open(concat!(env!("OUT_DIR"), "/graph.params"))
.unwrap()
.read_to_end(&mut params_bytes)
.unwrap();
let params = tvm::runtime::load_param_dict(&params_bytes)
.unwrap()
.into_iter()
.map(|(k, v)| (k, v.to_owned()))
.collect::<HashMap<String, Tensor<'static>>>();
let graph =
Graph::try_from(&fs::read_to_string(concat!(env!("OUT_DIR"), "/graph.json")).unwrap()).unwrap();
let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
let x = Array::from_shape_vec(
(BATCH_SIZE, IN_DIM),
(0..BATCH_SIZE * IN_DIM)
.map(|x| x as f32)
.collect::<Vec<f32>>(),
).unwrap();
let w = Array::try_from(params.get("dense0_weight").unwrap())
.unwrap()
.into_shape((IN_DIM * 2, IN_DIM))
.unwrap();
let b = Array::try_from(params.get("dense0_bias").unwrap()).unwrap();
let dense = x.dot(&w.t()) + &b;
let left = dense.slice(s![.., 0..IN_DIM]);
let right = dense.slice(s![.., IN_DIM..]);
let expected_o0 = &left + 1f32;
let expected_o1 = &right - 1f32;
exec.load_params(params);
exec.set_input("data", x.clone().into());
check_sum!(exec, data, x);
check_sum!(exec, dense0_weight, w);
check_sum!(exec, dense0_bias, b);
exec.run();
check_sum!(exec, 0, expected_o0);
check_sum!(exec, 1, expected_o1);
check_sum!(exec, 2, dense);
}
[package]
name = "test-tvm-basic"
version = "0.0.0"
license = "Apache-2.0"
authors = ["Nick Hynes <nhynes@berkeley.edu>"]
[dependencies]
ndarray = "0.11.2"
tvm = { path = "../../" }
[build-dependencies]
ar = "0.6.0"
extern crate ar;
use std::{env, path::PathBuf, process::Command};
use ar::Builder;
use std::fs::File;
fn main() {
let out_dir = env::var("OUT_DIR").unwrap();
let output = Command::new(concat!(
env!("CARGO_MANIFEST_DIR"),
"/src/build_test_lib.py"
)).arg(&out_dir)
.output()
.expect("Failed to execute command");
if output.stderr.len() > 0 {
panic!(String::from_utf8(output.stderr).unwrap());
}
let in_path: PathBuf = [&out_dir, "test.o"].iter().collect();
let out_path: PathBuf = [&out_dir, "libtest.a"].iter().collect();
let mut builder = Builder::new(File::create(out_path.to_str().unwrap()).unwrap());
builder.append_path(in_path.to_str().unwrap()).unwrap();
println!("cargo:rustc-link-lib=static=test");
println!("cargo:rustc-link-search=native={}", out_dir);
}
#!/usr/bin/env python3
"""Prepares a simple TVM library for testing."""
from os import path as osp
import sys
import tvm
def main():
n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.create_schedule(C.op)
s[C].parallel(s[C].op.axis[0])
print(tvm.lower(s, [A, B, C], simple_mode=True))
tvm.build(s, [A, B, C], 'llvm --system-lib').save(osp.join(sys.argv[1], 'test.o'))
if __name__ == '__main__':
main()
extern crate ndarray;
#[macro_use]
extern crate tvm;
use ndarray::Array;
use tvm::{
ffi::runtime::DLTensor,
runtime::{Module, SystemLibModule},
};
fn main() {
let syslib = SystemLibModule::default();
let add = syslib
.get_function("default_function")
.expect("main function not found");
let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]);
let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]);
let mut c = Array::from_vec(vec![0f32; 4]);
let e = Array::from_vec(vec![2f32, 2., 4., 4.]);
let mut a_dl: DLTensor = (&mut a).into();
let mut b_dl: DLTensor = (&mut b).into();
let mut c_dl: DLTensor = (&mut c).into();
call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl);
assert!(c.all_close(&e, 1e-8f32));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment