Commit 695647db by MORITA Kazutaka, committed by Jared Roesch

Improve NNVM to Relay conversion (#2734)

* Improve NNVM to Relay conversion

* fix pylint

* support __lshift_scalar__, abs, ceil, floor, and trunc to pass CI
parent 52d5cf89
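
The change below wires a to_relay cross-check into NNVM's check_function test helper and makes the converter dtype-aware. A minimal, untested sketch of the conversion path being exercised (the toy graph, shapes, and llvm target are illustrative assumptions; to_relay's signature follows its use in the test hunk below):

import numpy as np
import tvm
import nnvm
import nnvm.symbol as sym
from nnvm.to_relay import to_relay
from tvm import relay
from tvm.contrib import graph_runtime

# Build a small NNVM graph, convert it to Relay, then compile and run it.
x = sym.Variable("x")
y = sym.relu(x + 1.0)
g = nnvm.graph.create(y)

shape = {"x": (1, 4)}
dtype = {"x": "float32"}
func, params = to_relay(g, shape, dtype, params={})

with relay.build_config(opt_level=3):
    graph, lib, build_params = relay.build(func, target="llvm")

m = graph_runtime.create(graph, lib, tvm.cpu(0))
m.set_input("x", np.ones((1, 4), dtype="float32"))
m.set_input(**params)
m.set_input(**build_params)
m.run()
print(m.get_output(0).asnumpy())
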
@@ -8,10 +8,12 @@ import numpy as np
import tvm
from tvm.contrib import graph_runtime
from tvm.testing import check_numerical_grads
from tvm import relay
import nnvm
from nnvm.compiler import graph_util
from nnvm.compiler.graph_attr import TCODE_TO_DTYPE, DTYPE_TO_TCODE
from nnvm.to_relay import to_relay
from .config import ctx_list
def infer_shapes_dtypes(graph, shape=None, dtype=None, fallback_dtype=None):
@@ -441,6 +443,23 @@ def check_function(symbol, forward=None, backward=None, grad_input_vars=None,
        debug_stage = "running"
        nnvm_res = main_function(**np_inputs)

        try:
            logging.debug("checking to_relay conversion")
            inputs = np_inputs_without_head_grads.copy()
            func, inputs = to_relay(main_graph, shape, dtype, params=inputs)
            with relay.build_config(opt_level=3):
                graph, lib, params = relay.build(func, target=target)
            m = graph_runtime.create(graph, lib, ctx)
            m.set_input(**inputs)
            m.set_input(**params)
            m.run()
            for i in range(out_len):
                relay_out = m.get_output(i).asnumpy()
                tvm.testing.assert_allclose(nnvm_res[i], relay_out, atol=atol, rtol=rtol)
        except NotImplementedError as err:
            # the NNVM operator is not supported yet
            logging.warning(err)

        if backward_graph is not None:
            grad_var_names = [x.attr('name') for x in grad_input_vars]
            nnvm_grads = {x: v for x, v in zip(grad_var_names, nnvm_res[out_len:])}
@@ -41,7 +41,7 @@ def _init_op(new_op):

def _softmax_op(new_op):
    """softmax/log_softmax"""
    def _impl(inputs, attrs):
    def _impl(inputs, attrs, _dtype='float32'):
        assert len(inputs) == 1
        axis = attrs.get_int("axis", -1)
        return new_op(inputs[0], axis=axis)
@@ -50,13 +50,14 @@ def _softmax_op(new_op):

def _reduce(new_op):
    """Reduction ops like sum/min/max"""
    def _impl(inputs, attrs):
    def _impl(inputs, attrs, _dtype='float32'):
        assert len(inputs) == 1
        axis = attrs.get_int_tuple("axis", [])
        keepdims = attrs.get_bool("keepdims", False)
        exclude = attrs.get_bool("exclude", False)
        # use None to reduce over all axes
        axis = None if len(axis) == 0 else axis
        return new_op(inputs[0], axis=axis, keepdims=keepdims)
        return new_op(inputs[0], axis=axis, keepdims=keepdims, exclude=exclude)
    return _impl
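
With the exclude flag forwarded, a reduction with exclude=True keeps the listed axes and reduces over all the others. A small numpy sketch of that semantics (assuming Relay's exclude matches NNVM's, which this conversion relies on):

import numpy as np

def reduced_axes(ndim, axis, exclude):
    # with exclude=True, reduce over every axis *not* listed in `axis`
    axis = tuple(a % ndim for a in axis)
    return tuple(i for i in range(ndim) if i not in axis) if exclude else axis

x = np.arange(24, dtype="float32").reshape(2, 3, 4)
# equivalent of sum(x, axis=(1,), exclude=True): actually reduces axes (0, 2)
print(np.sum(x, axis=reduced_axes(x.ndim, (1,), True)).shape)  # -> (3,)
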
@@ -97,7 +98,7 @@ def _upsampling(inputs, attrs):
    return _op.nn.upsampling(inputs[0], scale=scale)

def _elemwise_sum(inputs, _):
def _elemwise_sum(inputs, _, _dtype='float32'):
    assert len(inputs) > 0
    res = inputs[0]
    for x in inputs[1:]:
@@ -106,20 +107,28 @@ def _elemwise_sum(inputs, _):

def _binop_scalar(new_op):
    def _impl(inputs, attrs):
    def _impl(inputs, attrs, odtype='float32'):
        assert len(inputs) == 1
        scalar = attrs.get_float("scalar")
        # Note: binary scalar only works for float op for now
        scalar = _expr.const(scalar, dtype="float32")
        scalar = _expr.const(scalar, dtype=odtype)
        return new_op(inputs[0], scalar)
    return _impl

def _rbinop_scalar(new_op):
    def _impl(inputs, attrs):
    def _impl(inputs, attrs, odtype='float32'):
        assert len(inputs) == 1
        scalar = attrs.get_float("scalar")
        # Note: binary scalar only works for float op for now
        scalar = _expr.const(scalar, dtype="float32")
        scalar = _expr.const(scalar, dtype=odtype)
        return new_op(scalar, inputs[0])
    return _impl
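
These signature changes give every converter (softmax, reduce, elemwise_sum, and the scalar binops above) the same (inputs, attrs, odtype) interface, so the scalar constant can now be created in the graph's output dtype instead of a hard-coded float32. A minimal Relay-level sketch of the difference for an int32 __add_scalar__ (variable names are illustrative):

from tvm import relay

x = relay.var("x", shape=(4,), dtype="int32")
# before this change: relay.add(x, relay.const(1.0, dtype="float32")) -> dtype mismatch
# after: the constant follows odtype, so int32 graphs stay int32
y = relay.add(x, relay.const(1, dtype="int32"))
print(y)
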
def _compare(new_op):
    """Compare ops like greater/less"""
    def _impl(inputs, _, odtype='float32'):
        assert len(inputs) == 2
        return new_op(inputs[0], inputs[1]).astype(odtype)
    return _impl
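
A sketch of what _compare(_op.greater) would produce: the boolean result of the comparison is cast back to odtype, matching NNVM's convention of returning float tensors from comparison ops (variable names are illustrative):

from tvm import relay

lhs = relay.var("lhs", shape=(2, 2), dtype="float32")
rhs = relay.var("rhs", shape=(2, 2), dtype="float32")
# greater yields a boolean tensor; astype casts it to the requested output dtype
out = relay.greater(lhs, rhs).astype("float32")
print(out)
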
@@ -476,8 +476,8 @@ bool TypeSolver::Solve() {
      rnode->resolved = false;
      this->ReportError(
          RELAY_ERROR(
              "an internal invariant was violated while" \
              "typechecking your program" <<
              "an internal invariant was violated while " \
              "typechecking your program " <<
              err.what()), rnode->location);
    }
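
For context on the one-line fix above: adjacent string literals concatenate with no separator in C++ (Python behaves the same way), so without the trailing spaces the message would read "...violated whiletypechecking your program...". A quick Python illustration of the same pitfall:

msg = ("an internal invariant was violated while"
       "typechecking your program")
print(msg)    # words run together: "...while" + "typechecking..."

fixed = ("an internal invariant was violated while "
         "typechecking your program ")
print(fixed)  # spaces restore the intended wording
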