Commit 92a00489 by SasakiSaki Committed by Tianqi Chen

[Relay] Improve more operator mxnet frontend importer (#2772)

parent 2919a3ee
...@@ -298,6 +298,51 @@ def _mx_leaky_relu(inputs, attrs): ...@@ -298,6 +298,51 @@ def _mx_leaky_relu(inputs, attrs):
raise RuntimeError("act_type: {} is not supported".format(act_type)) raise RuntimeError("act_type: {} is not supported".format(act_type))
def _mx_make_power(power):
def _impl(inputs, _): # Note: no attrs
assert len(inputs) == 1
scalar = _expr.const(power, dtype=None)
# Note: int maps to "int32", float maps to "float32"
return _op.power(inputs[0], scalar)
return _impl
def _mx_make_exponent(base):
# exp(b, x) = e^b * e^x
def _impl(inputs, _): # Note: no attrs
assert len(inputs) == 1
scalar = _op.exp(_expr.const(base, dtype="float32"))
return _op.multiply(inputs[0], scalar)
return _impl
def _mx_make_logarithm(base):
# log(b, x) = log(x) / log(b)
def _impl(inputs, _): # Note: no attrs
assert len(inputs) == 1
scalar = _op.log(_expr.const(base, dtype="float32"))
return _op.divide(inputs[0], scalar)
return _impl
def _mx_expm1():
# exp_minus_1 x = exp(x) - 1
def _impl(inputs, _): # Note: no attrs
assert len(inputs) == 1
one = _expr.const(1, dtype="float32")
return _op.log(_op.subtract(inputs[0], one))
return _impl
def _mx_log1p():
# 1_plus_log x = log(x + 1)
def _impl(inputs, _): # Note: no attrs
assert len(inputs) == 1
one = _expr.const(1, dtype="float32")
return _op.log(_op.add(inputs[0], one))
return _impl
def _mx_lrn(inputs, attrs): def _mx_lrn(inputs, attrs):
new_attrs = {} new_attrs = {}
new_attrs["alpha"] = attrs.get_float("alpha", 0.0001) new_attrs["alpha"] = attrs.get_float("alpha", 0.0001)
...@@ -450,7 +495,6 @@ _identity_list = [ ...@@ -450,7 +495,6 @@ _identity_list = [
"exp", "exp",
"sigmoid", "sigmoid",
"tanh", "tanh",
"exp",
"negative", "negative",
"reshape_like", "reshape_like",
"zeros_like", "zeros_like",
...@@ -482,6 +526,20 @@ _convert_map = { ...@@ -482,6 +526,20 @@ _convert_map = {
"_minimum" : _rename(_op.minimum), "_minimum" : _rename(_op.minimum),
"flatten" : _rename(_op.nn.batch_flatten), "flatten" : _rename(_op.nn.batch_flatten),
"Flatten" : _rename(_op.nn.batch_flatten), "Flatten" : _rename(_op.nn.batch_flatten),
# scalar power
"square" : _mx_make_power(2),
"sqrt" : _mx_make_power(1/2),
"rsqrt" : _mx_make_power(-1/2),
"cbrt" : _mx_make_power(1/3),
"rcbrt" : _mx_make_power(-1/3),
"__pow_scalar__" : _binop_scalar(_op.power),
"_power_scalar" : _binop_scalar(_op.power),
"__rsub_scalar__" : _rbinop_scalar(_op.subtract),
"_rminus_scalar" : _rbinop_scalar(_op.subtract),
"__rdiv_scalar__" : _rbinop_scalar(_op.divide),
"_rdiv_scalar" : _rbinop_scalar(_op.divide),
"__rpow_scalar__" : _rbinop_scalar(_op.power),
# scalar op
"__add_scalar__" : _binop_scalar(_op.add), "__add_scalar__" : _binop_scalar(_op.add),
"_plus_scalar" : _binop_scalar(_op.add), "_plus_scalar" : _binop_scalar(_op.add),
"__sub_scalar__" : _binop_scalar(_op.subtract), "__sub_scalar__" : _binop_scalar(_op.subtract),
...@@ -490,13 +548,10 @@ _convert_map = { ...@@ -490,13 +548,10 @@ _convert_map = {
"_mul_scalar" : _binop_scalar(_op.multiply), "_mul_scalar" : _binop_scalar(_op.multiply),
"__div_scalar__" : _binop_scalar(_op.divide), "__div_scalar__" : _binop_scalar(_op.divide),
"_div_scalar" : _binop_scalar(_op.divide), "_div_scalar" : _binop_scalar(_op.divide),
"__pow_scalar__" : _binop_scalar(_op.power), "log2" : _mx_make_logarithm(2),
"_power_scalar" : _binop_scalar(_op.power), "log10" : _mx_make_logarithm(10),
"__rsub_scalar__" : _rbinop_scalar(_op.subtract), "log1p" : _mx_log1p,
"_rminus_scalar" : _rbinop_scalar(_op.subtract), "expm1" : _mx_expm1,
"__rdiv_scalar__" : _rbinop_scalar(_op.divide),
"_rdiv_scalar" : _rbinop_scalar(_op.divide),
"__rpow_scalar__" : _rbinop_scalar(_op.power),
"_equal_scalar" : _mx_compare(_op.equal, _binop_scalar), "_equal_scalar" : _mx_compare(_op.equal, _binop_scalar),
"_not_equal_scalar" : _mx_compare(_op.not_equal, _binop_scalar), "_not_equal_scalar" : _mx_compare(_op.not_equal, _binop_scalar),
"_greater_scalar" : _mx_compare(_op.greater, _binop_scalar), "_greater_scalar" : _mx_compare(_op.greater, _binop_scalar),
...@@ -506,6 +561,7 @@ _convert_map = { ...@@ -506,6 +561,7 @@ _convert_map = {
"_maximum_scalar" : _binop_scalar(_op.maximum), "_maximum_scalar" : _binop_scalar(_op.maximum),
"_minimum_scalar" : _binop_scalar(_op.minimum), "_minimum_scalar" : _binop_scalar(_op.minimum),
# reduction ops # reduction ops
"mean" : _reduce(_op.mean),
"max" : _reduce(_op.max), "max" : _reduce(_op.max),
"min" : _reduce(_op.min), "min" : _reduce(_op.min),
"sum" : _reduce(_op.sum), "sum" : _reduce(_op.sum),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment