# test_graph_pass.py
"""Unittest cases for graph pass"""
import nnvm
import nnvm.compiler
from nnvm import symbol as sym
from nnvm.compiler import graph_util, graph_attr

def test_infer_attr():
    """Check shape and dtype inference on a one-variable graph.

    Builds the symbol ``y = x * 2``, wraps it in a graph, and asserts that
    the first inferred output shape and dtype equal the shape and dtype
    supplied for ``x``.
    """
    x = sym.Variable("x")
    y = x * 2
    g = nnvm.graph.create(y)

    # Only the output side of each inference result is checked below, so
    # the input-side result is discarded.
    _, oshape = graph_util.infer_shape(g, x=(10, 20))
    # Normalise to a tuple before comparing against the tuple literal.
    assert tuple(oshape[0]) == (10, 20)

    _, otype = graph_util.infer_dtype(g, x="float32")
    assert otype[0] == "float32"

# Run the test directly when this file is executed as a script.
if __name__ == "__main__":
    test_infer_attr()