Commit bfa966a8 by Alexander Pivovarov Committed by Tianqi Chen

Fix Error messages in tflite.py (#3320)

parent 45ef90c0
...@@ -180,7 +180,6 @@ def _convert_convolution(insym, keras_layer, symtab): ...@@ -180,7 +180,6 @@ def _convert_convolution(insym, keras_layer, symtab):
else: else:
kernel_h, kernel_w, in_channels, n_filters = weightList[0].shape kernel_h, kernel_w, in_channels, n_filters = weightList[0].shape
weight = weightList[0].transpose([3, 2, 0, 1]) weight = weightList[0].transpose([3, 2, 0, 1])
dilation = [1, 1]
if isinstance(keras_layer.dilation_rate, (list, tuple)): if isinstance(keras_layer.dilation_rate, (list, tuple)):
dilation = [keras_layer.dilation_rate[0], keras_layer.dilation_rate[1]] dilation = [keras_layer.dilation_rate[0], keras_layer.dilation_rate[1]]
else: else:
......
...@@ -203,7 +203,6 @@ def _convert_convolution(inexpr, keras_layer, etab): ...@@ -203,7 +203,6 @@ def _convert_convolution(inexpr, keras_layer, etab):
else: else:
kernel_h, kernel_w, in_channels, n_filters = weightList[0].shape kernel_h, kernel_w, in_channels, n_filters = weightList[0].shape
weight = weightList[0].transpose([3, 2, 0, 1]) weight = weightList[0].transpose([3, 2, 0, 1])
dilation = [1, 1]
if isinstance(keras_layer.dilation_rate, (list, tuple)): if isinstance(keras_layer.dilation_rate, (list, tuple)):
dilation = [keras_layer.dilation_rate[0], keras_layer.dilation_rate[1]] dilation = [keras_layer.dilation_rate[0], keras_layer.dilation_rate[1]]
else: else:
......
...@@ -156,7 +156,7 @@ class OperatorConverter(object): ...@@ -156,7 +156,7 @@ class OperatorConverter(object):
if tensor_wrapper.tensor.Type() == TensorType.INT32: if tensor_wrapper.tensor.Type() == TensorType.INT32:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int32).reshape( return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int32).reshape(
tensor_wrapper.tensor.ShapeAsNumpy()) tensor_wrapper.tensor.ShapeAsNumpy())
raise NotImplementedError("Not support tensor type {}" raise NotImplementedError("Tensor type {} is currently not supported"
.format(str(tensor_wrapper.tensor.Type()))) .format(str(tensor_wrapper.tensor.Type())))
def get_tensor_type_str(self, tensor_type): def get_tensor_type_str(self, tensor_type):
...@@ -172,7 +172,8 @@ class OperatorConverter(object): ...@@ -172,7 +172,8 @@ class OperatorConverter(object):
return "float32" return "float32"
if tensor_type == TensorType.INT32: if tensor_type == TensorType.INT32:
return "int32" return "int32"
raise NotImplementedError("Not support tensor type {}".format(str(tensor_type))) raise NotImplementedError("Tensor type {} is currently not supported"
.format(str(tensor_type)))
def convert_conv2d(self, op): def convert_conv2d(self, op):
"""Convert TFLite conv2d""" """Convert TFLite conv2d"""
...@@ -450,8 +451,8 @@ class OperatorConverter(object): ...@@ -450,8 +451,8 @@ class OperatorConverter(object):
conv_options = DepthwiseConv2DOptions() conv_options = DepthwiseConv2DOptions()
conv_options.Init(op_options.Bytes, op_options.Pos) conv_options.Init(op_options.Bytes, op_options.Pos)
depth_multiplier = conv_options.DepthMultiplier() depth_multiplier = conv_options.DepthMultiplier()
assert depth_multiplier == 1, "TF frontend have transformed it be 1 " \ assert depth_multiplier == 1, "TF frontend transforms it to be 1 regardless of what " \
"no matter original value be set by 0.25, 0.5 or any else" "original value is set to 0.25, 0.5 or anything else"
else: else:
raise tvm.error.OpNotImplemented( raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend TFLite.'.format(conv_type)) 'Operator {} is not supported for frontend TFLite.'.format(conv_type))
......
...@@ -21,7 +21,7 @@ from tvm.contrib import graph_runtime ...@@ -21,7 +21,7 @@ from tvm.contrib import graph_runtime
from tvm.relay.testing.config import ctx_list from tvm.relay.testing.config import ctx_list
import keras import keras
# prevent keras from using up all gpu memory # prevent Keras from using up all gpu memory
import tensorflow as tf import tensorflow as tf
from keras.backend.tensorflow_backend import set_session from keras.backend.tensorflow_backend import set_session
config = tf.ConfigProto() config = tf.ConfigProto()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment