/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ #[macro_use] extern crate ndarray; extern crate serde; extern crate serde_json; extern crate tvm_runtime; use std::{collections::HashMap, convert::TryFrom, fs, io::Read}; use ndarray::Array; use tvm_runtime::{Graph, GraphExecutor, SystemLibModule, Tensor}; const BATCH_SIZE: usize = 4; const IN_DIM: usize = 8; macro_rules! check_sum { ($e:expr, $a:ident, $b:ident) => { let a = Array::try_from($e.get_input(stringify!($a)).unwrap()).unwrap(); check_sum!(a, $b); }; ($e:expr, $a:expr, $b:ident) => { let a = Array::try_from($e.get_output($a).unwrap()).unwrap(); check_sum!(a, $b); }; ($a:ident, $b:ident) => { let a_sum: f32 = $a.scalar_sum(); let b_sum: f32 = $b.scalar_sum(); assert!((a_sum - b_sum).abs() < 1e-2, "{} != {}", a_sum, b_sum); }; } fn main() { let syslib = SystemLibModule::default(); let mut params_bytes = Vec::new(); fs::File::open(concat!(env!("OUT_DIR"), "/graph.params")) .unwrap() .read_to_end(&mut params_bytes) .unwrap(); let params = tvm_runtime::load_param_dict(¶ms_bytes) .unwrap() .into_iter() .map(|(k, v)| (k, v.to_owned())) .collect::<HashMap<String, Tensor<'static>>>(); let graph = Graph::try_from(&fs::read_to_string(concat!(env!("OUT_DIR"), "/graph.json")).unwrap()) .unwrap(); let mut exec = GraphExecutor::new(graph, &syslib).unwrap(); let x = Array::from_shape_vec( (BATCH_SIZE, IN_DIM), (0..BATCH_SIZE * IN_DIM) .map(|x| x as f32) .collect::<Vec<f32>>(), ) .unwrap(); let w = Array::try_from(params.get("dense0_weight").unwrap()) .unwrap() .into_shape((IN_DIM * 2, IN_DIM)) .unwrap(); let b = Array::try_from(params.get("dense0_bias").unwrap()).unwrap(); let dense = x.dot(&w.t()) + &b; let left = dense.slice(s![.., 0..IN_DIM]); let right = dense.slice(s![.., IN_DIM..]); let expected_o0 = &left + 1f32; let expected_o1 = &right - 1f32; exec.load_params(params); exec.set_input("data", (&x).into()); check_sum!(exec, data, x); check_sum!(exec, dense0_weight, w); check_sum!(exec, dense0_bias, b); exec.run(); check_sum!(exec, 0, expected_o0); check_sum!(exec, 1, expected_o1); check_sum!(exec, 2, dense); }