Commit 8d3b392d by Ehsan M. Kermani Committed by Nick Hynes

[RUST][FRONTEND] Fix resnet example (#3000)

Due to the previous changes the frontend resnet example failed to build.  So this patch 

1) fixes it 
2) adds ~~a local `run_tests.sh` to remedy non-existence of MXNet CI (used in python build example)~~ the example build to CI with random weights and a flag for pretrained resnet weights

Please review: @tqchen @nhynes @kazimuth
parent 151ccdf9
......@@ -155,7 +155,7 @@ TVMPODValue! {
Bytes(val) => {
(TVMValue { v_handle: val.clone() as *const _ as *mut c_void }, TVMTypeCode_kBytes)
}
Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kStr)}
Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kStr) }
}
}
......@@ -260,12 +260,24 @@ impl<'a> From<&'a str> for TVMArgValue<'a> {
}
}
impl<'a> From<String> for TVMArgValue<'a> {
fn from(s: String) -> Self {
Self::String(CString::new(s).unwrap())
}
}
impl<'a> From<&'a CStr> for TVMArgValue<'a> {
fn from(s: &'a CStr) -> Self {
Self::Str(s)
}
}
impl<'a> From<&'a TVMByteArray> for TVMArgValue<'a> {
fn from(s: &'a TVMByteArray) -> Self {
Self::Bytes(s)
}
}
impl<'a> TryFrom<TVMArgValue<'a>> for &'a str {
type Error = ValueDowncastError;
fn try_from(val: TVMArgValue<'a>) -> Result<Self, Self::Error> {
......
......@@ -17,7 +17,7 @@
* under the License.
*/
use std::str::FromStr;
use std::{os::raw::c_char, str::FromStr};
use failure::Error;
......@@ -157,17 +157,57 @@ impl_tvm_context!(
DLDeviceType_kDLExtDev: [ext_dev]
);
/// A struct holding TVM byte-array.
///
/// ## Example
///
/// ```
/// let v = b"hello";
/// let barr = TVMByteArray::from(&v);
/// assert_eq!(barr.len(), v.len());
/// assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]);
/// ```
impl TVMByteArray {
/// Gets the underlying byte-array
pub fn data(&self) -> &'static [u8] {
unsafe { std::slice::from_raw_parts(self.data as *const u8, self.size) }
}
/// Gets the length of the underlying byte-array
pub fn len(&self) -> usize {
self.size
}
/// Converts the underlying byte-array to `Vec<u8>`
pub fn to_vec(&self) -> Vec<u8> {
self.data().to_vec()
}
}
impl<'a> From<&'a [u8]> for TVMByteArray {
fn from(bytes: &[u8]) -> Self {
Self {
data: bytes.as_ptr() as *const i8,
size: bytes.len(),
// Needs AsRef for Vec
impl<T: AsRef<[u8]>> From<T> for TVMByteArray {
fn from(arg: T) -> Self {
let arg = arg.as_ref();
TVMByteArray {
data: arg.as_ptr() as *const c_char,
size: arg.len(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn convert() {
let v = vec![1u8, 2, 3];
let barr = TVMByteArray::from(&v);
assert_eq!(barr.len(), v.len());
assert_eq!(barr.to_vec(), vec![1u8, 2, 3]);
let v = b"hello";
let barr = TVMByteArray::from(&v);
assert_eq!(barr.len(), v.len());
assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]);
}
}
......@@ -21,12 +21,25 @@ This end-to-end example shows how to:
* build `Resnet 18` with `tvm` and `nnvm` from Python
* use the provided Rust frontend API to test for an input image
To run the example, first `tvm`, `nnvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet`
To run the example with pretrained resnet weights, first `tvm`, `nnvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet`
and to install `tvm` and `nnvm` with `llvm` follow the [TVM installation guide](https://docs.tvm.ai/install/index.html).
* **Build the example**: `cargo build`
* **Build the example**: `cargo build
To have a successful build, note that it is required to instruct Rust compiler to link to the compiled shared library, for example with
`println!("cargo:rustc-link-search=native={}", build_path)`. See the `build.rs` for more details.
* **Run the example**: `cargo run`
Note: To use pretrained weights, one can enable `--pretrained` in `build.rs` with
```
let output = Command::new("python")
.arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py"))
.arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR")))
.arg(&format!("--pretrained"))
.output()
.expect("Failed to execute command");
```
Otherwise, *random weights* are used, therefore, the prediction will be `limpkin, Aramus pictus`!
......@@ -17,16 +17,23 @@
* under the License.
*/
use std::process::Command;
use std::{path::Path, process::Command};
fn main() {
let output = Command::new(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py"))
let output = Command::new("python3")
.arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py"))
.arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR")))
.output()
.expect("Failed to execute command");
assert!(
std::path::Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_lib.o")).exists(),
Path::new(&format!("{}/deploy_lib.o", env!("CARGO_MANIFEST_DIR"))).exists(),
"Could not prepare demo: {}",
String::from_utf8(output.stderr).unwrap().trim()
String::from_utf8(output.stderr)
.unwrap()
.trim()
.split("\n")
.last()
.unwrap_or("")
);
println!(
"cargo:rustc-link-search=native={}",
......
......@@ -24,19 +24,18 @@ import sys
import numpy as np
import mxnet as mx
from mxnet.gluon.model_zoo.vision import get_model
from mxnet.gluon.utils import download
import tvm
from tvm import relay
from tvm.relay import testing
from tvm.contrib import graph_runtime, cc
import nnvm
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
parser = argparse.ArgumentParser(description='Resnet build example')
aa = parser.add_argument
aa('--build-dir', type=str, required=True, help='directory to put the build artifacts')
aa('--pretrained', action='store_true', help='use a pretrained resnet')
aa('--batch-size', type=int, default=1, help='input image batch size')
aa('--opt-level', type=int, default=3,
help='level of optimization. 0 is unoptimized and 3 is the highest level')
......@@ -45,7 +44,7 @@ aa('--image-shape', type=str, default='3,224,224', help='input image dimensions'
aa('--image-name', type=str, default='cat.png', help='name of input image to download')
args = parser.parse_args()
target_dir = osp.dirname(osp.dirname(osp.realpath(__file__)))
build_dir = args.build_dir
batch_size = args.batch_size
opt_level = args.opt_level
target = tvm.target.create(args.target)
......@@ -57,30 +56,42 @@ def build(target_dir):
deploy_lib = osp.join(target_dir, 'deploy_lib.o')
if osp.exists(deploy_lib):
return
# download the pretrained resnet18 trained on imagenet1k dataset for
# image classification task
block = get_model('resnet18_v1', pretrained=True)
sym, params = nnvm.frontend.from_mxnet(block)
# add the softmax layer for prediction
net = nnvm.sym.softmax(sym)
if args.pretrained:
# needs mxnet installed
from mxnet.gluon.model_zoo.vision import get_model
# if `--pretrained` is enabled, it downloads a pretrained
# resnet18 trained on imagenet1k dataset for image classification task
block = get_model('resnet18_v1', pretrained=True)
net, params = relay.frontend.from_mxnet(block, {"data": data_shape})
# we want a probability so add a softmax operator
net = relay.Function(net.params, relay.nn.softmax(net.body),
None, net.type_params, net.attrs)
else:
# use random weights from relay.testing
net, params = relay.testing.resnet.get_workload(
num_layers=18, batch_size=batch_size, image_shape=image_shape)
# compile the model
with nnvm.compiler.build_config(opt_level=opt_level):
graph, lib, params = nnvm.compiler.build(
net, target, shape={"data": data_shape}, params=params)
with relay.build_config(opt_level=opt_level):
graph, lib, params = relay.build_module.build(net, target, params=params)
# save the model artifacts
lib.save(deploy_lib)
cc.create_shared(osp.join(target_dir, "deploy_lib.so"),
[osp.join(target_dir, "deploy_lib.o")])
with open(osp.join(target_dir, "deploy_graph.json"), "w") as fo:
fo.write(graph.json())
fo.write(graph)
with open(osp.join(target_dir,"deploy_param.params"), "wb") as fo:
fo.write(nnvm.compiler.save_param_dict(params))
fo.write(relay.save_param_dict(params))
def download_img_labels():
""" Download an image and imagenet1k class labels for test"""
from mxnet.gluon.utils import download
img_name = 'cat.png'
synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
'4d0b62f3d01426887599d4f7ede23ee5/raw/',
......@@ -97,11 +108,11 @@ def download_img_labels():
w = csv.writer(fout)
w.writerows(synset.items())
def test_build(target_dir):
def test_build(build_dir):
""" Sanity check with random input"""
graph = open(osp.join(target_dir, "deploy_graph.json")).read()
lib = tvm.module.load(osp.join(target_dir, "deploy_lib.so"))
params = bytearray(open(osp.join(target_dir,"deploy_param.params"), "rb").read())
graph = open(osp.join(build_dir, "deploy_graph.json")).read()
lib = tvm.module.load(osp.join(build_dir, "deploy_lib.so"))
params = bytearray(open(osp.join(build_dir,"deploy_param.params"), "rb").read())
input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32"))
ctx = tvm.cpu()
module = graph_runtime.create(graph, lib, ctx)
......@@ -112,10 +123,11 @@ def test_build(target_dir):
if __name__ == '__main__':
logger.info("building the model")
build(target_dir)
build(build_dir)
logger.info("build was successful")
logger.info("test the build artifacts")
test_build(target_dir)
test_build(build_dir)
logger.info("test was successful")
download_img_labels()
logger.info("image and synset downloads are successful")
if args.pretrained:
download_img_labels()
logger.info("image and synset downloads are successful")
......@@ -84,7 +84,7 @@ fn main() {
let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap();
let runtime_create_fn_ret = call_packed!(
runtime_create_fn,
&graph,
graph,
&lib,
&ctx.device_type,
&ctx.device_id
......@@ -107,8 +107,7 @@ fn main() {
.get_function("set_input", false)
.unwrap();
let data_str = "data".to_string();
call_packed!(set_input_fn, &data_str, &input).unwrap();
call_packed!(set_input_fn, "data".to_string(), &input).unwrap();
// get `run` function from runtime module
let ref run_fn = graph_runtime_module.get_function("run", false).unwrap();
// execute the run function. Note that it has no argument
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
//! Provides [`TVMByteArray`] used for passing the model parameters
//! (stored as byte-array) to a runtime module.
//!
//! For more detail, please see the example `resnet` in `examples` repository.
use std::os::raw::c_char;
use tvm_common::ffi;
/// A struct holding TVM byte-array.
///
/// ## Example
///
/// ```
/// let v = b"hello".to_vec();
/// let barr = TVMByteArray::from(&v);
/// assert_eq!(barr.len(), v.len());
/// assert_eq!(barr.data(), vec![104i8, 101, 108, 108, 111]);
/// ```
#[derive(Debug, Clone)]
pub struct TVMByteArray {
pub(crate) inner: ffi::TVMByteArray,
}
impl TVMByteArray {
pub(crate) fn new(barr: ffi::TVMByteArray) -> TVMByteArray {
TVMByteArray { inner: barr }
}
/// Gets the length of the underlying byte-array
pub fn len(&self) -> usize {
self.inner.size
}
/// Gets the underlying byte-array as `Vec<i8>`
pub fn data(&self) -> Vec<i8> {
unsafe {
let sz = self.len();
let mut ret_buf = Vec::with_capacity(sz);
ret_buf.set_len(sz);
self.inner.data.copy_to(ret_buf.as_mut_ptr(), sz);
ret_buf
}
}
}
impl<'a, T: AsRef<[u8]>> From<T> for TVMByteArray {
fn from(arg: T) -> Self {
let arg = arg.as_ref();
let barr = ffi::TVMByteArray {
data: arg.as_ptr() as *const c_char,
size: arg.len(),
};
TVMByteArray::new(barr)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn convert() {
let v = vec![1u8, 2, 3];
let barr = TVMByteArray::from(&v);
assert_eq!(barr.len(), v.len());
assert_eq!(barr.data(), vec![1i8, 2, 3]);
let v = b"hello".to_vec();
let barr = TVMByteArray::from(&v);
assert_eq!(barr.len(), v.len());
assert_eq!(barr.data(), vec![104i8, 101, 108, 108, 111]);
}
}
......@@ -47,7 +47,7 @@ use failure::Error;
use tvm_common::ffi;
use crate::function;
use crate::{function, TVMArgValue};
/// Device type can be from a supported device name. See the supported devices
/// in [TVM](https://github.com/dmlc/tvm).
......@@ -60,7 +60,7 @@ use crate::function;
///```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TVMDeviceType(pub usize);
pub struct TVMDeviceType(pub i64);
impl Default for TVMDeviceType {
/// default device is cpu.
......@@ -141,6 +141,12 @@ impl<'a> From<&'a str> for TVMDeviceType {
}
}
impl<'a> From<&TVMDeviceType> for TVMArgValue<'a> {
fn from(dev: &TVMDeviceType) -> Self {
Self::Int(dev.0)
}
}
/// Represents the underlying device context. Default is cpu.
///
/// ## Examples
......
......@@ -30,7 +30,7 @@
//!
//! Checkout the `examples` repository for more details.
#![feature(box_syntax)]
#![feature(box_syntax, type_alias_enum_variants)]
#[macro_use]
extern crate failure;
......@@ -48,7 +48,6 @@ use std::{
use failure::Error;
pub use crate::{
bytearray::TVMByteArray,
context::{TVMContext, TVMDeviceType},
errors::*,
function::Function,
......@@ -56,7 +55,7 @@ pub use crate::{
ndarray::NDArray,
tvm_common::{
errors as common_errors,
ffi::{self, TVMType},
ffi::{self, TVMByteArray, TVMType},
packed_func::{TVMArgValue, TVMRetValue},
},
};
......@@ -89,7 +88,6 @@ pub(crate) fn set_last_error(err: &Error) {
#[macro_use]
pub mod function;
pub mod bytearray;
pub mod context;
pub mod errors;
pub mod module;
......
......@@ -76,3 +76,7 @@ cargo run --bin float
cargo run --bin array
cargo run --bin string
cd -
cd examples/resnet
cargo build
cd -
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment