Commit acb9fd62 by thefiddler Committed by Tianqi Chen

Add support for keras upsampling (#305)

* Add support for keras upsampling

* Fix code formatting

* Fix indentation

* Fix indentation round 2

* Hide unused parameter

* Only enable UpSampling2D since the others are not currently supported by TVM

* Improve error messages and code layout
parent 6a0fb6ef
...@@ -268,6 +268,31 @@ def _convert_pooling(insym, keras_layer, symtab): ...@@ -268,6 +268,31 @@ def _convert_pooling(insym, keras_layer, symtab):
raise TypeError("Unsupported pooling type : {}".format(keras_layer)) raise TypeError("Unsupported pooling type : {}".format(keras_layer))
def _convert_upsample(insym, keras_layer, _):
_check_data_format(keras_layer)
upsample_type = type(keras_layer).__name__
if upsample_type == "UpSampling1D":
h = keras_layer.size
params = {'scale': h}
elif upsample_type == "UpSampling2D":
h, w = keras_layer.size
if h != w:
raise TypeError("Unsupported upsampling type with different axes size : {}"
.format(keras_layer.size))
params = {'scale': h}
elif upsample_type == "UpSampling3D":
h, w, d = keras_layer.size
if h != w or w != d:
raise TypeError("Unsupported upsampling type with different axes size : {}"
.format(keras_layer.size))
params = {'scale': h}
else:
raise TypeError("Unsupported upsampling type : {}".format(upsample_type))
return _sym.upsampling(insym, **params)
def _convert_batchnorm(insym, keras_layer, symtab): def _convert_batchnorm(insym, keras_layer, symtab):
params = {'scale': False, params = {'scale': False,
'center': False, 'center': False,
...@@ -358,6 +383,7 @@ _convert_map = { ...@@ -358,6 +383,7 @@ _convert_map = {
'Subtract' : _convert_merge, 'Subtract' : _convert_merge,
'Multiply' : _convert_merge, 'Multiply' : _convert_merge,
'ZeroPadding2D' : _convert_padding, 'ZeroPadding2D' : _convert_padding,
'UpSampling2D' : _convert_upsample,
# 'ZeroPadding1D' : _convert_padding, # 'ZeroPadding1D' : _convert_padding,
# 'AveragePooling1D' : _convert_pooling, # 'AveragePooling1D' : _convert_pooling,
...@@ -367,7 +393,7 @@ _convert_map = { ...@@ -367,7 +393,7 @@ _convert_map = {
# 'Cropping1D' : _convert_cropping, # 'Cropping1D' : _convert_cropping,
# 'Cropping2D' : _convert_cropping, # 'Cropping2D' : _convert_cropping,
# 'UpSampling1D' : _convert_upsample, # 'UpSampling1D' : _convert_upsample,
# 'UpSampling2D' : _convert_upsample, # 'UpSampling3D' : _convert_upsample,
# 'Conv1D' : _convert_convolution1d, # 'Conv1D' : _convert_convolution1d,
# 'GRU' : _convert_gru, # 'GRU' : _convert_gru,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment