Commit e2970b22 by Ehsan M. Kermani Committed by Nick Hynes

[RUST][FRONTEND] Add rust frontend v0.1 (#2292)

parent 18b2ebac
max_width = 100 max_width = 100
hard_tabs = false hard_tabs = false
tab_spaces = 2 tab_spaces = 4
newline_style = "Auto" newline_style = "Auto"
use_small_heuristics = "Default" use_small_heuristics = "Default"
indent_style = "Block" indent_style = "Block"
...@@ -38,7 +38,7 @@ trailing_comma = "Vertical" ...@@ -38,7 +38,7 @@ trailing_comma = "Vertical"
match_block_trailing_comma = false match_block_trailing_comma = false
blank_lines_upper_bound = 1 blank_lines_upper_bound = 1
blank_lines_lower_bound = 0 blank_lines_lower_bound = 0
edition = "2015" edition = "2018"
merge_derives = true merge_derives = true
use_try_shorthand = true use_try_shorthand = true
use_field_init_shorthand = false use_field_init_shorthand = false
...@@ -50,8 +50,8 @@ unstable_features = false ...@@ -50,8 +50,8 @@ unstable_features = false
disable_all_formatting = false disable_all_formatting = false
skip_children = false skip_children = false
hide_parse_errors = false hide_parse_errors = false
error_on_line_overflow = false error_on_line_overflow = true
error_on_unformatted = false error_on_unformatted = true
report_todo = "Never" report_todo = "Never"
report_fixme = "Never" report_fixme = "Never"
ignore = [] ignore = []
......
[package] [workspace]
name = "tvm" members = [
version = "0.1.0" "common",
license = "Apache-2.0" "runtime",
description = "TVM Rust runtime" "runtime/tests/test_tvm_basic",
repository = "https://github.com/dmlc/tvm" "runtime/tests/test_nnvm",
readme = "README.md" "frontend",
keywords = ["tvm", "nnvm"] "frontend/tests/basics",
categories = ["api-bindings", "science"] "frontend/tests/callback",
authors = ["TVM Contributors"] "frontend/examples/resnet"
]
[features]
default = ["nom/std"]
sgx = ["nom/alloc"]
[dependencies]
bounded-spsc-queue = "0.4.0"
error-chain = { version = "0.12.0", default-features = false }
itertools = "0.7.8"
lazy_static = "1.1.0"
ndarray = "0.11.2"
nom = {version = "4.0.0", default-features = false }
serde = "1.0.59"
serde_derive = "1.0.79"
serde_json = "1.0.17"
[target.'cfg(not(target_env = "sgx"))'.dependencies]
num_cpus = "1.8.0"
target
**/*.rs.bk
Cargo.lock
/tvm-sys/src/bindgen.rs
[package]
name = "tvm-common"
version = "0.1.0"
authors = ["TVM Contributors"]
license = "Apache-2.0"
[features]
runtime = []
frontend = ["tvm-sys"]
[dependencies]
error-chain = { version = "0.12.0", default-features = false }
tvm-sys = { version = "0.1.0", path = "tvm-sys", optional = true }
//! Error types for `TVMArgValue` and `TVMRetValue` conversions.
error_chain! {
errors {
TryFromTVMArgValueError(expected: String, actual: String) {
description("mismatched types while converting from TVMArgValue")
display("expected `{}` but given `{}`", expected, actual)
}
TryFromTVMRetValueError(expected: String, actual: String) {
description("mismatched types while downcasting TVMRetValue")
display("invalid downcast: expected `{}` but given `{}`", expected, actual)
}
}
}
//! This crate contains the refactored basic components required
//! for `runtime` and `frontend` TVM crates.
#![crate_name = "tvm_common"]
#![recursion_limit = "1024"]
#![allow(non_camel_case_types, unused_imports)]
#![feature(box_syntax, try_from)]
#[macro_use]
extern crate error_chain;
/// Unified ffi module for both runtime and frontend crates.
pub mod ffi {
#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)]
#[cfg(feature = "frontend")]
pub extern crate tvm_sys as ts;
#[cfg(feature = "runtime")]
pub mod runtime {
use std::os::raw::{c_char, c_int, c_void};
include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs"));
pub type BackendPackedCFunc = extern "C" fn(
args: *const TVMValue,
type_codes: *const c_int,
num_args: c_int,
) -> c_int;
}
}
pub mod errors;
pub mod ty;
pub mod value;
pub use errors::*;
pub use ty::TVMTypeCode;
pub use value::{TVMArgValue, TVMRetValue, TVMValue};
//! This module containes `TVMTypeCode` and `TVMType` with some conversion methods.
//!
//! # Example
//!
//! ```
//! let dtype = TVMType::from("float");
//! println!("dtype is: {}", dtype);
//! ```
use std::{
ffi::{CStr, CString},
fmt::{self, Display, Formatter},
};
/// TVM type codes.
#[repr(u32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum TVMTypeCode {
kDLInt = 0,
kDLUInt = 1,
kDLFloat = 2,
kHandle = 3,
kNull = 4,
kTVMType = 5,
kTVMContext = 6,
kArrayHandle = 7,
kNodeHandle = 8,
kModuleHandle = 9,
kFuncHandle = 10,
kStr = 11,
kBytes = 12,
kNDArrayContainer = 13,
}
impl Default for TVMTypeCode {
fn default() -> Self {
TVMTypeCode::kDLInt
}
}
impl From<TVMTypeCode> for i64 {
fn from(arg: TVMTypeCode) -> i64 {
match arg {
TVMTypeCode::kDLInt => 0,
TVMTypeCode::kDLUInt => 1,
TVMTypeCode::kDLFloat => 2,
TVMTypeCode::kHandle => 3,
TVMTypeCode::kNull => 4,
TVMTypeCode::kTVMType => 5,
TVMTypeCode::kTVMContext => 6,
TVMTypeCode::kArrayHandle => 7,
TVMTypeCode::kNodeHandle => 8,
TVMTypeCode::kModuleHandle => 9,
TVMTypeCode::kFuncHandle => 10,
TVMTypeCode::kStr => 11,
TVMTypeCode::kBytes => 12,
TVMTypeCode::kNDArrayContainer => 13,
}
}
}
impl Into<TVMTypeCode> for i64 {
fn into(self) -> TVMTypeCode {
match self {
0 => TVMTypeCode::kDLInt,
1 => TVMTypeCode::kDLUInt,
2 => TVMTypeCode::kDLFloat,
3 => TVMTypeCode::kHandle,
4 => TVMTypeCode::kNull,
5 => TVMTypeCode::kTVMType,
6 => TVMTypeCode::kTVMContext,
7 => TVMTypeCode::kArrayHandle,
8 => TVMTypeCode::kNodeHandle,
9 => TVMTypeCode::kModuleHandle,
10 => TVMTypeCode::kFuncHandle,
11 => TVMTypeCode::kStr,
12 => TVMTypeCode::kBytes,
13 => TVMTypeCode::kNDArrayContainer,
_ => unreachable!(),
}
}
}
impl Display for TVMTypeCode {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(
f,
"{}",
match self {
TVMTypeCode::kDLInt => "int",
TVMTypeCode::kDLUInt => "uint",
TVMTypeCode::kDLFloat => "float",
TVMTypeCode::kHandle => "handle",
TVMTypeCode::kNull => "null",
TVMTypeCode::kTVMType => "TVM type",
TVMTypeCode::kTVMContext => "TVM context",
TVMTypeCode::kArrayHandle => "Array handle",
TVMTypeCode::kNodeHandle => "Node handle",
TVMTypeCode::kModuleHandle => "Module handle",
TVMTypeCode::kFuncHandle => "Function handle",
TVMTypeCode::kStr => "string",
TVMTypeCode::kBytes => "bytes",
TVMTypeCode::kNDArrayContainer => "ndarray container",
}
)
}
}
macro_rules! impl_prim_type {
($type:ty, $variant:ident) => {
impl<'a> From<&'a $type> for TVMTypeCode {
fn from(_arg: &$type) -> Self {
TVMTypeCode::$variant
}
}
impl<'a> From<&'a mut $type> for TVMTypeCode {
fn from(_arg: &mut $type) -> Self {
TVMTypeCode::$variant
}
}
};
}
impl_prim_type!(usize, kDLInt);
impl_prim_type!(i64, kDLInt);
impl_prim_type!(i32, kDLInt);
impl_prim_type!(i16, kDLInt);
impl_prim_type!(i8, kDLInt);
impl_prim_type!(u64, kDLUInt);
impl_prim_type!(u32, kDLUInt);
impl_prim_type!(u16, kDLUInt);
impl_prim_type!(u8, kDLUInt);
impl_prim_type!(f64, kDLFloat);
impl_prim_type!(f32, kDLFloat);
impl_prim_type!(str, kStr);
impl_prim_type!(CStr, kStr);
impl_prim_type!(String, kStr);
impl_prim_type!(CString, kStr);
impl_prim_type!([u8], kBytes);
[package]
name = "tvm-sys"
version = "0.1.0"
authors = ["TVM Contributors"]
license = "Apache-2.0"
description = "Raw C API"
[build-dependencies]
bindgen = "0.37.4"
extern crate bindgen;
use std::path::PathBuf;
fn main() {
println!("cargo:rerun-if-env-changed=TVM_HOME");
println!("cargo:rustc-link-lib=dylib=tvm_runtime");
println!("cargo:rustc-link-search={}/build", env!("TVM_HOME"));
let bindings = bindgen::Builder::default()
.header(format!(
"{}/include/tvm/runtime/c_runtime_api.h",
env!("TVM_HOME")
))
.clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME")))
.blacklist_type("max_align_t") // @see rust-bindgen#550
.layout_tests(false)
.derive_partialeq(true)
.derive_eq(true)
.generate()
.expect("unable to generate bindings");
bindings
.write_to_file(PathBuf::from("src/bindgen.rs"))
.expect("can not write the bindings!");
}
#![allow(
non_camel_case_types,
non_snake_case,
non_upper_case_globals,
dead_code,
improper_ctypes
)]
include!("bindgen.rs");
target
**/*.rs.bk
Cargo.lock
/tests/basics/add_*
/examples/resnet/deploy_*
/examples/resnet/*.png
/examples/resnet/synset.*
[package]
name = "tvm-frontend"
version = "0.1.0"
license = "Apache-2.0"
description = "Rust frontend support for TVM"
repository = "https://github.com/dmlc/tvm"
homepage = "https://github.com/dmlc/tvm"
readme = "README.md"
keywords = ["rust", "tvm", "nnvm"]
categories = ["api-bindings", "science"]
authors = ["TVM Contributors"]
[lib]
name = "tvm_frontend"
crate-type = ["dylib"]
[dependencies]
error-chain = "0.12.0"
lazy_static = "1.1.0"
ndarray = "0.12.1"
num-traits = "0.2"
tvm-common = { version = "0.1.0", path = "../common/", features = ["frontend"] }
[features]
blas = ["ndarray/blas"]
# TVM Runtime Frontend Support
This crate provides an idiomatic Rust API for [TVM](https://github.com/dmlc/tvm) runtime frontend. Currently this requires **Nightly Rust** and tested on `rustc 1.32.0-nightly`
## What Does This Crate Offer?
Here is a major workflow
1. Train your **Deep Learning** model using any major framework such as [PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.incubator.apache.org/) or [TensorFlow](https://www.tensorflow.org/)
2. Use **TVM** to build optimized model artifacts on a supported context such as CPU, GPU, OpenCL and specialized accelerators.
3. Deploy your models using **Rust** :heart:
### Example: Deploy Image Classification from Pretrained Resnet18 on ImageNet1k
Please checkout [examples/resnet](examples/resnet) for the complete end-to-end example.
Here's a Python snippet for downloading and building a pretrained Resnet18 via Apache MXNet and TVM
```python
block = get_model('resnet18_v1', pretrained=True)
sym, params = nnvm.frontend.from_mxnet(block)
# add the softmax layer for prediction
net = nnvm.sym.softmax(sym)
# compile the model
with nnvm.compiler.build_config(opt_level=opt_level):
graph, lib, params = nnvm.compiler.build(
net, target, shape={"data": data_shape}, params=params)
# same the model artifacts
lib.save(os.path.join(target_dir, "deploy_lib.o"))
cc.create_shared(os.path.join(target_dir, "deploy_lib.so"),
[os.path.join(target_dir, "deploy_lib.o")])
with open(os.path.join(target_dir, "deploy_graph.json"), "w") as fo:
fo.write(graph.json())
with open(os.path.join(target_dir,"deploy_param.params"), "wb") as fo:
fo.write(nnvm.compiler.save_param_dict(params))
```
Now, we need to input the artifacts to create and run the *Graph Runtime* to detect our input cat image
![cat](https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true)
as demostrated in the following Rust snippet
```rust
let graph = fs::read_to_string("deploy_graph.json")?;
// load the built module
let lib = Module::load(&Path::new("deploy_lib.so"))?;
// get the global TVM graph runtime function
let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap();
let runtime_create_fn_ret = call_packed!(
runtime_create_fn,
&graph,
&lib,
&ctx.device_type,
&ctx.device_id
)?;
// get graph runtime module
let graph_runtime_module: Module = runtime_create_fn_ret.try_into()?;
// get the registered `load_params` from runtime module
let ref load_param_fn = graph_runtime_module
.get_function("load_params", false)
.unwrap();
// parse parameters and convert to TVMByteArray
let params: Vec<u8> = fs::read("deploy_param.params")?;
let barr = TVMByteArray::from(&params);
// load the parameters
call_packed!(load_param_fn, &barr)?;
// get the set_input function
let ref set_input_fn = graph_runtime_module
.get_function("set_input", false)
.unwrap();
call_packed!(set_input_fn, "data", &input)?;
// get `run` function from runtime module
let ref run_fn = graph_runtime_module.get_function("run", false).unwrap();
// execute the run function. Note that it has no argument
call_packed!(run_fn,)?;
// prepare to get the output
let output_shape = &mut [1, 1000];
let output = empty(output_shape, TVMContext::cpu(0), TVMType::from("float32"));
// get the `get_output` function from runtime module
let ref get_output_fn = graph_runtime_module
.get_function("get_output", false)
.unwrap();
// execute the get output function
call_packed!(get_output_fn, &0, &output)?;
// flatten the output as Vec<f32>
let output = output.to_vec::<f32>()?;
```
and the model correctly predicts the input image as **tiger cat**.
## Installations
Please follow TVM [installations](https://docs.tvm.ai/install/index.html), `export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`.
*Note:* To run the end-to-end examples and tests, `tvm`, `nnvm` and `topi` need to be added to your `PYTHONPATH` or it's automatic via an Anaconda environment when it is installed individually.
## Supported TVM Functionalities
### Use TVM to Generate Shared Library
One can use the following Python snippet to generate `add_gpu.so` which add two vectors on GPU.
```python
import os
import tvm
from tvm.contrib import cc
def test_add(target_dir):
if not tvm.module.enabled("cuda"):
print(f"skip {__file__} because cuda is not enabled...")
return
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
s = tvm.create_schedule(C.op)
bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd")
fadd_cuda.save(os.path.join(target_dir, "add_gpu.o"))
fadd_cuda.imported_modules[0].save(os.path.join(target_dir, "add_gpu.ptx"))
cc.create_shared(os.path.join(target_dir, "add_gpu.so"),
[os.path.join(target_dir, "add_gpu.o")])
if __name__ == "__main__":
import sys
if len(sys.argv) != 2:
sys.exit(-1)
test_add(sys.argv[1])
```
### Run the Generated Shared Library
The following code snippet demonstrates how to load and test the generated shared library (`add_gpu.so`) in Rust.
```rust
extern crate tvm_frontend as tvm;
use tvm::*;
fn main() {
let shape = &mut [2];
let mut data = vec![3f32, 4.0];
let mut arr = empty(shape, TVMContext::gpu(0), TVMType::from("float32"));
arr.copy_from_buffer(data.as_mut_slice());
let mut ret = empty(shape, TVMContext::gpu(0), TVMType::from("float32"));
let mut fadd = Module::load(&Path::new("add_gpu.so")).unwrap();
let fadd_dep = Module::load(&Path::new("add_gpu.ptx")).unwrap();
assert!(fadd.enabled("gpu"));
fadd.import_module(fadd_dep);
fadd.entry();
function::Builder::from(&mut fadd)
.arg(&arr)
.arg(&arr)
.set_output(&mut ret)?
.invoke()
.unwrap();
assert_eq!(ret.to_vec::<f32>().unwrap(), vec![6f32, 8.0]);
}
```
**Note:** it is required to instruct the `rustc` to link to the generated `add_gpu.so` in runtime, for example by
`cargo:rustc-link-search=native=add_gpu`.
See the tests and examples custom `build.rs` for more details.
### Convert and Register a Rust Function as a TVM Packed Function
One can use `register_global_func!` macro to convert and register a Rust
function of type `fn(&[TVMArgValue]) -> Result<TVMRetValue>` to a global TVM **packed function** as follows
```rust
#[macro_use]
extern crate tvm_frontend as tvm;
use std::convert::TryInto;
use tvm::*;
fn main() {
register_global_func! {
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
let mut ret = 0f32;
let shape = &mut [2];
for arg in args.iter() {
let e = empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
let arg: NDArray = arg.try_into()?;
let arr = arg.copy_to_ndarray(e).unwrap();
let rnd: ArrayD<f32> = ArrayD::try_from(&arr).unwrap();
ret += rnd.scalar_sum();
}
let ret_val = TVMRetValue::from(&ret);
Ok(ret_val)
}
}
let shape = &mut [2];
let mut data = vec![3f32, 4.0];
let mut arr = empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
arr.copy_from_buffer(data.as_mut_slice());
let mut registered = function::Builder::default();
let ret: f64 = registered
.get_function("sum", true)
.arg(&arr)
.arg(&arr)
.invoke()
.unwrap()
.try_into()
.unwrap();
assert_eq!(ret, 14f64);
}
```
[package]
name = "resnet"
version = "0.0.0"
authors = ["TVM Contributors"]
license = "Apache-2.0"
build = "build.rs"
[dependencies]
ndarray = "0.12.1"
tvm-frontend = { path = "../../" }
image = "0.20.1"
csv = "1"
## Resnet example
This end-to-end example shows how to:
* build `Resnet 18` with `tvm` and `nnvm` from Python
* use the provided Rust frontend API to test for an input image
To run the example, first `tvm`, `nnvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet`
and to install `tvm` and `nnvm` with `llvm` follow the [TVM installation guide](https://docs.tvm.ai/install/index.html).
* **Build the example**: `cargo build`
To have a successful build, note that it is required to instruct Rust compiler to link to the compiled shared library, for example with
`println!("cargo:rustc-link-search=native={}", build_path)`. See the `build.rs` for more details.
* **Run the example**: `cargo run`
use std::process::Command;
fn main() {
let output = Command::new(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py"))
.output()
.expect("Failed to execute command");
assert!(
std::path::Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_lib.o")).exists(),
"Could not prepare demo: {}",
String::from_utf8(output.stderr).unwrap().trim()
);
println!(
"cargo:rustc-link-search=native={}",
env!("CARGO_MANIFEST_DIR")
);
}
#!/usr/bin/env python3
import argparse
import csv
import logging
from os import path as osp
import sys
import numpy as np
import mxnet as mx
from mxnet.gluon.model_zoo.vision import get_model
from mxnet.gluon.utils import download
import tvm
from tvm.contrib import graph_runtime, cc
import nnvm
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
parser = argparse.ArgumentParser(description='Resnet build example')
aa = parser.add_argument
aa('--batch-size', type=int, default=1, help='input image batch size')
aa('--opt-level', type=int, default=3,
help='level of optimization. 0 is unoptimized and 3 is the highest level')
aa('--target', type=str, default='llvm', help='target context for compilation')
aa('--image-shape', type=str, default='3,224,224', help='input image dimensions')
aa('--image-name', type=str, default='cat.png', help='name of input image to download')
args = parser.parse_args()
target_dir = osp.dirname(osp.dirname(osp.realpath(__file__)))
batch_size = args.batch_size
opt_level = args.opt_level
target = tvm.target.create(args.target)
image_shape = tuple(map(int, args.image_shape.split(",")))
data_shape = (batch_size,) + image_shape
def build(target_dir):
""" Compiles resnet18 with TVM"""
deploy_lib = osp.join(target_dir, 'deploy_lib.o')
if osp.exists(deploy_lib):
return
# download the pretrained resnet18 trained on imagenet1k dataset for
# image classification task
block = get_model('resnet18_v1', pretrained=True)
sym, params = nnvm.frontend.from_mxnet(block)
# add the softmax layer for prediction
net = nnvm.sym.softmax(sym)
# compile the model
with nnvm.compiler.build_config(opt_level=opt_level):
graph, lib, params = nnvm.compiler.build(
net, target, shape={"data": data_shape}, params=params)
# save the model artifacts
lib.save(deploy_lib)
cc.create_shared(osp.join(target_dir, "deploy_lib.so"),
[osp.join(target_dir, "deploy_lib.o")])
with open(osp.join(target_dir, "deploy_graph.json"), "w") as fo:
fo.write(graph.json())
with open(osp.join(target_dir,"deploy_param.params"), "wb") as fo:
fo.write(nnvm.compiler.save_param_dict(params))
def download_img_labels():
""" Download an image and imagenet1k class labels for test"""
img_name = 'cat.png'
synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
'4d0b62f3d01426887599d4f7ede23ee5/raw/',
'596b27d23537e5a1b5751d2b0481ef172f58b539/',
'imagenet1000_clsid_to_human.txt'])
synset_name = 'synset.txt'
download('https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true', img_name)
download(synset_url, synset_name)
with open(synset_name) as fin:
synset = eval(fin.read())
with open("synset.csv", "w") as fout:
w = csv.writer(fout)
w.writerows(synset.items())
def test_build(target_dir):
""" Sanity check with random input"""
graph = open(osp.join(target_dir, "deploy_graph.json")).read()
lib = tvm.module.load(osp.join(target_dir, "deploy_lib.so"))
params = bytearray(open(osp.join(target_dir,"deploy_param.params"), "rb").read())
input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32"))
ctx = tvm.cpu()
module = graph_runtime.create(graph, lib, ctx)
module.load_params(params)
module.run(data=input_data)
out = module.get_output(0).asnumpy()
if __name__ == '__main__':
logger.info("building the model")
build(target_dir)
logger.info("build was successful")
logger.info("test the build artifacts")
test_build(target_dir)
logger.info("test was successful")
download_img_labels()
logger.info("image and synset downloads are successful")
#![feature(try_from)]
extern crate csv;
extern crate image;
extern crate ndarray;
extern crate tvm_frontend as tvm;
use std::{
collections::HashMap,
convert::TryInto,
fs::{self, File},
path::Path,
};
use image::{FilterType, GenericImageView};
use ndarray::{Array, ArrayD, Axis};
use tvm::*;
fn main() {
let ctx = TVMContext::cpu(0);
let img = image::open(concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")).unwrap();
println!("original image dimensions: {:?}", img.dimensions());
// for bigger size images, one needs to first resize to 256x256
// with `img.resize_exact` method and then `image.crop` to 224x224
let img = img.resize(224, 224, FilterType::Nearest).to_rgb();
println!("resized image dimensions: {:?}", img.dimensions());
let mut pixels: Vec<f32> = vec![];
for pixel in img.pixels() {
let tmp = pixel.data;
// normalize the RGB channels using mean, std of imagenet1k
let tmp = [
(tmp[0] as f32 - 123.0) / 58.395, // R
(tmp[1] as f32 - 117.0) / 57.12, // G
(tmp[2] as f32 - 104.0) / 57.375, // B
];
for e in &tmp {
pixels.push(*e);
}
}
let arr = Array::from_shape_vec((224, 224, 3), pixels).unwrap();
let arr: ArrayD<f32> = arr.permuted_axes([2, 0, 1]).into_dyn();
// make arr shape as [1, 3, 224, 224] acceptable to resnet
let arr = arr.insert_axis(Axis(0));
// create input tensor from rust's ndarray
let input =
NDArray::from_rust_ndarray(&arr, TVMContext::cpu(0), TVMType::from("float32")).unwrap();
println!(
"input size is {:?}",
input.shape().expect("cannot get the input shape")
);
let graph =
fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")).unwrap();
// load the built module
let lib = Module::load(&Path::new(concat!(
env!("CARGO_MANIFEST_DIR"),
"/deploy_lib.so"
)))
.unwrap();
// get the global TVM graph runtime function
let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap();
let runtime_create_fn_ret = call_packed!(
runtime_create_fn,
&graph,
&lib,
&ctx.device_type,
&ctx.device_id
)
.unwrap();
// get graph runtime module
let graph_runtime_module: Module = runtime_create_fn_ret.try_into().unwrap();
// get the registered `load_params` from runtime module
let ref load_param_fn = graph_runtime_module
.get_function("load_params", false)
.unwrap();
// parse parameters and convert to TVMByteArray
let params: Vec<u8> =
fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params")).unwrap();
let barr = TVMByteArray::from(&params);
// load the parameters
call_packed!(load_param_fn, &barr).unwrap();
// get the set_input function
let ref set_input_fn = graph_runtime_module
.get_function("set_input", false)
.unwrap();
call_packed!(set_input_fn, "data", &input).unwrap();
// get `run` function from runtime module
let ref run_fn = graph_runtime_module.get_function("run", false).unwrap();
// execute the run function. Note that it has no argument
call_packed!(run_fn,).unwrap();
// prepare to get the output
let output_shape = &mut [1, 1000];
let output = NDArray::empty(output_shape, TVMContext::cpu(0), TVMType::from("float32"));
// get the `get_output` function from runtime module
let ref get_output_fn = graph_runtime_module
.get_function("get_output", false)
.unwrap();
// execute the get output function
call_packed!(get_output_fn, &0, &output).unwrap();
// flatten the output as Vec<f32>
let output = output.to_vec::<f32>().unwrap();
// find the maximum entry in the output and its index
let mut argmax = -1;
let mut max_prob = 0.;
for i in 0..output.len() {
if output[i] > max_prob {
max_prob = output[i];
argmax = i as i32;
}
}
// create a hash map of (class id, class name)
let mut synset: HashMap<i32, String> = HashMap::new();
let file = File::open("synset.csv").unwrap();
let mut rdr = csv::ReaderBuilder::new()
.has_headers(true)
.from_reader(file);
for result in rdr.records() {
let record = result.unwrap();
let id: i32 = record[0].parse().unwrap();
let cls = record[1].to_string();
synset.insert(id, cls);
}
println!(
"input image belongs to the class `{}` with probability {}",
synset
.get(&argmax)
.expect("cannot find the class id for argmax"),
max_prob
);
}
//! Provides [`TVMByteArray`] used for passing the model parameters
//! (stored as byte-array) to a runtime module.
//!
//! For more detail, please see the example `resnet` in `examples` repository.
use std::os::raw::c_char;
use crate::ts;
/// A struct holding TVM byte-array.
///
/// ## Example
///
/// ```
/// let v = b"hello".to_vec();
/// let barr = TVMByteArray::from(&v);
/// assert_eq!(barr.len(), v.len());
/// assert_eq!(barr.data(), vec![104i8, 101, 108, 108, 111]);
/// ```
#[derive(Debug, Clone)]
pub struct TVMByteArray {
pub(crate) inner: ts::TVMByteArray,
}
impl TVMByteArray {
pub(crate) fn new(barr: ts::TVMByteArray) -> TVMByteArray {
TVMByteArray { inner: barr }
}
/// Gets the length of the underlying byte-array
pub fn len(&self) -> usize {
self.inner.size
}
/// Gets the underlying byte-array as `Vec<i8>`
pub fn data(&self) -> Vec<i8> {
unsafe {
let sz = self.len();
let mut ret_buf = Vec::with_capacity(sz);
ret_buf.set_len(sz);
self.inner.data.copy_to(ret_buf.as_mut_ptr(), sz);
ret_buf
}
}
}
impl<'a> From<&'a Vec<u8>> for TVMByteArray {
fn from(arg: &Vec<u8>) -> Self {
let barr = ts::TVMByteArray {
data: arg.as_ptr() as *const c_char,
size: arg.len(),
};
TVMByteArray::new(barr)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn convert() {
let v = vec![1u8, 2, 3];
let barr = TVMByteArray::from(&v);
assert_eq!(barr.len(), v.len());
assert_eq!(barr.data(), vec![1i8, 2, 3]);
let v = b"hello".to_vec();
let barr = TVMByteArray::from(&v);
assert_eq!(barr.len(), v.len());
assert_eq!(barr.data(), vec![104i8, 101, 108, 108, 111]);
}
}
//! Provides [`TVMContext`] and related device specific queries.
//!
//! Create a new context by device type (cpu is 1) and device id.
//!
//! # Example
//!
//! ```
//! let ctx = TVMContext::new(1, 0);
//! let cpu0 = TVMContext::cpu(0);
//! assert_eq!(ctx, cpu0);
//! ```
//!
//! Or from a supported device name.
//!
//! ```
//! let cpu0 = TVMContext::from("cpu");
//! println!("{}", cpu0);
//! ```
use std::{
fmt::{self, Display, Formatter},
os::raw::c_void,
ptr,
};
use crate::{function, ts, Result};
/// Device type can be from a supported device name. See the supported devices
/// in [TVM](https://github.com/dmlc/tvm).
///
/// ## Example
///
/// ```
/// let cpu = TVMDeviceType::from("cpu");
/// println!("device is: {}", cpu);
///```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TVMDeviceType(pub usize);
impl Default for TVMDeviceType {
/// default device is cpu.
fn default() -> Self {
TVMDeviceType(1)
}
}
impl From<TVMDeviceType> for ts::DLDeviceType {
fn from(device_type: TVMDeviceType) -> Self {
match device_type.0 {
1 => ts::DLDeviceType_kDLCPU,
2 => ts::DLDeviceType_kDLGPU,
3 => ts::DLDeviceType_kDLCPUPinned,
4 => ts::DLDeviceType_kDLOpenCL,
7 => ts::DLDeviceType_kDLVulkan,
8 => ts::DLDeviceType_kDLMetal,
9 => ts::DLDeviceType_kDLVPI,
10 => ts::DLDeviceType_kDLROCM,
12 => ts::DLDeviceType_kDLExtDev,
_ => panic!("device type not found!"),
}
}
}
impl From<ts::DLDeviceType> for TVMDeviceType {
fn from(device_type: ts::DLDeviceType) -> Self {
match device_type {
ts::DLDeviceType_kDLCPU => TVMDeviceType(1),
ts::DLDeviceType_kDLGPU => TVMDeviceType(2),
ts::DLDeviceType_kDLCPUPinned => TVMDeviceType(3),
ts::DLDeviceType_kDLOpenCL => TVMDeviceType(4),
ts::DLDeviceType_kDLVulkan => TVMDeviceType(7),
ts::DLDeviceType_kDLMetal => TVMDeviceType(8),
ts::DLDeviceType_kDLVPI => TVMDeviceType(9),
ts::DLDeviceType_kDLROCM => TVMDeviceType(10),
ts::DLDeviceType_kDLExtDev => TVMDeviceType(12),
_ => panic!("device type not found!"),
}
}
}
impl Display for TVMDeviceType {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(
f,
"{}",
match self {
TVMDeviceType(1) => "cpu",
TVMDeviceType(2) => "gpu",
TVMDeviceType(3) => "cpu_pinned",
TVMDeviceType(4) => "opencl",
TVMDeviceType(8) => "meta",
TVMDeviceType(9) => "vpi",
TVMDeviceType(10) => "rocm",
TVMDeviceType(_) => "rpc",
}
)
}
}
impl<'a> From<&'a str> for TVMDeviceType {
fn from(type_str: &'a str) -> Self {
match type_str {
"cpu" => TVMDeviceType(1),
"llvm" => TVMDeviceType(1),
"stackvm" => TVMDeviceType(1),
"gpu" => TVMDeviceType(2),
"cuda" => TVMDeviceType(2),
"nvptx" => TVMDeviceType(2),
"cl" => TVMDeviceType(4),
"opencl" => TVMDeviceType(4),
"metal" => TVMDeviceType(8),
"vpi" => TVMDeviceType(9),
"rocm" => TVMDeviceType(10),
_ => panic!("{:?} not supported!", type_str),
}
}
}
/// Represents the underlying device context. Default is cpu.
///
/// ## Examples
///
/// ```
/// let ctx = TVMContext::from("gpu");
/// assert!(ctx.exist());
///
/// ```
///
/// It is possible to query the underlying context as follows
///
/// ```
/// println!("maximun threads per block: {}", ctx.max_threads_per_block());
/// println!("compute version: {}", ctx.compute_version());
/// ```
#[derive(Debug, Default, Clone, Copy, Hash, PartialEq, Eq)]
pub struct TVMContext {
/// Supported device types
pub device_type: TVMDeviceType,
/// Device id
pub device_id: usize,
}
impl TVMContext {
/// Creates context from device type and id.
pub fn new(device_type: TVMDeviceType, device_id: usize) -> Self {
TVMContext {
device_type: device_type,
device_id: device_id,
}
}
}
macro_rules! impl_ctxs {
($(($ctx:ident, $dldevt:expr));+) => {
$(
impl TVMContext {
pub fn $ctx(device_id: usize) -> Self {
Self::new(TVMDeviceType($dldevt), device_id)
}
}
)+
};
}
impl_ctxs!((cpu, 1);
(gpu, 2);
(nvptx, 2);
(cuda, 2);
(cpu_pinned, 3);
(cl, 4);
(opencl, 4);
(metal, 8);
(vpi, 9);
(rocm, 10);
(opengl, 11);
(ext_dev, 12));
impl<'a> From<&'a str> for TVMContext {
fn from(target: &str) -> Self {
TVMContext::new(TVMDeviceType::from(target), 0)
}
}
impl TVMContext {
/// Checks whether the context exists or not.
pub fn exist(&self) -> bool {
let func = function::Function::get("_GetDeviceAttr", true /* is_global */)
.expect("API function always exists");
let dt = self.device_type.0 as usize;
// `unwrap` is ok here because if there is any error,
// if would occure inside `call_packed!`
let ret = call_packed!(func, &dt, &self.device_id, &0)
.unwrap()
.prim_value;
ret != 0
}
/// Synchronize the context stream.
pub fn sync(&self) -> Result<()> {
check_call!(ts::TVMSynchronize(
self.device_type.0 as i32,
self.device_id as i32,
ptr::null_mut() as *mut c_void
));
Ok(())
}
}
macro_rules! impl_device_attrs {
($(($attr_name:ident, $attr_kind:expr));+) => {
$(
impl TVMContext {
pub fn $attr_name(&self) -> usize {
let func = function::Function::get("_GetDeviceAttr", true /* is_global */)
.expect("API function always exists");
let dt = self.device_type.0 as usize;
// `unwrap` is ok here because if there is any error,
// if would occur in function call.
let ret = function::Builder::from(func)
.args(&[dt, self.device_id, $attr_kind])
.invoke()
.unwrap();
ret.prim_value as usize
}
}
)+
};
}
impl_device_attrs!((max_threads_per_block, 1);
(warp_size, 2);
(max_shared_memory_per_block, 3);
(compute_version, 4);
(device_name, 5);
(max_clock_rate, 6);
(multi_processor_count, 7);
(max_thread_dimensions, 8));
impl From<ts::DLContext> for TVMContext {
fn from(ctx: ts::DLContext) -> Self {
TVMContext {
device_type: TVMDeviceType::from(ctx.device_type),
device_id: ctx.device_id as usize,
}
}
}
impl From<TVMContext> for ts::DLContext {
fn from(ctx: TVMContext) -> Self {
ts::DLContext {
device_type: ctx.device_type.into(),
device_id: ctx.device_id as i32,
}
}
}
impl Display for TVMContext {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "{}({})", self.device_type, self.device_id)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn context() {
let ctx = TVMContext::cpu(0);
println!("ctx: {}", ctx);
let default_ctx = TVMContext::new(TVMDeviceType(1), 0);
assert_eq!(ctx.clone(), default_ctx);
assert_ne!(ctx, TVMContext::gpu(0));
let str_ctx = TVMContext::new(TVMDeviceType::from("gpu"), 0);
assert_eq!(str_ctx.clone(), str_ctx);
assert_ne!(str_ctx, TVMContext::new(TVMDeviceType::from("cpu"), 0));
}
#[test]
fn sync() {
let ctx = TVMContext::cpu(0);
assert!(ctx.sync().is_ok())
}
}
//! This module implements TVM custom [`Error`], [`ErrorKind`] and [`Result`] types.
use std::{ffi, option};
use crate::{common_errors, rust_ndarray};
error_chain! {
errors {
EmptyArray {
description("cannot convert from an empty array")
}
NullHandle(name: String) {
description("null handle")
display("requested `{}` handle is null", name)
}
FunctionNotFound {
description("function not found")
display("function was not set in `function::Builder`")
}
TypeMismatch(expected: String, found: String) {
description("type mismatch!")
display("expected type `{}`, but found `{}`", expected, found)
}
MissingShapeError {
description("ndarray `shape()` returns `None`")
display("called `Option::unwrap()` on a `None` value")
}
AtMostOneReturn {
description("TVM functions accept at most one return value")
}
}
foreign_links {
ShapeError(rust_ndarray::ShapeError);
NulError(ffi::NulError);
IntoStringError(ffi::IntoStringError);
CommonError(common_errors::Error);
}
}
impl From<option::NoneError> for Error {
fn from(_err: option::NoneError) -> Self {
ErrorKind::MissingShapeError.into()
}
}
//! [TVM](https://github.com/dmlc/tvm) is a compiler stack for deep learning systems.
//!
//! This crate provides an idiomatic Rust API for TVM runtime frontend.
//!
//! One particular use case is that given optimized deep learning model artifacts,
//! (compiled with TVM) which include a shared library
//! `lib.so`, `graph.json` and a byte-array `param.params`, one can load them
//! in Rust idomatically to create a TVM Graph Runtime and
//! run the model for some inputs and get the
//! desired predictions *all in Rust*.
//!
//! Checkout the `examples` repository for more details.
#![crate_name = "tvm_frontend"]
#![recursion_limit = "1024"]
#![allow(non_camel_case_types, unused_unsafe)]
#![feature(
try_from,
try_trait,
fn_traits,
unboxed_closures,
box_syntax,
option_replace
)]
#[macro_use]
extern crate error_chain;
extern crate tvm_common as common;
#[macro_use]
extern crate lazy_static;
extern crate ndarray as rust_ndarray;
extern crate num_traits;
use std::{
ffi::{CStr, CString},
str,
};
use crate::common::ffi::ts;
// Macro to check the return call to TVM runtime shared library.
macro_rules! check_call {
($e:expr) => {{
if unsafe { $e } != 0 {
panic!("{}", $crate::get_last_error());
}
}};
}
/// Gets the last error message.
pub fn get_last_error() -> &'static str {
unsafe {
match CStr::from_ptr(ts::TVMGetLastError()).to_str() {
Ok(s) => s,
Err(_) => "Invalid UTF-8 message",
}
}
}
pub(crate) fn set_last_error(err: &Error) {
let c_string = CString::new(err.to_string()).unwrap();
unsafe {
ts::TVMAPISetLastError(c_string.as_ptr());
}
}
#[macro_use]
pub mod function;
pub mod bytearray;
pub mod context;
pub mod errors;
pub mod module;
pub mod ndarray;
pub mod ty;
pub mod value;
pub use crate::{
bytearray::TVMByteArray,
common::{
errors as common_errors,
ty::TVMTypeCode,
value::{TVMArgValue, TVMRetValue, TVMValue},
},
context::{TVMContext, TVMDeviceType},
errors::*,
function::Function,
module::Module,
ndarray::NDArray,
ty::TVMType,
};
/// Outputs the current TVM version.
pub fn version() -> &'static str {
match str::from_utf8(ts::TVM_VERSION) {
Ok(s) => s,
Err(_) => "Invalid UTF-8 string",
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn print_version() {
println!("TVM version: {}", version());
}
#[test]
fn set_error() {
let err = ErrorKind::EmptyArray;
set_last_error(&err.into());
assert_eq!(get_last_error().trim(), ErrorKind::EmptyArray.to_string());
}
}
//! Provides the [`Module`] type and methods for working with runtime TVM modules.
use std::{
convert::TryInto,
ffi::CString,
os::raw::{c_char, c_int},
path::Path,
ptr,
};
use crate::ts;
use crate::{function::Function, ErrorKind, Result};
const ENTRY_FUNC: &'static str = "__tvm_main__";
/// Wrapper around TVM module handle which contains an entry function.
/// The entry function can be applied to an imported module through [`entry_func`].
/// Also [`is_released`] shows whether the module is dropped or not.
///
/// [`entry_func`]:struct.Module.html#method.entry_func
/// [`is_released`]:struct.Module.html#method.is_released
#[derive(Debug, Clone)]
pub struct Module {
pub(crate) handle: ts::TVMModuleHandle,
is_released: bool,
entry_func: Option<Function>,
}
impl Module {
pub(crate) fn new(handle: ts::TVMModuleHandle, is_released: bool) -> Self {
Self {
handle,
is_released,
entry_func: None,
}
}
pub fn entry(&mut self) -> Option<&Function> {
if self.entry_func.is_none() {
self.entry_func = self.get_function(ENTRY_FUNC, false).ok();
}
self.entry_func.as_ref()
}
/// Gets a function by name from a registered module.
pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function> {
let name = CString::new(name)?;
let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle;
check_call!(ts::TVMModGetFunction(
self.handle,
name.as_ptr() as *const c_char,
query_import as c_int,
&mut fhandle as *mut _
));
if fhandle.is_null() {
bail!(ErrorKind::NullHandle(format!("{}", name.into_string()?)))
} else {
Ok(Function::new(fhandle, false, false))
}
}
/// Imports a dependent module such as `.ptx` for gpu.
pub fn import_module(&self, dependent_module: Module) {
check_call!(ts::TVMModImport(self.handle, dependent_module.handle))
}
/// Loads a module shared library from path.
pub fn load<P: AsRef<Path>>(path: &P) -> Result<Module> {
let ext = path.as_ref().extension()?.to_str()?;
let func = Function::get("module._LoadFromFile", true /* is_global */)
.expect("API function always exists");
let ret: Module = call_packed!(func, path.as_ref().to_str()?, ext)?.try_into()?;
Ok(ret)
}
/// Checks if a target device is enabled for a module.
pub fn enabled(&self, target: &str) -> bool {
let func = Function::get("module._Enabled", true /* is_global */)
.expect("API function always exists");
// `unwrap` is safe here because if there is any error during the
// function call, it would occur in `call_packed!`.
let ret: i64 = call_packed!(func, target).unwrap().try_into().unwrap();
ret != 0
}
/// Returns the underlying module handle.
pub fn handle(&self) -> ts::TVMModuleHandle {
self.handle
}
/// Returns true if the underlying module has been dropped and false otherwise.
pub fn is_released(&self) -> bool {
self.is_released
}
}
impl Drop for Module {
fn drop(&mut self) {
if !self.is_released {
check_call!(ts::TVMModFree(self.handle));
self.is_released = true;
}
}
}
//! This module implements the required conversions from Rust types to TVM types.
//!
//! In TVM frontend only conversions from Rust's 32-bits (POD) numeric types (i32, u32, f32)
//! and 64-bits pointers are supported.
use std::{
fmt::{self, Display, Formatter},
ops::{Deref, DerefMut},
};
use crate::ts;
use crate::{Function, Module, NDArray, TVMByteArray, TVMContext, TVMDeviceType, TVMTypeCode};
macro_rules! impl_prim_type {
($type:ty, $variant:ident) => {
impl From<$type> for TVMTypeCode {
fn from(_arg: $type) -> Self {
TVMTypeCode::$variant
}
}
impl<'a> From<&'a $type> for TVMTypeCode {
fn from(_arg: &$type) -> Self {
TVMTypeCode::$variant
}
}
impl<'a> From<&'a mut $type> for TVMTypeCode {
fn from(_arg: &mut $type) -> Self {
TVMTypeCode::$variant
}
}
};
}
impl_prim_type!(TVMDeviceType, kDLInt);
impl_prim_type!(TVMContext, kTVMContext);
impl_prim_type!(TVMType, kTVMType);
impl_prim_type!(Function, kFuncHandle);
impl_prim_type!(Module, kModuleHandle);
impl_prim_type!(NDArray, kArrayHandle);
impl_prim_type!(TVMByteArray, kBytes);
/// See the [module-level documentation](../ty/index.html) for more details.
///
/// Wrapper around underlying TVMType
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct TVMType {
// inner fields are (code: u8, bits: u8, lanes: u16)
pub inner: ts::TVMType,
}
impl TVMType {
pub(crate) fn new(type_code: u8, bits: u8, lanes: u16) -> Self {
TVMType {
inner: ts::TVMType {
code: type_code,
bits: bits,
lanes: lanes,
},
}
}
}
/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}`
/// such as "int32", "float32" or with lane "float32x1".
impl<'a> From<&'a str> for TVMType {
fn from(type_str: &'a str) -> Self {
if type_str == "bool" {
return TVMType::new(1, 1, 1);
}
let mut type_lanes = type_str.split("x");
let typ = type_lanes.next().expect("Missing dtype");
let lanes = type_lanes
.next()
.map(|l| u16::from_str_radix(l, 10).expect(&format!("Bad dtype lanes: {}", l)))
.unwrap_or(1);
let (type_name, bits) = match typ.find(char::is_numeric) {
Some(idx) => {
let (name, bits_str) = typ.split_at(idx);
(
name,
u8::from_str_radix(bits_str, 10)
.expect(&format!("Bad dtype bits: {}", bits_str)),
)
}
None => (typ, 32),
};
let type_code = match type_name {
"int" => 0,
"uint" => 1,
"float" => 2,
"handle" => 3,
_ => unimplemented!(),
};
TVMType::new(type_code, bits, lanes)
}
}
impl Display for TVMType {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
let ts::TVMType { code, bits, lanes } = self.inner;
if bits == 1 && lanes == 1 {
return write!(f, "bool");
}
let mut tcode_str = match code {
0 => "int",
1 => "uint",
2 => "float",
4 => "handle",
_ => "Unknown",
}
.to_string();
tcode_str += &bits.to_string();
if lanes > 1 {
tcode_str += &format!("x{}", lanes.to_string());
}
f.write_str(&tcode_str)
}
}
impl From<TVMType> for ts::DLDataType {
fn from(dtype: TVMType) -> Self {
dtype.inner
}
}
impl From<ts::DLDataType> for TVMType {
fn from(dtype: ts::DLDataType) -> Self {
Self::new(dtype.code, dtype.bits, dtype.lanes)
}
}
impl Deref for TVMType {
type Target = ts::TVMType;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl DerefMut for TVMType {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
//! This module implements [`TVMArgValue`] and [`TVMRetValue`] types
//! and their conversions needed for the types used in frontend crate.
//! `TVMRetValue` is the owned version of `TVMPODValue`.
use std::{convert::TryFrom, mem, os::raw::c_void};
use crate::{
common_errors::*, ts, Function, Module, NDArray, TVMArgValue, TVMByteArray, TVMContext,
TVMDeviceType, TVMRetValue, TVMType, TVMTypeCode, TVMValue,
};
macro_rules! impl_tvm_val_from_handle {
($($ty:ty),+) => {
$(
impl<'a> From<&'a $ty> for TVMValue {
fn from(arg: &$ty) -> Self {
let inner = ts::TVMValue {
v_handle: arg.handle as *mut _ as *mut c_void,
};
Self::new(inner)
}
}
)+
}
}
impl_tvm_val_from_handle!(Module, Function, NDArray);
impl<'a> From<&'a TVMType> for TVMValue {
fn from(ty: &TVMType) -> Self {
let inner = ts::TVMValue { v_type: ty.inner };
Self::new(inner)
}
}
impl<'a> From<&'a TVMContext> for TVMValue {
fn from(ctx: &TVMContext) -> Self {
let inner = ts::TVMValue {
v_ctx: ctx.clone().into(),
};
Self::new(inner)
}
}
impl<'a> From<&'a TVMDeviceType> for TVMValue {
fn from(dev: &TVMDeviceType) -> Self {
let inner = ts::TVMValue {
v_int64: dev.0 as i64,
};
Self::new(inner)
}
}
impl<'a> From<&'a TVMByteArray> for TVMValue {
fn from(barr: &TVMByteArray) -> Self {
let inner = ts::TVMValue {
v_handle: &barr.inner as *const ts::TVMByteArray as *mut c_void,
};
Self::new(inner)
}
}
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for NDArray {
type Error = Error;
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
if arg.type_code == TVMTypeCode::kArrayHandle {
let handle = unsafe { arg.value.inner.v_handle };
let arr_handle = unsafe { mem::transmute::<*mut c_void, ts::TVMArrayHandle>(handle) };
Ok(Self::new(arr_handle, true))
} else {
bail!(ErrorKind::TryFromTVMArgValueError(
stringify!(NDArray).to_string(),
arg.type_code.to_string()
))
}
}
}
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for Module {
type Error = Error;
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
if arg.type_code == TVMTypeCode::kModuleHandle {
let handle = unsafe { arg.value.inner.v_handle };
Ok(Self::new(handle, false))
} else {
bail!(ErrorKind::TryFromTVMArgValueError(
stringify!(Module).to_string(),
arg.type_code.to_string()
))
}
}
}
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMByteArray {
type Error = Error;
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
if arg.type_code == TVMTypeCode::kBytes {
unsafe {
let barr_ptr =
mem::transmute::<*mut c_void, *mut ts::TVMByteArray>(arg.value.inner.v_handle);
Ok(Self::new(*barr_ptr))
}
} else {
bail!(ErrorKind::TryFromTVMArgValueError(
stringify!(TVMByteArray).to_string(),
arg.type_code.to_string()
))
}
}
}
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMType {
type Error = Error;
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
if arg.type_code == TVMTypeCode::kTVMType {
let ty = unsafe { arg.value.inner.v_type };
Ok(TVMType::from(ty))
} else {
bail!(ErrorKind::TryFromTVMArgValueError(
stringify!(TVMType).to_string(),
arg.type_code.to_string()
))
}
}
}
impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMContext {
type Error = Error;
fn try_from(arg: &TVMArgValue<'a>) -> Result<Self> {
if arg.type_code == TVMTypeCode::kTVMContext {
let ty = unsafe { arg.value.inner.v_ctx };
Ok(TVMContext::from(ty))
} else {
bail!(ErrorKind::TryFromTVMArgValueError(
stringify!(TVMContext).to_string(),
arg.type_code.to_string()
))
}
}
}
macro_rules! impl_boxed_ret_value {
($type:ty, $code:expr) => {
impl From<$type> for TVMRetValue {
fn from(val: $type) -> Self {
TVMRetValue {
prim_value: 0,
box_value: box val,
type_code: $code,
}
}
}
impl TryFrom<TVMRetValue> for $type {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<$type> {
if let Ok(val) = ret.box_value.downcast::<$type>() {
Ok(*val)
} else {
bail!(ErrorKind::TryFromTVMRetValueError(
stringify!($type).to_string(),
ret.type_code.to_string()
))
}
}
}
};
}
impl_boxed_ret_value!(TVMType, TVMTypeCode::kTVMType);
impl_boxed_ret_value!(TVMContext, TVMTypeCode::kTVMContext);
impl_boxed_ret_value!(TVMByteArray, TVMTypeCode::kBytes);
impl TryFrom<TVMRetValue> for Module {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<Module> {
if let Ok(handle) = ret.box_value.downcast::<ts::TVMModuleHandle>() {
Ok(Module::new(*handle, false))
} else {
bail!(ErrorKind::TryFromTVMRetValueError(
stringify!(TVMTypeCode::kModuleHandle).to_string(),
ret.type_code.to_string()
))
}
}
}
impl TryFrom<TVMRetValue> for Function {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<Function> {
if let Ok(handle) = ret.box_value.downcast::<ts::TVMFunctionHandle>() {
Ok(Function::new(*handle, false, false))
} else {
bail!(ErrorKind::TryFromTVMRetValueError(
stringify!(TVMTypeCode::kFuncHandle).to_string(),
ret.type_code.to_string()
))
}
}
}
impl TryFrom<TVMRetValue> for NDArray {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<NDArray> {
if let Ok(handle) = ret.box_value.downcast::<ts::TVMArrayHandle>() {
Ok(NDArray::new(*handle, false))
} else {
bail!(ErrorKind::TryFromTVMRetValueError(
stringify!(TVMTypeCode::kArrayHandle).to_string(),
ret.type_code.to_string()
))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::convert::TryInto;
#[test]
fn bytearray() {
let w = vec![1u8, 2, 3, 4, 5];
let v = TVMByteArray::from(&w);
let tvm: TVMByteArray = TVMRetValue::from(v).try_into().unwrap();
assert_eq!(tvm.data(), w.iter().map(|e| *e as i8).collect::<Vec<i8>>());
}
#[test]
fn ty() {
let t = TVMType::from("int32");
let tvm: TVMType = TVMRetValue::from(t).try_into().unwrap();
assert_eq!(tvm, t);
}
#[test]
fn ctx() {
let c = TVMContext::from("gpu");
let tvm: TVMContext = TVMRetValue::from(c).try_into().unwrap();
assert_eq!(tvm, c);
}
}
/target
**/*.rs.bk
Cargo.lock
*.o
*.so
*.ptx
*.json
[package]
name = "basics"
version = "0.0.0"
authors = ["TVM Contributors"]
license = "Apache-2.0"
build = "build.rs"
[dependencies]
ndarray = "0.12.1"
tvm-frontend = { path = "../../" }
[features]
default = ["cpu"]
cpu = []
gpu = []
fn main() {
let out_dir = std::env::var("OUT_DIR").unwrap();
let output = std::process::Command::new(concat!(env!("CARGO_MANIFEST_DIR"), "/src/tvm_add.py"))
.args(&[
if cfg!(feature = "cpu") {
"llvm"
} else {
"cuda"
},
&std::env::var("OUT_DIR").unwrap(),
])
.output()
.expect("Failed to execute command");
assert!(
std::path::Path::new(&format!("{}/test_add.so", out_dir)).exists(),
"Could not build tvm lib: {}",
String::from_utf8(output.stderr)
.unwrap()
.trim()
.split("\n")
.last()
.unwrap_or("")
);
println!("cargo:rustc-link-search=native={}", out_dir);
}
extern crate ndarray as rust_ndarray;
extern crate tvm_frontend as tvm;
use tvm::*;
fn main() {
let shape = &mut [2];
let mut data = vec![3f32, 4.0];
let (ctx, ctx_name) = if cfg!(feature = "cpu") {
(TVMContext::cpu(0), "cpu")
} else {
(TVMContext::gpu(0), "gpu")
};
let dtype = TVMType::from("float32");
let mut arr = NDArray::empty(shape, ctx, dtype);
arr.copy_from_buffer(data.as_mut_slice());
let mut ret = NDArray::empty(shape, ctx, dtype);
let mut fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap();
if !fadd.enabled(ctx_name) {
return;
}
if cfg!(feature = "gpu") {
fadd.import_module(Module::load(&concat!(env!("OUT_DIR"), "/test_add.ptx")).unwrap());
}
function::Builder::from(&mut fadd)
.arg(&arr)
.arg(&arr)
.set_output(&mut ret)
.unwrap()
.invoke()
.unwrap();
assert_eq!(ret.to_vec::<f32>().unwrap(), vec![6f32, 8.0]);
}
#!/usr/bin/env python3
import os.path as osp
import sys
import tvm
from tvm.contrib import cc
def main(target, out_dir):
n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
s = tvm.create_schedule(C.op)
if target == 'cuda':
bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, tvm.thread_axis('blockIdx.x'))
s[C].bind(tx, tvm.thread_axis('threadIdx.x'))
fadd = tvm.build(s, [A, B, C], target, target_host='llvm', name='myadd')
fadd.save(osp.join(out_dir, 'test_add.o'))
if target == 'cuda':
fadd.imported_modules[0].save(os.path.join(out_dir, 'test_add.ptx'))
cc.create_shared(
osp.join(out_dir, 'test_add.so'), [osp.join(out_dir, 'test_add.o')])
if __name__ == '__main__':
main(sys.argv[1], sys.argv[2])
[package]
name = "callback"
version = "0.0.0"
authors = ["TVM Contributors"]
[dependencies]
ndarray = "0.12.1"
tvm-frontend = { path = "../../" }
#![feature(extern_crate_item_prelude, try_from)]
#![allow(unused_imports)]
extern crate ndarray as rust_ndarray;
#[macro_use]
extern crate tvm_frontend as tvm;
use rust_ndarray::ArrayD;
use std::convert::{TryFrom, TryInto};
use tvm::*;
fn main() {
register_global_func! {
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
let mut ret = 0f32;
let shape = &mut [2];
for arg in args.iter() {
let e = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
let arg: NDArray = arg.try_into()?;
let arr = arg.copy_to_ndarray(e)?;
let rnd: ArrayD<f32> = ArrayD::try_from(&arr)?;
ret += rnd.scalar_sum();
}
Ok(TVMRetValue::from(ret))
}
}
let shape = &mut [2];
let mut data = vec![3f32, 4.0];
let mut arr = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
arr.copy_from_buffer(data.as_mut_slice());
let mut registered = function::Builder::default();
let ret: f32 = registered
.get_function("sum", true)
.arg(&arr)
.arg(&arr)
.invoke()
.unwrap()
.try_into()
.unwrap();
assert_eq!(ret, 14f32);
}
#![feature(extern_crate_item_prelude, panic_info_message)]
#![allow(unused_imports)]
use std::panic;
#[macro_use]
extern crate tvm_frontend as tvm;
use tvm::*;
fn main() {
register_global_func! {
fn error(_args: &[TVMArgValue]) -> Result<TVMRetValue> {
Err(ErrorKind::TypeMismatch(
format!("{}", "i64".to_string()),
format!("{}", "f64".to_string()),
).into())
}
}
let mut registered = function::Builder::default();
registered.get_function("error", true);
assert!(registered.func.is_some());
registered.args(&[10, 20]);
println!("expected error message is:");
panic::set_hook(Box::new(|panic_info| {
if let Some(msg) = panic_info.message() {
println!("{:?}", msg);
}
if let Some(location) = panic_info.location() {
println!(
"panic occurred in file '{}' at line {}",
location.file(),
location.line()
);
} else {
println!("panic occurred but can't get location information");
}
}));
let _result = registered.invoke();
}
#![feature(extern_crate_item_prelude, try_from)]
#![allow(unused_imports)]
#[macro_use]
extern crate tvm_frontend as tvm;
use std::convert::TryInto;
use tvm::*;
fn main() {
register_global_func! {
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
let mut ret = 0.0;
for arg in args.iter() {
let val: f64 = arg.try_into()?;
ret += val;
}
Ok(TVMRetValue::from(&ret))
}
}
let mut registered = function::Builder::default();
registered.get_function("sum", true);
assert!(registered.func.is_some());
let ret: f64 = registered
.args(&[10.0f64, 20.0, 30.0])
.invoke()
.unwrap()
.try_into()
.unwrap();
assert_eq!(ret, 60f64);
}
#![feature(extern_crate_item_prelude, try_from)]
#![allow(unused_imports)]
extern crate tvm_frontend as tvm;
use std::convert::TryInto;
use tvm::*;
fn main() {
fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
let mut ret = 0i64;
for arg in args.iter() {
let val: i64 = arg.try_into()?;
ret += val;
}
Ok(TVMRetValue::from(&ret))
}
tvm::function::register(sum, "mysum".to_owned(), false).unwrap();
let mut registered = function::Builder::default();
registered.get_function("mysum", true);
assert!(registered.func.is_some());
let ret: i64 = registered
.args(&[10, 20, 30])
.invoke()
.unwrap()
.try_into()
.unwrap();
assert_eq!(ret, 60);
}
#![feature(extern_crate_item_prelude, try_from)]
#![allow(unused_imports)]
#[macro_use]
extern crate tvm_frontend as tvm;
use std::convert::TryInto;
use tvm::*;
// FIXME
fn main() {
register_global_func! {
fn concate_str(args: &[TVMArgValue]) -> Result<TVMRetValue> {
let mut ret = "".to_string();
for arg in args.iter() {
let val: String = arg.try_into()?;
ret += val.as_str();
}
Ok(TVMRetValue::from(ret))
}
}
let mut registered = function::Builder::default();
registered.get_function("concate_str", true);
assert!(registered.func.is_some());
let a = "a".to_string();
let b = "b".to_string();
let c = "c".to_string();
let ret: String = registered
.args(&[a, b, c])
.invoke()
.unwrap()
.try_into()
.unwrap();
assert_eq!(ret, "abc".to_owned());
}
language: rust
rust:
- nightly
matrix:
fast_finish: true
[package]
name = "tvm-runtime"
version = "0.1.0"
license = "Apache-2.0"
description = "A static TVM runtime"
repository = "https://github.com/dmlc/tvm"
readme = "README.md"
keywords = ["tvm", "nnvm"]
categories = ["api-bindings", "science"]
authors = ["TVM Contributors"]
[features]
default = ["nom/std"]
sgx = ["nom/alloc"]
[dependencies]
bounded-spsc-queue = "0.4.0"
error-chain = { version = "0.12.0", default-features = false }
itertools = "0.7.8"
lazy_static = "1.1.0"
ndarray = "0.11.2"
nom = {version = "4.0.0", default-features = false }
serde = "1.0.59"
serde_derive = "1.0.79"
serde_json = "1.0.17"
tvm-common = { version = "0.1.0", path = "../common/", features = ["runtime"] }
[target.'cfg(not(target_env = "sgx"))'.dependencies]
num_cpus = "1.8.0"
...@@ -3,50 +3,50 @@ use alloc::alloc::{self, Layout}; ...@@ -3,50 +3,50 @@ use alloc::alloc::{self, Layout};
#[cfg(not(target_env = "sgx"))] #[cfg(not(target_env = "sgx"))]
use std::alloc::{self, Layout}; use std::alloc::{self, Layout};
use errors::*; use crate::errors::*;
const DEFAULT_ALIGN_BYTES: usize = 4; const DEFAULT_ALIGN_BYTES: usize = 4;
#[derive(PartialEq, Eq)] #[derive(PartialEq, Eq)]
pub struct Allocation { pub struct Allocation {
layout: Layout, layout: Layout,
ptr: *mut u8, ptr: *mut u8,
} }
impl Allocation { impl Allocation {
/// Allocates a chunk of memory of `size` bytes with optional alignment. /// Allocates a chunk of memory of `size` bytes with optional alignment.
pub fn new(size: usize, align: Option<usize>) -> Result<Self> { pub fn new(size: usize, align: Option<usize>) -> Result<Self> {
let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES); let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES);
let layout = Layout::from_size_align(size, alignment)?; let layout = Layout::from_size_align(size, alignment)?;
let ptr = unsafe { alloc::alloc(layout.clone()) }; let ptr = unsafe { alloc::alloc(layout.clone()) };
if ptr.is_null() { if ptr.is_null() {
alloc::handle_alloc_error(layout); alloc::handle_alloc_error(layout);
}
Ok(Self {
ptr: ptr,
layout: layout,
})
}
pub fn as_mut_ptr(&self) -> *mut u8 {
self.ptr
}
/// Returns the size of the Allocation in bytes.
pub fn size(&self) -> usize {
self.layout.size()
}
/// Returns the byte alignment of the Allocation.
pub fn align(&self) -> usize {
self.layout.align()
} }
Ok(Self {
ptr: ptr,
layout: layout,
})
}
pub fn as_mut_ptr(&self) -> *mut u8 {
self.ptr
}
/// Returns the size of the Allocation in bytes.
pub fn size(&self) -> usize {
self.layout.size()
}
/// Returns the byte alignment of the Allocation.
pub fn align(&self) -> usize {
self.layout.align()
}
} }
impl Drop for Allocation { impl Drop for Allocation {
fn drop(&mut self) { fn drop(&mut self) {
unsafe { unsafe {
alloc::dealloc(self.ptr, self.layout.clone()); alloc::dealloc(self.ptr, self.layout.clone());
}
} }
}
} }
...@@ -4,16 +4,12 @@ use alloc::alloc; ...@@ -4,16 +4,12 @@ use alloc::alloc;
use std::alloc; use std::alloc;
use std::num; use std::num;
use crate::common::errors as common_errors;
use ndarray; use ndarray;
use serde_json; use serde_json;
error_chain! { error_chain! {
errors { errors {
TryFromTVMRetValueError(expected: String, actual: i64) {
description("mismatched types while downcasting TVMRetValue")
display("invalid downcast: expected `{}` but was `{}`", expected, actual)
}
GraphFormatError(msg: String) { GraphFormatError(msg: String) {
description("unable to load graph") description("unable to load graph")
display("could not load graph json: {}", msg) display("could not load graph json: {}", msg)
...@@ -29,11 +25,12 @@ error_chain! { ...@@ -29,11 +25,12 @@ error_chain! {
GraphDeserialize(serde_json::Error); GraphDeserialize(serde_json::Error);
ParseInt(num::ParseIntError); ParseInt(num::ParseIntError);
ShapeError(ndarray::ShapeError); ShapeError(ndarray::ShapeError);
CommonError(common_errors::Error);
} }
} }
impl From<alloc::LayoutErr> for Error { impl From<alloc::LayoutErr> for Error {
fn from(_err: alloc::LayoutErr) -> Error { fn from(_err: alloc::LayoutErr) -> Error {
Error::from_kind(ErrorKind::Msg("Layout error".to_string())) Error::from_kind(ErrorKind::Msg("Layout error".to_string()))
} }
} }
...@@ -10,13 +10,13 @@ ...@@ -10,13 +10,13 @@
//! For examples of use, please refer to the multi-file tests in the `tests` directory. //! For examples of use, please refer to the multi-file tests in the `tests` directory.
#![feature( #![feature(
alloc, alloc,
allocator_api, allocator_api,
box_syntax, box_syntax,
fn_traits, fn_traits,
try_from, try_from,
unboxed_closures, unboxed_closures,
vec_remove_item vec_remove_item
)] )]
#[cfg(target_env = "sgx")] #[cfg(target_env = "sgx")]
...@@ -39,29 +39,36 @@ extern crate serde; ...@@ -39,29 +39,36 @@ extern crate serde;
#[macro_use] #[macro_use]
extern crate serde_derive; extern crate serde_derive;
extern crate serde_json; extern crate serde_json;
extern crate tvm_common as common;
pub mod ffi { mod allocator;
#![allow( mod array;
non_camel_case_types, pub mod errors;
non_snake_case, mod module;
non_upper_case_globals, #[macro_use]
unused mod packed_func;
)] mod graph;
#[cfg(target_env = "sgx")]
#[macro_use]
pub mod sgx;
mod threading;
mod workspace;
pub mod runtime { pub use crate::common::{errors::*, ffi, TVMArgValue, TVMRetValue};
use std::os::raw::{c_char, c_int, c_void};
include!(concat!( pub use self::{
env!("CARGO_MANIFEST_DIR"), array::*, errors::*, graph::*, module::*, packed_func::*, threading::*, workspace::*,
"/src/runtime/c_runtime_api.rs" };
));
pub type BackendPackedCFunc = #[cfg(target_env = "sgx")]
extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int; use self::sgx::ocall_packed_func;
}
}
pub mod errors;
pub mod runtime;
pub use errors::*; #[no_mangle]
pub extern "C" fn TVMAPISetLastError(cmsg: *const i8) {
#[cfg(not(target_env = "sgx"))]
unsafe {
panic!(std::ffi::CStr::from_ptr(cmsg).to_str().unwrap());
}
#[cfg(target_env = "sgx")]
ocall_packed!("__sgx_set_last_error__", cmsg);
}
use std::{ use std::{
collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex, collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
}; };
use ffi::runtime::BackendPackedCFunc; use crate::{
use runtime::packed_func::{wrap_backend_packed_func, PackedFunc}; ffi::runtime::BackendPackedCFunc,
packed_func::{wrap_backend_packed_func, PackedFunc},
};
pub trait Module { pub trait Module {
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc>; fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc>;
} }
pub struct SystemLibModule; pub struct SystemLibModule;
lazy_static! { lazy_static! {
static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, BackendPackedCFunc>> = static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, BackendPackedCFunc>> =
Mutex::new(HashMap::new()); Mutex::new(HashMap::new());
} }
impl Module for SystemLibModule { impl Module for SystemLibModule {
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc> { fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc> {
SYSTEM_LIB_FUNCTIONS SYSTEM_LIB_FUNCTIONS
.lock() .lock()
.unwrap() .unwrap()
.get(name.as_ref()) .get(name.as_ref())
.map(|func| wrap_backend_packed_func(func.to_owned())) .map(|func| wrap_backend_packed_func(func.to_owned()))
} }
} }
impl Default for SystemLibModule { impl Default for SystemLibModule {
fn default() -> Self { fn default() -> Self {
SystemLibModule {} SystemLibModule {}
} }
} }
#[no_mangle] #[no_mangle]
pub extern "C" fn TVMBackendRegisterSystemLibSymbol( pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
cname: *const c_char, cname: *const c_char,
func: BackendPackedCFunc, func: BackendPackedCFunc,
) -> i32 { ) -> i32 {
let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() }; let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
SYSTEM_LIB_FUNCTIONS SYSTEM_LIB_FUNCTIONS
.lock() .lock()
.unwrap() .unwrap()
.insert(name.to_string(), func); .insert(name.to_string(), func);
return 0; return 0;
} }
use std::{convert::TryFrom, marker::PhantomData, os::raw::c_void};
use super::Tensor;
use crate::ffi::runtime::{
BackendPackedCFunc, DLTensor as _DLTensor, TVMTypeCode_kArrayHandle,
TVMTypeCode_kNDArrayContainer, TVMValue as _TVMValue,
};
use super::DLTensor;
use crate::{
common::{TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue},
errors::*,
};
pub type PackedFunc = Box<Fn(&[TVMArgValue]) -> TVMRetValue + Send + Sync>;
/// Calls a packed function and returns a `TVMRetValue`.
///
/// # Example
///
/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
#[macro_export]
macro_rules! call_packed {
($fn:expr, $($args:expr),+) => {
$fn(&[$($args.into(),)+])
};
($fn:expr) => {
$fn(&Vec::new())
};
}
impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
fn from(arr: &'a DLTensor) -> Self {
let raw = _TVMValue {
v_handle: arr as *const _ as *mut DLTensor as *mut c_void,
};
TVMArgValue {
value: TVMValue::new(raw),
type_code: TVMTypeCode::kArrayHandle,
lifetime: PhantomData,
}
}
}
impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
fn from(arr: &'a mut DLTensor) -> Self {
let raw = _TVMValue {
v_handle: arr as *mut _ as *mut c_void,
};
TVMArgValue {
value: TVMValue::new(raw),
type_code: TVMTypeCode::kArrayHandle,
lifetime: PhantomData,
}
}
}
impl<'a> TryFrom<TVMArgValue<'a>> for Tensor<'a> {
type Error = Error;
fn try_from(val: TVMArgValue<'a>) -> Result<Self> {
ensure!(
val.type_code == TVMTypeCode::kArrayHandle
|| val.type_code == TVMTypeCode::kNDArrayContainer,
"Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
TVMTypeCode::kArrayHandle,
TVMTypeCode::kNDArrayContainer,
val.type_code,
);
let dlt = unsafe { *(val.value.v_handle as *mut _DLTensor as *const _DLTensor) };
Ok(DLTensor::new(dlt).into())
}
}
impl<'a, 't> From<&'t Tensor<'a>> for TVMRetValue {
fn from(val: &'t Tensor<'a>) -> Self {
TVMRetValue {
prim_value: 0,
box_value: box DLTensor::from(val),
type_code: TVMTypeCode::kNDArrayContainer,
}
}
}
impl<'a> TryFrom<TVMRetValue> for Tensor<'a> {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<Self> {
ensure!(
ret.type_code == TVMTypeCode::kArrayHandle
|| ret.type_code == TVMTypeCode::kNDArrayContainer,
"Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
TVMTypeCode_kArrayHandle,
TVMTypeCode_kNDArrayContainer,
ret.type_code,
);
let dlt = unsafe { *(ret.prim_value as *mut _DLTensor as *const _DLTensor) };
Ok(DLTensor::new(dlt).into())
}
}
// @see `WrapPackedFunc` in `llvm_module.cc`.
pub(crate) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc {
box move |args: &[TVMArgValue]| {
func(
args.iter()
.map(|ref arg| arg.value.inner)
.collect::<Vec<_TVMValue>>()
.as_ptr(),
args.iter()
.map(|ref arg| arg.type_code as i32)
.collect::<Vec<i32>>()
.as_ptr() as *const i32,
args.len() as i32,
);
TVMRetValue::default()
}
}
use std::{ use std::{
ffi::CString, ffi::CString,
os::raw::{c_char, c_int}, os::raw::{c_char, c_int},
}; };
use errors::Result; use errors::Result;
...@@ -11,50 +11,48 @@ pub use runtime::threading::tvm_run_worker as run_worker; ...@@ -11,50 +11,48 @@ pub use runtime::threading::tvm_run_worker as run_worker;
#[macro_export] #[macro_export]
macro_rules! tvm_ocall { macro_rules! tvm_ocall {
($func: expr) => { ($func: expr) => {
match $func { match $func {
0 => Ok(()), 0 => Ok(()),
err => Err(format!("SGX error: {}", err)), err => Err(format!("SGX error: {}", err)),
} }
}; };
} }
pub type SgxStatus = u32; pub type SgxStatus = u32;
#[cfg(target_env = "sgx")] #[cfg(target_env = "sgx")]
extern "C" { extern "C" {
fn tvm_ocall_packed_func( fn tvm_ocall_packed_func(
name: *const c_char, name: *const c_char,
arg_values: *const TVMValue, arg_values: *const TVMValue,
type_codes: *const c_int, type_codes: *const c_int,
num_args: c_int, num_args: c_int,
ret_val: *mut TVMValue, ret_val: *mut TVMValue,
ret_type_code: *mut c_int, ret_type_code: *mut c_int,
) -> SgxStatus; ) -> SgxStatus;
} }
pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Result<TVMRetValue> { pub fn ocall_packed_func<S: AsRef<str>>(fn_name: S, args: &[TVMArgValue]) -> Result<TVMRetValue> {
let mut ret_val = TVMValue { v_int64: 0 }; let mut ret_val = TVMValue { v_int64: 0 };
let ret_type_code = 0i64; let ret_type_code = 0i64;
unsafe { unsafe {
tvm_ocall!(tvm_ocall_packed_func( tvm_ocall!(tvm_ocall_packed_func(
CString::new(fn_name.as_ref()).unwrap().as_ptr(), CString::new(fn_name.as_ref()).unwrap().as_ptr(),
args args.iter()
.iter() .map(|ref arg| arg.value)
.map(|ref arg| arg.value) .collect::<Vec<TVMValue>>()
.collect::<Vec<TVMValue>>() .as_ptr(),
.as_ptr(), args.iter()
args .map(|ref arg| arg.type_code as i32)
.iter() .collect::<Vec<i32>>()
.map(|ref arg| arg.type_code as i32) .as_ptr() as *const i32,
.collect::<Vec<i32>>() args.len() as i32,
.as_ptr() as *const i32, &mut ret_val as *mut TVMValue,
args.len() as i32, &mut (ret_type_code as i32) as *mut c_int,
&mut ret_val as *mut TVMValue, ))?;
&mut (ret_type_code as i32) as *mut c_int, }
))?; Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64))
}
Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64))
} }
#[macro_export] #[macro_export]
...@@ -70,13 +68,13 @@ macro_rules! ocall_packed { ...@@ -70,13 +68,13 @@ macro_rules! ocall_packed {
} }
pub fn shutdown() { pub fn shutdown() {
if env!("TVM_NUM_THREADS") != "0" { if env!("TVM_NUM_THREADS") != "0" {
sgx_join_threads() sgx_join_threads()
} }
} }
impl Drop for SystemLibModule { impl Drop for SystemLibModule {
fn drop(&mut self) { fn drop(&mut self) {
shutdown() shutdown()
} }
} }
use std::{ use std::{
cell::RefCell, cell::RefCell,
os::raw::{c_int, c_void}, os::raw::{c_int, c_void},
ptr, ptr,
}; };
use super::allocator::Allocation; use super::allocator::Allocation;
use errors::*; use crate::errors::*;
const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h` const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`
struct WorkspacePool { struct WorkspacePool {
workspaces: Vec<Allocation>, workspaces: Vec<Allocation>,
free: Vec<usize>, free: Vec<usize>,
in_use: Vec<usize>, in_use: Vec<usize>,
} }
impl WorkspacePool { impl WorkspacePool {
fn new() -> Self { fn new() -> Self {
WorkspacePool { WorkspacePool {
workspaces: Vec::new(), workspaces: Vec::new(),
free: Vec::new(), free: Vec::new(),
in_use: Vec::new(), in_use: Vec::new(),
}
} }
}
fn alloc_new(&mut self, size: usize) -> Result<*mut u8> {
self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
self.in_use.push(self.workspaces.len() - 1);
Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
}
fn alloc(&mut self, size: usize) -> Result<*mut u8> { fn alloc_new(&mut self, size: usize) -> Result<*mut u8> {
if self.free.len() == 0 { self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
return self.alloc_new(size); self.in_use.push(self.workspaces.len() - 1);
Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
} }
let idx = self
.free fn alloc(&mut self, size: usize) -> Result<*mut u8> {
.iter() if self.free.len() == 0 {
.fold(None, |cur_ws_idx: Option<usize>, &idx| { return self.alloc_new(size);
let ws_size = self.workspaces[idx].size(); }
if !ws_size >= size { let idx = self
return cur_ws_idx; .free
.iter()
.fold(None, |cur_ws_idx: Option<usize>, &idx| {
let ws_size = self.workspaces[idx].size();
if !ws_size >= size {
return cur_ws_idx;
}
cur_ws_idx.or(Some(idx)).and_then(|cur_idx| {
let cur_size = self.workspaces[cur_idx].size();
Some(match ws_size <= cur_size {
true => idx,
false => cur_idx,
})
})
});
match idx {
Some(idx) => {
self.free.remove_item(&idx).unwrap();
self.in_use.push(idx);
Ok(self.workspaces[idx].as_mut_ptr())
}
None => self.alloc_new(size),
} }
cur_ws_idx.or(Some(idx)).and_then(|cur_idx| {
let cur_size = self.workspaces[cur_idx].size();
Some(match ws_size <= cur_size {
true => idx,
false => cur_idx,
})
})
});
match idx {
Some(idx) => {
self.free.remove_item(&idx).unwrap();
self.in_use.push(idx);
Ok(self.workspaces[idx].as_mut_ptr())
}
None => self.alloc_new(size),
} }
}
fn free(&mut self, ptr: *mut u8) -> Result<()> { fn free(&mut self, ptr: *mut u8) -> Result<()> {
let mut ws_idx = None; let mut ws_idx = None;
for i in 0..self.in_use.len() { for i in 0..self.in_use.len() {
let idx = self.in_use[i]; let idx = self.in_use[i];
if self.workspaces[idx].as_mut_ptr() == ptr { if self.workspaces[idx].as_mut_ptr() == ptr {
self.in_use.remove(i); self.in_use.remove(i);
ws_idx = Some(idx); ws_idx = Some(idx);
break; break;
} }
}
Ok(self
.free
.push(ws_idx.ok_or("Tried to free nonexistent workspace.")?))
} }
Ok(
self
.free
.push(ws_idx.ok_or("Tried to free nonexistent workspace.")?),
)
}
} }
thread_local!(static WORKSPACE_POOL: RefCell<WorkspacePool> = RefCell::new(WorkspacePool::new())); thread_local!(static WORKSPACE_POOL: RefCell<WorkspacePool> = RefCell::new(WorkspacePool::new()));
...@@ -84,36 +82,36 @@ const WORKSPACE_PAGE_SIZE: usize = 4 << 10; ...@@ -84,36 +82,36 @@ const WORKSPACE_PAGE_SIZE: usize = 4 << 10;
#[no_mangle] #[no_mangle]
pub extern "C" fn TVMBackendAllocWorkspace( pub extern "C" fn TVMBackendAllocWorkspace(
_device_type: c_int, _device_type: c_int,
_device_id: c_int, _device_id: c_int,
size: u64, size: u64,
_dtype_code_hint: c_int, _dtype_code_hint: c_int,
_dtype_bits_hint: c_int, _dtype_bits_hint: c_int,
) -> *mut c_void { ) -> *mut c_void {
let nbytes = if size == 0 { let nbytes = if size == 0 {
WORKSPACE_PAGE_SIZE WORKSPACE_PAGE_SIZE
} else { } else {
size as usize size as usize
}; };
WORKSPACE_POOL.with(|pool_cell| { WORKSPACE_POOL.with(|pool_cell| {
pool_cell pool_cell
.borrow_mut() .borrow_mut()
.alloc(nbytes as usize) .alloc(nbytes as usize)
.unwrap_or(ptr::null_mut()) as *mut c_void .unwrap_or(ptr::null_mut()) as *mut c_void
}) })
} }
#[no_mangle] #[no_mangle]
pub extern "C" fn TVMBackendFreeWorkspace( pub extern "C" fn TVMBackendFreeWorkspace(
_device_type: c_int, _device_type: c_int,
_device_id: c_int, _device_id: c_int,
ptr: *mut c_void, ptr: *mut c_void,
) -> c_int { ) -> c_int {
WORKSPACE_POOL.with(|pool_cell| { WORKSPACE_POOL.with(|pool_cell| {
(match pool_cell.borrow_mut().free(ptr as *mut u8) { (match pool_cell.borrow_mut().free(ptr as *mut u8) {
Ok(()) => 0, Ok(()) => 0,
Err(_) => -1, Err(_) => -1,
}) as c_int }) as c_int
}); });
return 0; return 0;
} }
#!/usr/bin/env python3
"""Builds a simple NNVM graph for testing.""" """Builds a simple NNVM graph for testing."""
from os import path as osp from os import path as osp
......
...@@ -3,37 +3,37 @@ ...@@ -3,37 +3,37 @@
extern crate serde; extern crate serde;
extern crate serde_json; extern crate serde_json;
extern crate tvm; extern crate tvm_runtime;
use std::{convert::TryFrom, fs, io::Read}; use std::{convert::TryFrom, fs, io::Read};
use tvm::runtime::Graph; use tvm_runtime::Graph;
#[test] #[test]
fn test_load_graph() { fn test_load_graph() {
let mut params_bytes = Vec::new(); let mut params_bytes = Vec::new();
fs::File::open(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.params")) fs::File::open(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.params"))
.expect("Could not find TVM graph. Did you run `tests/build_model.py`?") .expect("Could not find TVM graph. Did you run `tests/build_model.py`?")
.read_to_end(&mut params_bytes) .read_to_end(&mut params_bytes)
.unwrap(); .unwrap();
let _params = tvm::runtime::load_param_dict(&params_bytes); let _params = tvm_runtime::load_param_dict(&params_bytes);
let graph = Graph::try_from( let graph = Graph::try_from(
&fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json")).unwrap(), &fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json")).unwrap(),
) )
.unwrap(); .unwrap();
assert_eq!(graph.nodes[3].op, "tvm_op"); assert_eq!(graph.nodes[3].op, "tvm_op");
assert_eq!( assert_eq!(
graph.nodes[3] graph.nodes[3]
.attrs .attrs
.as_ref() .as_ref()
.unwrap() .unwrap()
.get("func_name") .get("func_name")
.unwrap(), .unwrap(),
"fuse_dense" "fuse_dense"
); );
assert_eq!(graph.nodes[5].inputs[0].index, 0); assert_eq!(graph.nodes[5].inputs[0].index, 0);
assert_eq!(graph.nodes[6].inputs[0].index, 1); assert_eq!(graph.nodes[6].inputs[0].index, 1);
assert_eq!(graph.heads.len(), 2); assert_eq!(graph.heads.len(), 2);
} }
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
name = "test-nnvm" name = "test-nnvm"
version = "0.0.0" version = "0.0.0"
license = "Apache-2.0" license = "Apache-2.0"
authors = ["Nick Hynes <nhynes@berkeley.edu>"] authors = ["TVM Contributors"]
[dependencies] [dependencies]
ndarray = "0.11.2" ndarray = "0.11.2"
tvm = { path = "../../" }
serde = "1.0.59" serde = "1.0.59"
serde_json = "1.0.17" serde_json = "1.0.17"
tvm-runtime = { path = "../../" }
[build-dependencies] [build-dependencies]
ar = "0.6.0" ar = "0.6.0"
extern crate ar; extern crate ar;
use std::{ use std::{env, fs::File, path::Path, process::Command};
env,
fs::File,
path::{Path, PathBuf},
process::Command,
};
use ar::Builder; use ar::Builder;
fn main() { fn main() {
let out_dir = env::var("OUT_DIR").unwrap(); let out_dir = env::var("OUT_DIR").unwrap();
let output = Command::new(concat!( let output = Command::new(concat!(
env!("CARGO_MANIFEST_DIR"), env!("CARGO_MANIFEST_DIR"),
"/src/build_test_graph.py" "/src/build_test_graph.py"
)) ))
.arg(&out_dir) .arg(&out_dir)
.output() .output()
.expect("Failed to execute command"); .expect("Failed to execute command");
assert!( assert!(
Path::new(&format!("{}/graph.o", out_dir)).exists(), Path::new(&format!("{}/graph.o", out_dir)).exists(),
"Could not build graph lib: {}", "Could not build graph lib: {}",
String::from_utf8(output.stderr) String::from_utf8(output.stderr)
.unwrap() .unwrap()
.trim() .trim()
.split("\n") .split("\n")
.last() .last()
.unwrap_or("") .unwrap_or("")
); );
let in_path: PathBuf = [&out_dir, "graph.o"].iter().collect(); let mut builder = Builder::new(File::create(format!("{}/libgraph.a", out_dir)).unwrap());
let out_path: PathBuf = [&out_dir, "libgraph.a"].iter().collect(); builder.append_path(format!("{}/graph.o", out_dir)).unwrap();
let mut builder = Builder::new(File::create(out_path.to_str().unwrap()).unwrap());
builder.append_path(in_path.to_str().unwrap()).unwrap();
println!("cargo:rustc-link-lib=static=graph"); println!("cargo:rustc-link-lib=static=graph");
println!("cargo:rustc-link-search=native={}", out_dir); println!("cargo:rustc-link-search=native={}", out_dir);
} }
...@@ -23,6 +23,7 @@ def _get_model(dshape): ...@@ -23,6 +23,7 @@ def _get_model(dshape):
def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10): def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10):
if isinstance(graph, sym.Symbol): if isinstance(graph, sym.Symbol):
graph = nnvm.graph.create(graph) graph = nnvm.graph.create(graph)
ishapes, _ = graph_util.infer_shape(graph, **input_shapes) ishapes, _ = graph_util.infer_shape(graph, **input_shapes)
param_shapes = dict(zip(graph.index.input_names, ishapes)) param_shapes = dict(zip(graph.index.input_names, ishapes))
np.random.seed(seed) np.random.seed(seed)
...@@ -40,6 +41,7 @@ def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10): ...@@ -40,6 +41,7 @@ def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10):
initializer(param, init_value) initializer(param, init_value)
# init_value /= init_value.sum() + 1e-10 # init_value /= init_value.sum() + 1e-10
params[param] = tvm.nd.array(init_value) params[param] = tvm.nd.array(init_value)
return params return params
def main(): def main():
...@@ -56,6 +58,7 @@ def main(): ...@@ -56,6 +58,7 @@ def main():
lib.save(osp.join(sys.argv[1], 'graph.o')) lib.save(osp.join(sys.argv[1], 'graph.o'))
with open(osp.join(out_dir, 'graph.json'), 'w') as f_resnet: with open(osp.join(out_dir, 'graph.json'), 'w') as f_resnet:
f_resnet.write(graph.json()) f_resnet.write(graph.json())
with open(osp.join(out_dir, 'graph.params'), 'wb') as f_params: with open(osp.join(out_dir, 'graph.params'), 'wb') as f_params:
f_params.write(nnvm.compiler.save_param_dict(params)) f_params.write(nnvm.compiler.save_param_dict(params))
......
...@@ -5,76 +5,78 @@ extern crate ndarray; ...@@ -5,76 +5,78 @@ extern crate ndarray;
extern crate serde; extern crate serde;
extern crate serde_json; extern crate serde_json;
extern crate tvm; extern crate tvm_runtime;
use std::{collections::HashMap, convert::TryFrom, fs, io::Read}; use std::{collections::HashMap, convert::TryFrom, fs, io::Read};
use ndarray::Array; use ndarray::Array;
use tvm::runtime::{Graph, GraphExecutor, SystemLibModule, Tensor}; use tvm_runtime::{Graph, GraphExecutor, SystemLibModule, Tensor};
const BATCH_SIZE: usize = 4; const BATCH_SIZE: usize = 4;
const IN_DIM: usize = 8; const IN_DIM: usize = 8;
macro_rules! check_sum { macro_rules! check_sum {
($e:expr, $a:ident, $b:ident) => { ($e:expr, $a:ident, $b:ident) => {
let a = Array::try_from($e.get_input(stringify!($a)).unwrap()).unwrap(); let a = Array::try_from($e.get_input(stringify!($a)).unwrap()).unwrap();
check_sum!(a, $b); check_sum!(a, $b);
}; };
($e:expr, $a:expr, $b:ident) => { ($e:expr, $a:expr, $b:ident) => {
let a = Array::try_from($e.get_output($a).unwrap()).unwrap(); let a = Array::try_from($e.get_output($a).unwrap()).unwrap();
check_sum!(a, $b); check_sum!(a, $b);
}; };
($a:ident, $b:ident) => { ($a:ident, $b:ident) => {
let a_sum: f32 = $a.scalar_sum(); let a_sum: f32 = $a.scalar_sum();
let b_sum: f32 = $b.scalar_sum(); let b_sum: f32 = $b.scalar_sum();
assert!((a_sum - b_sum).abs() < 1e-2, "{} != {}", a_sum, b_sum); assert!((a_sum - b_sum).abs() < 1e-2, "{} != {}", a_sum, b_sum);
}; };
} }
fn main() { fn main() {
let syslib = SystemLibModule::default(); let syslib = SystemLibModule::default();
let mut params_bytes = Vec::new(); let mut params_bytes = Vec::new();
fs::File::open(concat!(env!("OUT_DIR"), "/graph.params")) fs::File::open(concat!(env!("OUT_DIR"), "/graph.params"))
.unwrap() .unwrap()
.read_to_end(&mut params_bytes) .read_to_end(&mut params_bytes)
.unwrap(); .unwrap();
let params = tvm::runtime::load_param_dict(&params_bytes) let params = tvm_runtime::load_param_dict(&params_bytes)
.unwrap() .unwrap()
.into_iter() .into_iter()
.map(|(k, v)| (k, v.to_owned())) .map(|(k, v)| (k, v.to_owned()))
.collect::<HashMap<String, Tensor<'static>>>(); .collect::<HashMap<String, Tensor<'static>>>();
let graph = let graph =
Graph::try_from(&fs::read_to_string(concat!(env!("OUT_DIR"), "/graph.json")).unwrap()).unwrap(); Graph::try_from(&fs::read_to_string(concat!(env!("OUT_DIR"), "/graph.json")).unwrap())
let mut exec = GraphExecutor::new(graph, &syslib).unwrap(); .unwrap();
let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
let x = Array::from_shape_vec( let x = Array::from_shape_vec(
(BATCH_SIZE, IN_DIM), (BATCH_SIZE, IN_DIM),
(0..BATCH_SIZE * IN_DIM) (0..BATCH_SIZE * IN_DIM)
.map(|x| x as f32) .map(|x| x as f32)
.collect::<Vec<f32>>(), .collect::<Vec<f32>>(),
).unwrap(); )
let w = Array::try_from(params.get("dense0_weight").unwrap())
.unwrap()
.into_shape((IN_DIM * 2, IN_DIM))
.unwrap(); .unwrap();
let b = Array::try_from(params.get("dense0_bias").unwrap()).unwrap(); let w = Array::try_from(params.get("dense0_weight").unwrap())
let dense = x.dot(&w.t()) + &b; .unwrap()
let left = dense.slice(s![.., 0..IN_DIM]); .into_shape((IN_DIM * 2, IN_DIM))
let right = dense.slice(s![.., IN_DIM..]); .unwrap();
let expected_o0 = &left + 1f32; let b = Array::try_from(params.get("dense0_bias").unwrap()).unwrap();
let expected_o1 = &right - 1f32; let dense = x.dot(&w.t()) + &b;
let left = dense.slice(s![.., 0..IN_DIM]);
let right = dense.slice(s![.., IN_DIM..]);
let expected_o0 = &left + 1f32;
let expected_o1 = &right - 1f32;
exec.load_params(params); exec.load_params(params);
exec.set_input("data", x.clone().into()); exec.set_input("data", (&x).into());
check_sum!(exec, data, x); check_sum!(exec, data, x);
check_sum!(exec, dense0_weight, w); check_sum!(exec, dense0_weight, w);
check_sum!(exec, dense0_bias, b); check_sum!(exec, dense0_bias, b);
exec.run(); exec.run();
check_sum!(exec, 0, expected_o0); check_sum!(exec, 0, expected_o0);
check_sum!(exec, 1, expected_o1); check_sum!(exec, 1, expected_o1);
check_sum!(exec, 2, dense); check_sum!(exec, 2, dense);
} }
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
name = "test-tvm-basic" name = "test-tvm-basic"
version = "0.0.0" version = "0.0.0"
license = "Apache-2.0" license = "Apache-2.0"
authors = ["Nick Hynes <nhynes@berkeley.edu>"] authors = ["TVM Contributors"]
[dependencies] [dependencies]
ndarray = "0.11.2" ndarray = "0.11.2"
tvm = { path = "../../" } tvm-runtime = { path = "../../" }
[build-dependencies] [build-dependencies]
ar = "0.6.0" ar = "0.6.0"
extern crate ar; extern crate ar;
use std::{env, path::PathBuf, process::Command}; use std::{env, path::Path, process::Command};
use ar::Builder; use ar::Builder;
use std::fs::File; use std::fs::File;
fn main() { fn main() {
let out_dir = env::var("OUT_DIR").unwrap(); let out_dir = env::var("OUT_DIR").unwrap();
let output = Command::new(concat!( let output = Command::new(concat!(
env!("CARGO_MANIFEST_DIR"), env!("CARGO_MANIFEST_DIR"),
"/src/build_test_lib.py" "/src/build_test_lib.py"
)).arg(&out_dir) ))
.arg(&out_dir)
.output() .output()
.expect("Failed to execute command"); .expect("Failed to execute command");
if output.stderr.len() > 0 { assert!(
panic!(String::from_utf8(output.stderr).unwrap()); Path::new(&format!("{}/test.o", out_dir)).exists(),
} "Could not build tvm lib: {}",
String::from_utf8(output.stderr)
.unwrap()
.trim()
.split("\n")
.last()
.unwrap_or("")
);
let in_path: PathBuf = [&out_dir, "test.o"].iter().collect(); let mut builder = Builder::new(File::create(format!("{}/libtest.a", out_dir)).unwrap());
let out_path: PathBuf = [&out_dir, "libtest.a"].iter().collect(); builder.append_path(format!("{}/test.o", out_dir)).unwrap();
let mut builder = Builder::new(File::create(out_path.to_str().unwrap()).unwrap());
builder.append_path(in_path.to_str().unwrap()).unwrap();
println!("cargo:rustc-link-lib=static=test"); println!("cargo:rustc-link-lib=static=test");
println!("cargo:rustc-link-search=native={}", out_dir); println!("cargo:rustc-link-search=native={}", out_dir);
} }
extern crate ndarray; extern crate ndarray;
#[macro_use] #[macro_use]
extern crate tvm; extern crate tvm_runtime;
use ndarray::Array; use ndarray::Array;
use tvm::{ use tvm_runtime::{DLTensor, Module, SystemLibModule};
ffi::runtime::DLTensor,
runtime::{Module, SystemLibModule},
};
fn main() { fn main() {
let syslib = SystemLibModule::default(); let syslib = SystemLibModule::default();
let add = syslib let add = syslib
.get_function("default_function") .get_function("default_function")
.expect("main function not found"); .expect("main function not found");
let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]); let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]);
let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]); let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]);
let mut c = Array::from_vec(vec![0f32; 4]); let mut c = Array::from_vec(vec![0f32; 4]);
let e = Array::from_vec(vec![2f32, 2., 4., 4.]); let e = Array::from_vec(vec![2f32, 2., 4., 4.]);
let mut a_dl: DLTensor = (&mut a).into(); let mut a_dl: DLTensor = (&mut a).into();
let mut b_dl: DLTensor = (&mut b).into(); let mut b_dl: DLTensor = (&mut b).into();
let mut c_dl: DLTensor = (&mut c).into(); let mut c_dl: DLTensor = (&mut c).into();
call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl); call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl);
assert!(c.all_close(&e, 1e-8f32)); assert!(c.all_close(&e, 1e-8f32));
} }
mod allocator;
mod array;
mod module;
#[macro_use]
mod packed_func;
mod graph;
#[cfg(target_env = "sgx")]
#[macro_use]
pub mod sgx;
mod threading;
mod workspace;
use std::os::raw::c_char;
pub use self::{array::*, graph::*, module::*, packed_func::*, threading::*, workspace::*};
#[cfg(target_env = "sgx")]
use self::sgx::ocall_packed_func;
#[no_mangle]
pub extern "C" fn TVMAPISetLastError(cmsg: *const c_char) {
#[cfg(not(target_env = "sgx"))]
unsafe {
panic!(std::ffi::CStr::from_ptr(cmsg).to_str().unwrap());
}
#[cfg(target_env = "sgx")]
ocall_packed!("__sgx_set_last_error__", cmsg);
}
use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void};
use super::Tensor;
use ffi::runtime::{
BackendPackedCFunc, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLTensor,
TVMTypeCode_kArrayHandle, TVMTypeCode_kHandle, TVMTypeCode_kNDArrayContainer, TVMValue,
};
use errors::*;
pub type PackedFunc = Box<Fn(&[TVMArgValue]) -> TVMRetValue + Send + Sync>;
/// Calls a packed function and returns a `TVMRetValue`.
///
/// # Example
///
/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
#[macro_export]
macro_rules! call_packed {
($fn:expr, $($args:expr),+) => {
$fn(&[$($args.into(),)+])
};
($fn:expr) => {
$fn(&Vec::new())
};
}
/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way
/// to obtain a `TVMArgValue` is automatically via `call_packed!`.
#[derive(Clone, Copy)]
pub struct TVMArgValue<'a> {
_lifetime: PhantomData<&'a ()>,
pub(crate) value: TVMValue,
pub(crate) type_code: i64,
}
impl<'a> TVMArgValue<'a> {
pub fn new(value: TVMValue, type_code: i64) -> Self {
TVMArgValue {
_lifetime: PhantomData,
value: value,
type_code: type_code,
}
}
}
/// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode.
macro_rules! impl_prim_tvm_arg {
($type:ty, $field:ident, $code:expr, $as:ty) => {
impl<'a> From<$type> for TVMArgValue<'a> {
fn from(val: $type) -> Self {
TVMArgValue {
value: TVMValue { $field: val as $as },
type_code: $code as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> TryFrom<TVMArgValue<'a>> for $type {
type Error = Error;
fn try_from(val: TVMArgValue<'a>) -> Result<Self> {
ensure!(
val.type_code == $code as i64,
"Could not downcast arg. Expected `{}`, got `{}`",
$code,
val.type_code
);
Ok(unsafe { val.value.$field as $type })
}
}
};
($type:ty, $field:ident, $code:expr) => {
impl_prim_tvm_arg!($type, $field, $code, $type);
};
($type:ty,v_int64) => {
impl_prim_tvm_arg!($type, v_int64, DLDataTypeCode_kDLInt, i64);
};
($type:ty,v_float64) => {
impl_prim_tvm_arg!($type, v_float64, DLDataTypeCode_kDLFloat, f64);
};
}
impl_prim_tvm_arg!(f32, v_float64);
impl_prim_tvm_arg!(f64, v_float64);
impl_prim_tvm_arg!(i8, v_int64);
impl_prim_tvm_arg!(u8, v_int64);
impl_prim_tvm_arg!(i32, v_int64);
impl_prim_tvm_arg!(u32, v_int64);
impl_prim_tvm_arg!(i64, v_int64);
impl_prim_tvm_arg!(u64, v_int64);
/// Creates a conversion to a `TVMArgValue` for an object handle.
impl<'a, T> From<*const T> for TVMArgValue<'a> {
fn from(ptr: *const T) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: ptr as *mut T as *mut c_void,
},
type_code: TVMTypeCode_kArrayHandle as i64,
_lifetime: PhantomData,
}
}
}
/// Creates a conversion to a `TVMArgValue` for a mutable object handle.
impl<'a, T> From<*mut T> for TVMArgValue<'a> {
fn from(ptr: *mut T) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: ptr as *mut c_void,
},
type_code: TVMTypeCode_kHandle as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
fn from(arr: &'a mut DLTensor) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: arr as *mut _ as *mut c_void,
},
type_code: TVMTypeCode_kArrayHandle as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
fn from(arr: &'a DLTensor) -> Self {
TVMArgValue {
value: TVMValue {
v_handle: arr as *const _ as *mut DLTensor as *mut c_void,
},
type_code: TVMTypeCode_kArrayHandle as i64,
_lifetime: PhantomData,
}
}
}
impl<'a> TryFrom<TVMArgValue<'a>> for Tensor<'a> {
type Error = Error;
fn try_from(val: TVMArgValue<'a>) -> Result<Self> {
ensure!(
val.type_code == TVMTypeCode_kArrayHandle as i64
|| val.type_code == TVMTypeCode_kNDArrayContainer as i64,
"Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
TVMTypeCode_kArrayHandle,
TVMTypeCode_kNDArrayContainer,
val.type_code,
);
let dlt = unsafe { *(val.value.v_handle as *mut DLTensor as *const DLTensor) };
Ok(dlt.into())
}
}
/// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
/// Can be downcasted using `try_from` if it contains the desired type.
///
/// # Example
///
/// ```
/// let a = 42u32;
/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
///
/// let s = "hello, world!";
/// let t: TVMRetValue = s.into();
/// assert_eq!(String::try_from(t).unwrap(), s);
/// ```
pub struct TVMRetValue {
/// A primitive return value, if any.
prim_value: u64,
/// An object return value, if any.
box_value: Box<Any>,
/// The DLDataTypeCode which determines whether `prim_value` or `box_value` is in use.
type_code: i64,
}
#[cfg(target_env = "sgx")]
impl TVMRetValue {
pub(crate) fn from_tvm_value(value: TVMValue, type_code: i64) -> Self {
unsafe {
Self {
prim_value: match type_code {
0 | 1 => value.v_int64 as u64,
2 => value.v_float64 as u64,
3 | 7 | 8 | 9 | 10 => value.v_handle as u64,
11 | 12 => value.v_str as u64,
_ => 0,
} as u64,
box_value: box (),
type_code: type_code,
}
}
}
pub fn into_tvm_value(self) -> (TVMValue, i64) {
let val = match self.type_code {
0 | 1 => TVMValue {
v_int64: self.prim_value.clone() as i64,
},
2 => TVMValue {
v_float64: self.prim_value.clone() as f64,
},
3 | 7 | 8 | 9 | 10 | 13 => TVMValue {
v_handle: Box::into_raw(self.box_value) as *mut c_void,
},
11 | 12 => TVMValue {
v_str: Box::into_raw(self.box_value) as *const _,
},
_ => unreachable!(),
};
(val, self.type_code)
}
}
impl Default for TVMRetValue {
fn default() -> Self {
TVMRetValue {
prim_value: 0,
box_value: box (),
type_code: 0,
}
}
}
macro_rules! impl_prim_ret_value {
($type:ty, $code:expr) => {
impl From<$type> for TVMRetValue {
fn from(val: $type) -> Self {
TVMRetValue {
prim_value: val as u64,
box_value: box (),
type_code: $code,
}
}
}
impl TryFrom<TVMRetValue> for $type {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<$type> {
if ret.type_code == $code {
Ok(ret.prim_value as $type)
} else {
bail!(ErrorKind::TryFromTVMRetValueError(
stringify!($type).to_string(),
ret.type_code
))
}
}
}
};
}
macro_rules! impl_boxed_ret_value {
($type:ty, $code:expr) => {
impl From<$type> for TVMRetValue {
fn from(val: $type) -> Self {
TVMRetValue {
prim_value: 0,
box_value: box val,
type_code: $code,
}
}
}
impl TryFrom<TVMRetValue> for $type {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<$type> {
if let Ok(val) = ret.box_value.downcast::<$type>() {
Ok(*val)
} else {
bail!(ErrorKind::TryFromTVMRetValueError(
stringify!($type).to_string(),
ret.type_code
))
}
}
}
};
}
impl_prim_ret_value!(i8, 0);
impl_prim_ret_value!(u8, 1);
impl_prim_ret_value!(i16, 0);
impl_prim_ret_value!(u16, 1);
impl_prim_ret_value!(i32, 0);
impl_prim_ret_value!(u32, 1);
impl_prim_ret_value!(f32, 2);
impl_prim_ret_value!(i64, 0);
impl_prim_ret_value!(u64, 1);
impl_prim_ret_value!(f64, 2);
impl_prim_ret_value!(isize, 0);
impl_prim_ret_value!(usize, 1);
impl_boxed_ret_value!(String, 11);
impl<'a, 't> From<&'t Tensor<'a>> for TVMRetValue {
fn from(val: &'t Tensor<'a>) -> Self {
TVMRetValue {
prim_value: 0,
box_value: box DLTensor::from(val),
type_code: TVMTypeCode_kNDArrayContainer as i64,
}
}
}
impl<'a> TryFrom<TVMRetValue> for Tensor<'a> {
type Error = Error;
fn try_from(ret: TVMRetValue) -> Result<Self> {
ensure!(
ret.type_code == TVMTypeCode_kArrayHandle as i64
|| ret.type_code == TVMTypeCode_kNDArrayContainer as i64,
"Could not downcast arg. Expected `{}` or `{}`, but got `{}`",
TVMTypeCode_kArrayHandle,
TVMTypeCode_kNDArrayContainer,
ret.type_code,
);
let dlt = unsafe { *(ret.prim_value as *mut DLTensor as *const DLTensor) };
Ok(dlt.into())
}
}
// @see `WrapPackedFunc` in `llvm_module.cc`.
pub(super) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc {
box move |args: &[TVMArgValue]| {
func(
args
.iter()
.map(|ref arg| arg.value)
.collect::<Vec<TVMValue>>()
.as_ptr(),
args
.iter()
.map(|ref arg| arg.type_code as i32)
.collect::<Vec<i32>>()
.as_ptr() as *const i32,
args.len() as i32,
);
TVMRetValue::default()
}
}
...@@ -2,24 +2,60 @@ ...@@ -2,24 +2,60 @@
set -e set -e
export LD_LIBRARY_PATH=lib:$LD_LIBRARY_PATH export TVM_HOME="$(git rev-parse --show-toplevel)"
tvm_root="$(git rev-parse --show-toplevel)" export LD_LIBRARY_PATH="$TVM_HOME/lib":"$TVM_HOME/build":"$TVM_HOME/nnvm":$LD_LIBRARY_PATH
export PYTHONPATH="$tvm_root/python":"$tvm_root/nnvm/python":"$tvm_root/topi/python" export PYTHONPATH="$TVM_HOME/python":"$TVM_HOME/nnvm/python":"$TVM_HOME/topi/python"
export RUST_DIR="$TVM_HOME/rust"
#cd rust cd $RUST_DIR
#cargo fmt -- --check cargo fmt -- --check
# test common
cd $RUST_DIR/common
cargo build --features runtime
cargo test --features runtime --tests
cargo build --features frontend
cargo test --features frontend --tests
# test runtime
cd $RUST_DIR/runtime
# run basic tests # run basic tests
#python3 tests/build_model.py python3 tests/build_model.py
#cargo test --tests cargo test --tests
# run TVM module test # run TVM module test
#cd tests/test_tvm_basic cd tests/test_tvm_basic
#cargo run cargo run
#cd - cd -
# run NNVM graph test # run NNVM graph test
#cd tests/test_nnvm cd tests/test_nnvm
#cargo run cargo run
#cd - cd -
# test frontend
cd $RUST_DIR/frontend
cargo test --tests -- --test-threads=1
# run basic tests on cpu
cd tests/basics
cargo build --features cpu
cargo run --features cpu
# uncomment when have more CI resources
# cargo build --features gpu
# cargo run --features gpu
# fi
cd -
# run callback tests separately: https://discuss.tvm.ai/t/are-global-functions-need-to-be-accessed-in-separate-processes/1075
cd tests/callback
cargo build
cargo run --bin int
cargo run --bin float
cargo run --bin array
cargo run --bin string
cd -
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment