Commit 9d44279e by Tianqi Chen, committed via GitHub

[RELAY][Refactor] TextPrinter, move ret_type after body in Function. (#1918)

parent 4c13ee22
@@ -131,6 +131,13 @@ class BaseAttrsNode : public Node {
   */
  inline void PrintDocString(std::ostream &os) const;  // NOLINT(*)
 /*!
+ * \brief Visit attributes that do not equal the default value.
+ *
+ * \note This is useful to extract fields for concise printing.
+ * \param v The visitor
+ */
+TVM_DLL virtual void VisitNonDefaultAttrs(AttrVisitor* v) = 0;
+/*!
  * \brief Get the field information
  * \return The fields in the Attrs.
  */
@@ -199,6 +206,7 @@ class DictAttrsNode : public BaseAttrsNode {
  TVM_DLL static Attrs make(Map<std::string, NodeRef> dict);
  // implementations
  void VisitAttrs(AttrVisitor* v) final;
+ void VisitNonDefaultAttrs(AttrVisitor* v) final;
  void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
  Array<AttrFieldInfo> ListFieldInfo() const final;
  bool ContentEqual(const Node* other) const final;
@@ -300,15 +308,15 @@ struct AttrNopEntry {
    return *this;
  }
  template<typename T>
- TSelf& set_default(DMLC_ATTRIBUTE_UNUSED T value) {
+ TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) {
    return *this;
  }
  template<typename T>
- TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) {
+ TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) {
    return *this;
  }
  template<typename T>
- TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) {
+ TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) {
    return *this;
  }
 };
@@ -603,7 +611,7 @@ class AttrDocEntry {
    return *this;
  }
  template<typename T>
- TSelf& set_default(DMLC_ATTRIBUTE_UNUSED T value) {
+ TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) {
    std::ostringstream os;
    os << info_->type_info << ", default=" << value;
    info_->type_info = os.str();
@@ -649,6 +657,57 @@ class AttrExistVisitor {
    return AttrNopEntry();
  }
 };
+template<typename T>
+struct AttrTriggerNonDefaultEntry {
+  using TSelf = AttrTriggerNonDefaultEntry<T>;
+  // constructor
+  AttrTriggerNonDefaultEntry(
+      AttrVisitor* visitor, const char* key, T* data)
+      : visitor_(visitor), key_(key), data_(data) {}
+  ~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION {
+    if (trigger_) {
+      visitor_->Visit(key_, data_);
+    }
+  }
+  TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) {
+    return *this;
+  }
+  TSelf& set_default(const T& value) {
+    if (AttrsEqual()(value, *data_)) {
+      trigger_ = false;
+    }
+    return *this;
+  }
+  TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) {
+    return *this;
+  }
+  TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) {
+    return *this;
+  }
+
+ private:
+  AttrVisitor* visitor_;
+  const char* key_;
+  T* data_;
+  bool trigger_{true};
+};
+
+class AttrNonDefaultVisitor {
+ public:
+  explicit AttrNonDefaultVisitor(AttrVisitor* visitor)
+      : visitor_(visitor) {
+  }
+  template<typename T>
+  AttrTriggerNonDefaultEntry<T>
+  operator()(const char* key, T* value) {
+    return AttrTriggerNonDefaultEntry<T>(visitor_, key, value);
+  }
+
+ private:
+  AttrVisitor* visitor_;
+};
 }  // namespace detail
 /*!
@@ -665,6 +724,11 @@ class AttrsNode : public BaseAttrsNode {
    self()->__VisitAttrs__(vis);
  }
+ void VisitNonDefaultAttrs(AttrVisitor* v) final {
+   detail::AttrNonDefaultVisitor vis(v);
+   self()->__VisitAttrs__(vis);
+ }
  void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final {
    CHECK_EQ(args.size() % 2, 0);
    const int kLinearSearchBound = 16;
......
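In effect, the new VisitNonDefaultAttrs hook lets a printer walk only the fields whose values differ from their declared defaults, which is what keeps the text format concise. A rough Python sketch of the idea (all names here are hypothetical, not TVM API):

    def visit_non_default_attrs(fields, visit):
        """fields: (key, value, default) triples; visit: callback for kept fields."""
        for key, value, default in fields:
            if value != default:  # the C++ side compares with AttrsEqual()
                visit(key, value)

    # Only 'axis' is reported; 'epsilon' still holds its default value.
    visit_non_default_attrs(
        [("axis", 2, 1), ("epsilon", 1e-5, 1e-5)],
        lambda k, v: print(k, "=", v))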
@@ -13,7 +13,7 @@ namespace tvm {
 namespace relay {
 /*! \brief Attributes used in convolution operators */
-struct ConvAttrs : public tvm::AttrsNode<ConvAttrs> {
+struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
   Array<IndexExpr> strides;
   Array<IndexExpr> padding;
   Array<IndexExpr> dilation;
@@ -25,7 +25,7 @@ struct ConvAttrs : public tvm::AttrsNode<ConvAttrs> {
   std::string out_layout;
   DataType out_dtype;
-  TVM_DECLARE_ATTRS(ConvAttrs, "relay.attrs.ConvAttrs") {
+  TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {
     TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
         .describe("Specifies the strides of the convolution.");
     TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
@@ -55,14 +55,14 @@ struct ConvAttrs : public tvm::AttrsNode<ConvAttrs> {
         .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
                   "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
                   "dimensions respectively.");
-    TVM_ATTR_FIELD(out_layout).set_default("__undef__")
+    TVM_ATTR_FIELD(out_layout).set_default("")
         .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
                   "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
                   "dimensions respectively. Default to be same as input layout.");
     // use 0 bits to indicate none.
     TVM_ATTR_FIELD(out_dtype)
-        .set_default(Int(0))
+        .set_default(NullValue<DataType>())
         .describe("Output data type, set to explicit type under mixed precision setting");
   }
 };
@@ -123,7 +123,7 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
               "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
               "dimensions respectively.");
     TVM_ATTR_FIELD(out_dtype)
-        .set_default(Int(0))
+        .set_default(NullValue<DataType>())
         .describe("Output data type, set to explicit type under mixed precision setting");
   }
 };
......
@@ -78,7 +78,7 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
       .describe("Target shape.");
   TVM_ATTR_FIELD(dtype)
       .describe("Target data type.")
-      .set_default(Int(0));
+      .set_default(NullValue<DataType>());
 }
};  // struct InitOpAttrs
......
@@ -181,8 +181,6 @@ class FunctionNode : public ExprNode {
  public:
  /*! \brief Function parameters */
  tvm::Array<Var> params;
- /*! \brief User annotated return type of the function. */
- Type ret_type;
  /*!
   * \brief
   * The expression which represents the computation of the function,
@@ -190,6 +188,8 @@ class FunctionNode : public ExprNode {
   * or sub-expressions may reference the type variables.
   */
  Expr body;
+ /*! \brief User annotated return type of the function. */
+ Type ret_type;
  /*!
   * \brief Type parameters of the function.
   * Enables the function to vary its type based on these.
@@ -201,8 +201,8 @@ class FunctionNode : public ExprNode {
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("params", &params);
-   v->Visit("ret_type", &ret_type);
    v->Visit("body", &body);
+   v->Visit("ret_type", &ret_type);
    v->Visit("type_params", &type_params);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
@@ -217,8 +217,8 @@ class FunctionNode : public ExprNode {
  TVM_DLL FuncType func_type_annotation() const;
  TVM_DLL static Function make(tvm::Array<Var> params,
-                              Type ret_type,
                               Expr body,
+                              Type ret_type,
                               tvm::Array<TypeParam> ty_params);
  static constexpr const char* _type_key = "relay.Function";
......
@@ -48,6 +48,11 @@ class OpNode : public relay::ExprNode {
   */
  std::string attrs_type_key;
  /*!
+  * \brief attribute type index,
+  * this field varies in each run and is not exposed to frontend.
+  */
+ uint32_t attrs_type_index{0};
+ /*!
   * \brief number of input arguments to the operator,
   * -1 means it is variable length
   */
@@ -416,6 +421,7 @@ inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) {  // NOLINT(*)
 inline OpRegistry& OpRegistry::set_attrs_type_key(  // NOLINT(*)
     const std::string& type_key) {
   get()->attrs_type_key = type_key;
+  get()->attrs_type_index = Node::TypeKey2Index(type_key.c_str());
   return *this;
 }
......
@@ -14,13 +14,15 @@ from . import libinfo
 #----------------------------
 if sys.version_info[0] == 3:
     string_types = (str,)
-    numeric_types = (float, int, np.float32, np.int32)
+    integer_types = (int, np.int32)
+    numeric_types = integer_types + (float, np.float32)
     # this function is needed for python3
     # to convert ctypes.char_p .value back to python str
     py_str = lambda x: x.decode('utf-8')
 else:
     string_types = (basestring,)
-    numeric_types = (float, int, long, np.float32, np.int32)
+    integer_types = (int, long, np.int32)
+    numeric_types = integer_types + (float, np.float32)
     py_str = lambda x: x
......
-# pylint: disable=wildcard-import, redefined-builtin
+# pylint: disable=wildcard-import, redefined-builtin, invalid-name
 """The Relay IR namespace containing the IR definition and compiler."""
 from . import base
 from . import ty
@@ -19,6 +19,9 @@ from . import image
 # Span
 Span = base.Span
+# Env
+Environment = env.Environment
+
 # Type
 Type = ty.Type
 TupleType = ty.TupleType
@@ -40,3 +43,7 @@ Call = expr.Call
 Let = expr.Let
 If = expr.If
 TupleGetItem = expr.TupleGetItem
+
+# helper functions
+var = expr.var
+const = expr.const
@@ -3,6 +3,7 @@
 from __future__ import absolute_import as _abs
 from .._ffi.node import NodeBase, register_node as _register_tvm_node
 from . import _make
+from . import _expr
 NodeBase = NodeBase
@@ -20,7 +21,19 @@ def register_relay_node(type_key=None):
     return _register_tvm_node(type_key)

+
+class RelayNode(NodeBase):
+    def astext(self):
+        """Get the text format of the expression.
+
+        Returns
+        -------
+        text : str
+            The text format of the expression.
+        """
+        return _expr._text_print(self)
+
+
 @register_relay_node
-class Span(NodeBase):
+class Span(RelayNode):
     def __init__(self, source, lineno, col_offset):
         self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset)
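For illustration, once RelayNode is in place every Relay node gains astext(); a hedged usage sketch (assuming the text printer is registered on the C++ side):

    import tvm.relay as relay

    x = relay.var("x", shape=(1, 2))
    f = relay.Function([x], x)
    print(f.astext())  # dumps the function in the text format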
 # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
 """A global environment storing everything needed to interpret or compile a Relay program."""
-from .base import register_relay_node, NodeBase
+from .base import register_relay_node, RelayNode
 from . import _make
 from . import _env
 @register_relay_node
-class Environment(NodeBase):
+class Environment(RelayNode):
     """The global Relay environment containing functions,
     options and more.
     """
......
 # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
 """The expression nodes of Relay."""
 from __future__ import absolute_import
-from .base import NodeBase, register_relay_node
-from . import _expr
+import numpy as _np
+from .base import RelayNode, register_relay_node
 from . import _make
+from . import ty as _ty
+from .._ffi import base as _base, node as _node
+from .. import nd as _nd
 from .. import convert

-class Expr(NodeBase):
+class Expr(RelayNode):
     """The base type for all Relay expressions."""
     @property
     def checked_type(self):
@@ -56,7 +60,7 @@ class Tuple(Expr):
 @register_relay_node
 class Var(Expr):
-    """A local variable in Tvm.Relay.
+    """A local variable in Relay.

     Local variable can be used to declare input
     arguments to a function, or intermediate variables.
@@ -101,26 +105,26 @@ class Function(Expr):
     params: List[tvm.relay.Var]
         List of input parameters to the function.

-    ret_type: tvm.relay.Type
-        The return type annotation of the function.
-
     body: tvm.relay.Expr
         The body of the function.

+    ret_type: Optional[tvm.relay.Type]
+        The return type annotation of the function.
+
     type_params: Optional[List[tvm.relay.TypeParam]]
         The additional type parameters, this is only
         used in advanced use cases of template functions.
     """
     def __init__(self,
                  params,
-                 ret_type,
                  body,
+                 ret_type=None,
                  type_params=None):
         if type_params is None:
             type_params = convert([])
         self.__init_handle_by_constructor__(
-            _make.Function, params, ret_type, body, type_params)
+            _make.Function, params, body, ret_type, type_params)

 @register_relay_node
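To illustrate the reordered signature, a minimal sketch (imports assumed; relay.var is the helper added later in this commit):

    x = relay.var("x", shape=(10,))
    f = relay.Function([x], x)  # ret_type is now optional and inferred
    g = relay.Function([x], x, relay.TensorType((10,), "float32"))  # explicit return type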
@@ -158,7 +162,7 @@ class Let(Expr):
     Parameters
     ----------
-    var: tvm.relay.Var
+    variable: tvm.relay.Var
         The local variable to be bound.

     value: tvm.relay.Expr
@@ -167,9 +171,9 @@ class Let(Expr):
     body: tvm.relay.Expr
         The body of the let binding.
     """
-    def __init__(self, var, value, body):
+    def __init__(self, variable, value, body):
         self.__init_handle_by_constructor__(
-            _make.Let, var, value, body)
+            _make.Let, variable, value, body)
 @register_relay_node
@@ -208,4 +212,105 @@ class TupleGetItem(Expr):
         self.__init_handle_by_constructor__(
             _make.TupleGetItem, tuple_value, index)
+
+debug_print = _expr._debug_print
+
+class TupleWrapper(_node.NodeGeneric):
+    """TupleWrapper.
+
+    This class is a Python wrapper for a Relay tuple of known size.
+    It allows for accessing the fields of the Relay tuple as though
+    it were a Python tuple.
+
+    Parameters
+    ----------
+    tuple_value: tvm.relay.Expr
+        The input tuple
+
+    size: int
+        The size of the tuple.
+    """
+    def __init__(self, tuple_value, size):
+        self.tuple_value = tuple_value
+        self.size = size
+
+    def asnode(self):
+        """Returns the underlying Relay tuple if this wrapper is passed
+        as an argument to an FFI function."""
+        return self.tuple_value
+
+    def __getitem__(self, key):
+        return self.tuple_value.fields[key]
+
+    def __len__(self):
+        return len(self.tuple_value.fields)
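A hedged usage sketch of TupleWrapper (the wrapped tuple is built by hand here; in practice ops with multiple outputs hand one back):

    t = TupleWrapper(Tuple([relay.var("a"), relay.var("b")]), 2)
    first = t[0]  # indexes tuple_value.fields, like a Python tuple
    n = len(t)    # 2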
+
+def var(name_hint,
+        type_annotation=None,
+        shape=None,
+        dtype="float32"):
+    """Create a new tvm.relay.Var.
+
+    This is a simple wrapper function that allows specifying
+    shape and dtype directly.
+
+    Parameters
+    ----------
+    name_hint: str
+        The name of the variable.
+        This name only acts as a hint, and is not used
+        for equality.
+
+    type_annotation: Optional[tvm.relay.Type, str]
+        The type annotation on the variable.
+        When type_annotation is a str, we will create a scalar variable.
+
+    shape: Optional[List[tvm.Expr]]
+        The shape of the tensor type.
+
+    dtype: str, optional
+        The data type of the tensor.
+
+    Examples
+    --------
+    .. code-block:: python
+
+       # The following 4 lines are equivalent to each other
+       x = tvm.relay.Var("x", tvm.relay.TensorType([1, 2]))
+       x = tvm.relay.var("x", tvm.relay.TensorType([1, 2]))
+       x = tvm.relay.var("x", shape=[1, 2])
+       x = tvm.relay.var("x", shape=[1, 2], dtype="float32")
+
+       # The following 2 lines are equivalent to each other.
+       y = tvm.relay.var("x", "float32")
+       y = tvm.relay.var("x", shape=(), dtype="float32")
+    """
+    if type_annotation is not None and shape is not None:
+        raise ValueError("Can only specify either type_annotation or shape.")
+    if shape is not None:
+        type_annotation = _ty.TensorType(shape, dtype)
+    elif isinstance(type_annotation, str):
+        type_annotation = _ty.TensorType((), type_annotation)
+    return Var(name_hint, type_annotation)
+
+def const(value, dtype=None):
+    """Create a constant value.
+
+    Parameters
+    ----------
+    value: Union[bool, int, float, numpy.ndarray, tvm.nd.NDArray]
+        The constant value.
+
+    dtype: str, optional
+        The data type of the value.
+    """
+    if isinstance(value, _base.numeric_types):
+        value = _np.array(value, dtype=dtype)
+    elif isinstance(value, (bool, list)):
+        value = _np.array(value, dtype=dtype)
+    if isinstance(value, (_np.ndarray, _np.generic)):
+        value = _nd.array(value)
+    if not isinstance(value, _nd.NDArray):
+        raise ValueError("value has to be scalar or NDArray")
+    return Constant(value)
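A usage sketch for the new const helper (numpy import assumed):

    import numpy as np
    import tvm.relay as relay

    c0 = relay.const(1.0, dtype="float32")  # Python scalar -> Constant node
    c1 = relay.const(np.ones((2, 2)))       # numpy array  -> Constant node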
@@ -11,32 +11,6 @@ from .expr import Expr, Constant, Let, Var, Function, If
 from .env import Environment

-class TupleWrapper(tvm._ffi.node.NodeGeneric):
-    """TupleWrapper.
-
-    This class is a Python wrapper for a Relay tuple of known size.
-    It allows for accessing the fields of the Relay tuple as though
-    it were a Python tuple.
-    """
-    def __init__(self, tuple_value, size):
-        self.tuple_value = tuple_value
-        self.size = size
-
-    def asnode(self):
-        """Returns the underlying Relay tuple if this wrapper is passed
-        as an argument to an FFI function."""
-        return self.tuple_value
-
-    def __getitem__(self, key):
-        return self.tuple_value.fields[key]
-
-    def __len__(self):
-        return len(self.tuple_value.fields)
-
 def _convert_to_value(arg, ctxt=tvm.cpu(0)):
     # type: (Any, tvm.Context) -> tvm.nd.NDArray
     """Convert Python values into the appropriate types
@@ -132,8 +106,8 @@ class PartialFunc(object):
     """Converts a PartialFunc into a :py:class:`~relay.Function`."""
     return Function(
         self.params,
-        self.ret_type,
         self.body,
+        self.ret_type,
         self.type_params)

 #pylint: disable=invalid-name
@@ -325,7 +299,7 @@ class IRBuilder(object):
     def _on_exit():
         bindings, _, _, ret_value = self.exit_scope()
         exp = _mk_let(bindings, ret_value)
-        self.env.add(name, Function(params, ret_type, exp))
+        self.env.add(name, Function(params, exp, ret_type))
     return WithScope(10, _on_exit)
......
"""Neural network operations.""" """Neural network operations."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from tvm.relay.ir_builder import TupleWrapper from ...expr import TupleWrapper
from . import _make from . import _make
...@@ -145,7 +145,7 @@ def conv2d_transpose(data, ...@@ -145,7 +145,7 @@ def conv2d_transpose(data,
weight_layout, output_padding, out_dtype) weight_layout, output_padding, out_dtype)
def softmax(data, axis): def softmax(data, axis=1):
r"""Computes softmax. r"""Computes softmax.
.. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)} .. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}
...@@ -158,7 +158,7 @@ def softmax(data, axis): ...@@ -158,7 +158,7 @@ def softmax(data, axis):
data: relay.Expr data: relay.Expr
The input data to the operator. The input data to the operator.
axis: int axis: int, optional
The axis to sum over when computing softmax The axis to sum over when computing softmax
Returns Returns
......
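With the new default, axis can be omitted; a brief sketch (imports assumed):

    x = relay.var("x", shape=(2, 3))
    y = relay.nn.softmax(x)          # axis defaults to 1
    z = relay.nn.softmax(x, axis=0)  # explicit axis still works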
 # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
 """The type nodes of the Relay language."""
 from enum import IntEnum
-from .base import NodeBase, register_relay_node
+from .base import RelayNode, register_relay_node
 from . import _make

-class Type(NodeBase):
+class Type(RelayNode):
     """The base type for all Relay types."""
     def __eq__(self, other):
@@ -21,27 +21,25 @@ class Type(NodeBase):
         """Compares two Relay types by referential equality."""
         return super().__eq__(other)

 @register_relay_node
 class TensorType(Type):
-    """A concrete TensorType in Relay, see tvm/relay/type.h for more details.
+    """A concrete TensorType in Relay.

     This is the type assigned to tensors with a known dtype and shape. For
     example a tensor of `float32` and `(5, 5)`.
-    """
-    def __init__(self, shape, dtype):
-        """Construct a tensor type.
-
-        Parameters
-        ----------
-        shape: list of tvm.Expr
-        dtype: str
-
-        Returns
-        -------
-        tensor_type: The TensorType
-        """
-        self.__init_handle_by_constructor__(_make.TensorType, shape, dtype)
+
+    Parameters
+    ----------
+    shape: List[tvm.Expr]
+        The shape of the Tensor
+
+    dtype: str, optional
+        The content data type.
+    """
+    def __init__(self, shape, dtype="float32"):
+        self.__init_handle_by_constructor__(
+            _make.TensorType, shape, dtype)

 class Kind(IntEnum):
......
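The reworked constructor makes dtype optional; a short sketch:

    t0 = relay.TensorType((5, 5), "int32")
    t1 = relay.TensorType([1, 2])  # dtype defaults to "float32"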
@@ -17,11 +17,15 @@ namespace tvm {
 template <typename FType>
 class AttrFunctor;

+#define ATTR_FUNCTOR_DEFAULT \
+  { return VisitAttrDefault_(op, std::forward<Args>(args)...); }
+
 #define ATTR_FUNCTOR_DISPATCH(OP) \
   vtable.template set_dispatch<OP>( \
       [](const NodeRef& n, TSelf* self, Args... args) { \
-        return self->Visit_(static_cast<const OP*>(n.node_.get()), \
+        return self->VisitAttr_(static_cast<const OP*>(n.node_.get()), \
                             std::forward<Args>(args)...); \
       }); \

 // A functor for common attribute information.
@@ -40,21 +44,21 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
   * \param args Additional arguments.
   * \return The result of the call
   */
-  virtual R Visit(const NodeRef& n, Args... args) {
+  virtual R VisitAttr(const NodeRef& n, Args... args) {
     static FType vtable = InitVTable();
     if (vtable.can_dispatch(n)) {
       return vtable(n, this, std::forward<Args>(args)...);
     } else {
-      return VisitDefault_(n, std::forward<Args>(args)...);
+      return VisitAttrDefault_(n.get(), std::forward<Args>(args)...);
     }
   }
-  virtual R Visit_(const ArrayNode* op, Args... args) = 0;
-  virtual R Visit_(const StrMapNode* op, Args... args) = 0;
-  virtual R Visit_(const ir::IntImm* op, Args... args) = 0;
-  virtual R Visit_(const ir::UIntImm* op, Args... args) = 0;
-  virtual R Visit_(const ir::FloatImm* op, Args... args) = 0;
-  virtual R Visit_(const ir::StringImm* op, Args... args) = 0;
-  virtual R VisitDefault_(const NodeRef& n, Args... args) = 0;
+  virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::IntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::UIntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::FloatImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::StringImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttrDefault_(const Node* node, Args... args) = 0;

  private:
   // initialize the vtable.
......
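The change above replaces the pure-virtual Visit_ overloads with ones that fall back to a shared default handler. A rough Python analogue of that dispatch-with-default pattern (hypothetical names, not TVM API):

    class AttrFunctor:
        def visit_attr(self, node, *args):
            handler = getattr(self, "visit_attr_" + type(node).__name__, None)
            if handler is not None:
                return handler(node, *args)              # per-type override
            return self.visit_attr_default(node, *args)  # shared fallback

        def visit_attr_default(self, node, *args):
            raise NotImplementedError()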
@@ -11,6 +11,10 @@ void DictAttrsNode::VisitAttrs(AttrVisitor* v) {
   v->Visit("__dict__", &dict);
 }

+void DictAttrsNode::VisitNonDefaultAttrs(AttrVisitor* v) {
+  v->Visit("__dict__", &dict);
+}
+
 void DictAttrsNode::InitByPackedArgs(
     const runtime::TVMArgs& args, bool allow_unknown) {
   for (int i = 0; i < args.size(); i += 2) {
@@ -55,48 +59,48 @@ class AttrsEqualChecker :
     if (!equal_) return false;
     if (lhs.same_as(rhs)) return true;
     if (!lhs.defined() || !rhs.defined()) return false;
-    if (!this->Visit(lhs, rhs)) {
+    if (!this->VisitAttr(lhs, rhs)) {
       equal_ = false;
     }
     return equal_;
   }

-  bool VisitDefault_(const NodeRef& lhs, const NodeRef& other) final {
+  bool VisitAttrDefault_(const Node* lhs, const NodeRef& other) final {
     if (lhs->derived_from<BaseAttrsNode>()) {
-      return static_cast<const BaseAttrsNode*>(lhs.get())->ContentEqual(other.get());
+      return static_cast<const BaseAttrsNode*>(lhs)->ContentEqual(other.get());
     }
-    return lhs.same_as(other);
+    return lhs == other.get();
   }

-  bool Visit_(const IntImm* lhs, const NodeRef& other) final {
+  bool VisitAttr_(const IntImm* lhs, const NodeRef& other) final {
     if (const auto* rhs = other.as<IntImm>()) {
       return lhs->value == rhs->value;
     }
     return false;
   }

-  bool Visit_(const UIntImm* lhs, const NodeRef& other) final {
+  bool VisitAttr_(const UIntImm* lhs, const NodeRef& other) final {
     if (const auto* rhs = other.as<UIntImm>()) {
       return lhs->value == rhs->value;
     }
     return false;
   }

-  bool Visit_(const FloatImm* lhs, const NodeRef& other) final {
+  bool VisitAttr_(const FloatImm* lhs, const NodeRef& other) final {
     if (const auto* rhs = other.as<FloatImm>()) {
       return lhs->value == rhs->value;
     }
     return false;
   }

-  bool Visit_(const StringImm* lhs, const NodeRef& other) final {
+  bool VisitAttr_(const StringImm* lhs, const NodeRef& other) final {
     if (const auto* rhs = other.as<StringImm>()) {
       return lhs->value == rhs->value;
     }
     return false;
   }

-  bool Visit_(const ArrayNode* lhs, const NodeRef& other) final {
+  bool VisitAttr_(const ArrayNode* lhs, const NodeRef& other) final {
     if (const auto* rhs = other.as<ArrayNode>()) {
       if (rhs->data.size() != lhs->data.size()) return false;
       for (size_t i = 0; i < lhs->data.size(); ++i) {
@@ -106,7 +110,7 @@ class AttrsEqualChecker :
     return true;
   }

-  bool Visit_(const StrMapNode* lhs, const NodeRef& other) final {
+  bool VisitAttr_(const StrMapNode* lhs, const NodeRef& other) final {
     if (const auto* rhs = other.as<StrMapNode>()) {
       if (rhs->data.size() != lhs->data.size()) return false;
       for (const auto& kv : lhs->data) {
@@ -127,38 +131,38 @@ class AttrContentHasher :
  public:
   size_t result_{0};

-  void VisitDefault_(const NodeRef& value) final {
+  void VisitAttrDefault_(const Node* value) final {
     if (value->derived_from<BaseAttrsNode>()) {
-      Update(static_cast<const BaseAttrsNode*>(value.get())->ContentHash());
+      Update(static_cast<const BaseAttrsNode*>(value)->ContentHash());
     } else {
-      Update(NodeHash()(value));
+      Update(NodeHash()(GetRef<NodeRef>(value)));
     }
   }

-  void Visit_(const IntImm* op) final {
+  void VisitAttr_(const IntImm* op) final {
     Update(std::hash<int64_t>()(op->value));
   }

-  void Visit_(const UIntImm* op) final {
+  void VisitAttr_(const UIntImm* op) final {
     Update(std::hash<uint64_t>()(op->value));
   }

-  void Visit_(const FloatImm* op) final {
+  void VisitAttr_(const FloatImm* op) final {
     Update(std::hash<double>()(op->value));
   }

-  void Visit_(const StringImm* op) final {
+  void VisitAttr_(const StringImm* op) final {
     Update(std::hash<std::string>()(op->value));
   }

-  void Visit_(const ArrayNode* op) final {
+  void VisitAttr_(const ArrayNode* op) final {
     Update(op->data.size());
     for (size_t i = 0; i < op->data.size(); ++i) {
-      this->Visit(NodeRef(op->data[i]));
+      this->VisitAttr(NodeRef(op->data[i]));
     }
   }

-  void Visit_(const StrMapNode* lhs) final {
+  void VisitAttr_(const StrMapNode* lhs) final {
     using Entry = std::pair<std::string, NodePtr<Node> >;
     std::vector<Entry> data(lhs->data.begin(), lhs->data.end());
     std::sort(data.begin(), data.end(), [](const Entry& a, const Entry& b) {
@@ -166,7 +170,7 @@ class AttrContentHasher :
     });
     for (const Entry& kv : data) {
       Update(std::hash<std::string>()(kv.first));
-      this->Visit(NodeRef(kv.second));
+      this->VisitAttr(NodeRef(kv.second));
     }
   }
@@ -184,7 +188,7 @@ bool AttrsEqual::Equal(const NodeRef& lhs, const NodeRef& rhs) {
 size_t AttrsHash::Hash(const NodeRef& node) {
   if (!node.defined()) return 0;
   AttrContentHasher hasher;
-  hasher.Visit(node);
+  hasher.VisitAttr(node);
   return hasher.result_;
 }
......
@@ -208,6 +208,8 @@ class JSONAttrGetter : public AttrVisitor {
   node_->type_key = node->type_key();
   // specially handle global object
   auto* f = dmlc::Registry<NodeFactoryReg>::Find(node_->type_key);
+  CHECK(f != nullptr)
+      << "Node type \'" << node_->type_key << "\' is not registered in TVM";
   if (f->fglobal_key != nullptr) {
     node_->global_key = f->fglobal_key(node);
     return;
......
@@ -51,6 +51,8 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) {
   return Span(n);
 }

+TVM_REGISTER_NODE_TYPE(SpanNode);
+
 TVM_REGISTER_API("relay._make.Span")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     *ret = SpanNode::make(args[0], args[1], args[2]);
......
/*!
 * Copyright (c) 2018 by Contributors
 * \file debug_printer.cc
 * \brief A pretty printer for the Relay IR.
 * As we have not yet settled on a formal syntax, right now this is only for debugging purposes.
 */
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/environment.h>
#include <tvm/relay/error.h>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include "../pass/type_functor.h"
#include "doc.h"

namespace tvm {
namespace relay {

using namespace tvm::runtime;
Doc KindDocify(TypeParamNode::Kind k) {
  switch (k) {
    case TypeParamNode::kShapeVar:
      return DocOfStr("ShapeVar");
    case TypeParamNode::kShape:
      return DocOfStr("Shape");
    case TypeParamNode::kBaseType:
      return DocOfStr("BaseType");
    case TypeParamNode::kType:
      return DocOfStr("Type");
    default:
      LOG(FATAL) << "unreachable code: case not handled in kind";
      throw;  // LOG(FATAL) throws, but the compiler does not know that
  }
}

template<typename T>
std::vector<Doc> MapDocify(const tvm::Array<T>& arr, const std::function<Doc(const T&)>& f) {
  std::vector<Doc> vec;
  for (size_t i = 0; i < arr.size(); ++i) {
    vec.push_back(f(arr[i]));
  }
  return vec;
}

template<typename T, typename Hash = std::hash<T>, typename Eq = std::equal_to<T>>
class Counter {
  std::unordered_map<T, size_t, Hash, Eq> cnt_;

 public:
  Counter() = default;
  Counter(const Counter&) = delete;
  size_t operator()(const T& t) {
    auto v = cnt_.count(t) == 0 ? 0 : cnt_.at(t) + 1;
    cnt_[t] = v;
    return v;
  }
};

std::string Mangle(const std::string& str, size_t s) {
  return str + "_" + std::to_string(s);
  // return s == 0 ? str : str + "_" + std::to_string(s - 1);
  // The above line looks prettier but is dangerous:
  // suppose we have x, x, x_0 - mangling would give x, x_0, x_0!
  // The safe approach gives x_0, x_1, x_0_1, and in fact never clashes:
  // stripping _([0-9]*) is the inverse of mangle under all circumstances.
  // Another problem: we need to prevent Var/TypeParam/GlobalVar from clashing with each other.
}
constexpr size_t indent = 2;

struct TypeParamName {
  bool operator==(const TypeParamName&) const {
    return true;
  }
};

struct mhash {
  size_t operator()(const ::tvm::relay::TypeParamName&) const noexcept {
    return 0;
  }
};

class TypeDocifier : private TypeFunctor<Doc(const Type& n)> {
  Environment env;
  Counter<TypeParamName, mhash> cnt;
  std::unordered_map<TypeParam, Doc, NodeHash, NodeEqual> map;

  std::vector<Doc> DocifyTypeArray(const tvm::Array<Type>& arr) {
    return MapDocify<Type>(arr, [=](const Type& t) { return Docify(t); });
  }

  std::vector<Doc> DocifyTypeParam(const tvm::Array<TypeParam>& arr) {
    return MapDocify<TypeParam>(arr, [=](const TypeParam& tp) {
      return Docify(tp);
    });
  }

  std::vector<Doc> DocifyTypeConstraint(const tvm::Array<TypeConstraint>& arr) {
    return MapDocify<TypeConstraint>(arr, [=](const TypeConstraint& tc) { return Docify(tc); });
  }

  Doc VisitType_(const TensorTypeNode* t) final {
    return DocOfStr("tensor");
  }

  Doc VisitType_(const TypeParamNode* p) final {
    auto tp = GetRef<TypeParam>(p);
    if (map.count(tp) == 0) {
      auto name =
          DocOfStr(Mangle("tp", cnt(TypeParamName())) +
                   std::string(":")) +
          KindDocify(p->kind);
      map.insert(std::pair<TypeParam, Doc>(tp, name));
    }
    return map.at(tp);
  }

  Doc Quantify(const tvm::Array<TypeParam>& tp, const Doc& d) {
    if (tp.size() == 0) {
      return d;
    }
    return Seq("forall", DocifyTypeParam(tp), ",") + Sep() + d;
  }

  Doc Constraint(const tvm::Array<TypeConstraint>& tc, const Doc& d) {
    if (tc.size() == 0) {
      return d;
    }
    return Seq("(", DocifyTypeConstraint(tc), ") =>") + Sep() + d;
  }

  Doc VisitType_(const FuncTypeNode* f) final {
    auto inner = Seq("<", DocifyTypeArray(f->arg_types), ">") + Sep() +
                 DocOfStr("->") + Sep() + Docify(f->ret_type);
    return Group(Quantify(f->type_params,
                          Constraint(f->type_constraints, inner)));
  }

  Doc VisitType_(const TypeRelationNode* r) final {
    return DocOfStr("Relation") + Seq("(", DocifyTypeArray(r->args), ")");
  }

  Doc VisitType_(const TupleTypeNode* t) final {
    return Seq("<", DocifyTypeArray(t->fields), ">");
  }

  Doc VisitType_(const IncompleteTypeNode* i) final {
    return DocOfStr("_");
  }

 public:
  TypeDocifier(const Environment& env) : env(env) { }

  Doc Docify(const Type& t) { return t.get() ? (*this)(t) : DocOfStr("_"); }
};
class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
  Environment env;
  Counter<std::string> cnt;
  std::unordered_map<Var, std::string, NodeHash, NodeEqual> map;
  TypeDocifier td;

  std::string VarName(const Var& v) {
    if (map.count(v) == 0) {
      map.insert(std::pair<Var, std::string>(v, Mangle(v->name_hint, cnt(v->name_hint))));
    }
    return map.at(v);
  }

  Doc TypeAnnotation(const Doc& d, const Type& t) {
    // test for t being null. probably should not be null; should talk to jared.
    if (!t.get() || t.as<IncompleteTypeNode>()) {
      return d;
    } else {
      return d + DocOfStr(":") + td.Docify(t);
    }
  }

  std::vector<Doc> DocifyExprArray(const tvm::Array<Expr>& arr) {
    std::vector<Doc> vec;
    for (size_t i = 0; i < arr.size(); ++i) {
      vec.push_back(Docify(arr[i]));
    }
    return vec;
  }

  std::vector<Doc> DocifyParamArray(const tvm::Array<Var>& arr) {
    std::vector<Doc> vec;
    for (Var param : arr) {
      vec.emplace_back(TypeAnnotation(DocOfStr(VarName(param)),
                                      param->type_annotation));
    }
    return vec;
  }

  Doc VisitExpr_(const ConstantNode* c) final {
    return DocOfStr("some_constant");
  }

  Doc VisitExpr_(const TupleNode* t) final {
    return Seq("<", DocifyExprArray(t->fields), ">");
  }

  Doc VisitExpr_(const VarNode* v) final {
    return DocOfStr(VarName(GetRef<Var>(v)));
  }

  Doc VisitExpr_(const GlobalVarNode* g) final {
    return DocOfStr(g->name_hint);
  }

  Doc VisitExpr_(const FunctionNode* f) final {
    return Group(TypeAnnotation(Seq("(", DocifyParamArray(f->params), ")"), f->ret_type) + Sep() +
                 DocOfStr("=>") + Sep() +
                 Block(indent, "{", Docify(f->body), "}"));
  }

  Doc VisitExpr_(const CallNode* c) final {
    return Docify(c->op) + Seq("<", DocifyExprArray(c->args), ">");
  }

  Doc VisitExpr_(const LetNode* l) final {
    return Group(DocOfStr("let") + Sep() +
                 TypeAnnotation(Docify(l->var), l->var->type_annotation) + Sep() +
                 DocOfStr("=") + Sep() + Docify(l->value) + DocOfStr(";") + Endl() +
                 Docify(l->body));
  }

  Doc VisitExpr_(const IfNode* i) final {
    return Group(DocOfStr("if") + Sep() + Docify(i->cond) + Sep() +
                 Block(indent, "{", Docify(i->true_branch), "}") + Sep() +
                 DocOfStr("else") + Sep() +
                 Block(indent, "{", Docify(i->false_branch), "}"));
  }

  Doc VisitExpr_(const OpNode* o) final {
    return DocOfStr(o->name);
  }

  Doc VisitExpr_(const TupleGetItemNode* g) final {
    return Docify(g->tuple) + DocOfStr(std::string(".") + std::to_string(g->index));
  }

 public:
  ExprDocifier(const Environment& env) : env(env), td(env) { }

  Doc Docify(const Expr& e) { return (*this)(e); }
};
Doc DocOfExpr(const Environment& env, const Expr& expr) {
  ExprDocifier d(env);
  return d.Docify(expr);
}

Doc DocOfType(const Environment& env, const Type& expr) {
  TypeDocifier d(env);
  return d.Docify(expr);
}

RDoc ExprRDoc(const Environment& env, const Expr& expr) {
  return Layout(DocOfExpr(env, expr));
}

RDoc TypeRDoc(const Environment& env, const Type& expr) {
  return Layout(DocOfType(env, expr));
}

std::ostream& DebugPrint(const Environment& env, const Expr& e, std::ostream& os) {
  return os << ExprRDoc(env, e);
}

std::ostream& DebugPrint(const Environment& env, const Type& t, std::ostream& os) {
  return os << TypeRDoc(env, t);
}

std::string PrintExpr(const Environment& env, const Expr& e) {
  std::stringstream ss;
  ss << ExprRDoc(env, e);
  return ss.str();
}

std::string PrintType(const Environment& env, const Type& t) {
  std::stringstream ss;
  ss << TypeRDoc(env, t);
  return ss.str();
}

TVM_REGISTER_API("relay._expr._debug_print")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    NodeRef x = args[1];
    if (x.as<TypeNode>()) {
      *ret = PrintType(args[0], Downcast<Type>(x));
    } else {
      *ret = PrintExpr(args[0], Downcast<Expr>(x));
    }
  });

}  // namespace relay
}  // namespace tvm
@@ -100,6 +100,8 @@ void EnvironmentNode::Merge(const Environment &env) {
   }
 }

+TVM_REGISTER_NODE_TYPE(EnvironmentNode);
+
 TVM_REGISTER_API("relay._make.Environment")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     *ret = EnvironmentNode::make(args[0]);
......
@@ -17,6 +17,8 @@ Constant ConstantNode::make(runtime::NDArray data) {
   return Constant(n);
 }

+TVM_REGISTER_NODE_TYPE(ConstantNode);
+
 TVM_REGISTER_API("relay._make.Constant")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     *ret = ConstantNode::make(args[0]);
@@ -44,6 +46,8 @@ Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
   return Tuple(n);
 }

+TVM_REGISTER_NODE_TYPE(TupleNode);
+
 TVM_REGISTER_API("relay._make.Tuple")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     *ret = TupleNode::make(args[0]);
@@ -61,6 +65,8 @@ Var VarNode::make(std::string name_hint, Type type_annotation) {
   return Var(n);
 }

+TVM_REGISTER_NODE_TYPE(VarNode);
+
 TVM_REGISTER_API("relay._make.Var")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     *ret = VarNode::make(args[0], args[1]);
@@ -82,6 +88,8 @@ GlobalVar GlobalVarNode::make(std::string name_hint) {
   return GlobalVar(n);
 }

+TVM_REGISTER_NODE_TYPE(GlobalVarNode);
+
 TVM_REGISTER_API("relay._make.GlobalVar")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     *ret = GlobalVarNode::make(args[0]);
@@ -94,13 +102,13 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 Function FunctionNode::make(tvm::Array<Var> params,
-                            Type ret_type,
                             Expr body,
+                            Type ret_type,
                             tvm::Array<TypeParam> type_params) {
   NodePtr<FunctionNode> n = make_node<FunctionNode>();
   n->params = std::move(params);
-  n->ret_type = std::move(ret_type);
   n->body = std::move(body);
+  n->ret_type = std::move(ret_type);
   n->type_params = std::move(type_params);
   return Function(n);
 }
@@ -113,6 +121,8 @@ FuncType FunctionNode::func_type_annotation() const {
   return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {});
 }

+TVM_REGISTER_NODE_TYPE(FunctionNode);
+
 TVM_REGISTER_API("relay._make.Function")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     *ret = FunctionNode::make(args[0], args[1], args[2], args[3]);
@@ -135,6 +145,8 @@ Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
   return Call(n);
 }

+TVM_REGISTER_NODE_TYPE(CallNode);
+
 TVM_REGISTER_API("relay._make.Call")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     *ret = CallNode::make(args[0], args[1], args[2], args[3]);
@@ -154,6 +166,8 @@ Let LetNode::make(Var var, Expr value, Expr body) {
   return Let(n);
 }

+TVM_REGISTER_NODE_TYPE(LetNode);
+
 TVM_REGISTER_API("relay._make.Let")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     *ret = LetNode::make(args[0], args[1], args[2]);
@@ -173,6 +187,8 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
   return If(n);
 }

+TVM_REGISTER_NODE_TYPE(IfNode);
+
 TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) {
     *ret = IfNode::make(args[0], args[1], args[2]);
   });
@@ -190,6 +206,8 @@ TupleGetItem TupleGetItemNode::make(Expr tuple, int index) {
   return TupleGetItem(n);
 }

+TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
+
 TVM_REGISTER_API("relay._make.TupleGetItem").set_body([](TVMArgs args, TVMRetValue* ret) {
     *ret = TupleGetItemNode::make(args[0], args[1]);
   });
......
@@ -92,7 +92,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
       body.same_as(op->body)) {
     return GetRef<Expr>(op);
   } else {
-    return FunctionNode::make(params, ret_type, body, ty_params);
+    return FunctionNode::make(params, body, ret_type, ty_params);
   }
 }
......
@@ -22,6 +22,8 @@ TensorType TensorTypeNode::Scalar(DataType dtype) {
   return TensorTypeNode::make({}, dtype);
 }

+TVM_REGISTER_NODE_TYPE(TensorTypeNode);
+
 TVM_REGISTER_API("relay._make.TensorType")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     Array<IndexExpr> shape = args[0];
@@ -30,8 +32,8 @@ TVM_REGISTER_API("relay._make.TensorType")
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<TensorTypeNode>([](const TensorTypeNode *node,
                                  tvm::IRPrinter *p) {
-    p->stream << "TensorTypeNode(" << node->dtype << ", " << node->shape << ")";
+    p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
   });

 TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) {
@@ -41,6 +43,8 @@ TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) {
   return TypeParam(n);
 }

+TVM_REGISTER_NODE_TYPE(TypeParamNode);
+
 TVM_REGISTER_API("relay._make.TypeParam")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     int kind = args[1];
@@ -61,6 +65,8 @@ IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) {
   return IncompleteType(n);
 }

+TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
+
 TVM_REGISTER_API("relay._make.IncompleteType")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     int kind = args[0];
@@ -86,6 +92,8 @@ FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
   return FuncType(n);
 }

+TVM_REGISTER_NODE_TYPE(FuncTypeNode);
+
 TVM_REGISTER_API("relay._make.FuncType")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     *ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]);
@@ -111,6 +119,8 @@ TypeRelation TypeRelationNode::make(TypeRelationFn func,
   return TypeRelation(n);
 }

+TVM_REGISTER_NODE_TYPE(TypeRelationNode);
+
 TVM_REGISTER_API("relay._make.TypeRelation")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     *ret = TypeRelationNode::make(args[0], args[1], args[2], args[3]);
@@ -129,6 +139,8 @@ TupleType TupleTypeNode::make(Array<Type> fields) {
   return TupleType(n);
 }

+TVM_REGISTER_NODE_TYPE(TupleTypeNode);
+
 TVM_REGISTER_API("relay._make.TupleType")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     *ret = TupleTypeNode::make(args[0]);
......
@@ -11,7 +11,7 @@
 namespace tvm {
 namespace relay {

-TVM_REGISTER_NODE_TYPE(ConvAttrs);
+TVM_REGISTER_NODE_TYPE(Conv2DAttrs);

 bool Conv2DRel(const Array<Type>& types,
                int num_inputs,
@@ -25,7 +25,7 @@ bool Conv2DRel(const Array<Type>& types,
   static const Layout kNCHW("NCHW");
   static const Layout kOIHW("OIHW");

-  const ConvAttrs* param = attrs.as<ConvAttrs>();
+  const Conv2DAttrs* param = attrs.as<Conv2DAttrs>();
   CHECK(param != nullptr);
   const Layout in_layout(param->data_layout);
   const Layout kernel_layout(param->weight_layout);
@@ -113,7 +113,7 @@ Expr MakeConv2D(Expr data,
                 std::string weight_layout,
                 std::string out_layout,
                 DataType out_dtype) {
-  auto attrs = make_node<ConvAttrs>();
+  auto attrs = make_node<Conv2DAttrs>();
   attrs->strides = std::move(strides);
   attrs->padding = std::move(padding);
   attrs->dilation = std::move(dilation);
@@ -148,6 +148,7 @@ with the layer input to produce a tensor of outputs.
 (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
 )code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.Conv2DAttrs")
 .set_num_inputs(2)
 .add_argument("data", "Tensor", "The input tensor.")
 .add_argument("weight", "Tensor", "The weight tensor.")
@@ -296,6 +297,7 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW`
 out_width = (width-1)*strides[1]-2*padding[1]+kernel_size[1]+output_padding[1]
 )code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.Conv2DTransposeAttrs")
 .set_num_inputs(2)
 .add_argument("data", "Tensor", "The input tensor.")
 .add_argument("weight", "Tensor", "The weight tensor.")
......
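Attaching the attrs type key ("relay.attrs.Conv2DAttrs") lets the text printer recover a conv2d call's attributes and print only the fields that differ from their defaults. A hedged sketch mirroring the channels assertion in test_meta_data further down:

from tvm import relay

# Sketch: a non-default field such as channels should appear in the text form.
x = relay.var("x", shape=(1, 3, 224, 224))
w = relay.var("w")
y = relay.nn.conv2d(x, w, kernel_size=(3, 3), padding=(1, 1), channels=16)
assert "channels=16" in relay.Function([x, w], y).astext()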
...@@ -78,6 +78,7 @@ RELAY_REGISTER_OP("nn.dense") ...@@ -78,6 +78,7 @@ RELAY_REGISTER_OP("nn.dense")
- **out**: `(x1, x2, ..., xn, units)`. - **out**: `(x1, x2, ..., xn, units)`.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.DenseAttrs")
.set_num_inputs(2) .set_num_inputs(2)
.add_argument("data", "nD Tensor", "Input data.") .add_argument("data", "nD Tensor", "Input data.")
.add_argument("weight", "2D Tensor", "Weight matrix.") .add_argument("weight", "2D Tensor", "Weight matrix.")
...@@ -107,6 +108,7 @@ RELAY_REGISTER_OP("nn.leaky_relu") ...@@ -107,6 +108,7 @@ RELAY_REGISTER_OP("nn.leaky_relu")
`y = x > 0 ? x : alpha * x` `y = x > 0 ? x : alpha * x`
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.LeakyReluAttrs")
.set_num_inputs(1) .set_num_inputs(1)
.add_argument("data", "Tensor", "Input data.") .add_argument("data", "Tensor", "Input data.")
.set_support_level(3) .set_support_level(3)
...@@ -135,6 +137,7 @@ RELAY_REGISTER_OP("nn.softmax") ...@@ -135,6 +137,7 @@ RELAY_REGISTER_OP("nn.softmax")
- **data**: The input data - **data**: The input data
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.SoftmaxAttrs")
.set_num_inputs(1) .set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1) .set_support_level(1)
...@@ -163,6 +166,7 @@ RELAY_REGISTER_OP("nn.log_softmax") ...@@ -163,6 +166,7 @@ RELAY_REGISTER_OP("nn.log_softmax")
- **data**: The input data - **data**: The input data
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.SoftmaxAttrs")
.set_num_inputs(1) .set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1) .set_support_level(1)
...@@ -171,9 +175,9 @@ RELAY_REGISTER_OP("nn.log_softmax") ...@@ -171,9 +175,9 @@ RELAY_REGISTER_OP("nn.log_softmax")
// BatchFlatten // BatchFlatten
bool BatchFlattenRel(const Array<Type>& types, bool BatchFlattenRel(const Array<Type>& types,
int num_inputs, int num_inputs,
const Attrs& attrs, const Attrs& attrs,
const TypeReporter& reporter) { const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2); CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>(); const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false; if (data == nullptr) return false;
...@@ -278,6 +282,7 @@ centered at that value (zero padding is added where necessary). ...@@ -278,6 +282,7 @@ centered at that value (zero padding is added where necessary).
- **data**: The input tensor. - **data**: The input tensor.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.LRNAttrs")
.set_num_inputs(1) .set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2) .set_support_level(2)
...@@ -296,12 +301,12 @@ Expr MakeL2Normalize(Expr data, ...@@ -296,12 +301,12 @@ Expr MakeL2Normalize(Expr data,
} }
TVM_REGISTER_API("relay.op.nn._make.l2_normalize") TVM_REGISTER_API("relay.op.nn._make.l2_normalize")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeL2Normalize, args, rv); runtime::detail::unpack_call<Expr, 3>(MakeL2Normalize, args, rv);
}); });
RELAY_REGISTER_OP("nn.l2_normalize") RELAY_REGISTER_OP("nn.l2_normalize")
.describe(R"code(L2 Normalization layer. .describe(R"code(L2 Normalization layer.
Normalizes along dimension axis using an L2 norm Normalizes along dimension axis using an L2 norm
...@@ -352,6 +357,7 @@ During training, each element of the input is set to zero with probability ``p`` ...@@ -352,6 +357,7 @@ During training, each element of the input is set to zero with probability ``p``
The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input unchanged. The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input unchanged.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.DropoutAttrs")
.set_num_inputs(1) .set_num_inputs(1)
.add_argument("data", "Tensor", "Input to which dropout will be applied.") .add_argument("data", "Tensor", "Input to which dropout will be applied.")
.set_support_level(1) .set_support_level(1)
...@@ -478,6 +484,7 @@ axis to be the last item in the input shape. ...@@ -478,6 +484,7 @@ axis to be the last item in the input shape.
.. note:: .. note::
This operator can be optimized away for inference. This operator can be optimized away for inference.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.BatchNormAttrs")
.set_num_inputs(5) .set_num_inputs(5)
.add_argument("data", "Tensor", "Input to which batch_norm will be applied.") .add_argument("data", "Tensor", "Input to which batch_norm will be applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.") .add_argument("gamma", "Tensor", "The gamma scale factor.")
......
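The same keys on dense, leaky_relu, softmax, log_softmax, lrn, dropout, and batch_norm feed VisitNonDefaultAttrs, so a call made entirely with default attributes prints without an attribute list. A small sketch (the exact float formatting in the output is an assumption):

from tvm import relay

x = relay.var("x")
print(relay.nn.leaky_relu(x, alpha=0.3).astext())  # expect something like alpha=0.3
print(relay.nn.softmax(x).astext())                # defaults elided: softmax(%x)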
...@@ -60,7 +60,7 @@ bool PadRel(const Array<Type>& types, ...@@ -60,7 +60,7 @@ bool PadRel(const Array<Type>& types,
} }
// Handler to create a call to the padding op used by front-end FFI // Handler to create a call to the padding op used by front-end FFI
Expr MakePad(Expr data, Array<Array<IndexExpr> > pad_width, double pad_value) { Expr MakePad(Expr data, Array<Array<IndexExpr> > pad_width, double pad_value) {
auto attrs = make_node<PadAttrs>(); auto attrs = make_node<PadAttrs>();
attrs->pad_value = pad_value; attrs->pad_value = pad_value;
attrs->pad_width = std::move(pad_width); attrs->pad_width = std::move(pad_width);
......
...@@ -76,6 +76,7 @@ RELAY_REGISTER_OP("expand_dims") ...@@ -76,6 +76,7 @@ RELAY_REGISTER_OP("expand_dims")
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_num_inputs(1) .set_num_inputs(1)
.set_attrs_type_key("relay.attrs.ExpandDimsAttrs")
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1) .set_support_level(1)
.add_type_rel("ExpandDims", ExpandDimsRel); .add_type_rel("ExpandDims", ExpandDimsRel);
...@@ -481,6 +482,7 @@ RELAY_REGISTER_OP("zeros") ...@@ -481,6 +482,7 @@ RELAY_REGISTER_OP("zeros")
.describe(R"code(Fill array with zeros. .describe(R"code(Fill array with zeros.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.InitOpAttrs")
.set_num_inputs(0) .set_num_inputs(0)
.set_support_level(3) .set_support_level(3)
.add_type_rel("InitOp", InitOpRel); .add_type_rel("InitOp", InitOpRel);
...@@ -503,6 +505,7 @@ RELAY_REGISTER_OP("ones") ...@@ -503,6 +505,7 @@ RELAY_REGISTER_OP("ones")
.describe(R"code(Fill array with ones. .describe(R"code(Fill array with ones.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.InitOpAttrs")
.set_num_inputs(0) .set_num_inputs(0)
.set_support_level(3) .set_support_level(3)
.add_type_rel("InitOp", InitOpRel); .add_type_rel("InitOp", InitOpRel);
...@@ -697,6 +700,7 @@ RELAY_REGISTER_OP("squeeze") ...@@ -697,6 +700,7 @@ RELAY_REGISTER_OP("squeeze")
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_num_inputs(1) .set_num_inputs(1)
.set_attrs_type_key("relay.attrs.SqueezeAttrs")
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3) .set_support_level(3)
.add_type_rel("Squeeze", SqueezeRel); .add_type_rel("Squeeze", SqueezeRel);
......
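The tensor ops follow the same pattern; expand_dims, for instance, now prints its ExpandDimsAttrs fields, exactly as test_call_attrs asserts further down. A sketch:

from tvm import relay

x = relay.var("x")
z = relay.expand_dims(x, axis=2, num_newaxis=2)
assert "num_newaxis=2" in z.astext()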
...@@ -74,7 +74,10 @@ class CalcDep : private ExprMutator { ...@@ -74,7 +74,10 @@ class CalcDep : private ExprMutator {
} }
Expr VisitExpr_(const FunctionNode* f) final { Expr VisitExpr_(const FunctionNode* f) final {
return FunctionNode::make(f->params, f->ret_type, Eliminate(f->body), f->type_params); return FunctionNode::make(f->params,
Eliminate(f->body),
f->ret_type,
f->type_params);
} }
// generate the let list from dependency graph // generate the let list from dependency graph
......
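The reordering in CalcDep tracks the new FunctionNode::make signature, in which the body now precedes the return type; the Python constructor follows the same order. A sketch of the updated call, matching the rewritten tests below (passing None to leave the return type to be inferred is an assumption):

from tvm import relay

x = relay.var("x", shape=(3, 2))
# New order: params, body, ret_type, type_params.
f = relay.Function([x], x, None, [])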
...@@ -20,6 +20,7 @@ class TypeFunctor; ...@@ -20,6 +20,7 @@ class TypeFunctor;
#define TYPE_FUNCTOR_DEFAULT \ #define TYPE_FUNCTOR_DEFAULT \
{ return VisitTypeDefault_(op, std::forward<Args>(args)...); } { return VisitTypeDefault_(op, std::forward<Args>(args)...); }
#define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \ #define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \ vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \ [](const NodeRef& n, TSelf* self, Args... args) { \
......
import tvm
from tvm import relay
from tvm.relay.expr import debug_print
from tvm.relay.ir_builder import IRBuilder
ib = IRBuilder()
def show(e):
r = debug_print(ib.env, e)
assert r is not None
def test_constant():
arr = tvm.nd.array(10)
const = relay.Constant(arr)
show(const)
# should print the array inside?
def test_tuple():
fields = tvm.convert([])
tup = relay.Tuple(fields)
show(tup)
def test_local_var():
name_hint = 's'
lv = relay.Var(name_hint)
show(lv)
def test_dup_var():
lv = relay.Var('s')
rv = relay.Var('s')
show(relay.Tuple([lv, rv]))
def test_large_dup_var():
av = relay.Var('s')
bv = relay.Var('s')
cv = relay.Var('s')
show(relay.Tuple([av, bv, cv]))
def test_global_var():
name_hint = 'g'
gv = relay.GlobalVar(name_hint)
assert gv.name_hint == name_hint
show(gv)
def test_function():
param_names = ['a', 'b', 'c', 'd']
params = tvm.convert([relay.Var(n) for n in param_names])
ret_type = None
body = params[0]
type_params = tvm.convert([])
fn = relay.Function(params, ret_type, body, type_params)
show(fn)
def test_call():
op = relay.Var('f')
arg_names = ['a', 'b', 'c', 'd']
args = tvm.convert([relay.Var(n) for n in arg_names])
call = relay.Call(op, args, None, None)
show(call)
def test_let():
ty = relay.ty.TensorType((10, 20), 'float32')
lv = relay.Var('x', ty)
arr = tvm.nd.array(10)
value = relay.Constant(arr)
let = relay.Let(lv, value, lv)
show(let)
def test_if():
cond = relay.Var('cond')
left = relay.Var('left')
right = relay.Var('right')
ife = relay.If(cond, left, right)
show(ife)
def test_tuple_get_item():
t = relay.Var('t')
g = relay.TupleGetItem(t, 0)
show(g)
import tvm
import numpy as np
from tvm import relay
do_print = [False]
def show(text):
if do_print[0]:
print("---------------------------")
print(text)
def test_func():
x = relay.var("x", shape=(3, 2))
y = relay.var("y")
one = relay.const(10e10, dtype="float32")
z = relay.add(x, one)
z = relay.add(z, z)
f = relay.Function([x, y], z)
show(z.astext())
show(f.astext())
def test_env():
x = relay.var("x", "float32")
y = relay.var("y", "float32")
z = relay.add(x, y)
z = relay.add(z, z)
f = relay.Function([x, y], z)
env = relay.Environment()
env.add("myf", f)
text = env.astext()
assert "def @myf" in text
assert "%1 = add(%0, %0) # ty=float32" in text
show(text)
def test_meta_data():
n, c, h, w = tvm.var("n"), 10, 224, 224
x = relay.var("x", shape=(n, c, h, w))
w = relay.var("w")
z = relay.nn.conv2d(x, w,
kernel_size=(3, 3),
padding=(1, 1),
channels=2)
f = relay.Function([x, w], z)
text = f.astext()
assert "channels=2" in text
assert "meta.Variable(id=0)" in text
show(text)
text = relay.const([1,2,3]).astext()
assert "meta.relay.Constant(id=0)" in text
show(text)
def test_call_attrs():
x = relay.var("x")
# non default args
z = relay.nn.softmax(x, axis=2)
assert "axis=2" in z.astext()
# default args
z = relay.nn.softmax(x)
assert "softmax(%x)" in z.astext()
# non default args
z = relay.expand_dims(x, axis=2, num_newaxis=2)
assert "num_newaxis=2" in z.astext()
def test_let_if_scope():
x = relay.var("x", "float32")
y = relay.var("y", "float32")
cond = relay.var("cond", "bool")
v1 = relay.var("v")
v2 = relay.var("v", "float32")
then_branch = relay.Let(
v1, relay.const(1, "float32"),
relay.Let(v2, x, relay.subtract(v1, v2)))
v3 = relay.var("v")
let2 = relay.Let(v3, y, v3)
else_branch = relay.add(let2, let2)
result = relay.If(cond, then_branch, else_branch)
f = relay.Function([x, y, cond], result)
text = f.astext()
assert text.count("{") == 4
assert "%cond: bool" in text
show(f.astext())
if __name__ == "__main__":
do_print[0] = True
test_let_if_scope()
test_func()
test_env()
test_meta_data()
test_call_attrs()
...@@ -10,7 +10,7 @@ def test_well_formed(): ...@@ -10,7 +10,7 @@ def test_well_formed():
let = relay.Let(x, v, x) let = relay.Let(x, v, x)
assert well_formed(let) assert well_formed(let)
assert not well_formed(relay.Let(x, v, let)) assert not well_formed(relay.Let(x, v, let))
f = relay.Function([x], ty, x) f = relay.Function([x], x, ty)
assert well_formed(f) assert well_formed(f)
# this test should pass in case of weak uniqueness (only test for shadowing) # this test should pass in case of weak uniqueness (only test for shadowing)
# but we want all binders to be distinct from each other. # but we want all binders to be distinct from each other.
......
...@@ -262,49 +262,49 @@ def test_function_alpha_equal(): ...@@ -262,49 +262,49 @@ def test_function_alpha_equal():
basic_args = [relay.Var("v3", tt1), relay.Var("v4", tt2)] basic_args = [relay.Var("v3", tt1), relay.Var("v4", tt2)]
basic_tps = [tp1, tp2] basic_tps = [tp1, tp2]
func = relay.Function([v1, v2], func = relay.Function([v1, v2], v1,
tt2, v1, basic_tps) tt2, basic_tps)
mapped = relay.Function(basic_args, tt2, basic_args[0], basic_tps) mapped = relay.Function(basic_args, basic_args[0], tt2, basic_tps)
assert alpha_equal(func, mapped) assert alpha_equal(func, mapped)
fewer_params = relay.Function([relay.Var("v4", tt2)], tt2, v4, basic_tps) fewer_params = relay.Function([relay.Var("v4", tt2)], v4, tt2, basic_tps)
assert not alpha_equal(func, fewer_params) assert not alpha_equal(func, fewer_params)
more_params = relay.Function([relay.Var("v3", tt1), more_params = relay.Function([relay.Var("v3", tt1),
relay.Var("v4", tt2), relay.Var("v4", tt2),
relay.Var("v2", tt2)], tt2, v4, basic_tps) relay.Var("v2", tt2)], v4, tt2, basic_tps)
assert not alpha_equal(func, more_params) assert not alpha_equal(func, more_params)
params_unordered = relay.Function([v2, v1], params_unordered = relay.Function([v2, v1], v1,
tt2, v1, basic_tps) tt2, basic_tps)
assert not alpha_equal(func, params_unordered) assert not alpha_equal(func, params_unordered)
params_mismatch = relay.Function([v1, v3], params_mismatch = relay.Function([v1, v3], v1,
tt2, v1, basic_tps) tt2, basic_tps)
assert not alpha_equal(func, params_mismatch) assert not alpha_equal(func, params_mismatch)
# also would not typecheck # also would not typecheck
ret_type_mismatch = relay.Function(basic_args, tt1, v4, basic_tps) ret_type_mismatch = relay.Function(basic_args, v4, tt1, basic_tps)
assert not alpha_equal(func, ret_type_mismatch) assert not alpha_equal(func, ret_type_mismatch)
# also mis-typed # also mis-typed
different_body = relay.Function(basic_args, tt2, v3, basic_tps) different_body = relay.Function(basic_args, v3, tt2, basic_tps)
assert not alpha_equal(func, different_body) assert not alpha_equal(func, different_body)
fewer_type_params = relay.Function(basic_args, tt2, v4, [tp1]) fewer_type_params = relay.Function(basic_args, v4, tt2, [tp1])
assert not alpha_equal(func, fewer_type_params) assert not alpha_equal(func, fewer_type_params)
more_type_params = relay.Function(basic_args, tt2, v4, [tp1, tp2, tp3]) more_type_params = relay.Function(basic_args, v4, tt2, [tp1, tp2, tp3])
assert not alpha_equal(func, more_type_params) assert not alpha_equal(func, more_type_params)
type_params_unordered = relay.Function(basic_args, tt2, v4, [tp2, tp1]) type_params_unordered = relay.Function(basic_args, v4, tt2, [tp2, tp1])
assert not alpha_equal(func, type_params_unordered) assert not alpha_equal(func, type_params_unordered)
different_type_params = relay.Function(basic_args, tt2, v4, [tp3, tp4]) different_type_params = relay.Function(basic_args, v4, tt2, [tp3, tp4])
assert not alpha_equal(func, different_type_params) assert not alpha_equal(func, different_type_params)
# a well-typed example that also differs in body, ret type, and type params # a well-typed example that also differs in body, ret type, and type params
tupled_example = relay.Function(basic_args, tt3, relay.Tuple([v3, v4])) tupled_example = relay.Function(basic_args, relay.Tuple([v3, v4]), tt3)
assert not alpha_equal(func, tupled_example) assert not alpha_equal(func, tupled_example)
......
...@@ -59,7 +59,7 @@ def test_recursion(): ...@@ -59,7 +59,7 @@ def test_recursion():
n = relay.Var("n", e.int32) n = relay.Var("n", e.int32)
data = relay.Var("data", e.float32) data = relay.Var("data", e.float32)
funcbody = relay.If(equal(n, convert(0)), data, f(subtract(n, convert(1.0)), log(data))) funcbody = relay.If(equal(n, convert(0)), data, f(subtract(n, convert(1.0)), log(data)))
value = relay.Function([n, data], e.float32, funcbody, []) value = relay.Function([n, data], funcbody, e.float32, [])
orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0))) orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0)))
assert alpha_equal(dead_code_elimination(orig), orig) assert alpha_equal(dead_code_elimination(orig), orig)
assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three)), e.three) assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three)), e.three)
......
...@@ -13,7 +13,7 @@ def test_free_vars(): ...@@ -13,7 +13,7 @@ def test_free_vars():
let = relay.Let(x, v, x) let = relay.Let(x, v, x)
fvx = free_vars(let) fvx = free_vars(let)
assert len(free_vars(let)) == 0 assert len(free_vars(let)) == 0
f = relay.Function([x], ty, x) f = relay.Function([x], x, ty)
assert len(free_vars(f)) == 0 assert len(free_vars(f)) == 0
......