Commit fdf795a0
authored Sep 20, 2018 by Siva, committed by Tianqi Chen, Sep 19, 2018
[FRONTEND][TENSORFLOW] GPU support for tensorflow models. (#1718)
parent ae5a28db

Showing 3 changed files with 97 additions and 39 deletions:

    nnvm/python/nnvm/frontend/tensorflow.py                  +64 -26
    nnvm/tests/python/frontend/tensorflow/test_forward.py    +20 -8
    tutorials/nnvm/from_tensorflow.py                        +13 -5
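In short: nnvm.frontend.from_tensorflow gains a layout argument, and the convolution and pooling converters learn to transpose NHWC TensorFlow graphs into the NCHW layout that CUDA schedules expect. A minimal usage sketch (assuming graph_def is a loaded TensorFlow GraphDef and x the input array, as in the tutorial changed below):

    import nnvm
    import nnvm.compiler

    # Ask the frontend for NCHW; it inserts the required transposes itself.
    sym, params = nnvm.frontend.from_tensorflow(graph_def, layout="NCHW")
    graph, lib, params = nnvm.compiler.build(
        sym, target='cuda', target_host='llvm',
        shape={'DecodeJpeg/contents': x.shape},
        dtype={'DecodeJpeg/contents': 'uint8'}, params=params)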
nnvm/python/nnvm/frontend/tensorflow.py
...
@@ -35,6 +35,7 @@ class AttrCvt(object):
         self._ignores.append('use_cudnn_on_gpu')
         self._ignores.append('_node_name')
         self._ignores.append('is_training')
+        self._ignores.append('_target_layout')
         # Retain the names
         try:
             attrs['name'] = attrs['_node_name']
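For context, _ignores lists attribute names that the frontend consumes itself and must not forward to the NNVM operator; '_target_layout' joins the list because the importer now attaches it to every node (see the @@ -1036 hunk below). A minimal sketch of the pattern, using a hypothetical class rather than the real AttrCvt:

    # Hypothetical illustration of the ignore-list pattern, not the real AttrCvt.
    class AttrFilterSketch(object):
        def __init__(self):
            self._ignores = ['use_cudnn_on_gpu', '_node_name',
                             'is_training', '_target_layout']

        def __call__(self, attrs):
            # Frontend-internal keys are dropped before attrs reach the operator.
            return {k: v for k, v in attrs.items() if k not in self._ignores}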
...
@@ -121,6 +122,9 @@ def _pooling(name):
     def _impl(inputs, attr, params):
         attr['data_format'] = attr['data_format'].decode("utf-8")
+        flip_layout = False
+
+        input_shape = attr['_input_shapes'][inputs[0]][0]
         if attr['data_format'] == 'NHWC':
             attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
...
@@ -129,11 +133,17 @@ def _pooling(name):
         else:
             raise TypeError("Unsupported data_format type : {}".format(attr['data_format']))

+        if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
+            tmp_shape = attr['_input_shapes'][inputs[0]][0]
+            input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
+            inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
+            attr['data_format'] = "NCHW"
+            flip_layout = True
+
         # Fix strides
         attr['strides'] = (attr['strides'][1], attr['strides'][2])

         # Fix padding
-        input_shapes = attr['_input_shapes'][inputs[0]]
         attr['padding'] = attr['padding'].decode("utf-8")

         if attr['padding'] == 'VALID':
...
@@ -142,11 +152,11 @@ def _pooling(name):
         stride_h, stride_w = attr['strides']
         kernel_h, kernel_w = attr['kernel_shape']
         if attr['data_format'] == 'NHWC':
-            in_h = input_shapes[0][1]
-            in_w = input_shapes[0][2]
+            in_h = input_shape[1]
+            in_w = input_shape[2]
         else:
-            in_h = input_shapes[0][2]
-            in_w = input_shapes[0][3]
+            in_h = input_shape[2]
+            in_w = input_shape[3]
         pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
         pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
...
@@ -158,7 +168,7 @@ def _pooling(name):
         if name == "avg_pool":
             attr['count_include_pad'] = False

-        return AttrCvt(
+        out = AttrCvt(
             op_name=_dimension_picker(name),
             transforms={
                 'kernel_shape':'pool_size',
...
@@ -166,33 +176,53 @@ def _pooling(name):
             ignores=['ksize'],
             extras={'ceil_mode': False},
             custom_check=_dimension_constraint())(inputs, attr)

+        if flip_layout:
+            out = _sym.transpose(out, axes=(0, 2, 3, 1))
+
+        return out
+
     return _impl
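The flip itself is a plain axis permutation: NHWC to NCHW is (0, 3, 1, 2), and (0, 2, 3, 1) applied to the pooled output undoes it, which is exactly the flip_layout branch above. A quick numpy shape check:

    import numpy as np

    x_nhwc = np.zeros((1, 224, 224, 3))          # batch, height, width, channels
    x_nchw = np.transpose(x_nhwc, (0, 3, 1, 2))  # batch, channels, height, width
    assert x_nchw.shape == (1, 3, 224, 224)
    # The inverse permutation restores the original NHWC shape.
    assert np.transpose(x_nchw, (0, 2, 3, 1)).shape == x_nhwc.shape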
 def _conv(opname):
     def _impl(inputs, attr, params):
         attr['data_format'] = attr['data_format'].decode("utf-8")
-        input_shapes = attr['_input_shapes'][inputs[0]]
+        flip_layout = False
+
+        input_shape = attr['_input_shapes'][inputs[0]][0]

         # Extract kernel shape from params
-        conv_param_weights = params[inputs[1].list_output_names()[0]]
+        weights_shape = params[inputs[1].list_output_names()[0]].shape
+
+        if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
+            input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
+            inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
+            if opname == 'conv':
+                weights_shape = [weights_shape[ii] for ii in (3, 2, 0, 1)]
+                inputs[1] = _sym.transpose(inputs[1], axes=(3, 2, 0, 1))
+            else:
+                weights_shape = [weights_shape[ii] for ii in (2, 3, 0, 1)]
+                inputs[1] = _sym.transpose(inputs[1], axes=(2, 3, 0, 1))
+
+            attr['data_format'] = "NCHW"
+            flip_layout = True

         if attr['data_format'] == 'NHWC':
-            kernel_h, kernel_w, _, depth_mult = conv_param_weights.shape
-            attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
+            kernel_h, kernel_w, _, depth_mult = weights_shape
+            attr['kernel_shape'] = (weights_shape[0], weights_shape[1])
             if opname == 'conv':
-                attr['channels'] = conv_param_weights.shape[3]
+                attr['channels'] = weights_shape[3]
             else:
-                attr['channels'] = input_shapes[0][3] * depth_mult
+                attr['channels'] = input_shape[3] * depth_mult

             if 'dilations' in attr:
                 attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
         elif attr['data_format'] == 'NCHW':
-            depth_mult, _, kernel_h, kernel_w = conv_param_weights.shape
-            attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
+            depth_mult, _, kernel_h, kernel_w = weights_shape
+            attr['kernel_shape'] = (weights_shape[2], weights_shape[3])
             if opname == 'conv':
-                attr['channels'] = conv_param_weights.shape[1]
+                attr['channels'] = weights_shape[0]
             else:
-                attr['channels'] = input_shapes[0][1] * depth_mult
+                attr['channels'] = input_shape[0] * depth_mult
+
+            if attr['channels'] < 0:
+                attr['channels'] *= -1

             if 'dilations' in attr:
                 attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
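The weight permutations follow from the storage conventions: TensorFlow keeps conv2d weights as HWIO (height, width, in-channels, out-channels) and depthwise weights as (height, width, channels, multiplier), while an NCHW convolution wants OIHW, hence (3, 2, 0, 1) for conv and (2, 3, 0, 1) for depthwise above. Checking the shapes with numpy:

    import numpy as np

    w_hwio = np.zeros((3, 3, 16, 32))   # TF conv2d: h, w, in_ch, out_ch
    assert np.transpose(w_hwio, (3, 2, 0, 1)).shape == (32, 16, 3, 3)  # OIHW

    w_dw = np.zeros((3, 3, 16, 2))      # TF depthwise: h, w, channels, multiplier
    assert np.transpose(w_dw, (2, 3, 0, 1)).shape == (16, 2, 3, 3)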
...
@@ -215,11 +245,11 @@ def _conv(opname):
         stride_h, stride_w = attr['strides']
         kernel_h, kernel_w = attr['kernel_shape']
         if attr['data_format'] == 'NHWC':
-            in_h = input_shapes[0][1]
-            in_w = input_shapes[0][2]
+            in_h = input_shape[1]
+            in_w = input_shape[2]
         else:
-            in_h = input_shapes[0][2]
-            in_w = input_shapes[0][3]
+            in_h = input_shape[2]
+            in_w = input_shape[3]
         pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
         pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
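_get_pad_pair computes TensorFlow's SAME padding: total padding is chosen so the output size equals ceil(input / stride), then split with the extra pixel placed after. A standalone sketch of that arithmetic (the real helper is defined earlier in this file; this is an assumed reimplementation of the conventional formula):

    # Sketch of SAME-padding math, assuming the standard formula.
    def get_pad_pair_sketch(in_size, kernel, stride):
        out_size = (in_size + stride - 1) // stride       # ceil(in / stride)
        pad = max(0, (out_size - 1) * stride + kernel - in_size)
        pad_before = pad // 2
        return [pad_before, pad - pad_before]

    assert get_pad_pair_sketch(224, 3, 2) == [0, 1]       # 224 -> 112, 3x3 kernel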
...
@@ -248,7 +278,7 @@ def _conv(opname):
         else:
             attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'

-        return AttrCvt(
+        out = AttrCvt(
             op_name=_dimension_picker('conv'),
             transforms={
                 'kernel_shape': 'kernel_size',
...
@@ -257,6 +287,11 @@ def _conv(opname):
                 'group': ('groups', 1)},
             extras={'use_bias': len(inputs) == 3},
             custom_check=_dimension_constraint())(inputs, attr)

+        if flip_layout:
+            out = _sym.transpose(out, axes=(0, 2, 3, 1))
+
+        return out
+
     return _impl

 def _decode_image():
...
@@ -305,7 +340,7 @@ def _matmul():
     def _impl(inputs, attr, params):
         channels = _infer_channels(inputs[1], params, not attr['transpose_b'])
         if attr['transpose_a']:
-            inputs[0] = _sym.transpose(inputs[0], axis(1, 0))
+            inputs[0] = _sym.transpose(inputs[0], axes(1, 0))
         if not attr['transpose_b']:
             inputs[1] = _sym.transpose(inputs[1], axes=(1, 0))
         return AttrCvt(op_name="dense",
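MatMul is lowered to dense, which stores its weight as (units, in_features) and computes x · wᵀ; that is why inputs[1] is transposed exactly when transpose_b is unset (the hunk also corrects the stray axis spelling to axes on the transpose_a path). A numpy sanity check of the equivalence:

    import numpy as np

    x = np.random.rand(2, 4).astype('float32')
    w = np.random.rand(4, 3).astype('float32')  # TF MatMul weight, transpose_b=False
    w_dense = w.T                               # dense layout: (units, in_features)
    # dense(x, w_dense) = x . w_dense^T, which reproduces x . w
    np.testing.assert_allclose(x.dot(w), x.dot(w_dense.T), rtol=1e-6)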
...
@@ -948,7 +983,7 @@ class GraphProto(object):
         self._num_param = 0
         self._num_rnn_layer = False

-    def from_tensorflow(self, graph):
+    def from_tensorflow(self, graph, layout="NHWC"):
         """Construct nnvm nodes from tensorflow graph definition - GraphDef.

         Follow the tensorflow graph definition to parse and convert it to NNVM.
...
@@ -1036,6 +1071,9 @@ class GraphProto(object):
                 # Pass the node name too in attr
                 attr["_node_name"] = node.name

+                # Pass the target layout
+                attr["_target_layout"] = layout
+
                 #ToDo: Some of the tensorflow operators internaly maintain
                 #execution layers and its output name will the layer number along with
                 #graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
...
@@ -1265,7 +1303,7 @@ class GraphProto(object):
         return inputs

-def from_tensorflow(graph):
+def from_tensorflow(graph, layout="NHWC"):
     """ Load tensorflow graph which is a python tensorflow graph object into nnvm graph.
     The companion parameters will be handled automatically.
...
@@ -1283,5 +1321,5 @@ def from_tensorflow(graph):
     Dict of converted parameters stored in tvm.ndarray format
     """
     g = GraphProto()
-    sym, params = g.from_tensorflow(graph)
+    sym, params = g.from_tensorflow(graph, layout)
     return sym, params
nnvm/tests/python/frontend/tensorflow/test_forward.py
...
@@ -26,11 +26,15 @@ import nnvm.testing.tf
 #######################################################################
 # Generic run functions for TVM & tensorflow
 # ------------------------------------------
-def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype):
+def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype, target='llvm'):
     """ Generic function to compile on nnvm and execute on tvm """
-    sym, params = nnvm.frontend.from_tensorflow(graph_def)
-    target = 'llvm'
+    layout = None
+    if target == "cuda":
+        layout = "NCHW"
+    sym, params = nnvm.frontend.from_tensorflow(graph_def, layout=layout)
+    target_host = 'llvm'
     if isinstance(input_data, list):
         shape_dict = {}
         dtype_dict = {}
...
@@ -41,10 +45,10 @@ def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype)
         shape_dict = {input_node: input_data.shape}
         dtype_dict = {input_node: input_data.dtype}

-    graph, lib, params = nnvm.compiler.build(sym, target, shape_dict,
-                                             dtype=dtype_dict, params=params)
+    graph, lib, params = nnvm.compiler.build(sym, target=target, target_host=target_host,
+                                             shape=shape_dict, dtype=dtype_dict, params=params)

-    ctx = tvm.cpu(0)
+    ctx = tvm.context(target, 0)
     from tvm.contrib import graph_runtime
     m = graph_runtime.create(graph, lib, ctx)
     # set inputs
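tvm.context resolves a target string to a device context, so the same helper can now execute on CPU ('llvm') or GPU ('cuda'); ctx.exist, used in the next hunk, reports whether that device is actually available in the current build. A small sketch:

    import tvm

    ctx = tvm.context('cuda', 0)       # equivalent to tvm.gpu(0)
    if not ctx.exist:                  # CUDA runtime not enabled in this build
        ctx = tvm.context('llvm', 0)   # fall back to the CPU, i.e. tvm.cpu(0)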
...
@@ -106,9 +110,17 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False)
             )
         tf_output = run_tf_graph(sess, in_data, in_name, out_name)
-        tvm_output = run_tvm_graph(final_graph_def, in_data,
-                                   in_node, tf_output.shape, tf_output.dtype)
-        np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
+        for device in ["llvm", "cuda"]:
+            ctx = tvm.context(device, 0)
+            if not ctx.exist:
+                print("Skip because %s is not enabled" % device)
+                continue
+            tvm_output = run_tvm_graph(final_graph_def, in_data, in_node,
+                                       tf_output.shape, tf_output.dtype, target=device)
+            np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
         sess.close()

 #######################################################################
...
tutorials/nnvm/from_tensorflow.py
...
@@ -50,6 +50,16 @@ map_proto_url = os.path.join(repo_base, map_proto)
 lable_map = 'imagenet_synset_to_human_label_map.txt'
 lable_map_url = os.path.join(repo_base, lable_map)

+# Target settings
+# Use these commented settings to build for cuda.
+#target = 'cuda'
+#target_host = 'llvm'
+#layout = "NCHW"
+#ctx = tvm.gpu(0)
+target = 'llvm'
+target_host = 'llvm'
+layout = None
+ctx = tvm.cpu(0)
+
 ######################################################################
 # Download required files
...
@@ -99,7 +109,7 @@ x = np.array(image)
 # Results:
 #   sym: nnvm graph for given tensorflow protobuf.
 #   params: params converted from tensorflow params (tensor protobuf).
-sym, params = nnvm.frontend.from_tensorflow(graph_def)
+sym, params = nnvm.frontend.from_tensorflow(graph_def, layout=layout)

 print ("Tensorflow protobuf imported as nnvm graph")
######################################################################
######################################################################
...
@@ -113,18 +123,16 @@ print ("Tensorflow protobuf imported as nnvm graph")
...
@@ -113,18 +123,16 @@ print ("Tensorflow protobuf imported as nnvm graph")
# lib: target library which can be deployed on target with tvm runtime.
# lib: target library which can be deployed on target with tvm runtime.
import
nnvm.compiler
import
nnvm.compiler
target
=
'llvm'
shape_dict
=
{
'DecodeJpeg/contents'
:
x
.
shape
}
shape_dict
=
{
'DecodeJpeg/contents'
:
x
.
shape
}
dtype_dict
=
{
'DecodeJpeg/contents'
:
'uint8'
}
dtype_dict
=
{
'DecodeJpeg/contents'
:
'uint8'
}
graph
,
lib
,
params
=
nnvm
.
compiler
.
build
(
sym
,
target
,
shape_dic
t
,
dtype
=
dtype_dict
,
params
=
params
)
graph
,
lib
,
params
=
nnvm
.
compiler
.
build
(
sym
,
shape
=
shape_dict
,
target
=
target
,
target_host
=
target_hos
t
,
dtype
=
dtype_dict
,
params
=
params
)
######################################################################
######################################################################
# Execute the portable graph on TVM
# Execute the portable graph on TVM
# ---------------------------------
# ---------------------------------
# Now we can try deploying the NNVM compiled model on
cpu
target.
# Now we can try deploying the NNVM compiled model on target.
from
tvm.contrib
import
graph_runtime
from
tvm.contrib
import
graph_runtime
ctx
=
tvm
.
cpu
(
0
)
dtype
=
'uint8'
dtype
=
'uint8'
m
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
m
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
# set inputs
# set inputs
...
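After graph_runtime.create, execution follows the usual set-input/run/get-output sequence in the truncated context; a sketch of that pattern (the input name follows the Inception tutorial above, and the output index is an assumption):

    m.set_input('DecodeJpeg/contents', tvm.nd.array(x.astype(dtype)))
    m.set_input(**params)
    m.run()
    tvm_output = m.get_output(0)   # class scores as a tvm.nd.NDArray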