Commit 695647db authored by MORITA Kazutaka, committed by Jared Roesch

Improve NNVM to Relay conversion (#2734)

* Improve NNVM to Relay conversion

* fix pylint

* support __lshift_scalar__, abs, ceil, floor, and trunc to pass CI
parent 52d5cf89
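
The entry point this change exercises is nnvm.to_relay.to_relay, which the updated test harness below now calls on every checked graph. A minimal round-trip sketch, assuming an LLVM-enabled TVM build at this revision (the graph, shapes, and target here are illustrative, not taken from the commit):

    import numpy as np
    import nnvm
    import nnvm.symbol as sym
    import tvm
    from tvm import relay
    from tvm.contrib import graph_runtime
    from nnvm.to_relay import to_relay

    # y = x + 1.0 lowers to NNVM's __add_scalar__, which this commit
    # maps to _binop_scalar(op.add) on the Relay side.
    x = sym.Variable('x')
    g = nnvm.graph.create(x + 1.0)

    func, params = to_relay(g, {'x': (2, 2)}, {'x': 'float32'}, params={})
    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(func, target='llvm')

    m = graph_runtime.create(graph, lib, tvm.cpu(0))
    m.set_input('x', np.ones((2, 2), dtype='float32'))
    m.set_input(**params)
    m.run()
    print(m.get_output(0).asnumpy())  # expect all 2.0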
@@ -8,10 +8,12 @@ import numpy as np
 import tvm
 from tvm.contrib import graph_runtime
 from tvm.testing import check_numerical_grads
+from tvm import relay
 import nnvm
 from nnvm.compiler import graph_util
 from nnvm.compiler.graph_attr import TCODE_TO_DTYPE, DTYPE_TO_TCODE
+from nnvm.to_relay import to_relay
 from .config import ctx_list

 def infer_shapes_dtypes(graph, shape=None, dtype=None, fallback_dtype=None):
@@ -441,6 +443,23 @@ def check_function(symbol, forward=None, backward=None, grad_input_vars=None,
             debug_stage = "running"
             nnvm_res = main_function(**np_inputs)

+            try:
+                logging.debug("checking to_relay conversion")
+                inputs = np_inputs_without_head_grads.copy()
+                func, inputs = to_relay(main_graph, shape, dtype, params=inputs)
+                with relay.build_config(opt_level=3):
+                    graph, lib, params = relay.build(func, target=target)
+                m = graph_runtime.create(graph, lib, ctx)
+                m.set_input(**inputs)
+                m.set_input(**params)
+                m.run()
+                for i in range(out_len):
+                    relay_out = m.get_output(i).asnumpy()
+                    tvm.testing.assert_allclose(nnvm_res[i], relay_out, atol=atol, rtol=rtol)
+            except NotImplementedError as err:
+                # the NNVM operator is not supported yet
+                logging.warning(err)
+
             if backward_graph is not None:
                 grad_var_names = [x.attr('name') for x in grad_input_vars]
                 nnvm_grads = {x: v for x, v in zip(grad_var_names, nnvm_res[out_len:])}
@@ -6,7 +6,8 @@ import numpy
 from tvm import relay, nd
 from tvm.relay import op, expr, var
 from tvm.relay.frontend.common import StrAttrsDict
-from tvm.relay.frontend.nnvm_common import _rename
+from tvm.relay.frontend.nnvm_common import _rename, _binop_scalar, _rbinop_scalar, \
+    _elemwise_sum, _softmax_op, _compare, _reduce
 from .symbol import Symbol
 from .compiler import graph_attr
 from .graph import create as graph_create
@@ -25,11 +26,6 @@ def _dense(children, attrs, odtype='float32'):
     else:
         return dense

-def _nn_softmax(children, attrs, odtype='float32'):
-    assert len(children) == 1
-    axis = attrs.get_int('axis', 1)
-    return op.nn.softmax(children[0], axis)

 def _conv2d(children, attrs, odtype='float32'):
     use_bias = attrs.get_bool('use_bias', True)
@@ -150,84 +146,6 @@ def _transpose(children, attrs, odtype='float32'):
     return op.transpose(children[0], axes=axes)

-def _add(children, attrs, odtype='float32'):
-    if len(children) == 1:
-        left = children[0]
-        scalar = attrs.get_float('scalar')
-        right = relay.const(scalar, dtype=odtype)
-    else:
-        assert len(children) == 2
-        left = children[0]
-        right = children[1]
-    return op.add(left, right)
-
-def _subtract(children, attrs, odtype='float32'):
-    if len(children) == 1:
-        left = children[0]
-        scalar = attrs.get_float('scalar')
-        right = relay.const(scalar, dtype=odtype)
-    else:
-        assert len(children) == 2
-        left = children[0]
-        right = children[1]
-    return op.subtract(left, right)
-
-def _rsubtract(children, attrs, odtype='float32'):
-    if len(children) == 1:
-        left = children[0]
-        scalar = attrs.get_float('scalar')
-        right = relay.const(scalar, dtype=odtype)
-    else:
-        assert len(children) == 2
-        left = children[0]
-        right = children[1]
-    return op.subtract(right, left)
-
-def _multiply(children, attrs, odtype='float32'):
-    if len(children) == 1:
-        left = children[0]
-        scalar = attrs.get_float('scalar')
-        right = relay.const(scalar, dtype=odtype)
-    else:
-        assert len(children) == 2
-        left = children[0]
-        right = children[1]
-    return op.multiply(left, right)
-
-def _divide(children, attrs, odtype='float32'):
-    if len(children) == 1:
-        left = children[0]
-        scalar = attrs.get_float('scalar')
-        right = relay.const(scalar, dtype=odtype)
-    else:
-        assert len(children) == 2
-        left = children[0]
-        right = children[1]
-    return op.divide(left, right)
-
-def _rshift(children, attrs, odtype='float32'):
-    if len(children) == 1:
-        left = children[0]
-        scalar = attrs.get_float('scalar')
-        right = relay.const(scalar, dtype='int32')
-    else:
-        assert len(children) == 2
-        left = children[0]
-        right = children[1]
-    return op.right_shift(left, right)

 def _clip(children, attrs, odtype='float32'):
     a_min = attrs.get_float('a_min')
     a_max = attrs.get_float('a_max')
@@ -255,9 +173,6 @@ def broadcast_to(children, attrs, odtype='float32'):
     rconst = relay.Constant(nd.array(array))
     return op.broadcast_to_like(data, rconst)

-def _copy(children, attrs, odtype='float32'):
-    return op.copy(children[0])

 def _global_avg_pool2d(children, attrs, odtype='float32'):
     data = children[0]
@@ -309,42 +224,10 @@ def _full_like(children, attrs, odtype='float32'):
     return op.full_like(children[0], fill_value)

-def _greater(children, attrs, odtype='float32'):
-    out_type = attrs.get_str('out_type')
-    if out_type:
-        return op.greater(children[0], children[1]).astype(out_type)
-    else:
-        return op.greater(children[0], children[1])
-
-def _greater_equal(children, attrs, odtype='float32'):
-    out_type = attrs.get_str('out_type', None)
-    if out_type:
-        return op.greater_equal(children[0], children[1]).astype(out_type)
-    else:
-        return op.greater_equal(children[0], children[1])
-
-def _less(children, attrs, odtype='float32'):
-    out_type = attrs.get_str('out_type', None)
-    if out_type:
-        return op.less(children[0], children[1]).astype(out_type)
-    else:
-        return op.less(children[0], children[1])
-
-def _less_equal(children, attrs, odtype='float32'):
-    out_type = attrs.get_str('out_type', None)
-    if out_type:
-        return op.less_equal(children[0], children[1]).astype(out_type)
-    else:
-        return op.less_equal(children[0], children[1])

 def _strided_slice(children, attrs, odtype='float32'):
     begin = attrs.get_int_list('begin')
     end = attrs.get_int_list('end')
-    strides = attrs.get_int_list('strides', None)
+    strides = attrs.get_int_list('stride', None)
     return op.strided_slice(children[0], begin, end, strides=strides)
@@ -358,14 +241,11 @@ def _split(children, attrs, odtype='float32'):
     axis = attrs.get_int('axis', 0)

-    return op.split(children[0], indices_or_sections, axis)
+    return op.split(children[0], indices_or_sections, axis).astuple()

 def _squeeze(children, attrs, odtype='float32'):
-    axis = None
-    try:
-        axis = [attrs.get_int('axis', None)]
-    except ValueError:
-        axis = axis or attrs.get_int_tuple('axis', None)
+    axis = attrs.get_int_tuple('axis', None)
+    axis = [axis] if isinstance(axis, int) else axis
     return op.squeeze(children[0], axis)
@@ -378,20 +258,60 @@ def _dropout(children, attrs, odtype='float32'):
     return op.nn.dropout(children[0], rate)

 def _mean(children, attrs, odtype='float32'):
-    axis = None
-    try:
-        axis = [attrs.get_int('axis', None)]
-    except ValueError:
-        axis = axis or attrs.get_int_tuple('axis', None)
+    axis = attrs.get_int_tuple('axis', None)
     keepdims = attrs.get_bool('keepdims')

     return op.mean(children[0], axis, keepdims)

+def _prelu(children, attrs, odtype='float32'):
+    axis = attrs.get_int('axis', 1)
+    return op.nn.prelu(children[0], children[1], axis)
+
+def _lrn(children, attrs, odtype='float32'):
+    size = attrs.get_int("size", 5)
+    axis = attrs.get_int("axis", 1)
+    bias = attrs.get_float("bias", 2)
+    alpha = attrs.get_float("alpha", 1e-05)
+    beta = attrs.get_float("beta", 0.75)
+    return op.nn.lrn(children[0], size, axis, bias, alpha, beta)
+
+def _l2_nomalize(children, attrs, odtype='float32'):
+    eps = attrs.get_float('eps')
+    axis = attrs.get_int_tuple('axis', None)
+    return op.nn.l2_normalize(children[0], eps, axis)
+
+def _take(children, attrs, odtype='float32'):
+    axis = attrs.get_int('axis', None)
+    return op.take(children[0], children[1], axis)
+
+def _matmul(children, attrs, odtype='float32'):
+    input_1_t = op.transpose(children[1], axes=(1, 0))
+    return op.nn.dense(children[0], input_1_t)
+
+def _collapse_sum(children, attrs, odtype='float32'):
+    for key in ["axis", "keepdims", "exclude"]:
+        if key in attrs.attrs:
+            raise NotImplementedError("Parameter '" + key + "' is not supported.")
+    return op.collapse_sum_like(children[0], children[1])
+
+def _not_implemented(new_op):
+    def _impl(children, attrs, odtype='float32'):
+        raise NotImplementedError(str(new_op) + " is not implemented.")
+    return _impl

 NNVM_OP_2_RELAY_OP = {
     'flatten': _nn_batch_flatten,
     'dense': _dense,
-    'softmax': _nn_softmax,
+    'softmax': _softmax_op(op.nn.softmax),
+    'log_softmax': _softmax_op(op.nn.log_softmax),
     'conv2d': _conv2d,
     'batch_norm': _batch_norm,
     'max_pool2d': _max_pool2d,
@@ -400,30 +320,47 @@ NNVM_OP_2_RELAY_OP = {
     'dropout': _dropout,
     'mean': _mean,
     # Addition
-    '__add_scalar__': _add,
-    'broadcast_add': _add,
-    'elemwise_add': _add,
+    '__add_scalar__': _binop_scalar(op.add),
+    'broadcast_add' : _rename(op.add),
+    'elemwise_add' : _rename(op.add),
     # Subtraction
-    '__sub_scalar__': _subtract,
-    '__rsub_scalar__': _rsubtract,
-    'broadcast_sub': _subtract,
-    'elemwise_sub': _subtract,
+    '__sub_scalar__' : _binop_scalar(op.subtract),
+    '__rsub_scalar__': _rbinop_scalar(op.subtract),
+    'broadcast_sub' : _rename(op.subtract),
+    'elemwise_sub' : _rename(op.subtract),
     # Multiply
-    '__mul_scalar__': _multiply,
-    'broadcast_mul': _multiply,
-    'elemwise_mul': _multiply,
+    '__mul_scalar__': _binop_scalar(op.multiply),
+    'broadcast_mul' : _rename(op.multiply),
+    'elemwise_mul' : _rename(op.multiply),
     # Division
-    '__div_scalar__': _divide,
-    'broadcast_div': _divide,
-    'elemwise_div': _divide,
+    '__div_scalar__': _binop_scalar(op.divide),
+    'broadcast_div' : _rename(op.divide),
+    'elemwise_div' : _rename(op.divide),
+    'broadcast_mod' : _rename(op.mod),
     # Negative
     'negative': _rename("negative"),
+    # Power
+    '__pow_scalar__': _binop_scalar(op.power),
+    '__rpow_scalar__': _rbinop_scalar(op.power),
+    'broadcast_pow': _rename(op.power),
+    # Sum
+    'sum': _reduce(op.sum),
+    'elemwise_sum': _elemwise_sum,
+    'collapse_sum': _collapse_sum,
+    'broadcast_max': _rename(op.maximum),
+    'broadcast_min': _rename(op.minimum),
     # Comparison
-    'greater': _greater,
-    'greater_equal': _greater_equal,
-    'less': _less,
-    'less_equal': _less_equal,
+    'greater': _compare(op.greater),
+    'broadcast_greater': _compare(op.greater),
+    'greater_equal': _compare(op.greater_equal),
+    'broadcast_greater_equal': _compare(op.greater_equal),
+    'less': _compare(op.less),
+    'broadcast_less': _compare(op.less),
+    'less_equal': _compare(op.less_equal),
+    'broadcast_less_equal': _compare(op.less_equal),
+    'broadcast_equal': _compare(op.equal),
+    'broadcast_not_equal': _compare(op.not_equal),
     # Activations
     'sigmoid': _rename('sigmoid'),
@@ -432,13 +369,17 @@ NNVM_OP_2_RELAY_OP = {
     'log': _rename('log'),
     'tanh': _rename('tanh'),
     'leaky_relu': _leaky_relu,
+    'prelu': _prelu,
     'clip': _clip,
     'round': _rename('round'),
     'cast': _cast,
     'expand_dims': _expand_dims,
     'broadcast_to': broadcast_to,
-    '__rshift_scalar__': _rshift,
-    'copy': _copy,
+    '__lshift_scalar__': _binop_scalar(op.left_shift),
+    '__rshift_scalar__': _binop_scalar(op.right_shift),
+    'broadcast_left_shift': _rename(op.left_shift),
+    'broadcast_right_shift': _rename(op.right_shift),
+    'copy': _rename(op.copy),
     'global_avg_pool2d': _global_avg_pool2d,
     'avg_pool2d': _avg_pool2d,
     'conv2d_transpose': _conv2d_transpose,
@@ -449,6 +390,21 @@ NNVM_OP_2_RELAY_OP = {
     'split': _split,
     'squeeze': _squeeze,
     'concatenate': _concatenate,
+    'abs': _rename(op.abs),
+    'ceil': _rename(op.ceil),
+    'floor': _rename(op.floor),
+    'trunc': _rename(op.trunc),
+    'take': _take,
+    'lrn': _lrn,
+    'l2_normalize': _l2_nomalize,
+    'matmul': _matmul,
+    'zeros_like': _rename(op.zeros_like),
+    'reshape_like': _rename(op.reshape_like),
+    'ones_like': _rename(op.ones_like),
+    'expand_like': _not_implemented("expand_like"),
+    'gather_nd': _not_implemented("gather_nd"),
+    'block_grad': _not_implemented("block_grad"),
 }
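
Every value in the table above is now a converter with the uniform signature (children, attrs, odtype), whether defined by hand in this file or produced by a factory (_rename, _binop_scalar, _compare, ...) shared with the MXNet frontend via nnvm_common, whose changes follow. A minimal sketch of that factory pattern, re-created here for illustration (binop_scalar and table are hypothetical names, not the shared helpers themselves):

    from tvm import relay
    from tvm.relay import op, var
    from tvm.relay.frontend.common import StrAttrsDict

    def binop_scalar(new_op):
        # Close over the Relay op; return a converter with the
        # (children, attrs, odtype) signature the table expects.
        def _impl(children, attrs, odtype='float32'):
            scalar = relay.const(attrs.get_float('scalar'), dtype=odtype)
            return new_op(children[0], scalar)
        return _impl

    table = {'__add_scalar__': binop_scalar(op.add)}

    x = var('x', shape=(2, 2), dtype='float32')
    attrs = StrAttrsDict({'scalar': '1.0'})
    print(table['__add_scalar__']([x], attrs, 'float32'))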
@@ -41,7 +41,7 @@ def _init_op(new_op):
 def _softmax_op(new_op):
     """softmax/log_softmax"""
-    def _impl(inputs, attrs):
+    def _impl(inputs, attrs, _dtype='float32'):
         assert len(inputs) == 1
         axis = attrs.get_int("axis", -1)
         return new_op(inputs[0], axis=axis)
@@ -50,13 +50,14 @@ def _softmax_op(new_op):
 def _reduce(new_op):
     """Reduction ops like sum/min/max"""
-    def _impl(inputs, attrs):
+    def _impl(inputs, attrs, _dtype='float32'):
         assert len(inputs) == 1
         axis = attrs.get_int_tuple("axis", [])
         keepdims = attrs.get_bool("keepdims", False)
+        exclude = attrs.get_bool("exclude", False)
         # use None for reduce over all axis.
         axis = None if len(axis) == 0 else axis
-        return new_op(inputs[0], axis=axis, keepdims=keepdims)
+        return new_op(inputs[0], axis=axis, keepdims=keepdims, exclude=exclude)
     return _impl
@@ -97,7 +98,7 @@ def _upsampling(inputs, attrs):
     return _op.nn.upsampling(inputs[0], scale=scale)

-def _elemwise_sum(inputs, _):
+def _elemwise_sum(inputs, _, _dtype='float32'):
     assert len(inputs) > 0
     res = inputs[0]
     for x in inputs[1:]:
@@ -106,20 +107,28 @@ def _elemwise_sum(inputs, _):
 def _binop_scalar(new_op):
-    def _impl(inputs, attrs):
+    def _impl(inputs, attrs, odtype='float32'):
         assert len(inputs) == 1
         scalar = attrs.get_float("scalar")
-        # Note: binary scalar only works for float op for now
-        scalar = _expr.const(scalar, dtype="float32")
+        scalar = _expr.const(scalar, dtype=odtype)
         return new_op(inputs[0], scalar)
     return _impl

 def _rbinop_scalar(new_op):
-    def _impl(inputs, attrs):
+    def _impl(inputs, attrs, odtype='float32'):
         assert len(inputs) == 1
         scalar = attrs.get_float("scalar")
-        # Note: binary scalar only works for float op for now
-        scalar = _expr.const(scalar, dtype="float32")
+        scalar = _expr.const(scalar, dtype=odtype)
         return new_op(scalar, inputs[0])
     return _impl

+def _compare(new_op):
+    """Compare ops like greater/less"""
+    def _impl(inputs, _, odtype='float32'):
+        assert len(inputs) == 2
+        return new_op(inputs[0], inputs[1]).astype(odtype)
+    return _impl
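
With the new odtype parameter threaded through, a converter built by _compare returns the comparison cast back to the requested dtype, so an NNVM graph that consumes the result as float keeps working. A small sketch of the expression it builds (shapes illustrative):

    from tvm import relay
    from tvm.relay import op

    x = relay.var('x', shape=(4,), dtype='float32')
    y = relay.var('y', shape=(4,), dtype='float32')
    # Equivalent of _compare(op.greater) applied with odtype='float32':
    print(op.greater(x, y).astype('float32'))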
@@ -476,8 +476,8 @@ bool TypeSolver::Solve() {
       rnode->resolved = false;
       this->ReportError(
           RELAY_ERROR(
-              "an internal invariant was violated while" \
-              "typechecking your program" <<
+              "an internal invariant was violated while " \
+              "typechecking your program " <<
               err.what()), rnode->location);
     }