Commit a82656d4 by Pariksheet Pinjari Committed by Tianqi Chen

Support imagescalar operator for onnx (#448)

parent 435201ee
# pylint: disable=import-self, invalid-name, unused-argument
"""ONNX: Open Neural Network Exchange frontend."""
from __future__ import absolute_import as _abs
import numpy as np
import tvm
from .. import symbol as _sym
from .. import graph as _graph
from ..compiler import graph_util
from .common import get_nnvm_op, Renamer, AttrConverter as AttrCvt
from .common import get_nnvm_op, Renamer, SymbolTable, AttrConverter as AttrCvt
__all__ = ['from_onnx']
......@@ -322,6 +323,18 @@ class ThresholdedRelu(OnnxOpConverter):
return _sym.relu(inputs[0] - alpha)
class ImageScaler(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
channelScale = attr['scale']
bias_attr = attr['bias']
bias = SymbolTable().new_const(np.array(bias_attr).reshape([3, 1, 1]))
scaledChannel = _sym.__mul_scalar__(inputs[0], scalar=channelScale)
ret = _sym.broadcast_add(scaledChannel, bias)
return ret
def _revert_caffe2_pad(attr):
"""Caffe2 require two times the normal padding."""
if len(attr) == 4:
......@@ -410,7 +423,7 @@ def _get_convert_map(opset):
'Scale': Scale.get_converter(opset),
# 'GRUUnit'
# 'ATen'
# 'ImageScaler'
'ImageScaler': ImageScaler.get_converter(opset),
# 'MeanVarianceNormalization'
# 'Crop'
# 'Embedding'
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment