Commit 695647db by MORITA Kazutaka Committed by Jared Roesch

Improve NNVM to Relay conversion (#2734)

* Improve NNVM to Relay conversion

* fix pylint

* support __lshift_scalar__, abs, ceil, floor, and trunc to pass CI
parent 52d5cf89
...@@ -8,10 +8,12 @@ import numpy as np ...@@ -8,10 +8,12 @@ import numpy as np
import tvm import tvm
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
from tvm.testing import check_numerical_grads from tvm.testing import check_numerical_grads
from tvm import relay
import nnvm import nnvm
from nnvm.compiler import graph_util from nnvm.compiler import graph_util
from nnvm.compiler.graph_attr import TCODE_TO_DTYPE, DTYPE_TO_TCODE from nnvm.compiler.graph_attr import TCODE_TO_DTYPE, DTYPE_TO_TCODE
from nnvm.to_relay import to_relay
from .config import ctx_list from .config import ctx_list
def infer_shapes_dtypes(graph, shape=None, dtype=None, fallback_dtype=None): def infer_shapes_dtypes(graph, shape=None, dtype=None, fallback_dtype=None):
...@@ -441,6 +443,23 @@ def check_function(symbol, forward=None, backward=None, grad_input_vars=None, ...@@ -441,6 +443,23 @@ def check_function(symbol, forward=None, backward=None, grad_input_vars=None,
debug_stage = "running" debug_stage = "running"
nnvm_res = main_function(**np_inputs) nnvm_res = main_function(**np_inputs)
try:
logging.debug("checking to_relay conversion")
inputs = np_inputs_without_head_grads.copy()
func, inputs = to_relay(main_graph, shape, dtype, params=inputs)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(func, target=target)
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**inputs)
m.set_input(**params)
m.run()
for i in range(out_len):
relay_out = m.get_output(i).asnumpy()
tvm.testing.assert_allclose(nnvm_res[i], relay_out, atol=atol, rtol=rtol)
except NotImplementedError as err:
# the NNVM operator is not supported yet
logging.warning(err)
if backward_graph is not None: if backward_graph is not None:
grad_var_names = [x.attr('name') for x in grad_input_vars] grad_var_names = [x.attr('name') for x in grad_input_vars]
nnvm_grads = {x: v for x, v in zip(grad_var_names, nnvm_res[out_len:])} nnvm_grads = {x: v for x, v in zip(grad_var_names, nnvm_res[out_len:])}
......
...@@ -41,7 +41,7 @@ def _init_op(new_op): ...@@ -41,7 +41,7 @@ def _init_op(new_op):
def _softmax_op(new_op): def _softmax_op(new_op):
"""softmax/log_softmax""" """softmax/log_softmax"""
def _impl(inputs, attrs): def _impl(inputs, attrs, _dtype='float32'):
assert len(inputs) == 1 assert len(inputs) == 1
axis = attrs.get_int("axis", -1) axis = attrs.get_int("axis", -1)
return new_op(inputs[0], axis=axis) return new_op(inputs[0], axis=axis)
...@@ -50,13 +50,14 @@ def _softmax_op(new_op): ...@@ -50,13 +50,14 @@ def _softmax_op(new_op):
def _reduce(new_op): def _reduce(new_op):
"""Reduction ops like sum/min/max""" """Reduction ops like sum/min/max"""
def _impl(inputs, attrs): def _impl(inputs, attrs, _dtype='float32'):
assert len(inputs) == 1 assert len(inputs) == 1
axis = attrs.get_int_tuple("axis", []) axis = attrs.get_int_tuple("axis", [])
keepdims = attrs.get_bool("keepdims", False) keepdims = attrs.get_bool("keepdims", False)
exclude = attrs.get_bool("exclude", False)
# use None for reduce over all axis. # use None for reduce over all axis.
axis = None if len(axis) == 0 else axis axis = None if len(axis) == 0 else axis
return new_op(inputs[0], axis=axis, keepdims=keepdims) return new_op(inputs[0], axis=axis, keepdims=keepdims, exclude=exclude)
return _impl return _impl
...@@ -97,7 +98,7 @@ def _upsampling(inputs, attrs): ...@@ -97,7 +98,7 @@ def _upsampling(inputs, attrs):
return _op.nn.upsampling(inputs[0], scale=scale) return _op.nn.upsampling(inputs[0], scale=scale)
def _elemwise_sum(inputs, _): def _elemwise_sum(inputs, _, _dtype='float32'):
assert len(inputs) > 0 assert len(inputs) > 0
res = inputs[0] res = inputs[0]
for x in inputs[1:]: for x in inputs[1:]:
...@@ -106,20 +107,28 @@ def _elemwise_sum(inputs, _): ...@@ -106,20 +107,28 @@ def _elemwise_sum(inputs, _):
def _binop_scalar(new_op): def _binop_scalar(new_op):
def _impl(inputs, attrs): def _impl(inputs, attrs, odtype='float32'):
assert len(inputs) == 1 assert len(inputs) == 1
scalar = attrs.get_float("scalar") scalar = attrs.get_float("scalar")
# Note: binary scalar only works for float op for now # Note: binary scalar only works for float op for now
scalar = _expr.const(scalar, dtype="float32") scalar = _expr.const(scalar, dtype=odtype)
return new_op(inputs[0], scalar) return new_op(inputs[0], scalar)
return _impl return _impl
def _rbinop_scalar(new_op): def _rbinop_scalar(new_op):
def _impl(inputs, attrs): def _impl(inputs, attrs, odtype='float32'):
assert len(inputs) == 1 assert len(inputs) == 1
scalar = attrs.get_float("scalar") scalar = attrs.get_float("scalar")
# Note: binary scalar only works for float op for now # Note: binary scalar only works for float op for now
scalar = _expr.const(scalar, dtype="float32") scalar = _expr.const(scalar, dtype=odtype)
return new_op(scalar, inputs[0]) return new_op(scalar, inputs[0])
return _impl return _impl
def _compare(new_op):
"""Compare ops like greater/less"""
def _impl(inputs, _, odtype='float32'):
assert len(inputs) == 2
return new_op(inputs[0], inputs[1]).astype(odtype)
return _impl
...@@ -476,8 +476,8 @@ bool TypeSolver::Solve() { ...@@ -476,8 +476,8 @@ bool TypeSolver::Solve() {
rnode->resolved = false; rnode->resolved = false;
this->ReportError( this->ReportError(
RELAY_ERROR( RELAY_ERROR(
"an internal invariant was violdated while" \ "an internal invariant was violdated while " \
"typechecking your program" << "typechecking your program " <<
err.what()), rnode->location); err.what()), rnode->location);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment