// test_graph_serde.rs
#![feature(try_from)]

extern crate serde;
extern crate serde_json;

extern crate tvm_runtime;

use std::{convert::TryFrom, fs, io::Read};

use tvm_runtime::Graph;

/// Loads the `tests/graph.params` and `tests/graph.json` fixtures and checks a few
/// structural properties of the deserialized [`Graph`].
///
/// The fixtures are expected to have been generated by `tests/build_model.py`.
#[test]
fn test_load_graph() {
    let mut params_bytes = Vec::new();
    fs::File::open(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.params"))
        .expect("Could not find TVM graph. Did you run `tests/build_model.py`?")
        .read_to_end(&mut params_bytes)
        .expect("Could not read tests/graph.params");
    // Check the parse result: discarding it would let a malformed params blob
    // pass this test unnoticed.
    let _params = tvm_runtime::load_param_dict(&params_bytes)
        .expect("Could not parse tests/graph.params");

    let graph_json =
        fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json"))
            .expect("Could not find TVM graph JSON. Did you run `tests/build_model.py`?");
    let graph = Graph::try_from(&graph_json).expect("Could not deserialize tests/graph.json");

    // Node 3 is expected to be the fused dense operator.
    let fused = &graph.nodes[3];
    assert_eq!(fused.op, "tvm_op");
    let func_name = fused
        .attrs
        .as_ref()
        .expect("node 3 should carry attrs")
        .get("func_name")
        .expect("node 3 attrs should include `func_name`");
    assert_eq!(func_name, "fuse_dense");

    assert_eq!(graph.nodes[5].inputs[0].index, 0);
    assert_eq!(graph.nodes[6].inputs[0].index, 1);
    assert_eq!(graph.heads.len(), 2);
}