Commit 695647db by MORITA Kazutaka, committed by Jared Roesch

Improve NNVM to Relay conversion (#2734)

* Improve NNVM to Relay conversion

* fix pylint

* support __lshift_scalar__, abs, ceil, floor, and trunc to pass CI
parent 52d5cf89
...
@@ -8,10 +8,12 @@ import numpy as np
 import tvm
 from tvm.contrib import graph_runtime
 from tvm.testing import check_numerical_grads
+from tvm import relay
 import nnvm
 from nnvm.compiler import graph_util
 from nnvm.compiler.graph_attr import TCODE_TO_DTYPE, DTYPE_TO_TCODE
+from nnvm.to_relay import to_relay
 from .config import ctx_list

 def infer_shapes_dtypes(graph, shape=None, dtype=None, fallback_dtype=None):
...
@@ -441,6 +443,23 @@ def check_function(symbol, forward=None, backward=None, grad_input_vars=None,
         debug_stage = "running"
         nnvm_res = main_function(**np_inputs)

+        try:
+            logging.debug("checking to_relay conversion")
+            inputs = np_inputs_without_head_grads.copy()
+            func, inputs = to_relay(main_graph, shape, dtype, params=inputs)
+            with relay.build_config(opt_level=3):
+                graph, lib, params = relay.build(func, target=target)
+            m = graph_runtime.create(graph, lib, ctx)
+            m.set_input(**inputs)
+            m.set_input(**params)
+            m.run()
+            for i in range(out_len):
+                relay_out = m.get_output(i).asnumpy()
+                tvm.testing.assert_allclose(nnvm_res[i], relay_out, atol=atol, rtol=rtol)
+        except NotImplementedError as err:
+            # the NNVM operator is not supported yet
+            logging.warning(err)
+
         if backward_graph is not None:
             grad_var_names = [x.attr('name') for x in grad_input_vars]
             nnvm_grads = {x: v for x, v in zip(grad_var_names, nnvm_res[out_len:])}
...
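For reference, the new `try` block in `check_function` round-trips each test graph through Relay and compares the outputs against the NNVM results. Below is a minimal standalone sketch of the same pattern; it assumes the NNVM/Relay APIs of this commit's era, and the one-op graph, shapes, and `llvm` target are illustrative, not taken from the patch.

```python
# A minimal sketch of the round-trip check added to check_function above.
# Assumption: TVM/NNVM APIs as of this commit; graph and shapes are made up.
import numpy as np
import tvm
import nnvm.symbol as sym
import nnvm.graph
from nnvm.to_relay import to_relay
from tvm import relay
from tvm.contrib import graph_runtime

x = sym.Variable('x')
graph = nnvm.graph.create(x + 1.0)   # lowers to NNVM's __add_scalar__

shape = {'x': (1, 10)}
dtype = {'x': 'float32'}
inputs = {'x': np.random.uniform(size=shape['x']).astype('float32')}

# Convert the NNVM graph to a Relay function, then build and run it.
func, inputs = to_relay(graph, shape, dtype, params=inputs)
with relay.build_config(opt_level=3):
    rgraph, lib, params = relay.build(func, target='llvm')

m = graph_runtime.create(rgraph, lib, tvm.cpu(0))
m.set_input(**inputs)
m.set_input(**params)
m.run()
print(m.get_output(0).asnumpy())
```

Conversions that hit an unsupported operator raise `NotImplementedError` (see `_not_implemented` in to_relay.py below), which the test downgrades to a warning so unconverted operators do not fail CI.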
...
@@ -6,7 +6,8 @@ import numpy
 from tvm import relay, nd
 from tvm.relay import op, expr, var
 from tvm.relay.frontend.common import StrAttrsDict
-from tvm.relay.frontend.nnvm_common import _rename
+from tvm.relay.frontend.nnvm_common import _rename, _binop_scalar, _rbinop_scalar, \
+    _elemwise_sum, _softmax_op, _compare, _reduce
 from .symbol import Symbol
 from .compiler import graph_attr
 from .graph import create as graph_create
...
@@ -25,11 +26,6 @@ def _dense(children, attrs, odtype='float32'):
     else:
         return dense

-def _nn_softmax(children, attrs, odtype='float32'):
-    assert len(children) == 1
-    axis = attrs.get_int('axis', 1)
-    return op.nn.softmax(children[0], axis)
-
 def _conv2d(children, attrs, odtype='float32'):
     use_bias = attrs.get_bool('use_bias', True)
...
@@ -150,84 +146,6 @@ def _transpose(children, attrs, odtype='float32'):
     return op.transpose(children[0], axes=axes)

-def _add(children, attrs, odtype='float32'):
-    if len(children) == 1:
-        left = children[0]
-        scalar = attrs.get_float('scalar')
-        right = relay.const(scalar, dtype=odtype)
-    else:
-        assert len(children) == 2
-        left = children[0]
-        right = children[1]
-    return op.add(left, right)
-
-def _subtract(children, attrs, odtype='float32'):
-    if len(children) == 1:
-        left = children[0]
-        scalar = attrs.get_float('scalar')
-        right = relay.const(scalar, dtype=odtype)
-    else:
-        assert len(children) == 2
-        left = children[0]
-        right = children[1]
-    return op.subtract(left, right)
-
-def _rsubtract(children, attrs, odtype='float32'):
-    if len(children) == 1:
-        left = children[0]
-        scalar = attrs.get_float('scalar')
-        right = relay.const(scalar, dtype=odtype)
-    else:
-        assert len(children) == 2
-        left = children[0]
-        right = children[1]
-    return op.subtract(right, left)
-
-def _multiply(children, attrs, odtype='float32'):
-    if len(children) == 1:
-        left = children[0]
-        scalar = attrs.get_float('scalar')
-        right = relay.const(scalar, dtype=odtype)
-    else:
-        assert len(children) == 2
-        left = children[0]
-        right = children[1]
-    return op.multiply(left, right)
-
-def _divide(children, attrs, odtype='float32'):
-    if len(children) == 1:
-        left = children[0]
-        scalar = attrs.get_float('scalar')
-        right = relay.const(scalar, dtype=odtype)
-    else:
-        assert len(children) == 2
-        left = children[0]
-        right = children[1]
-    return op.divide(left, right)
-
-def _rshift(children, attrs, odtype='float32'):
-    if len(children) == 1:
-        left = children[0]
-        scalar = attrs.get_float('scalar')
-        right = relay.const(scalar, dtype='int32')
-    else:
-        assert len(children) == 2
-        left = children[0]
-        right = children[1]
-    return op.right_shift(left, right)
-
 def _clip(children, attrs, odtype='float32'):
     a_min = attrs.get_float('a_min')
     a_max = attrs.get_float('a_max')
...
@@ -255,9 +173,6 @@ def broadcast_to(children, attrs, odtype='float32'):
     rconst = relay.Constant(nd.array(array))
     return op.broadcast_to_like(data, rconst)

-def _copy(children, attrs, odtype='float32'):
-    return op.copy(children[0])
-
 def _global_avg_pool2d(children, attrs, odtype='float32'):
     data = children[0]
...
@@ -309,42 +224,10 @@ def _full_like(children, attrs, odtype='float32'):
     return op.full_like(children[0], fill_value)

-def _greater(children, attrs, odtype='float32'):
-    out_type = attrs.get_str('out_type')
-    if out_type:
-        return op.greater(children[0], children[1]).astype(out_type)
-    else:
-        return op.greater(children[0], children[1])
-
-def _greater_equal(children, attrs, odtype='float32'):
-    out_type = attrs.get_str('out_type', None)
-    if out_type:
-        return op.greater_equal(children[0], children[1]).astype(out_type)
-    else:
-        return op.greater_equal(children[0], children[1])
-
-def _less(children, attrs, odtype='float32'):
-    out_type = attrs.get_str('out_type', None)
-    if out_type:
-        return op.less(children[0], children[1]).astype(out_type)
-    else:
-        return op.less(children[0], children[1])
-
-def _less_equal(children, attrs, odtype='float32'):
-    out_type = attrs.get_str('out_type', None)
-    if out_type:
-        return op.less_equal(children[0], children[1]).astype(out_type)
-    else:
-        return op.less_equal(children[0], children[1])
-
 def _strided_slice(children, attrs, odtype='float32'):
     begin = attrs.get_int_list('begin')
     end = attrs.get_int_list('end')
-    strides = attrs.get_int_list('strides', None)
+    strides = attrs.get_int_list('stride', None)
     return op.strided_slice(children[0], begin, end, strides=strides)
...
@@ -358,14 +241,11 @@ def _split(children, attrs, odtype='float32'):
     axis = attrs.get_int('axis', 0)
-    return op.split(children[0], indices_or_sections, axis)
+    return op.split(children[0], indices_or_sections, axis).astuple()

 def _squeeze(children, attrs, odtype='float32'):
-    axis = None
-    try:
-        axis = [attrs.get_int('axis', None)]
-    except ValueError:
-        axis = axis or attrs.get_int_tuple('axis', None)
+    axis = attrs.get_int_tuple('axis', None)
+    axis = [axis] if isinstance(axis, int) else axis
     return op.squeeze(children[0], axis)
...
@@ -378,20 +258,60 @@ def _dropout(children, attrs, odtype='float32'):
     return op.nn.dropout(children[0], rate)

 def _mean(children, attrs, odtype='float32'):
-    axis = None
-    try:
-        axis = [attrs.get_int('axis', None)]
-    except ValueError:
-        axis = axis or attrs.get_int_tuple('axis', None)
+    axis = attrs.get_int_tuple('axis', None)
     keepdims = attrs.get_bool('keepdims')
     return op.mean(children[0], axis, keepdims)

+def _prelu(children, attrs, odtype='float32'):
+    axis = attrs.get_int('axis', 1)
+    return op.nn.prelu(children[0], children[1], axis)
+
+def _lrn(children, attrs, odtype='float32'):
+    size = attrs.get_int("size", 5)
+    axis = attrs.get_int("axis", 1)
+    bias = attrs.get_float("bias", 2)
+    alpha = attrs.get_float("alpha", 1e-05)
+    beta = attrs.get_float("beta", 0.75)
+    return op.nn.lrn(children[0], size, axis, bias, alpha, beta)
+
+def _l2_nomalize(children, attrs, odtype='float32'):
+    eps = attrs.get_float('eps')
+    axis = attrs.get_int_tuple('axis', None)
+    return op.nn.l2_normalize(children[0], eps, axis)
+
+def _take(children, attrs, odtype='float32'):
+    axis = attrs.get_int('axis', None)
+    return op.take(children[0], children[1], axis)
+
+def _matmul(children, attrs, odtype='float32'):
+    input_1_t = op.transpose(children[1], axes=(1, 0))
+    return op.nn.dense(children[0], input_1_t)
+
+def _collapse_sum(children, attrs, odtype='float32'):
+    for key in ["axis", "keepdims", "exclude"]:
+        if key in attrs.attrs:
+            raise NotImplementedError("Parameter '" + key + "' is not supported.")
+    return op.collapse_sum_like(children[0], children[1])
+
+def _not_implemented(new_op):
+    def _impl(children, attrs, odtype='float32'):
+        raise NotImplementedError(str(new_op) + " is not implemented.")
+    return _impl
+
 NNVM_OP_2_RELAY_OP = {
     'flatten': _nn_batch_flatten,
     'dense': _dense,
-    'softmax': _nn_softmax,
+    'softmax': _softmax_op(op.nn.softmax),
+    'log_softmax': _softmax_op(op.nn.log_softmax),
     'conv2d': _conv2d,
     'batch_norm': _batch_norm,
     'max_pool2d': _max_pool2d,
...
@@ -400,30 +320,47 @@ NNVM_OP_2_RELAY_OP = {
     'dropout': _dropout,
     'mean': _mean,
     # Addition
-    '__add_scalar__': _add,
-    'broadcast_add': _add,
-    'elemwise_add': _add,
+    '__add_scalar__': _binop_scalar(op.add),
+    'broadcast_add' : _rename(op.add),
+    'elemwise_add' : _rename(op.add),
     # Subtraction
-    '__sub_scalar__': _subtract,
-    '__rsub_scalar__': _rsubtract,
-    'broadcast_sub': _subtract,
-    'elemwise_sub': _subtract,
+    '__sub_scalar__' : _binop_scalar(op.subtract),
+    '__rsub_scalar__': _rbinop_scalar(op.subtract),
+    'broadcast_sub' : _rename(op.subtract),
+    'elemwise_sub' : _rename(op.subtract),
     # Multiply
-    '__mul_scalar__': _multiply,
-    'broadcast_mul': _multiply,
-    'elemwise_mul': _multiply,
+    '__mul_scalar__': _binop_scalar(op.multiply),
+    'broadcast_mul' : _rename(op.multiply),
+    'elemwise_mul' : _rename(op.multiply),
     # Division
-    '__div_scalar__': _divide,
-    'broadcast_div': _divide,
-    'elemwise_div': _divide,
+    '__div_scalar__': _binop_scalar(op.divide),
+    'broadcast_div' : _rename(op.divide),
+    'elemwise_div' : _rename(op.divide),
+    'broadcast_mod' : _rename(op.mod),
     # Negative
     'negative': _rename("negative"),
+    # Power
+    '__pow_scalar__': _binop_scalar(op.power),
+    '__rpow_scalar__': _rbinop_scalar(op.power),
+    'broadcast_pow': _rename(op.power),
+    # Sum
+    'sum': _reduce(op.sum),
+    'elemwise_sum': _elemwise_sum,
+    'collapse_sum': _collapse_sum,
+    'broadcast_max': _rename(op.maximum),
+    'broadcast_min': _rename(op.minimum),
     # Comparsion
-    'greater': _greater,
-    'greater_equal': _greater_equal,
-    'less': _less,
-    'less_equal': _less_equal,
+    'greater': _compare(op.greater),
+    'broadcast_greater': _compare(op.greater),
+    'greater_equal': _compare(op.greater_equal),
+    'broadcast_greater_equal': _compare(op.greater_equal),
+    'less': _compare(op.less),
+    'broadcast_less': _compare(op.less),
+    'less_equal': _compare(op.less_equal),
+    'broadcast_less_equal': _compare(op.less_equal),
+    'broadcast_equal': _compare(op.equal),
+    'broadcast_not_equal': _compare(op.not_equal),
     # Activations
     'sigmoid': _rename('sigmoid'),
...
@@ -432,13 +369,17 @@ NNVM_OP_2_RELAY_OP = {
     'log': _rename('log'),
     'tanh': _rename('tanh'),
     'leaky_relu': _leaky_relu,
+    'prelu': _prelu,
     'clip': _clip,
     'round': _rename('round'),
     'cast': _cast,
     'expand_dims': _expand_dims,
     'broadcast_to': broadcast_to,
-    '__rshift_scalar__': _rshift,
-    'copy': _copy,
+    '__lshift_scalar__': _binop_scalar(op.left_shift),
+    '__rshift_scalar__': _binop_scalar(op.right_shift),
+    'broadcast_left_shift': _rename(op.left_shift),
+    'broadcast_right_shift': _rename(op.right_shift),
+    'copy': _rename(op.copy),
     'global_avg_pool2d': _global_avg_pool2d,
     'avg_pool2d': _avg_pool2d,
     'conv2d_transpose': _conv2d_transpose,
...
@@ -449,6 +390,21 @@ NNVM_OP_2_RELAY_OP = {
     'split': _split,
     'squeeze': _squeeze,
     'concatenate': _concatenate,
+    'abs': _rename(op.abs),
+    'ceil': _rename(op.ceil),
+    'floor': _rename(op.floor),
+    'trunc': _rename(op.trunc),
+    'take': _take,
+    'lrn': _lrn,
+    'l2_normalize': _l2_nomalize,
+    'matmul': _matmul,
+    'zeros_like': _rename(op.zeros_like),
+    'reshape_like': _rename(op.reshape_like),
+    'ones_like': _rename(op.ones_like),
+    'expand_like': _not_implemented("expand_like"),
+    'gather_nd': _not_implemented("gather_nd"),
+    'block_grad': _not_implemented("block_grad"),
 }
...
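One conversion in the table above deserves a note: NNVM's `matmul` is lowered as `dense(children[0], transpose(children[1]))` because Relay's `nn.dense(data, weight)` computes `data * weightᵀ`. A quick numpy check of that equivalence (a sketch for illustration only, not part of the patch):

```python
import numpy as np

a = np.random.rand(3, 4).astype('float32')
b = np.random.rand(4, 5).astype('float32')

# dense(data, weight) computes data @ weight.T, so
# matmul(a, b) == dense(a, transpose(b)).
dense = lambda data, weight: data @ weight.T
np.testing.assert_allclose(a @ b, dense(a, b.T), rtol=1e-5)
```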
...
@@ -41,7 +41,7 @@ def _init_op(new_op):

 def _softmax_op(new_op):
     """softmax/log_softmax"""
-    def _impl(inputs, attrs):
+    def _impl(inputs, attrs, _dtype='float32'):
         assert len(inputs) == 1
         axis = attrs.get_int("axis", -1)
         return new_op(inputs[0], axis=axis)
...
@@ -50,13 +50,14 @@ def _softmax_op(new_op):

 def _reduce(new_op):
     """Reduction ops like sum/min/max"""
-    def _impl(inputs, attrs):
+    def _impl(inputs, attrs, _dtype='float32'):
         assert len(inputs) == 1
         axis = attrs.get_int_tuple("axis", [])
         keepdims = attrs.get_bool("keepdims", False)
+        exclude = attrs.get_bool("exclude", False)
         # use None for reduce over all axis.
         axis = None if len(axis) == 0 else axis
-        return new_op(inputs[0], axis=axis, keepdims=keepdims)
+        return new_op(inputs[0], axis=axis, keepdims=keepdims, exclude=exclude)
     return _impl
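`_reduce` now also forwards NNVM's `exclude` flag to the Relay reduce op; when set, the reduction runs over every axis except the listed ones. A small numpy illustration of that semantics (a sketch; the data and axis values are made up):

```python
import numpy as np

data = np.arange(24, dtype='float32').reshape(2, 3, 4)
axis, exclude = (1,), True

# With exclude=True, reduce over all axes NOT in `axis` (here axes 0 and 2).
real_axis = tuple(i for i in range(data.ndim) if i not in axis) if exclude else axis
print(data.sum(axis=real_axis).shape)   # (3,)
```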
...
@@ -97,7 +98,7 @@ def _upsampling(inputs, attrs):
     return _op.nn.upsampling(inputs[0], scale=scale)

-def _elemwise_sum(inputs, _):
+def _elemwise_sum(inputs, _, _dtype='float32'):
     assert len(inputs) > 0
     res = inputs[0]
     for x in inputs[1:]:
...
@@ -106,20 +107,28 @@ def _elemwise_sum(inputs, _):

 def _binop_scalar(new_op):
-    def _impl(inputs, attrs):
+    def _impl(inputs, attrs, odtype='float32'):
         assert len(inputs) == 1
         scalar = attrs.get_float("scalar")
         # Note: binary scalar only works for float op for now
-        scalar = _expr.const(scalar, dtype="float32")
+        scalar = _expr.const(scalar, dtype=odtype)
         return new_op(inputs[0], scalar)
     return _impl

 def _rbinop_scalar(new_op):
-    def _impl(inputs, attrs):
+    def _impl(inputs, attrs, odtype='float32'):
         assert len(inputs) == 1
         scalar = attrs.get_float("scalar")
         # Note: binary scalar only works for float op for now
-        scalar = _expr.const(scalar, dtype="float32")
+        scalar = _expr.const(scalar, dtype=odtype)
         return new_op(scalar, inputs[0])
     return _impl
+
+def _compare(new_op):
+    """Compare ops like greater/less"""
+    def _impl(inputs, _, odtype='float32'):
+        assert len(inputs) == 2
+        return new_op(inputs[0], inputs[1]).astype(odtype)
+    return _impl
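These `_binop_scalar`/`_rbinop_scalar`/`_compare` combinators replace the hand-written per-op converters deleted from to_relay.py. A minimal sketch of invoking table entries directly, assuming the signatures added in this patch (the attribute values and variables here are illustrative):

```python
# Sketch only: exercises the combinators with hand-built inputs.
from tvm import relay
from tvm.relay import op
from tvm.relay.frontend.common import StrAttrsDict
from tvm.relay.frontend.nnvm_common import _binop_scalar, _compare

x = relay.var('x', shape=(2, 2), dtype='float32')
y = relay.var('y', shape=(2, 2), dtype='float32')

# '__add_scalar__' maps to _binop_scalar(op.add); the scalar attribute is
# parsed from NNVM's string attrs and materialized with the output dtype.
add_one = _binop_scalar(op.add)([x], StrAttrsDict({'scalar': '1.0'}), 'float32')

# 'greater' maps to _compare(op.greater), whose boolean result is cast
# back to the output dtype.
gt = _compare(op.greater)([x, y], None, 'float32')

print(add_one)
print(gt)
```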
...
@@ -476,8 +476,8 @@ bool TypeSolver::Solve() {
       rnode->resolved = false;
       this->ReportError(
         RELAY_ERROR(
-          "an internal invariant was violdated while" \
-          "typechecking your program" <<
+          "an internal invariant was violdated while " \
+          "typechecking your program " <<
           err.what()), rnode->location);
     }
...