Commit 58398d38 by Jared Roesch Committed by Tianqi Chen

Port from_nnvm to NNVM as to_relay (#2144)

parent 990521dd
import nnvm
from nnvm import testing
from nnvm import to_relay
import tvm
from tvm.relay import ir_pass
from tvm.relay import create_executor
from tvm.contrib import graph_runtime
import numpy as np
def check_model(sym, shapes, dtypes, params):
net = nnvm.graph.create(sym)
graph_json, mod, params = nnvm.compiler.build(
net,
'llvm',
shape=shapes,
dtype=dtypes,
params=params)
nnvm_rts = graph_runtime.create(graph_json, mod, tvm.cpu(0))
inputs = {}
for name in shapes:
np_array = np.random.rand(*shapes[name]).astype('float32')
inputs[name] = tvm.nd.array(np_array)
nnvm_rts.set_input(**params)
nnvm_rts.run(**inputs)
nnvm_out = nnvm_rts.get_output(0)
relay_model, params = to_relay.to_relay(net, shapes, dtypes, params)
relay_model = ir_pass.infer_type(relay_model)
relay_rts = create_executor(kind='graph', ctx=tvm.cpu(0), target='llvm')
inputs.update(params)
relay_out = relay_rts.evaluate(relay_model)(*list(inputs.values()))
np.testing.assert_allclose(nnvm_out.asnumpy(), relay_out.asnumpy())
# def test_mlp():
# mlp, params = testing.mlp.get_workload(1)
# shapes = { "data": (10, 3, 224, 224) }
# dtypes = { "data": 'float32' }
# check_model(mlp, shapes, dtypes, params)
if __name__ == "__main__":
test_mlp()
...@@ -101,11 +101,64 @@ class StrAttrsDict(object): ...@@ -101,11 +101,64 @@ class StrAttrsDict(object):
""" """
if key in self.attrs: if key in self.attrs:
tshape = self.attrs[key] tshape = self.attrs[key]
return tuple(int(x.strip()) for x in tshape.strip('()').split(',')) return tuple(int(x.strip()) for x in tshape.strip('()[]').split(','))
if isinstance(default, RequiredAttr): if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key)) raise AttributeError("Required attribute {} not found.".format(key))
return default return default
def get_tuple_tuple_int(self, key, default=RequiredAttr()):
"""Get int list attribute
Parameters
----------
key : str
The attribute key
default : float
The default value.
Returns
-------
value : The result
"""
if key in self.attrs:
value = self.attrs[key]
seq = []
for tup in value.strip('()').split('),'):
tup = tup.strip('[]()')
els = [int(x.strip('( ')) for x in tup.split(',')]
seq.append(tuple(els))
return tuple(seq)
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default
def get_int_list(self, key, default=RequiredAttr()):
"""Get int list attribute
Parameters
----------
key : str
The attribute key
default : float
The default value.
Returns
-------
value : The result
"""
if key in self.attrs:
tshape = self.attrs[key]
return tuple(int(x.strip()) for x in tshape.strip('[]()').split(','))
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default
def get_bool(self, key, default=RequiredAttr()): def get_bool(self, key, default=RequiredAttr()):
"""Get bool tuple attribute """Get bool tuple attribute
......
...@@ -8,138 +8,14 @@ from .. import expr as _expr ...@@ -8,138 +8,14 @@ from .. import expr as _expr
from .. import op as _op from .. import op as _op
from ... import nd as _nd from ... import nd as _nd
from .common import StrAttrsDict from .common import StrAttrsDict
from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce
from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast
from .nnvm_common import _clip, _transpose, _upsampling
from .nnvm_common import _elemwise_sum, _reshape
from .nnvm_common import _warn_not_used
__all__ = ['from_mxnet'] __all__ = ['from_mxnet']
def _get_relay_op(op_name):
op = getattr(_op, op_name)
if not op:
raise RuntimeError("Unable to map op_name {} to relay".format(op_name))
return op
def _warn_not_used(attr, op='nnvm'):
import warnings
err = "{} is ignored in {}.".format(attr, op)
warnings.warn(err)
def _rename(new_op):
if isinstance(new_op, str):
new_op = _get_relay_op(new_op)
# attrs are ignored.
def impl(inputs, _):
return new_op(*inputs)
return impl
def _reshape(inputs, attrs):
if attrs.get_bool("reverse", False):
raise RuntimeError("reshape do not support option reverse")
shape = attrs.get_int_tuple("shape")
return _op.reshape(inputs[0], newshape=shape)
def _init_op(new_op):
"""Init ops like zeros/ones"""
def _impl(inputs, attrs):
assert len(inputs) == 0
shape = attrs.get_int_tuple("shape")
dtype = attrs.get_str("dtype", "float32")
return new_op(shape=shape, dtype=dtype)
return _impl
def _softmax_op(new_op):
"""softmax/log_softmax"""
def _impl(inputs, attrs):
assert len(inputs) == 1
axis = attrs.get_int("axis", -1)
return new_op(inputs[0], axis=axis)
return _impl
def _reduce(new_op):
"""Reduction ops like sum/min/max"""
def _impl(inputs, attrs):
assert len(inputs) == 1
axis = attrs.get_int_tuple("axis", [])
keepdims = attrs.get_bool("keepdims", False)
# use None for reduce over all axis.
axis = None if len(axis) == 0 else axis
return new_op(inputs[0], axis=axis, keepdims=keepdims)
return _impl
def _arg_reduce(new_op):
"""Arg Reduction ops like argmin/argmax"""
def _impl(inputs, attrs):
assert len(inputs) == 1
axis = attrs.get_int("axis", None)
keepdims = attrs.get_bool("keepdims", False)
res = new_op(inputs[0], axis=[axis], keepdims=keepdims)
# cast to dtype.
res = res.astype("float32")
return res
return _impl
def _cast(inputs, attrs):
"""Type cast"""
dtype = attrs.get_str("dtype")
return _op.cast(inputs[0], dtype=dtype)
def _clip(inputs, attrs):
a_min = attrs.get_float("a_min")
a_max = attrs.get_float("a_max")
return _op.clip(inputs[0], a_min=a_min, a_max=a_max)
def _transpose(inputs, attrs):
axes = attrs.get_int_tuple("axes", None)
# translate default case
axes = None if len(axes) == 0 else axes
return _op.transpose(inputs[0], axes=axes)
def _upsampling(inputs, attrs):
scale = attrs.get_int("scale")
return _op.nn.upsampling(inputs[0], scale=scale)
def _elemwise_sum(inputs, _):
assert len(inputs) > 0
res = inputs[0]
for x in inputs[1:]:
res = _op.add(res, x)
return res
def _binop_scalar(new_op):
def _impl(inputs, attrs):
assert len(inputs) == 1
scalar = attrs.get_float("scalar")
# Note: binary scalar only works for float op for now
scalar = _expr.const(scalar, dtype="float32")
return new_op(inputs[0], scalar)
return _impl
def _rbinop_scalar(new_op):
def _impl(inputs, attrs):
assert len(inputs) == 1
scalar = attrs.get_float("scalar")
# Note: binary scalar only works for float op for now
scalar = _expr.const(scalar, dtype="float32")
return new_op(scalar, inputs[0])
return _impl
# All the functions with _mx prefix specific to MXNet.
# The functions without _mx prefix can be reused for
# NNVMv1 conversion to _op.
def _mx_fully_connected(inputs, attrs): def _mx_fully_connected(inputs, attrs):
import mxnet as mx import mxnet as mx
units = attrs.get_int("num_hidden") units = attrs.get_int("num_hidden")
...@@ -493,6 +369,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info): ...@@ -493,6 +369,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info):
jnodes = jgraph["nodes"] jnodes = jgraph["nodes"]
node_map = {} node_map = {}
for nid, node in enumerate(jnodes): for nid, node in enumerate(jnodes):
children = [node_map[e[0]][e[1]] for e in node["inputs"]] children = [node_map[e[0]][e[1]] for e in node["inputs"]]
attrs = StrAttrsDict(node.get("attrs", {})) attrs = StrAttrsDict(node.get("attrs", {}))
...@@ -501,7 +378,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info): ...@@ -501,7 +378,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info):
if op_name == "null": if op_name == "null":
shape = shape_dict[node_name] if node_name in shape_dict else None shape = shape_dict[node_name] if node_name in shape_dict else None
if isinstance(dtype_info, dict): if isinstance(dtype_info, dict):
dtype = dtype_info[node_name] if node_name in dtype_dict else "float32" dtype = dtype_info[node_name] if node_name in dtype_info else "float32"
else: else:
dtype = dtype_info dtype = dtype_info
node_map[nid] = [_expr.var(node_name, shape=shape, dtype=dtype)] node_map[nid] = [_expr.var(node_name, shape=shape, dtype=dtype)]
......
# pylint: disable=invalid-name, import-self, len-as-condition
"""Utility functions common to NNVM and MxNet conversion."""
from __future__ import absolute_import as _abs
from .. import expr as _expr
from .. import op as _op
def _get_relay_op(op_name):
op = _op
for path in op_name.split("."):
op = getattr(op, path)
if not op:
raise RuntimeError("Unable to map op_name {} to relay".format(op_name))
return op
def _warn_not_used(attr, op='nnvm'):
import warnings
err = "{} is ignored in {}.".format(attr, op)
warnings.warn(err)
def _rename(new_op):
if isinstance(new_op, str):
new_op = _get_relay_op(new_op)
# attrs are ignored.
def impl(inputs, _, _dtype='float32'):
return new_op(*inputs)
return impl
def _reshape(inputs, attrs):
if attrs.get_bool("reverse", False):
raise RuntimeError("reshape do not support option reverse")
shape = attrs.get_int_tuple("shape")
return _op.reshape(inputs[0], newshape=shape)
def _init_op(new_op):
"""Init ops like zeros/ones"""
def _impl(inputs, attrs):
assert len(inputs) == 0
shape = attrs.get_int_tuple("shape")
dtype = attrs.get_str("dtype", "float32")
return new_op(shape=shape, dtype=dtype)
return _impl
def _softmax_op(new_op):
"""softmax/log_softmax"""
def _impl(inputs, attrs):
assert len(inputs) == 1
axis = attrs.get_int("axis", -1)
return new_op(inputs[0], axis=axis)
return _impl
def _reduce(new_op):
"""Reduction ops like sum/min/max"""
def _impl(inputs, attrs):
assert len(inputs) == 1
axis = attrs.get_int_tuple("axis", [])
keepdims = attrs.get_bool("keepdims", False)
# use None for reduce over all axis.
axis = None if len(axis) == 0 else axis
return new_op(inputs[0], axis=axis, keepdims=keepdims)
return _impl
def _arg_reduce(new_op):
"""Arg Reduction ops like argmin/argmax"""
def _impl(inputs, attrs):
assert len(inputs) == 1
axis = attrs.get_int("axis", None)
keepdims = attrs.get_bool("keepdims", False)
res = new_op(inputs[0], axis=[axis], keepdims=keepdims)
# cast to dtype.
res = res.astype("float32")
return res
return _impl
def _cast(inputs, attrs):
"""Type cast"""
dtype = attrs.get_str("dtype")
return inputs[0].astype(dtype=dtype)
def _clip(inputs, attrs):
a_min = attrs.get_float("a_min")
a_max = attrs.get_float("a_max")
return _op.clip(inputs[0], a_min=a_min, a_max=a_max)
def _transpose(inputs, attrs):
axes = attrs.get_int_tuple("axes", None)
# translate default case
axes = None if len(axes) == 0 else axes
return _op.transpose(inputs[0], axes=axes)
def _upsampling(inputs, attrs):
scale = attrs.get_int("scale")
return _op.nn.upsampling(inputs[0], scale=scale)
def _elemwise_sum(inputs, _):
assert len(inputs) > 0
res = inputs[0]
for x in inputs[1:]:
res = _op.add(res, x)
return res
def _binop_scalar(new_op):
def _impl(inputs, attrs):
assert len(inputs) == 1
scalar = attrs.get_float("scalar")
# Note: binary scalar only works for float op for now
scalar = _expr.const(scalar, dtype="float32")
return new_op(inputs[0], scalar)
return _impl
def _rbinop_scalar(new_op):
def _impl(inputs, attrs):
assert len(inputs) == 1
scalar = attrs.get_float("scalar")
# Note: binary scalar only works for float op for now
scalar = _expr.const(scalar, dtype="float32")
return new_op(scalar, inputs[0])
return _impl
...@@ -9,6 +9,7 @@ from .op import schedule_injective, OpPattern ...@@ -9,6 +9,7 @@ from .op import schedule_injective, OpPattern
schedule_injective = _reg.schedule_injective schedule_injective = _reg.schedule_injective
schedule_broadcast = _reg.schedule_injective schedule_broadcast = _reg.schedule_injective
_reg.register_schedule("collapse_sum_like", _schedule_reduce) _reg.register_schedule("collapse_sum_like", _schedule_reduce)
_reg.register_schedule("broadcast_to_like", schedule_broadcast) _reg.register_schedule("broadcast_to_like", schedule_broadcast)
_reg.register_schedule("expand_dims", schedule_broadcast) _reg.register_schedule("expand_dims", schedule_broadcast)
......
...@@ -243,14 +243,11 @@ def schedule_l2_normalize(attrs, outs, target): ...@@ -243,14 +243,11 @@ def schedule_l2_normalize(attrs, outs, target):
reg.register_pattern("nn.l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE) reg.register_pattern("nn.l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE)
# Upsampling
@reg.register_schedule("nn.upsampling") reg.register_schedule("nn.upsampling", reg.schedule_injective)
def schedule_upsampling(_, outs, target): def schedule_upsampling(_, outs, target):
"""Schedule definition of upsampling""" """Schedule definition of upsampling"""
with target: with target:
return topi.generic.schedule_injective(outs) return topi.generic.schedule_injective(outs)
reg.register_pattern("nn.upsampling", OpPattern.INJECTIVE)
# pad # pad
reg.register_schedule("nn.pad", schedule_broadcast) reg.register_schedule("nn.pad", schedule_broadcast)
...@@ -253,6 +253,9 @@ class StorageAllocator : public StorageAllocaBaseVisitor { ...@@ -253,6 +253,9 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
size_t size = 1; size_t size = 1;
for (IndexExpr dim : ttype->shape) { for (IndexExpr dim : ttype->shape) {
const int64_t* pval = as_const_int(dim); const int64_t* pval = as_const_int(dim);
CHECK_GE(*pval, 0) <<
"can not allocate memory for tensor with negative shape" <<
*pval;
CHECK(pval != nullptr) CHECK(pval != nullptr)
<< "Cannot allocate memory symbolic tensor shape " << "Cannot allocate memory symbolic tensor shape "
<< ttype->shape; << ttype->shape;
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
// Alpha equal handler for relay. // Alpha Equal handler for Relay.
class AlphaEqualHandler: class AlphaEqualHandler:
public AttrsEqualHandler, public AttrsEqualHandler,
public TypeFunctor<bool(const Type&, const Type&)>, public TypeFunctor<bool(const Type&, const Type&)>,
...@@ -26,7 +26,7 @@ class AlphaEqualHandler: ...@@ -26,7 +26,7 @@ class AlphaEqualHandler:
* Check equality of two nodes. * Check equality of two nodes.
* \param lhs The left hand operand. * \param lhs The left hand operand.
* \param rhs The right hand operand. * \param rhs The right hand operand.
* \return The compare result. * \return The comparison result.
*/ */
bool Equal(const NodeRef& lhs, const NodeRef& rhs) { bool Equal(const NodeRef& lhs, const NodeRef& rhs) {
if (lhs.same_as(rhs)) return true; if (lhs.same_as(rhs)) return true;
...@@ -46,7 +46,7 @@ class AlphaEqualHandler: ...@@ -46,7 +46,7 @@ class AlphaEqualHandler:
* Check equality of two attributes. * Check equality of two attributes.
* \param lhs The left hand operand. * \param lhs The left hand operand.
* \param rhs The right hand operand. * \param rhs The right hand operand.
* \return The compare result. * \return The comparison result.
*/ */
bool AttrEqual(const NodeRef& lhs, const NodeRef& rhs) { bool AttrEqual(const NodeRef& lhs, const NodeRef& rhs) {
return AttrsEqualHandler::Equal(lhs, rhs); return AttrsEqualHandler::Equal(lhs, rhs);
...@@ -55,7 +55,7 @@ class AlphaEqualHandler: ...@@ -55,7 +55,7 @@ class AlphaEqualHandler:
* Check equality of two types. * Check equality of two types.
* \param lhs The left hand operand. * \param lhs The left hand operand.
* \param rhs The right hand operand. * \param rhs The right hand operand.
* \return The compare result. * \return the comparison result.
*/ */
bool TypeEqual(const Type& lhs, const Type& rhs) { bool TypeEqual(const Type& lhs, const Type& rhs) {
if (lhs.same_as(rhs)) return true; if (lhs.same_as(rhs)) return true;
...@@ -72,7 +72,7 @@ class AlphaEqualHandler: ...@@ -72,7 +72,7 @@ class AlphaEqualHandler:
* *
* \param lhs The left hand operand. * \param lhs The left hand operand.
* \param rhs The right hand operand. * \param rhs The right hand operand.
* \return The compare result. * \return The comparison result.
*/ */
bool ExprEqual(const Expr& lhs, const Expr& rhs) { bool ExprEqual(const Expr& lhs, const Expr& rhs) {
if (lhs.same_as(rhs)) return true; if (lhs.same_as(rhs)) return true;
......
...@@ -6,8 +6,11 @@ ...@@ -6,8 +6,11 @@
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <tvm/build_module.h>
#include <topi/elemwise.h> #include <topi/elemwise.h>
#include <topi/nn/upsampling.h> #include <topi/nn/upsampling.h>
#include <vector>
#include "../op_common.h"
#include "../layout.h" #include "../layout.h"
namespace tvm { namespace tvm {
...@@ -86,26 +89,37 @@ RELAY_REGISTER_OP("nn.upsampling") ...@@ -86,26 +89,37 @@ RELAY_REGISTER_OP("nn.upsampling")
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2) .set_support_level(2)
.add_type_rel("UpSampling", UpSamplingRel) .add_type_rel("UpSampling", UpSamplingRel)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const Attrs& attrs, "FTVMCompute", [](const Attrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Type& out_type, const Type& out_type,
const Target& target) { const Target& target) {
const auto* param = attrs.as<UpSamplingAttrs>(); const auto* uattrs = attrs.as<UpSamplingAttrs>();
const auto* out_ttype = out_type.as<TensorTypeNode>(); CHECK(uattrs != nullptr);
CHECK(param != nullptr); auto out_tt = out_type.as<TensorTypeNode>();
CHECK(param->layout == "NCHW" || param->layout == "NHWC"); CHECK(out_tt) << "expected a tensor type: " << out_type;
CHECK(out_ttype != nullptr); CHECK(uattrs->layout == "NCHW" || uattrs->layout == "NHWC")
Array<IndexExpr> oshape; << "unknown layout: " << uattrs->layout;
if (param->layout == "NCHW") {
oshape.push_back(out_ttype->shape[2]); Array<HalideIR::Expr> oshape;
oshape.push_back(out_ttype->shape[3]); if (uattrs->layout == "NCHW") {
} else if (param->layout == "NHWC") { oshape.push_back(out_tt->shape[2]);
oshape.push_back(out_ttype->shape[1]); oshape.push_back(out_tt->shape[3]);
oshape.push_back(out_ttype->shape[2]); } else if (uattrs->layout == "NHWC") {
oshape.push_back(out_tt->shape[1]);
oshape.push_back(out_tt->shape[2]);
} }
return Array<Tensor>{ topi::nn::upsampling(inputs[0], oshape, param->layout, param->method)};
return Array<Tensor>{
topi::nn::upsampling(
inputs[0],
oshape,
uattrs->layout,
uattrs->method)
};
}); });
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <algorithm> #include <algorithm>
#include "topi/tags.h" #include "topi/tags.h"
#include "topi/elemwise.h"
#include "topi/detail/ravel_unravel.h" #include "topi/detail/ravel_unravel.h"
#include "topi/detail/constant_utils.h" #include "topi/detail/constant_utils.h"
#include "tvm/tvm.h" #include "tvm/tvm.h"
...@@ -288,7 +289,7 @@ inline Tensor resize_bilinear_nchw(const Tensor& input, ...@@ -288,7 +289,7 @@ inline Tensor resize_bilinear_nchw(const Tensor& input,
* \return A Tensor resized to given shape * \return A Tensor resized to given shape
*/ */
inline Tensor resize_bilinear(const Tensor& input, inline Tensor resize_bilinear(const Tensor& input,
const Array<Expr>& shape, const Array<tvm::Expr>& shape,
std::string layout = "NCHW", std::string layout = "NCHW",
bool align_corners = false, bool align_corners = false,
std::string name = "tensor", std::string name = "tensor",
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment