/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

use std::{
    os::raw::{c_int, c_void},
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc, Barrier,
    },
};

#[cfg(not(target_env = "sgx"))]
use num_cpus;
#[cfg(not(target_env = "sgx"))]
use std::{
    env,
    thread::{self, JoinHandle},
};

#[cfg(target_env = "sgx")]
use std::{collections::VecDeque, ptr, sync::Mutex};

use bounded_spsc_queue::{self, Producer};

use tvm_common::ffi::TVMParallelGroupEnv;

#[cfg(target_env = "sgx")]
use super::{TVMArgValue, TVMRetValue};

pub(crate) type FTVMParallelLambda =
    extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32;

/// Holds a parallel job request made by a TVM library function.
struct Job {
    cb: FTVMParallelLambda,
    cdata: *const c_void,
    req_num_tasks: usize,
    pending: Arc<AtomicUsize>,
}

impl Job {
    /// Splits this job into a number of `Task`s which can be scheduled.
    fn tasks(&self, num_workers: usize) -> Vec<Task> {
        // A request of zero tasks means "use every available worker".
        let num_tasks = if self.req_num_tasks == 0 {
            num_workers
        } else {
            self.req_num_tasks.min(num_workers)
        };
        self.pending.store(num_tasks, Ordering::SeqCst);

        // Each task holds its own `Arc` to the barrier so that the pointer
        // handed to the C callback (via `TVMParallelGroupEnv`) remains valid
        // for as long as the task is running.
        let barrier = Arc::new(Barrier::new(num_tasks));

        (0..num_tasks)
            .map(move |i| Task {
                id: i,
                flambda: self.cb,
                barrier: Arc::clone(&barrier),
                num_task: num_tasks as i32,
                cdata: self.cdata,
                pending: Arc::clone(&self.pending),
            })
            .collect()
    }

    /// Waits for all tasks in this `Job` to be completed.
    fn wait(&self) {
        while self.pending.load(Ordering::Acquire) > 0 {
            #[cfg(not(target_env = "sgx"))]
            thread::yield_now();
        }
    }
}
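// Worked example of the split rule above: with 3 workers the pool offers
// `3 + 1 = 4` slots (the calling thread participates too, see
// `ThreadPool::launch`). A request for `req_num_tasks = 6` yields
// `min(6, 4) = 4` tasks, while `req_num_tasks = 0` yields one task per slot.
// The barrier is sized to the final task count so that every task's call to
// `TVMBackendParallelBarrier` can rendezvous.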
/// A chunk of work requested by a TVM function.
struct Task {
    id: usize,
    flambda: FTVMParallelLambda,
    cdata: *const c_void,
    pending: Arc<AtomicUsize>,
    // Owning the barrier keeps the `sync_handle` pointer passed to the C
    // callback valid while the task runs.
    barrier: Arc<Barrier>,
    num_task: i32,
}

unsafe impl Send for Task {}
unsafe impl Sync for Task {}

impl FnOnce<()> for Task {
    type Output = i32;
    extern "rust-call" fn call_once(self, _args: ()) -> Self::Output {
        // `penv` borrows `self.barrier`, which lives until the end of this call,
        // so the callback may safely dereference `sync_handle`.
        let penv = TVMParallelGroupEnv {
            sync_handle: &self.barrier as *const Arc<Barrier> as *mut c_void,
            num_task: self.num_task,
        };
        let status = (self.flambda)(self.id, &penv as *const _, self.cdata);
        self.pending.fetch_sub(1, Ordering::AcqRel);
        status
    }
}

#[derive(Default)]
struct Threads {
    #[allow(unused)]
    #[cfg(not(target_env = "sgx"))]
    handles: Vec<JoinHandle<()>>,
    queues: Vec<Producer<Task>>,
}

impl Threads {
    #[cfg(not(target_env = "sgx"))]
    fn launch<F: Sync + Send + FnOnce(Consumer<Task>) + 'static + Copy>(
        num_threads: usize,
        cb: F,
    ) -> Self {
        let (handles, queues) = (0..num_threads)
            .map(|_| {
                // Each worker thread gets its own bounded SPSC queue.
                let (p, c) = bounded_spsc_queue::make(2);
                let handle = thread::spawn(move || cb(c.into()));
                (handle, p)
            })
            .unzip();
        Threads { handles, queues }
    }

    #[cfg(target_env = "sgx")]
    fn launch<F: Sync + Send + FnOnce(Consumer<Task>) + 'static + Copy>(
        num_threads: usize,
        _cb: F,
    ) -> Self {
        let mut consumer_queues = SGX_QUEUES.lock().unwrap();
        let queues = (0..num_threads)
            .map(|_| {
                let (p, c) = bounded_spsc_queue::make(2);
                consumer_queues.push_back(c.into());
                p
            })
            .collect();
        // Ask the untrusted runtime to start threads that re-enter the enclave.
        ocall_packed!("__sgx_thread_group_launch__", num_threads as u64);
        Threads { queues }
    }
}

struct ThreadPool {
    num_workers: usize,
    #[allow(unused)]
    threads: Threads,
}

thread_local!(static THREAD_POOL: ThreadPool = ThreadPool::new());

impl ThreadPool {
    fn new() -> Self {
        let num_workers = max_concurrency();
        ThreadPool {
            num_workers,
            threads: Threads::launch(num_workers, ThreadPool::run_worker),
        }
    }

    fn launch(&self, job: Job) {
        // The calling thread also participates, hence `num_workers + 1` slots.
        let mut tasks = job.tasks(self.num_workers + 1);
        for (i, task) in tasks.split_off(1).into_iter().enumerate() {
            self.threads.queues[i].push(task);
        }
        // Run the remaining task inline, then block until the workers finish.
        tasks.pop().unwrap()();
        job.wait();
    }

    fn run_worker(queue: Consumer<Task>) {
        loop {
            let task = queue.pop();
            let result = task();
            if result == i32::min_value() {
                // Poison pill: shut this worker down.
                break;
            } else if result != 0 {
                panic!("Error running task.");
            }
        }
    }
}

// Send + Sync wrapper for bounded_spsc_queue::Consumer
struct Consumer<T> {
    consumer: bounded_spsc_queue::Consumer<T>,
}

impl<T> From<bounded_spsc_queue::Consumer<T>> for Consumer<T> {
    fn from(c: bounded_spsc_queue::Consumer<T>) -> Self {
        Consumer { consumer: c }
    }
}

impl<T> Consumer<T> {
    fn pop(&self) -> T {
        self.consumer.pop()
    }
}

unsafe impl<T> Send for Consumer<T> {}
unsafe impl<T> Sync for Consumer<T> {}

#[cfg(target_env = "sgx")]
lazy_static! {
    /// Holds tasks for untrusted threads which re-enter the enclave to execute.
    static ref SGX_QUEUES: Mutex<VecDeque<Consumer<Task>>> = Mutex::new(VecDeque::new());
}

#[cfg(all(not(target_arch = "wasm32"), not(target_env = "sgx")))]
fn max_concurrency() -> usize {
    if let Ok(threads_str) = env::var("TVM_NUM_THREADS").or_else(|_| env::var("OMP_NUM_THREADS")) {
        if let Ok(threads) = threads_str.parse::<usize>() {
            return threads;
        }
    }
    num_cpus::get_physical()
}

#[cfg(target_env = "sgx")]
fn max_concurrency() -> usize {
    usize::from_str_radix(env!("TVM_NUM_THREADS"), 10).unwrap_or(1)
}

#[cfg(target_arch = "wasm32")]
fn max_concurrency() -> usize {
    0 // wasm doesn't support threads yet
}

#[cfg(target_env = "sgx")]
pub fn tvm_run_worker(_args: &[TVMArgValue]) -> TVMRetValue {
    let q = {
        let mut qs = SGX_QUEUES.lock().unwrap();
        qs.pop_front()
        // `qs: MutexGuard` needs to be dropped here since `run_worker` won't return.
    };
    if let Some(q) = q {
        ThreadPool::run_worker(q);
    }
    TVMRetValue::default()
}

#[no_mangle]
pub extern "C" fn TVMBackendParallelLaunch(
    cb: FTVMParallelLambda,
    cdata: *const c_void,
    num_task: usize,
) -> c_int {
    if max_concurrency() == 0 {
        // No threading support (e.g. wasm): run the lambda serially as task 0.
        let penv = TVMParallelGroupEnv {
            sync_handle: 0 as *mut c_void,
            num_task: 1,
        };
        cb(0, &penv as *const _, cdata);
    } else {
        THREAD_POOL.with(|pool| {
            pool.launch(Job {
                cb,
                cdata,
                req_num_tasks: num_task,
                pending: Arc::new(AtomicUsize::new(0)),
            });
        });
    }
    0
}

#[cfg(target_env = "sgx")]
pub(crate) fn sgx_join_threads() {
    extern "C" fn poison_pill(
        _task_id: usize,
        _penv: *const TVMParallelGroupEnv,
        _cdata: *const c_void,
    ) -> i32 {
        i32::min_value()
    }

    THREAD_POOL.with(|pool| {
        pool.launch(Job {
            cb: poison_pill,
            cdata: ptr::null(),
            req_num_tasks: 0,
            pending: Arc::new(AtomicUsize::new(0)),
        });
    });
    ocall_packed!("__sgx_thread_group_join__", 0);
}

// @see issue 988 for information on why this function is used.
#[no_mangle]
pub extern "C" fn TVMBackendParallelBarrier(_task_id: usize, penv: *const TVMParallelGroupEnv) {
    let barrier: &Arc<Barrier> = unsafe { &*((*penv).sync_handle as *const Arc<Barrier>) };
    barrier.wait();
}

#[cfg(test)]
mod tests {
    use std::{ptr, thread, time::Duration};

    use super::*;

    #[test]
    fn test_max_concurrency() {
        env::set_var("TVM_NUM_THREADS", "42");
        env::set_var("OMP_NUM_THREADS", "24");
        assert_eq!(max_concurrency(), 42);
        env::remove_var("TVM_NUM_THREADS");
        assert_eq!(max_concurrency(), 24);
    }

    extern "C" fn flambda(
        task_id: usize,
        penv: *const TVMParallelGroupEnv,
        cdata: *const c_void,
    ) -> i32 {
        if cdata.is_null() {
            return 0;
        }
        unsafe {
            let &(ref counter, ref task_ids_sum) = &*(cdata as *const (AtomicUsize, AtomicUsize));
            thread::sleep(Duration::from_millis(50 * task_id as u64));
            counter.fetch_add(1, Ordering::SeqCst);
            task_ids_sum.fetch_add(task_id, Ordering::SeqCst);
            assert_eq!((*penv).num_task, 3);
        }
        0
    }

    #[test]
    fn test_parallel_launch() {
        TVMBackendParallelLaunch(flambda, ptr::null(), 6);
        let counter = AtomicUsize::new(0);
        let task_ids_sum = AtomicUsize::new(0);
        let cdata = (counter, task_ids_sum);
        let num_tasks = 3;
        TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks);
        assert_eq!(cdata.0.load(Ordering::SeqCst), num_tasks);
        assert_eq!(
            cdata.1.load(Ordering::SeqCst),
            (0..num_tasks).sum::<usize>()
        );
    }
}
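// A minimal usage sketch of the barrier round trip. `demo_lambda` is a
// hypothetical callback of the kind TVM's code generator emits, not part of
// this module's API: `TVMBackendParallelLaunch` fans it out across the pool,
// each task rendezvouses at the group barrier via `TVMBackendParallelBarrier`,
// and a zero return reports success.
#[cfg(test)]
mod usage_sketch {
    use std::{os::raw::c_void, ptr};

    use super::*;

    extern "C" fn demo_lambda(
        task_id: usize,
        penv: *const TVMParallelGroupEnv,
        _cdata: *const c_void,
    ) -> i32 {
        // A real kernel would process slice `task_id` of the work here, then
        // wait for its sibling tasks before any cross-task reduction step.
        TVMBackendParallelBarrier(task_id, penv);
        0
    }

    #[test]
    fn demo_round_trip() {
        // `num_task == 0` asks for one task per available slot (workers plus
        // the calling thread).
        assert_eq!(TVMBackendParallelLaunch(demo_lambda, ptr::null(), 0), 0);
    }
}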